Looping TCO for singleton components
This commit is contained in:
@@ -53,11 +53,19 @@ data JSStmt : StKind -> U where
|
||||
JReturn : JSExp -> JSStmt Return
|
||||
JLet : (nm : String) -> JSStmt (Assign nm) -> JSStmt Plain -- need somebody to assign
|
||||
JAssign : (nm : String) -> JSExp -> JSStmt (Assign nm)
|
||||
-- TODO - switch to Int tags
|
||||
JCase : ∀ a. JSExp -> List JAlt -> JSStmt a
|
||||
JIfThen : ∀ a. JSExp -> JSStmt a -> JSStmt a -> JSStmt a
|
||||
-- throw can't be used
|
||||
JError : ∀ a. String -> JSStmt a
|
||||
-- FIXME We're routing around the index here
|
||||
-- Might be able to keep the index if
|
||||
-- we add `Loop : List String -> StKind`
|
||||
-- JLoopAssign peels one off
|
||||
-- JContinue is a Loop Nil
|
||||
-- And LoopReturn
|
||||
JWhile : ∀ a. JSStmt a → JSStmt a
|
||||
JLoopAssign : (nm : String) → JSExp → JSStmt Plain
|
||||
JContinue : ∀ a. JSStmt a
|
||||
|
||||
Cont : StKind → U
|
||||
Cont e = JSExp -> JSStmt e
|
||||
@@ -109,6 +117,8 @@ freshName' nm env =
|
||||
env' = push env (Var nm')
|
||||
in (nm', env')
|
||||
|
||||
-- get list of arg names and an environment with either references or undefined
|
||||
-- depending on quantity
|
||||
freshNames : List (Quant × String) -> JSEnv -> (List String × JSEnv)
|
||||
freshNames nms env = go nms env Lin
|
||||
where
|
||||
@@ -132,6 +142,11 @@ simpleJSExp (LitString _) = True
|
||||
simpleJSExp (LitBool _) = True
|
||||
simpleJSExp _ = False
|
||||
|
||||
getEnv : Int → List JSExp → JSExp
|
||||
getEnv ix env = case getAt' ix env of
|
||||
Just e => e
|
||||
Nothing => fatalError "Bad bounds \{show ix}"
|
||||
|
||||
-- This is inspired by A-normalization, look into the continuation monad
|
||||
-- There is an index on JSStmt, adopted from Stefan Hoeck's code.
|
||||
--
|
||||
@@ -139,9 +154,7 @@ simpleJSExp _ = False
|
||||
-- is a continuation, which turns the final JSExpr into a JSStmt, and the function returns
|
||||
-- a JSStmt, wrapping recursive calls in JSnoc if necessary.
|
||||
termToJS : ∀ e. JSEnv -> CExp -> Cont e -> JSStmt e
|
||||
termToJS env (CBnd k) f = case getAt (cast k) env.jsenv of
|
||||
(Just e) => f e
|
||||
Nothing => fatalError "Bad bounds"
|
||||
termToJS env (CBnd k) f = f $ getEnv k env.jsenv
|
||||
termToJS env CErased f = f JUndefined
|
||||
termToJS env (CRaw str _) f = f (Raw str)
|
||||
termToJS env (CLam nm t) f =
|
||||
@@ -155,9 +168,7 @@ termToJS env (CPrimOp op t u) f = termToJS env t $ \ t => termToJS env u $ \ u =
|
||||
termToJS env (CMeta k) f = f $ LitString "META \{show k}"
|
||||
termToJS env (CLit lit) f = f (litToJS lit)
|
||||
-- if it's a var, just use the original
|
||||
termToJS env (CLet nm (CBnd k) u) f = case getAt (cast k) env.jsenv of
|
||||
Just e => termToJS (push env e) u f
|
||||
Nothing => fatalError "bad bounds"
|
||||
termToJS env (CLet nm (CBnd k) u) f = termToJS (push env $ getEnv k env.jsenv) u f
|
||||
-- For a let, we run with a continuation to JAssign to a pre-declared variable
|
||||
-- if JAssign comes back out, we either push the JSExpr into the environment or JConst it,
|
||||
-- depending on complexity. Otherwise, stick the declaration in front.
|
||||
@@ -169,6 +180,22 @@ termToJS env (CLet nm t u) f =
|
||||
then termToJS (push env exp) u f
|
||||
else JSnoc (JConst nm' exp) (termToJS env' u f)
|
||||
t' => JSnoc (JLet nm' t') (termToJS env' u f)
|
||||
termToJS env (CLetLoop args body) f =
|
||||
let off = length' args in
|
||||
-- Add lets for the args, we put this in a while and
|
||||
-- mutate the args, then continue for the self-call
|
||||
let (lets, env') = go (length' args - 1) args env Lin in
|
||||
JWhile $ foldr (\a b => JSnoc a b) (termToJS env' body f) lets
|
||||
where
|
||||
go : Int → List (Quant × String) -> JSEnv -> SnocList (JSStmt Plain) -> (List (JSStmt Plain) × JSEnv)
|
||||
go off Nil env acc = (acc <>> Nil, env)
|
||||
go off ((Many, n) :: ns) env acc =
|
||||
let (n', env') = freshName' n env
|
||||
in go off ns env' (acc :< JConst n' (getEnv off env.jsenv))
|
||||
go off ((Zero, n) :: ns) env acc =
|
||||
let env' = push env JUndefined
|
||||
in go off ns env' acc
|
||||
|
||||
termToJS env (CLetRec nm CErased u) f = termToJS (push env JUndefined) u f
|
||||
termToJS env (CLetRec nm t u) f =
|
||||
-- this shouldn't happen if where is lifted
|
||||
@@ -184,6 +211,19 @@ termToJS env (CConstr ix _ args qs) f = go args qs 0 (\ args => f $ LitObject ((
|
||||
go (t :: ts) (Many :: qs) ix k = termToJS env t $ \ t' => go ts qs (ix + 1) $ \ args => k $ ("h\{show ix}", t') :: args
|
||||
go (t :: ts) (q :: qs) ix k = go ts qs (ix + 1) $ \ args => k args
|
||||
go _ _ ix k = k Nil
|
||||
termToJS {e} env (CLoop args quants) f = runArgs (reverse env.jsenv) args quants
|
||||
where
|
||||
-- Here we drop the continuation. It _should_ be a JReturn wrapper, because of how we insert JLoop.
|
||||
-- But we're not statically checking that.
|
||||
runArgs : List JSExp → List CExp → List Quant → JSStmt e
|
||||
runArgs _ Nil Nil = JContinue
|
||||
runArgs _ Nil _ = fatalError "too few CExp"
|
||||
runArgs (Var x :: rest) (arg :: args) (Many :: qs) =
|
||||
termToJS env arg $ \ arg' => JSnoc (JLoopAssign x arg') $ runArgs rest args qs
|
||||
-- TODO check arg erased
|
||||
runArgs (JUndefined :: rest) (arg :: args) (q :: qs) = runArgs rest args qs
|
||||
runArgs (wat :: rest) (arg :: args) (q :: qs) = fatalError "bad env for quant \{show q}"
|
||||
runArgs a b c = fatalError "FALLBACK \{show $ length' a} \{show $ length' b} \{show $ length' c}"
|
||||
termToJS env (CAppRef nm args quants) f = termToJS env (CRef nm) (\ t' => (argsToJS env t' args quants Lin f))
|
||||
where
|
||||
etaExpand : JSEnv -> List Quant -> SnocList JSExp -> JSExp -> JSExp
|
||||
@@ -329,7 +369,11 @@ stmtToDoc (JPlain x) = expToDoc x ++ text ";"
|
||||
-- I might not need these split yet.
|
||||
stmtToDoc (JLet nm body) = text "let" <+> jsIdent nm ++ text ";" </> stmtToDoc body
|
||||
stmtToDoc (JAssign nm expr) = jsIdent nm <+> text "=" <+> expToDoc expr ++ text ";"
|
||||
stmtToDoc (JConst nm x) = text "const" <+> jsIdent nm <+> nest 2 (text "=" <+/> expToDoc x ++ text ";")
|
||||
stmtToDoc (JLoopAssign nm expr) = jsIdent nm <+> text "=" <+> expToDoc expr ++ text ";"
|
||||
stmtToDoc (JContinue) = text "continue" ++ text ";"
|
||||
stmtToDoc (JWhile stmt) = text "while (1)" <+> bracket "{" (stmtToDoc stmt) "}"
|
||||
-- In the loop case, this may be reassigned
|
||||
stmtToDoc (JConst nm x) = text "let" <+> jsIdent nm <+> nest 2 (text "=" <+/> expToDoc x ++ text ";")
|
||||
stmtToDoc (JReturn x) = text "return" <+> expToDoc x ++ text ";"
|
||||
stmtToDoc (JError str) = text "throw new Error(" ++ text (quoteString str) ++ text ");"
|
||||
stmtToDoc (JIfThen sc t e) =
|
||||
@@ -431,9 +475,11 @@ sortedNames defs names =
|
||||
getNames : (deep : Bool) → List (Bool × QName) → CExp → List (Bool × QName)
|
||||
-- liftIO calls a lambda statically
|
||||
getNames deep acc (CLam _ t) = getNames deep acc t
|
||||
getNames deep acc (CLetLoop _ t) = getNames deep acc t
|
||||
-- top level 0-ary function, doesn't happen
|
||||
getNames deep acc (CFun _ t) = if deep then getNames deep acc t else acc
|
||||
|
||||
-- REVIEW - True or deep?
|
||||
getNames deep acc (CLoop args qs) = foldl (getNames True) acc args
|
||||
getNames deep acc (CAppRef nm args qs) =
|
||||
if length' args == length' qs
|
||||
then case args of
|
||||
|
||||
@@ -36,6 +36,11 @@ data CExp : U where
|
||||
CLit : Literal -> CExp
|
||||
CLet : Name -> CExp -> CExp -> CExp
|
||||
CLetRec : Name -> CExp -> CExp -> CExp
|
||||
-- Might be able to use a bunch of flagged lets or something
|
||||
CLetLoop : List (Quant × Name) → CExp → CExp
|
||||
-- This is like a CAppRef, self-call
|
||||
-- If we know it's a tail call fn, we could handle all of this in codegen...
|
||||
CLoop : List CExp → List Quant → CExp
|
||||
CErased : CExp
|
||||
-- Data / type constructor
|
||||
CConstr : Nat → Name → List CExp → List Quant → CExp
|
||||
|
||||
@@ -20,6 +20,9 @@ import Data.SortedMap
|
||||
-- Find names of applications in tail position
|
||||
tailNames : CExp → List QName
|
||||
tailNames (CAppRef nm args n) = nm :: Nil
|
||||
-- these two shouldn't exist yet
|
||||
tailNames (CLoop _ _) = Nil
|
||||
tailNames (CLetLoop _ _) = Nil
|
||||
tailNames (CCase _ alts) = join $ map altTailNames alts
|
||||
where
|
||||
altTailNames : CAlt → List QName
|
||||
@@ -40,7 +43,8 @@ tailNames (CMeta _) = Nil
|
||||
tailNames (CRaw _ _) = Nil
|
||||
tailNames (CPrimOp _ _ _) = Nil
|
||||
|
||||
-- rewrite tail calls to return an object
|
||||
-- rewrite tail calls to return an object to a trampoline
|
||||
-- takes a list of the names in the group and the function body
|
||||
rewriteTailCalls : List QName → CExp → CExp
|
||||
rewriteTailCalls nms tm = case tm of
|
||||
CAppRef nm args qs =>
|
||||
@@ -63,11 +67,34 @@ rewriteTailCalls nms tm = case tm of
|
||||
rewriteAlt (CDefAlt t) = CDefAlt $ rewriteTailCalls nms t
|
||||
rewriteAlt (CLitAlt lit t) = CLitAlt lit $ rewriteTailCalls nms t
|
||||
|
||||
-- A looping version of TCO, specialized for single function calls
|
||||
-- takes a list of the name of the function and the function body
|
||||
rewriteLoop : QName → CExp → CExp
|
||||
rewriteLoop qn tm = case tm of
|
||||
(CAppRef nm args qs) =>
|
||||
if length' args == length' qs && nm == qn
|
||||
then CLoop args qs
|
||||
else tm
|
||||
(CLetRec nm t u) => CLetRec nm t $ rewriteLoop qn u
|
||||
(CLet nm t u) => CLet nm t $ rewriteLoop qn u
|
||||
(CCase sc alts) => CCase sc $ map rewriteAlt alts
|
||||
tm => tm
|
||||
where
|
||||
rewriteAlt : CAlt → CAlt
|
||||
rewriteAlt (CConAlt ix nm info args t) = CConAlt ix nm info args $ rewriteLoop qn t
|
||||
rewriteAlt (CDefAlt t) = CDefAlt $ rewriteLoop qn t
|
||||
rewriteAlt (CLitAlt lit t) = CLitAlt lit $ rewriteLoop qn t
|
||||
|
||||
-- the name of our trampoline
|
||||
bouncer : QName
|
||||
bouncer = QN "" "bouncer"
|
||||
|
||||
doOptimize : List (QName × CExp) → M (List (QName × CExp))
|
||||
doOptimize ((qn, exp) :: Nil) = do
|
||||
let (CFun args body) = exp | _ => error emptyFC "doOptimize \{show qn} not a CFun"
|
||||
let body = rewriteLoop qn body
|
||||
pure $ (qn, CFun args (CLetLoop args body)) :: Nil
|
||||
|
||||
doOptimize fns = do
|
||||
splitFuns <- traverse splitFun fns
|
||||
let nms = map fst fns
|
||||
@@ -112,6 +139,8 @@ tailCallOpt expMap = do
|
||||
|
||||
processGroup : ExpMap → List QName → M ExpMap
|
||||
processGroup expMap names = do
|
||||
-- Looks like only two are > 1
|
||||
debug $ \ _ => "compile.tco: group \{show $ length' names} \{show names}"
|
||||
let pairs = mapMaybe (flip lookupMap expMap) names
|
||||
updates <- doOptimize pairs
|
||||
pure $ foldl doUpdate expMap updates
|
||||
|
||||
Reference in New Issue
Block a user