Looping TCO for singleton components

This commit is contained in:
2026-02-26 22:14:25 -08:00
parent c4ff0c7c8c
commit 7a763bf40a
4 changed files with 93 additions and 15 deletions

View File

@@ -53,11 +53,19 @@ data JSStmt : StKind -> U where
JReturn : JSExp -> JSStmt Return
JLet : (nm : String) -> JSStmt (Assign nm) -> JSStmt Plain -- need somebody to assign
JAssign : (nm : String) -> JSExp -> JSStmt (Assign nm)
-- TODO - switch to Int tags
JCase : a. JSExp -> List JAlt -> JSStmt a
JIfThen : a. JSExp -> JSStmt a -> JSStmt a -> JSStmt a
-- throw can't be used
JError : a. String -> JSStmt a
-- FIXME We're routing around the index here
-- Might be able to keep the index if
-- we add `Loop : List String -> StKind`
-- JLoopAssign peels one off
-- JContinue is a Loop Nil
-- And LoopReturn
JWhile : a. JSStmt a JSStmt a
JLoopAssign : (nm : String) JSExp JSStmt Plain
JContinue : a. JSStmt a
Cont : StKind U
Cont e = JSExp -> JSStmt e
@@ -109,6 +117,8 @@ freshName' nm env =
env' = push env (Var nm')
in (nm', env')
-- get list of arg names and an environment with either references or undefined
-- depending on quantity
freshNames : List (Quant × String) -> JSEnv -> (List String × JSEnv)
freshNames nms env = go nms env Lin
where
@@ -132,6 +142,11 @@ simpleJSExp (LitString _) = True
simpleJSExp (LitBool _) = True
simpleJSExp _ = False
getEnv : Int → List JSExp → JSExp
getEnv ix env = case getAt' ix env of
Just e => e
Nothing => fatalError "Bad bounds \{show ix}"
-- This is inspired by A-normalization, look into the continuation monad
-- There is an index on JSStmt, adopted from Stefan Hoeck's code.
--
@@ -139,9 +154,7 @@ simpleJSExp _ = False
-- is a continuation, which turns the final JSExpr into a JSStmt, and the function returns
-- a JSStmt, wrapping recursive calls in JSnoc if necessary.
termToJS : ∀ e. JSEnv -> CExp -> Cont e -> JSStmt e
termToJS env (CBnd k) f = case getAt (cast k) env.jsenv of
(Just e) => f e
Nothing => fatalError "Bad bounds"
termToJS env (CBnd k) f = f $ getEnv k env.jsenv
termToJS env CErased f = f JUndefined
termToJS env (CRaw str _) f = f (Raw str)
termToJS env (CLam nm t) f =
@@ -155,9 +168,7 @@ termToJS env (CPrimOp op t u) f = termToJS env t $ \ t => termToJS env u $ \ u =
termToJS env (CMeta k) f = f $ LitString "META \{show k}"
termToJS env (CLit lit) f = f (litToJS lit)
-- if it's a var, just use the original
termToJS env (CLet nm (CBnd k) u) f = case getAt (cast k) env.jsenv of
Just e => termToJS (push env e) u f
Nothing => fatalError "bad bounds"
termToJS env (CLet nm (CBnd k) u) f = termToJS (push env $ getEnv k env.jsenv) u f
-- For a let, we run with a continuation to JAssign to a pre-declared variable
-- if JAssign comes back out, we either push the JSExpr into the environment or JConst it,
-- depending on complexity. Otherwise, stick the declaration in front.
@@ -169,6 +180,22 @@ termToJS env (CLet nm t u) f =
then termToJS (push env exp) u f
else JSnoc (JConst nm' exp) (termToJS env' u f)
t' => JSnoc (JLet nm' t') (termToJS env' u f)
termToJS env (CLetLoop args body) f =
let off = length' args in
-- Add lets for the args, we put this in a while and
-- mutate the args, then continue for the self-call
let (lets, env') = go (length' args - 1) args env Lin in
JWhile $ foldr (\a b => JSnoc a b) (termToJS env' body f) lets
where
go : Int → List (Quant × String) -> JSEnv -> SnocList (JSStmt Plain) -> (List (JSStmt Plain) × JSEnv)
go off Nil env acc = (acc <>> Nil, env)
go off ((Many, n) :: ns) env acc =
let (n', env') = freshName' n env
in go off ns env' (acc :< JConst n' (getEnv off env.jsenv))
go off ((Zero, n) :: ns) env acc =
let env' = push env JUndefined
in go off ns env' acc
termToJS env (CLetRec nm CErased u) f = termToJS (push env JUndefined) u f
termToJS env (CLetRec nm t u) f =
-- this shouldn't happen if where is lifted
@@ -184,6 +211,19 @@ termToJS env (CConstr ix _ args qs) f = go args qs 0 (\ args => f $ LitObject ((
go (t :: ts) (Many :: qs) ix k = termToJS env t $ \ t' => go ts qs (ix + 1) $ \ args => k $ ("h\{show ix}", t') :: args
go (t :: ts) (q :: qs) ix k = go ts qs (ix + 1) $ \ args => k args
go _ _ ix k = k Nil
termToJS {e} env (CLoop args quants) f = runArgs (reverse env.jsenv) args quants
where
-- Here we drop the continuation. It _should_ be a JReturn wrapper, because of how we insert JLoop.
-- But we're not statically checking that.
runArgs : List JSExp → List CExp → List Quant → JSStmt e
runArgs _ Nil Nil = JContinue
runArgs _ Nil _ = fatalError "too few CExp"
runArgs (Var x :: rest) (arg :: args) (Many :: qs) =
termToJS env arg $ \ arg' => JSnoc (JLoopAssign x arg') $ runArgs rest args qs
-- TODO check arg erased
runArgs (JUndefined :: rest) (arg :: args) (q :: qs) = runArgs rest args qs
runArgs (wat :: rest) (arg :: args) (q :: qs) = fatalError "bad env for quant \{show q}"
runArgs a b c = fatalError "FALLBACK \{show $ length' a} \{show $ length' b} \{show $ length' c}"
termToJS env (CAppRef nm args quants) f = termToJS env (CRef nm) (\ t' => (argsToJS env t' args quants Lin f))
where
etaExpand : JSEnv -> List Quant -> SnocList JSExp -> JSExp -> JSExp
@@ -329,7 +369,11 @@ stmtToDoc (JPlain x) = expToDoc x ++ text ";"
-- I might not need these split yet.
stmtToDoc (JLet nm body) = text "let" <+> jsIdent nm ++ text ";" </> stmtToDoc body
stmtToDoc (JAssign nm expr) = jsIdent nm <+> text "=" <+> expToDoc expr ++ text ";"
stmtToDoc (JConst nm x) = text "const" <+> jsIdent nm <+> nest 2 (text "=" <+/> expToDoc x ++ text ";")
stmtToDoc (JLoopAssign nm expr) = jsIdent nm <+> text "=" <+> expToDoc expr ++ text ";"
stmtToDoc (JContinue) = text "continue" ++ text ";"
stmtToDoc (JWhile stmt) = text "while (1)" <+> bracket "{" (stmtToDoc stmt) "}"
-- In the loop case, this may be reassigned
stmtToDoc (JConst nm x) = text "let" <+> jsIdent nm <+> nest 2 (text "=" <+/> expToDoc x ++ text ";")
stmtToDoc (JReturn x) = text "return" <+> expToDoc x ++ text ";"
stmtToDoc (JError str) = text "throw new Error(" ++ text (quoteString str) ++ text ");"
stmtToDoc (JIfThen sc t e) =
@@ -431,9 +475,11 @@ sortedNames defs names =
getNames : (deep : Bool) List (Bool × QName) CExp List (Bool × QName)
-- liftIO calls a lambda statically
getNames deep acc (CLam _ t) = getNames deep acc t
getNames deep acc (CLetLoop _ t) = getNames deep acc t
-- top level 0-ary function, doesn't happen
getNames deep acc (CFun _ t) = if deep then getNames deep acc t else acc
-- REVIEW - True or deep?
getNames deep acc (CLoop args qs) = foldl (getNames True) acc args
getNames deep acc (CAppRef nm args qs) =
if length' args == length' qs
then case args of

View File

@@ -36,6 +36,11 @@ data CExp : U where
CLit : Literal -> CExp
CLet : Name -> CExp -> CExp -> CExp
CLetRec : Name -> CExp -> CExp -> CExp
-- Might be able to use a bunch of flagged lets or something
CLetLoop : List (Quant × Name) CExp CExp
-- This is like a CAppRef, self-call
-- If we know it's a tail call fn, we could handle all of this in codegen...
CLoop : List CExp List Quant CExp
CErased : CExp
-- Data / type constructor
CConstr : Nat Name List CExp List Quant CExp

View File

@@ -20,6 +20,9 @@ import Data.SortedMap
-- Find names of applications in tail position
tailNames : CExp List QName
tailNames (CAppRef nm args n) = nm :: Nil
-- these two shouldn't exist yet
tailNames (CLoop _ _) = Nil
tailNames (CLetLoop _ _) = Nil
tailNames (CCase _ alts) = join $ map altTailNames alts
where
altTailNames : CAlt List QName
@@ -40,7 +43,8 @@ tailNames (CMeta _) = Nil
tailNames (CRaw _ _) = Nil
tailNames (CPrimOp _ _ _) = Nil
-- rewrite tail calls to return an object
-- rewrite tail calls to return an object to a trampoline
-- takes a list of the names in the group and the function body
rewriteTailCalls : List QName CExp CExp
rewriteTailCalls nms tm = case tm of
CAppRef nm args qs =>
@@ -63,11 +67,34 @@ rewriteTailCalls nms tm = case tm of
rewriteAlt (CDefAlt t) = CDefAlt $ rewriteTailCalls nms t
rewriteAlt (CLitAlt lit t) = CLitAlt lit $ rewriteTailCalls nms t
-- A looping version of TCO, specialized for single function calls
-- takes a list of the name of the function and the function body
rewriteLoop : QName CExp CExp
rewriteLoop qn tm = case tm of
(CAppRef nm args qs) =>
if length' args == length' qs && nm == qn
then CLoop args qs
else tm
(CLetRec nm t u) => CLetRec nm t $ rewriteLoop qn u
(CLet nm t u) => CLet nm t $ rewriteLoop qn u
(CCase sc alts) => CCase sc $ map rewriteAlt alts
tm => tm
where
rewriteAlt : CAlt CAlt
rewriteAlt (CConAlt ix nm info args t) = CConAlt ix nm info args $ rewriteLoop qn t
rewriteAlt (CDefAlt t) = CDefAlt $ rewriteLoop qn t
rewriteAlt (CLitAlt lit t) = CLitAlt lit $ rewriteLoop qn t
-- the name of our trampoline
bouncer : QName
bouncer = QN "" "bouncer"
doOptimize : List (QName × CExp) M (List (QName × CExp))
doOptimize ((qn, exp) :: Nil) = do
let (CFun args body) = exp | _ => error emptyFC "doOptimize \{show qn} not a CFun"
let body = rewriteLoop qn body
pure $ (qn, CFun args (CLetLoop args body)) :: Nil
doOptimize fns = do
splitFuns <- traverse splitFun fns
let nms = map fst fns
@@ -112,6 +139,8 @@ tailCallOpt expMap = do
processGroup : ExpMap List QName M ExpMap
processGroup expMap names = do
-- Looks like only two are > 1
debug $ \ _ => "compile.tco: group \{show $ length' names} \{show names}"
let pairs = mapMaybe (flip lookupMap expMap) names
updates <- doOptimize pairs
pure $ foldl doUpdate expMap updates