111 lines
3.7 KiB
Agda
111 lines
3.7 KiB
Agda
module Lib.TCO
|
||
|
||
import Prelude
|
||
import Data.Graph
|
||
import Lib.Ref2
|
||
import Lib.Common
|
||
import Lib.Types
|
||
import Lib.CompileExp
|
||
import Data.SortedMap
|
||
|
||
/-
|
||
This is modeled after Idris' tail call optimization written by Stefan Hoeck.
|
||
|
||
We collect the strongly connected components of the tail-call graph,
defunctionalize each component (build a data type whose constructors model
the calls between its functions, plus a "return" for a finished result),
and wrap it in a trampoline.
|
||
|
||
-/
|
||
|
||
-- Find the names of calls in tail position of a compiled expression. These
-- are the edges of the tail-call graph that `tailCallOpt` splits into groups.
--
-- Only `CApp (CRef nm) args 0` counts, because that is exactly the shape
-- `rewriteTailCalls` turns into a trampoline bounce. An application whose
-- last field is non-zero (presumably a partial application / missing-arg
-- count — confirm against Lib.CompileExp) used to be reported as well. That
-- added graph edges that are never bounced and could drag functions into a
-- trampolined group for nothing.
tailNames : CExp → List QName
tailNames (CApp (CRef name) _ 0) = name :: Nil
tailNames (CCase _ alts) = join $ map altTailNames alts
  where
    -- every branch of a case in tail position is itself in tail position
    altTailNames : CAlt → List QName
    altTailNames (CConAlt _ _ exp) = tailNames exp
    altTailNames (CDefAlt exp) = tailNames exp
    altTailNames (CLitAlt _ exp) = tailNames exp
-- only the body of a let / letrec is in tail position, not the bound value
tailNames (CLet _ _ t) = tailNames t
tailNames (CLetRec _ _ t) = tailNames t
tailNames (CConstr _ _) = Nil
tailNames (CBnd _) = Nil
-- look through the definition's parameter binder to its body
tailNames (CFun _ tm) = tailNames tm
-- calls inside a lambda body are not tail calls of the enclosing definition
tailNames (CLam _ _) = Nil
-- any other application (non-reference head, or non-zero last field) is not
-- something rewriteTailCalls can bounce
tailNames (CApp _ _ _) = Nil
tailNames (CRef _) = Nil
tailNames CErased = Nil
tailNames (CLit _) = Nil
tailNames (CMeta _) = Nil
tailNames (CRaw _ _) = Nil
tailNames (CPrimOp _ _ _) = Nil
|
||
|
||
-- Rewrite the tail positions of a function body so that it yields a
-- trampoline step (an object) instead of performing a call.
--
-- `nms` lists the functions of the current tail-call group:
--   * a call `CApp (CRef nm) args 0` to a member of `nms` becomes the
--     constructor `show nm` carrying the call's arguments;
--   * every other tail expression is wrapped in the "return" constructor,
--     marking a finished result rather than another call.
-- Tail position is followed into let / letrec bodies and into every case
-- branch.
rewriteTailCalls : List QName → CExp → CExp
rewriteTailCalls nms tm = case tm of
  CCase sc alts => CCase sc (map rewriteAlt alts)
  CLet nm t u => CLet nm t (rewriteTailCalls nms u)
  CLetRec nm t u => CLetRec nm t (rewriteTailCalls nms u)
  CApp (CRef nm) args 0 =>
    if elem nm nms then CConstr (show nm) args else done tm
  other => done other
  where
    -- package a finished value for the trampoline
    done : CExp → CExp
    done e = CConstr "return" (e :: Nil)
    -- the body of each branch inherits tail position from the case
    rewriteAlt : CAlt → CAlt
    rewriteAlt alt = case alt of
      CConAlt nm args t => CConAlt nm args (rewriteTailCalls nms t)
      CDefAlt t => CDefAlt (rewriteTailCalls nms t)
      CLitAlt lit t => CLitAlt lit (rewriteTailCalls nms t)
|
||
|
||
-- Name of the trampoline. The wrappers built by `doOptimize` call it with the
-- group's `REC_` function and the constructor describing the first call.
-- NOTE(review): `bouncer` is not defined in this module; presumably the
-- backend runtime provides it — confirm.
bouncer : QName
bouncer = QN Nil "bouncer"

-- Trampoline one group of (mutually) tail-recursive top-level functions.
--
-- Each input pair is a function name and its compiled body, which must be a
-- `CFun`. The result is:
--   * a new function `REC_<name of first member>` of one argument that cases
--     on that argument. It has one branch per member, tagged `show qn` and
--     binding the member's parameters, whose body is the member's body after
--     `rewriteTailCalls`;
--   * for every member, a replacement definition with the original
--     parameters that packs them into the member's constructor and hands it,
--     together with the `REC_` function, to `bouncer`.
--
-- Fails with `error` if a member is not a `CFun` or the group is empty.
--
-- Each body is destructured once by `splitFun`, and that split is reused for
-- both the branches and the wrappers. The former second `CFun` match in
-- `mkWrap` (and its "error in mkWrap" fallback) could never fail once
-- `splitFun` had succeeded.
doOptimize : List (QName × CExp) → M (List (QName × CExp))
doOptimize fns = do
  splitFuns <- traverse splitFun fns
  let nms = map fst fns
  let alts = map (mkAlt nms) splitFuns
  recName <- mkRecName nms
  -- `CBnd 0` is the single "arg" parameter of the REC_ function
  let recfun = CFun ("arg" :: Nil) $ CCase (CBnd 0) alts
  wrapped <- traverse (mkWrap recName) splitFuns
  pure $ (recName, recfun) :: wrapped
  where
    -- Public entry point for one member: same parameters, with body
    -- `bouncer REC_fn (member constructor applied to the parameters)`.
    mkWrap : QName → (QName × List Name × CExp) → M (QName × CExp)
    mkWrap recName (qn, args, _) = do
      let arglen = length' args
      -- Indices arglen-1 down to 0. Assuming `range 0 arglen` is
      -- 0 .. arglen-1 and the last parameter is `CBnd 0`, this passes the
      -- parameters in declaration order.
      let arg = CConstr (show qn) $ map (\k => CBnd (arglen - k - 1)) (range 0 arglen)
      let body = CApp (CRef bouncer) (CRef recName :: arg :: Nil) 0
      pure (qn, CFun args body)
    -- Name the combined function after the group's first member.
    mkRecName : List QName → M QName
    mkRecName Nil = error emptyFC "INTERNAL ERROR: Empty List in doOptimize"
    mkRecName (QN ns nm :: _) = pure $ QN ns "REC_\{nm}"
    -- One case branch of the REC_ function: tag `show qn`, bind the member's
    -- parameters, run its rewritten body.
    mkAlt : List QName → (QName × List Name × CExp) -> CAlt
    mkAlt nms (qn, args, tm) = CConAlt (show qn) args (rewriteTailCalls nms tm)
    -- Separate a member's parameters from its body; only `CFun` is accepted.
    splitFun : (QName × CExp) → M (QName × List Name × CExp)
    splitFun (qn, CFun args body) = pure (qn, args, body)
    splitFun (qn, _) = error emptyFC "TCO error: \{show qn} not a function"
|
||
|
||
-- Compiled definitions, keyed by their qualified name.
ExpMap : U
ExpMap = SortedMap QName CExp

-- Tail-call optimize the definitions in the map.
--
-- The steps are:
--   * build the tail-call graph, mapping each name to the names it calls in
--     tail position (see `tailNames`);
--   * split the graph into groups with `tarjan`;
--   * for each group, store the trampolined definitions produced by
--     `doOptimize`, which also include the group's new `REC_` function.
--
-- Group members that are not in the map are skipped. A group with no member
-- in the map is now left untouched. Previously it reached `doOptimize` with
-- an empty list and aborted with "INTERNAL ERROR: Empty List in doOptimize".
--
-- NOTE(review): every group returned by `tarjan` is trampolined here. Whether
-- that includes singleton, non-recursive components is decided in Data.Graph
-- — confirm.
tailCallOpt : ExpMap → M ExpMap
tailCallOpt expMap = do
  let graph = map (bimap id tailNames) (toList expMap)
  let groups = tarjan graph
  foldlM processGroup expMap groups
  where
    -- store one rewritten definition under its name
    doUpdate : ExpMap → QName × CExp → ExpMap
    doUpdate acc (k, v) = updateMap k v acc
    -- `acc` is the map threaded through the fold over groups
    processGroup : ExpMap → List QName → M ExpMap
    processGroup acc names = case mapMaybe (flip lookupMap acc) names of
      Nil => pure acc
      pairs => do
        updates <- doOptimize pairs
        pure $ foldl doUpdate acc updates
|