From affae1fecf238e1bfa1c1c500b6ee0713ba823d3 Mon Sep 17 00:00:00 2001 From: Steve Dunham Date: Wed, 20 Nov 2024 19:51:25 -0800 Subject: [PATCH] Allow where defs to refer to themselves --- port/Prelude.newt | 36 +++++++++++++++++++++++++++++------- src/Lib/Compile.idr | 7 +++++++ src/Lib/CompileExp.idr | 2 ++ src/Lib/Elab.idr | 23 +++++++++++++++++------ src/Lib/Eval.idr | 8 +++++++- src/Lib/Types.idr | 8 ++++++++ 6 files changed, 70 insertions(+), 14 deletions(-) diff --git a/port/Prelude.newt b/port/Prelude.newt index 5edf091..c7859d5 100644 --- a/port/Prelude.newt +++ b/port/Prelude.newt @@ -15,10 +15,19 @@ _||_ : Bool → Bool → Bool True || _ = True False || b = b +infixl 6 _==_ +class Eq a where + _==_ : a → a → Bool + data Nat : U where Z : Nat S : Nat -> Nat +instance Eq Nat where + Z == Z = True + S n == S m = n == m + x == y = False + data Maybe : U -> U where Just : {a : U} -> a -> Maybe a Nothing : {a : U} -> Maybe a @@ -146,12 +155,15 @@ instance Mul Nat where Z * _ = Z S n * m = m + n * m --- TODO Sub + infixl 7 _-_ -_-_ : Nat -> Nat -> Nat -Z - m = Z -n - Z = n -S n - S m = n - m +class Sub a where + _-_ : a → a → a + +instance Sub Nat where + Z - m = Z + n - Z = n + S n - S m = n - m infixr 7 _++_ class Concat a where @@ -168,8 +180,6 @@ pfunc length : String → Nat := "(s) => { return rval }" - - pfunc sconcat : String → String → String := "(x,y) => x + y" instance Concat String where _++_ = sconcat @@ -188,10 +198,22 @@ pfunc listToArray : {a : U} -> List a -> Array a := " return rval } " + pfunc alen : {a : U} -> Array a -> Int := "(a,arr) => arr.length" pfunc aget : {a : U} -> Array a -> Int -> a := "(a, arr, ix) => arr[ix]" pfunc aempty : {a : U} -> Unit -> Array a := "() => []" +pfunc arrayToList : {a} → Array a → List a := "(a,arr) => { + let rval = Nil(a) + for (let i = arr.length - 1;i >= 0; i--) { + rval = Cons(a, arr[i], rval) + } + return rval +}" + +-- for now I'll run this in JS +pfunc lines : String → List String := "(s) => arrayToList(s.split('\n'))" + -- TODO represent Nat as number at runtime pfunc natToInt : Nat -> Int := "(n) => { let rval = 0 diff --git a/src/Lib/Compile.idr b/src/Lib/Compile.idr index d078d7f..048966b 100644 --- a/src/Lib/Compile.idr +++ b/src/Lib/Compile.idr @@ -116,6 +116,13 @@ termToJS env (CLet nm t u) f = in case termToJS env t (JAssign nm') of (JAssign _ exp) => JSnoc (JConst nm' exp) (termToJS env' u f) t' => JSnoc (JLet nm' t') (termToJS env' u f) +termToJS env (CLetRec nm t u) f = + let nm' = fresh nm env + env' = push env (Var nm') + -- If it's a simple term, use const + in case termToJS env' t (JAssign nm') of + (JAssign _ exp) => JSnoc (JConst nm' exp) (termToJS env' u f) + t' => JSnoc (JLet nm' t') (termToJS env' u f) termToJS env (CApp t args) f = termToJS env t (\ t' => argsToJS args [<] (\ args' => f (Apply t' args'))) where diff --git a/src/Lib/CompileExp.idr b/src/Lib/CompileExp.idr index 0e03921..138eeac 100644 --- a/src/Lib/CompileExp.idr +++ b/src/Lib/CompileExp.idr @@ -35,6 +35,7 @@ data CExp : Type where CMeta : Nat -> CExp CLit : Literal -> CExp CLet : Name -> CExp -> CExp -> CExp + CLetRec : Name -> CExp -> CExp -> CExp ||| I'm counting Lam in the term for arity. This matches what I need in ||| code gen. @@ -117,6 +118,7 @@ compileTerm (Case _ t alts) = do pure $ CCase t' alts' compileTerm (Lit _ lit) = pure $ CLit lit compileTerm (Let _ nm t u) = pure $ CLet nm !(compileTerm t) !(compileTerm u) +compileTerm (LetRec _ nm t u) = pure $ CLetRec nm !(compileTerm t) !(compileTerm u) export compileFun : Tm -> M CExp diff --git a/src/Lib/Elab.idr b/src/Lib/Elab.idr index bd3aed9..ee70d2f 100644 --- a/src/Lib/Elab.idr +++ b/src/Lib/Elab.idr @@ -180,6 +180,8 @@ rename meta ren lvl v = go ren lvl v go ren lvl (VLit fc lit) = pure (Lit fc lit) go ren lvl (VLet fc name val body) = pure $ Let fc name !(go ren lvl val) !(go (lvl :: ren) (S lvl) body) + go ren lvl (VLetRec fc name val body) = + pure $ Let fc name !(go (lvl :: ren) (S lvl) val) !(go (lvl :: ren) (S lvl) body) lams : Nat -> List String -> Tm -> Tm lams 0 _ tm = tm @@ -357,7 +359,7 @@ unifyCatch fc ctx ty' ty = do debug "fail \{show ty'} \{show ty}" a <- quote ctx.lvl ty' b <- quote ctx.lvl ty - let msg = "unification failure: \{errorMsg err}\n failed to unify \{pprint names a}\n with \{pprint names b}\n " + let msg = "unification failure: \{errorMsg err}\n failed to unify \{pprint names a}\n with \{pprint names b}\n " throwError (E fc msg) case res of MkResult [] => pure () @@ -368,7 +370,7 @@ unifyCatch fc ctx ty' ty = do a <- quote ctx.lvl ty' b <- quote ctx.lvl ty let names = toList $ map fst ctx.types - let msg = "unification failure\n failed to unify \{pprint names a}\n with \{pprint names b}" + let msg = "unification failure\n failed to unify \{pprint names a}\n with \{pprint names b}" let msg = msg ++ "\nconstraints \{show cs.constraints}" throwError (E fc msg) -- error fc "Unification yields constraints \{show cs.constraints}" @@ -721,10 +723,19 @@ checkWhere ctx decls body ty = do clauses' <- traverse (makeClause top) clauses vty <- eval ctx.env CBN funTy debug "\{name} vty is \{show vty}" - tm <- buildTree ctx (MkProb clauses' vty) - vtm <- eval ctx.env CBN tm - let ctx' = define ctx name vtm vty - pure $ Let sigFC name tm !(checkWhere ctx' decls' body ty) + let ctx' = extend ctx name vty + + -- if I lift, I need to namespace it, add args, and add args when + -- calling locally + -- context could hold a Name -> Val (not Tm because levels) to help with that + -- e.g. "go" -> (VApp ... (VApp (VRef "ns.go") ...) + -- But I'll attempt letrec first + tm <- buildTree ctx' (MkProb clauses' vty) + vtm <- eval ctx'.env CBN tm + -- Should we run the rest with the definition in place? + -- I'm wondering if switching from bind to define will mess with metas + -- let ctx' = define ctx name vtm vty + pure $ LetRec sigFC name tm !(checkWhere ctx' decls' body ty) checkDone : Context -> List (String, Pattern) -> Raw -> Val -> M Tm diff --git a/src/Lib/Eval.idr b/src/Lib/Eval.idr index 30ae1f8..28b3be1 100644 --- a/src/Lib/Eval.idr +++ b/src/Lib/Eval.idr @@ -152,6 +152,7 @@ eval env mode (Meta fc i) = eval env mode (Lam fc x t) = pure $ VLam fc x (MkClosure env t) eval env mode (Pi fc x icit a b) = pure $ VPi fc x icit !(eval env mode a) (MkClosure env b) eval env mode (Let fc nm t u) = pure $ VLet fc nm !(eval env mode t) !(eval (VVar fc (length env) [<] :: env) mode u) +eval env mode (LetRec fc nm t u) = pure $ VLetRec fc nm !(eval (VVar fc (length env) [<] :: env) mode t) !(eval (VVar fc (length env) [<] :: env) mode u) -- Here, we assume env has everything. We push levels onto it during type checking. -- I think we could pass in an l and assume everything outside env is free and -- translate to a level @@ -187,6 +188,7 @@ quote l (VMeta fc i sp) = quote l (VLam fc x t) = pure $ Lam fc x !(quote (S l) !(t $$ VVar emptyFC l [<])) quote l (VPi fc x icit a b) = pure $ Pi fc x icit !(quote l a) !(quote (S l) !(b $$ VVar emptyFC l [<])) quote l (VLet fc nm t u) = pure $ Let fc nm !(quote l t) !(quote (S l) u) +quote l (VLetRec fc nm t u) = pure $ LetRec fc nm !(quote (S l) t) !(quote (S l) u) quote l (VU fc) = pure (U fc) quote l (VRef fc n def sp) = quoteSp l (Ref fc n def) sp quote l (VCase fc sc alts) = pure $ Case fc !(quote l sc) alts @@ -260,5 +262,9 @@ zonk top l env t = case t of (App fc t u) => zonkApp top l env t [!(zonk top l env u)] (Pi fc nm icit a b) => Pi fc nm icit <$> zonk top l env a <*> zonkBind top l env b (Let fc nm t u) => Let fc nm <$> zonk top l env t <*> zonkBind top l env u + (LetRec fc nm t u) => LetRec fc nm <$> zonkBind top l env t <*> zonkBind top l env u (Case fc sc alts) => Case fc <$> zonk top l env sc <*> traverse (zonkAlt top l env) alts - _ => pure t + U fc => pure $ U fc + Lit fc lit => pure $ Lit fc lit + Bnd fc ix => pure $ Bnd fc ix + Ref fc ix def => pure $ Ref fc ix def diff --git a/src/Lib/Types.idr b/src/Lib/Types.idr index 4bbf4dc..4cbeed7 100644 --- a/src/Lib/Types.idr +++ b/src/Lib/Types.idr @@ -101,6 +101,8 @@ data Tm : Type where Case : FC -> Tm -> List CaseAlt -> Tm -- need type? Let : FC -> Name -> Tm -> Tm -> Tm + -- for desugaring where + LetRec : FC -> Name -> Tm -> Tm -> Tm Lit : FC -> Literal -> Tm %name Tm t, u, v @@ -117,6 +119,7 @@ HasFC Tm where getFC (Case fc t xs) = fc getFC (Lit fc _) = fc getFC (Let fc _ _ _) = fc + getFC (LetRec fc _ _ _) = fc covering Show Tm @@ -141,6 +144,7 @@ Show Tm where show (Pi _ str Auto t u) = "(Pi {{\{str} : \{show t}}} => \{show u})" show (Case _ sc alts) = "(Case \{show sc} \{show alts})" show (Let _ nm t u) = "(Let \{nm} \{show t} \{show u})" + show (LetRec _ nm t u) = "(LetRec \{nm} \{show t} \{show u})" public export showTm : Tm -> String @@ -214,6 +218,7 @@ pprint names tm = go 0 names tm go p names (Case _ sc alts) = parens 0 p $ text "case" <+> go 0 names sc <+> text "of" ++ (nest 2 (line ++ stack (map (goAlt 0 names) alts))) go p names (Lit _ lit) = text (show lit) go p names (Let _ nm t u) = parens 0 p $ text "let" <+> text nm <+> ":=" <+> go 0 names t <+> "in" (nest 2 $ go 0 (nm :: names) u) + go p names (LetRec _ nm t u) = parens 0 p $ text "letrec" <+> text nm <+> ":=" <+> go 0 names t <+> "in" (nest 2 $ go 0 (nm :: names) u) data Val : Type @@ -246,6 +251,7 @@ data Val : Type where VLam : FC -> Name -> Closure -> Val VPi : FC -> Name -> Icit -> (a : Lazy Val) -> (b : Closure) -> Val VLet : FC -> Name -> Val -> Val -> Val + VLetRec : FC -> Name -> Val -> Val -> Val VU : FC -> Val VLit : FC -> Literal -> Val @@ -260,6 +266,7 @@ getValFC (VPi fc _ _ a b) = fc getValFC (VU fc) = fc getValFC (VLit fc _) = fc getValFC (VLet fc _ _ _) = fc +getValFC (VLetRec fc _ _ _) = fc public export @@ -281,6 +288,7 @@ Show Val where show (VU _) = "U" show (VLit _ lit) = show lit show (VLet _ nm a b) = "(%let \{show nm} = \{show a} in \{show b}" + show (VLetRec _ nm a b) = "(%letrec \{show nm} = \{show a} in \{show b}" public export Env : Type