Files
newt/src/Lib/TCO.newt
2025-04-07 14:29:55 -07:00

111 lines
3.7 KiB
Agda
Raw Blame History

module Lib.TCO
import Prelude
import Data.Graph
import Lib.Ref2
import Lib.Common
import Lib.Types
import Lib.CompileExp
import Data.SortedMap
/-
This is modeled after Idris' tail call optimization, written by Stefan Hoeck.
We collect the strongly connected components of the tail call graph,
defunctionalize each component (represent every tail call, and the final
"return", as a constructor of a data type), and drive it with a trampoline.
-/
-- Find the names of saturated calls to top-level functions that appear in
-- tail position of an expression. tailCallOpt uses these as the edges of the
-- tail call graph (caller -> callee).
--
-- Only `CApp (CRef nm) args 0` is reported, because that is the only shape
-- rewriteTailCalls turns into a trampoline bounce. An application whose last
-- field is non-zero is left alone by rewriteTailCalls (NOTE(review):
-- presumably an under-applied call that is eta-expanded into a closure), so
-- reporting it would only add spurious edges that can merge unrelated
-- functions into one trampolined group.
tailNames : CExp -> List QName
tailNames (CApp (CRef name) args 0) = name :: Nil
tailNames (CCase _ alts) = join $ map altTailNames alts
  where
    -- every branch of a case in tail position is itself in tail position
    altTailNames : CAlt -> List QName
    altTailNames (CConAlt _ _ exp) = tailNames exp
    altTailNames (CDefAlt exp) = tailNames exp
    altTailNames (CLitAlt _ exp) = tailNames exp
-- for lets only the body is in tail position, never the bound value
tailNames (CLet _ _ t) = tailNames t
tailNames (CLetRec _ _ t) = tailNames t
tailNames (CConstr _ args) = Nil
tailNames (CBnd _) = Nil
-- the body of a CFun is in tail position
tailNames (CFun _ tm) = tailNames tm
-- calls inside a lambda are not tail calls of the enclosing function
tailNames (CLam _ _) = Nil
-- remaining applications: non-saturated calls and calls of non-references
tailNames (CApp t args n) = Nil
tailNames (CRef _) = Nil
tailNames CErased = Nil
tailNames (CLit _) = Nil
tailNames (CMeta _) = Nil
tailNames (CRaw _ _) = Nil
tailNames (CPrimOp _ _ _) = Nil
-- Rewrite every tail position of a function body so that it produces a
-- trampoline step object instead of a plain value:
--   * a saturated call to a member of the group `nms` becomes a constructor
--     tagged `show nm` that carries the call's arguments;
--   * anything else is wrapped as `CConstr "return" (value :: Nil)`.
-- Let bodies and case branches are traversed because they are in tail position.
rewriteTailCalls : List QName -> CExp -> CExp
rewriteTailCalls nms tm = case tm of
  CApp (CRef nm) args 0 =>
    if elem nm nms then CConstr (show nm) args else done tm
  CLetRec nm t u => CLetRec nm t (go u)
  CLet nm t u => CLet nm t (go u)
  CCase sc alts => CCase sc (map goAlt alts)
  other => done other
  where
    go : CExp -> CExp
    go e = rewriteTailCalls nms e
    done : CExp -> CExp
    done e = CConstr "return" (e :: Nil)
    goAlt : CAlt -> CAlt
    goAlt (CConAlt nm args t) = CConAlt nm args (go t)
    goAlt (CDefAlt t) = CDefAlt (go t)
    goAlt (CLitAlt lit t) = CLitAlt lit (go t)
-- the name of our trampoline
-- doOptimize's wrappers call it as `bouncer recFun step`, where recFun is the
-- group's REC_ dispatcher and step is the wrapper's own call packed into a
-- constructor. NOTE(review): it is not defined in this module (presumably it
-- is supplied by the JS runtime); confirm that it keeps applying recFun until
-- it receives a "return" constructor.
bouncer : QName
bouncer = QN Nil "bouncer"
-- Turn a group of mutually tail-recursive functions into a trampolined
-- dispatcher plus one wrapper per function.
--
-- The dispatcher is named `REC_<local name of the first function>`. It takes
-- a single argument and cases on it. Each alternative is tagged `show qn`,
-- binds that function's parameters, and runs its original body with tail
-- calls rewritten by rewriteTailCalls.
--
-- Each wrapper keeps the original name and parameters. It packs its
-- arguments into a constructor with the same tag and hands the dispatcher
-- and that constructor to `bouncer`.
--
-- Returns the dispatcher first, then the wrappers in input order. Fails in M
-- if the group is empty or some member is not a CFun.
doOptimize : List (QName × CExp) -> M (List (QName × CExp))
doOptimize fns = do
  parts <- traverse splitFun fns
  let nms = map fst fns
  let alts = map (mkAlt nms) parts
  recName <- mkRecName nms
  let recfun = CFun ("arg" :: Nil) $ CCase (CBnd 0) alts
  -- wrappers are built from the parts splitFun already validated, so a
  -- non-CFun member is reported in exactly one place
  wrapped <- traverse (mkWrap recName) parts
  pure $ (recName, recfun) :: wrapped
  where
    -- Pack the wrapper's arguments into the step constructor and start the
    -- trampoline. The first parameter is the outermost binder, hence the
    -- descending indices (assuming CBnd 0 is the innermost binder).
    mkWrap : QName -> QName × List Name × CExp -> M (QName × CExp)
    mkWrap recName (qn, args, _) = do
      let arglen = length' args
      let arg = CConstr (show qn) $ map (\k => CBnd (arglen - k - 1)) (range 0 arglen)
      let body = CApp (CRef bouncer) (CRef recName :: arg :: Nil) 0
      pure $ (qn, CFun args body)
    -- the dispatcher lives in the first function's namespace
    mkRecName : List QName -> M QName
    mkRecName Nil = error emptyFC "INTERNAL ERROR: Empty List in doOptimize"
    mkRecName (QN ns nm :: _) = pure $ QN ns "REC_\{nm}"
    -- one dispatcher branch per function, tagged like its step constructor
    mkAlt : List QName -> (QName × List Name × CExp) -> CAlt
    mkAlt nms (qn, args, tm) = CConAlt (show qn) args (rewriteTailCalls nms tm)
    -- split a CFun into its name, parameter names and body
    splitFun : (QName × CExp) -> M (QName × List Name × CExp)
    splitFun (qn, CFun args body) = pure (qn, args, body)
    splitFun (qn, _) = error emptyFC "TCO error: \{show qn} not a function"
-- Compiled definitions keyed by qualified name: the map tailCallOpt builds
-- the tail call graph from and writes optimized definitions back into.
ExpMap : U
ExpMap = SortedMap QName CExp
-- Tail-call optimize every group of mutually tail-recursive definitions.
--
-- The tail call graph has an edge from each definition to every function it
-- calls in tail position, as found by tailNames. It is built once from the
-- input map and split into groups by tarjan.
--
-- Groups are then processed in the order tarjan returns them. Each group's
-- definitions are looked up in the map as updated so far (names not present
-- are skipped) and rewritten by doOptimize, and every resulting definition
-- is written back into the map.
tailCallOpt : ExpMap -> M ExpMap
tailCallOpt expMap = foldlM optimizeGroup expMap (tarjan callGraph)
  where
    callGraph : List (QName × List QName)
    callGraph = map (bimap id tailNames) (toList expMap)
    insertDef : ExpMap -> QName × CExp -> ExpMap
    insertDef acc (qn, e) = updateMap qn e acc
    optimizeGroup : ExpMap -> List QName -> M ExpMap
    optimizeGroup acc members = do
      updates <- doOptimize (mapMaybe (\nm => lookupMap nm acc) members)
      pure (foldl insertDef acc updates)