module Lib.TCO

import Prelude
import Data.Graph
import Lib.Ref2
import Lib.Common
import Lib.Types
import Lib.CompileExp
import Data.SortedMap

/-
This is modeled after Idris' tail call optimization written by Stefan Hoeck.

We collect strongly connected components of the tail call graph,
defunctionalize it (make a data type modelling function calls and "return"),
and wrap it in a trampoline.

Overview of the pieces in this module:
  * tailNames        - names applied in tail position of an expression
  * rewriteTailCalls - turns tail expressions into tagged objects
  * doOptimize       - builds the combined function and the wrappers for one group
  * tailCallOpt      - entry point: runs doOptimize over every group of the graph
-/

-- Find the names of applications in tail position.
--
-- Recurses into every alternative of a case, into the body of let / letrec
-- and into the body of a CFun. An application whose head is a CRef yields
-- that name; every other form yields Nil.
--
-- NOTE(review): the later `CApp (CRef nm) args n` clause returns the name for
-- any `n`, so the leading `... 0` clause is subsumed by it. rewriteTailCalls
-- only rewrites the `0` case - confirm whether applications with a non-zero
-- `n` are really meant to contribute edges to the tail call graph.
tailNames : CExp → List QName
tailNames (CApp (CRef name) args 0) = name :: Nil
tailNames (CCase _ alts) = join $ map altTailNames alts
  where
    -- Tail names of the right-hand side of a single case alternative.
    altTailNames : CAlt → List QName
    altTailNames (CConAlt _ _ exp) = tailNames exp
    altTailNames (CDefAlt exp) = tailNames exp
    altTailNames (CLitAlt _ exp) = tailNames exp
tailNames (CLet _ _ t) = tailNames t
tailNames (CLetRec _ _ t) = tailNames t
tailNames (CConstr _ args) = Nil
tailNames (CBnd _) = Nil
tailNames (CFun _ tm) = tailNames tm
tailNames (CLam _ _) = Nil
tailNames (CApp (CRef nm) args n) = nm :: Nil
tailNames (CApp t args n) = Nil
tailNames (CRef _) = Nil
tailNames CErased = Nil
tailNames (CLit _) = Nil
tailNames (CMeta _) = Nil
tailNames (CRaw _ _) = Nil
tailNames (CPrimOp _ _ _) = Nil

-- Rewrite the tail position of `tm` so that it produces a tagged object.
--
--   * `CApp (CRef nm) args 0` with `nm` in `nms` becomes `CConstr (show nm) args`,
--     i.e. a constructor named after the callee carrying the call's arguments.
--   * let / letrec and case are traversed: only their body / alternatives
--     are rewritten.
--   * every other expression (including a call to a name outside `nms`) is
--     wrapped as `CConstr "return" (tm :: Nil)`.
rewriteTailCalls : List QName → CExp → CExp
rewriteTailCalls nms tm = case tm of
    CApp (CRef nm) args 0 =>
      if elem nm nms then CConstr (show nm) args else CConstr "return" (tm :: Nil)
    CLetRec nm t u => CLetRec nm t $ rewriteTailCalls nms u
    CLet nm t u => CLet nm t $ rewriteTailCalls nms u
    CCase sc alts => CCase sc $ map rewriteAlt alts
    tm => CConstr "return" (tm :: Nil)
  where
    -- Rewrite the right-hand side of one alternative, keeping its pattern.
    rewriteAlt : CAlt -> CAlt
    rewriteAlt (CConAlt nm args t) = CConAlt nm args $ rewriteTailCalls nms t
    rewriteAlt (CDefAlt t) = CDefAlt $ rewriteTailCalls nms t
    rewriteAlt (CLitAlt lit t) = CLitAlt lit $ rewriteTailCalls nms t

-- The name of our trampoline: "bouncer" with an empty namespace.
-- NOTE(review): its definition is not in this module - presumably supplied by
-- the runtime / code generator; confirm.
bouncer : QName
bouncer = QN Nil "bouncer"

-- Defunctionalize one group of functions.
--
-- Given the (name, definition) pairs of a group, produces:
--   * one combined function, named by mkRecName, that takes a single
--     argument and cases on it, with one constructor alternative per function
--     of the group (see mkAlt);
--   * for every original function, a wrapper with the same name and argument
--     list that hands the combined function and a constructor holding its
--     arguments to `bouncer` (see mkWrap).
-- The combined function is the first element of the result.
--
-- Fails (via `error`) if `fns` is empty or if any definition is not a CFun.
doOptimize : List (QName × CExp) → M (List (QName × CExp))
doOptimize fns = do
  splitFuns <- traverse splitFun fns
  let nms = map fst fns
  let alts = map (mkAlt nms) splitFuns
  recName <- mkRecName nms
  let recfun = CFun ("arg" :: Nil) $ CCase (CBnd 0) alts
  wrapped <- traverse (mkWrap recName) fns
  pure $ (recName, recfun) :: wrapped
  where
    -- Replace the body of a function with a call to `bouncer`, passing the
    -- combined function and a constructor named `show qn` whose fields are the
    -- function's own arguments, referenced as `CBnd (arglen - k - 1)`.
    -- NOTE(review): assumes CBnd is a de Bruijn index with the last argument
    -- at index 0 and that `range 0 arglen` yields 0 .. arglen - 1 - confirm.
    mkWrap : QName → QName × CExp → M (QName × CExp)
    mkWrap recName (qn, CFun args _) = do
      let arglen = length' args
      let arg = CConstr (show qn) $ map (\k => CBnd (arglen - k - 1)) (range 0 arglen)
      let body = CApp (CRef bouncer) (CRef recName :: arg :: Nil) 0
      pure $ (qn, CFun args body)
    mkWrap _ (qn, _) = error emptyFC "error in mkWrap: \{show qn} not a CFun"

    -- Name of the combined function: "REC_" prepended to the base name of the
    -- first function of the group, in that function's namespace.
    mkRecName : List QName → M QName
    mkRecName Nil = error emptyFC "INTERNAL ERROR: Empty List in doOptimize"
    mkRecName (QN ns nm :: _) = pure $ QN ns "REC_\{nm}"

    -- Case alternative for one function: matches the constructor named
    -- `show qn`, binds the function's argument names and runs its body with
    -- the tail calls rewritten.
    mkAlt : List QName → (QName × List Name × CExp) -> CAlt
    mkAlt nms (qn, args, tm) = CConAlt (show qn) args (rewriteTailCalls nms tm)

    -- Split a CFun definition into name, argument names and body.
    splitFun : (QName × CExp) → M (QName × List Name × CExp)
    splitFun (qn, CFun args body) = pure (qn, args, body)
    splitFun (qn, _) = error emptyFC "TCO error: \{show qn} not a function"

-- Map from a definition's qualified name to its compiled expression.
ExpMap : U
ExpMap = SortedMap QName CExp

-- Entry point: apply the optimization to a map of definitions.
--
-- Builds a graph pairing each definition's name with its tailNames, splits it
-- into groups with `tarjan`, and folds processGroup over the groups, threading
-- the map through so each group sees the updates of the previous ones.
-- NOTE(review): `tarjan` comes from Data.Graph and is not visible here -
-- presumably it returns the strongly connected components; confirm which
-- groups (e.g. non-recursive singletons) it includes.
tailCallOpt : ExpMap → M ExpMap
tailCallOpt expMap = do
  let graph = map (bimap id tailNames) (toList expMap)
  let groups = tarjan graph
  foldlM processGroup expMap groups
  where
    -- Store one (name, definition) pair into the map via updateMap.
    doUpdate : ExpMap → QName × CExp → ExpMap
    doUpdate acc (k,v) = updateMap k v acc

    -- Optimize a single group and write the results back into the map.
    -- Names that lookupMap does not find in the map are dropped by mapMaybe.
    -- NOTE(review): if none of the names is found, doOptimize receives Nil and
    -- fails in mkRecName - confirm `tarjan` never yields such a group.
    processGroup : ExpMap → List QName → M ExpMap
    processGroup expMap names = do
      let pairs = mapMaybe (flip lookupMap expMap) names
      updates <- doOptimize pairs
      pure $ foldl doUpdate expMap updates