Allow where defs to refer to themselves

This commit is contained in:
2024-11-20 19:51:25 -08:00
parent 7c8c0c9df0
commit affae1fecf
6 changed files with 70 additions and 14 deletions

View File

@@ -15,10 +15,19 @@ _||_ : Bool → Bool → Bool
True || _ = True True || _ = True
False || b = b False || b = b
infixl 6 _==_
class Eq a where
_==_ : a a Bool
data Nat : U where data Nat : U where
Z : Nat Z : Nat
S : Nat -> Nat S : Nat -> Nat
instance Eq Nat where
Z == Z = True
S n == S m = n == m
x == y = False
data Maybe : U -> U where data Maybe : U -> U where
Just : {a : U} -> a -> Maybe a Just : {a : U} -> a -> Maybe a
Nothing : {a : U} -> Maybe a Nothing : {a : U} -> Maybe a
@@ -146,12 +155,15 @@ instance Mul Nat where
Z * _ = Z Z * _ = Z
S n * m = m + n * m S n * m = m + n * m
-- TODO Sub
infixl 7 _-_ infixl 7 _-_
_-_ : Nat -> Nat -> Nat class Sub a where
Z - m = Z _-_ : a a a
n - Z = n
S n - S m = n - m instance Sub Nat where
Z - m = Z
n - Z = n
S n - S m = n - m
infixr 7 _++_ infixr 7 _++_
class Concat a where class Concat a where
@@ -168,8 +180,6 @@ pfunc length : String → Nat := "(s) => {
return rval return rval
}" }"
pfunc sconcat : String String String := "(x,y) => x + y" pfunc sconcat : String String String := "(x,y) => x + y"
instance Concat String where instance Concat String where
_++_ = sconcat _++_ = sconcat
@@ -188,10 +198,22 @@ pfunc listToArray : {a : U} -> List a -> Array a := "
return rval return rval
} }
" "
pfunc alen : {a : U} -> Array a -> Int := "(a,arr) => arr.length" pfunc alen : {a : U} -> Array a -> Int := "(a,arr) => arr.length"
pfunc aget : {a : U} -> Array a -> Int -> a := "(a, arr, ix) => arr[ix]" pfunc aget : {a : U} -> Array a -> Int -> a := "(a, arr, ix) => arr[ix]"
pfunc aempty : {a : U} -> Unit -> Array a := "() => []" pfunc aempty : {a : U} -> Unit -> Array a := "() => []"
pfunc arrayToList : {a} Array a List a := "(a,arr) => {
let rval = Nil(a)
for (let i = arr.length - 1;i >= 0; i--) {
rval = Cons(a, arr[i], rval)
}
return rval
}"
-- for now I'll run this in JS
pfunc lines : String List String := "(s) => arrayToList(s.split('\n'))"
-- TODO represent Nat as number at runtime -- TODO represent Nat as number at runtime
pfunc natToInt : Nat -> Int := "(n) => { pfunc natToInt : Nat -> Int := "(n) => {
let rval = 0 let rval = 0

View File

@@ -116,6 +116,13 @@ termToJS env (CLet nm t u) f =
in case termToJS env t (JAssign nm') of in case termToJS env t (JAssign nm') of
(JAssign _ exp) => JSnoc (JConst nm' exp) (termToJS env' u f) (JAssign _ exp) => JSnoc (JConst nm' exp) (termToJS env' u f)
t' => JSnoc (JLet nm' t') (termToJS env' u f) t' => JSnoc (JLet nm' t') (termToJS env' u f)
termToJS env (CLetRec nm t u) f =
let nm' = fresh nm env
env' = push env (Var nm')
-- If it's a simple term, use const
in case termToJS env' t (JAssign nm') of
(JAssign _ exp) => JSnoc (JConst nm' exp) (termToJS env' u f)
t' => JSnoc (JLet nm' t') (termToJS env' u f)
termToJS env (CApp t args) f = termToJS env t (\ t' => argsToJS args [<] (\ args' => f (Apply t' args'))) termToJS env (CApp t args) f = termToJS env t (\ t' => argsToJS args [<] (\ args' => f (Apply t' args')))
where where

View File

@@ -35,6 +35,7 @@ data CExp : Type where
CMeta : Nat -> CExp CMeta : Nat -> CExp
CLit : Literal -> CExp CLit : Literal -> CExp
CLet : Name -> CExp -> CExp -> CExp CLet : Name -> CExp -> CExp -> CExp
CLetRec : Name -> CExp -> CExp -> CExp
||| I'm counting Lam in the term for arity. This matches what I need in ||| I'm counting Lam in the term for arity. This matches what I need in
||| code gen. ||| code gen.
@@ -117,6 +118,7 @@ compileTerm (Case _ t alts) = do
pure $ CCase t' alts' pure $ CCase t' alts'
compileTerm (Lit _ lit) = pure $ CLit lit compileTerm (Lit _ lit) = pure $ CLit lit
compileTerm (Let _ nm t u) = pure $ CLet nm !(compileTerm t) !(compileTerm u) compileTerm (Let _ nm t u) = pure $ CLet nm !(compileTerm t) !(compileTerm u)
compileTerm (LetRec _ nm t u) = pure $ CLetRec nm !(compileTerm t) !(compileTerm u)
export export
compileFun : Tm -> M CExp compileFun : Tm -> M CExp

View File

@@ -180,6 +180,8 @@ rename meta ren lvl v = go ren lvl v
go ren lvl (VLit fc lit) = pure (Lit fc lit) go ren lvl (VLit fc lit) = pure (Lit fc lit)
go ren lvl (VLet fc name val body) = go ren lvl (VLet fc name val body) =
pure $ Let fc name !(go ren lvl val) !(go (lvl :: ren) (S lvl) body) pure $ Let fc name !(go ren lvl val) !(go (lvl :: ren) (S lvl) body)
go ren lvl (VLetRec fc name val body) =
pure $ Let fc name !(go (lvl :: ren) (S lvl) val) !(go (lvl :: ren) (S lvl) body)
lams : Nat -> List String -> Tm -> Tm lams : Nat -> List String -> Tm -> Tm
lams 0 _ tm = tm lams 0 _ tm = tm
@@ -357,7 +359,7 @@ unifyCatch fc ctx ty' ty = do
debug "fail \{show ty'} \{show ty}" debug "fail \{show ty'} \{show ty}"
a <- quote ctx.lvl ty' a <- quote ctx.lvl ty'
b <- quote ctx.lvl ty b <- quote ctx.lvl ty
let msg = "unification failure: \{errorMsg err}\n failed to unify \{pprint names a}\n with \{pprint names b}\n " let msg = "unification failure: \{errorMsg err}\n failed to unify \{pprint names a}\n with \{pprint names b}\n "
throwError (E fc msg) throwError (E fc msg)
case res of case res of
MkResult [] => pure () MkResult [] => pure ()
@@ -368,7 +370,7 @@ unifyCatch fc ctx ty' ty = do
a <- quote ctx.lvl ty' a <- quote ctx.lvl ty'
b <- quote ctx.lvl ty b <- quote ctx.lvl ty
let names = toList $ map fst ctx.types let names = toList $ map fst ctx.types
let msg = "unification failure\n failed to unify \{pprint names a}\n with \{pprint names b}" let msg = "unification failure\n failed to unify \{pprint names a}\n with \{pprint names b}"
let msg = msg ++ "\nconstraints \{show cs.constraints}" let msg = msg ++ "\nconstraints \{show cs.constraints}"
throwError (E fc msg) throwError (E fc msg)
-- error fc "Unification yields constraints \{show cs.constraints}" -- error fc "Unification yields constraints \{show cs.constraints}"
@@ -721,10 +723,19 @@ checkWhere ctx decls body ty = do
clauses' <- traverse (makeClause top) clauses clauses' <- traverse (makeClause top) clauses
vty <- eval ctx.env CBN funTy vty <- eval ctx.env CBN funTy
debug "\{name} vty is \{show vty}" debug "\{name} vty is \{show vty}"
tm <- buildTree ctx (MkProb clauses' vty) let ctx' = extend ctx name vty
vtm <- eval ctx.env CBN tm
let ctx' = define ctx name vtm vty -- if I lift, I need to namespace it, add args, and add args when
pure $ Let sigFC name tm !(checkWhere ctx' decls' body ty) -- calling locally
-- context could hold a Name -> Val (not Tm because levels) to help with that
-- e.g. "go" -> (VApp ... (VApp (VRef "ns.go") ...)
-- But I'll attempt letrec first
tm <- buildTree ctx' (MkProb clauses' vty)
vtm <- eval ctx'.env CBN tm
-- Should we run the rest with the definition in place?
-- I'm wondering if switching from bind to define will mess with metas
-- let ctx' = define ctx name vtm vty
pure $ LetRec sigFC name tm !(checkWhere ctx' decls' body ty)
checkDone : Context -> List (String, Pattern) -> Raw -> Val -> M Tm checkDone : Context -> List (String, Pattern) -> Raw -> Val -> M Tm

View File

@@ -152,6 +152,7 @@ eval env mode (Meta fc i) =
eval env mode (Lam fc x t) = pure $ VLam fc x (MkClosure env t) eval env mode (Lam fc x t) = pure $ VLam fc x (MkClosure env t)
eval env mode (Pi fc x icit a b) = pure $ VPi fc x icit !(eval env mode a) (MkClosure env b) eval env mode (Pi fc x icit a b) = pure $ VPi fc x icit !(eval env mode a) (MkClosure env b)
eval env mode (Let fc nm t u) = pure $ VLet fc nm !(eval env mode t) !(eval (VVar fc (length env) [<] :: env) mode u) eval env mode (Let fc nm t u) = pure $ VLet fc nm !(eval env mode t) !(eval (VVar fc (length env) [<] :: env) mode u)
eval env mode (LetRec fc nm t u) = pure $ VLetRec fc nm !(eval (VVar fc (length env) [<] :: env) mode t) !(eval (VVar fc (length env) [<] :: env) mode u)
-- Here, we assume env has everything. We push levels onto it during type checking. -- Here, we assume env has everything. We push levels onto it during type checking.
-- I think we could pass in an l and assume everything outside env is free and -- I think we could pass in an l and assume everything outside env is free and
-- translate to a level -- translate to a level
@@ -187,6 +188,7 @@ quote l (VMeta fc i sp) =
quote l (VLam fc x t) = pure $ Lam fc x !(quote (S l) !(t $$ VVar emptyFC l [<])) quote l (VLam fc x t) = pure $ Lam fc x !(quote (S l) !(t $$ VVar emptyFC l [<]))
quote l (VPi fc x icit a b) = pure $ Pi fc x icit !(quote l a) !(quote (S l) !(b $$ VVar emptyFC l [<])) quote l (VPi fc x icit a b) = pure $ Pi fc x icit !(quote l a) !(quote (S l) !(b $$ VVar emptyFC l [<]))
quote l (VLet fc nm t u) = pure $ Let fc nm !(quote l t) !(quote (S l) u) quote l (VLet fc nm t u) = pure $ Let fc nm !(quote l t) !(quote (S l) u)
quote l (VLetRec fc nm t u) = pure $ LetRec fc nm !(quote (S l) t) !(quote (S l) u)
quote l (VU fc) = pure (U fc) quote l (VU fc) = pure (U fc)
quote l (VRef fc n def sp) = quoteSp l (Ref fc n def) sp quote l (VRef fc n def sp) = quoteSp l (Ref fc n def) sp
quote l (VCase fc sc alts) = pure $ Case fc !(quote l sc) alts quote l (VCase fc sc alts) = pure $ Case fc !(quote l sc) alts
@@ -260,5 +262,9 @@ zonk top l env t = case t of
(App fc t u) => zonkApp top l env t [!(zonk top l env u)] (App fc t u) => zonkApp top l env t [!(zonk top l env u)]
(Pi fc nm icit a b) => Pi fc nm icit <$> zonk top l env a <*> zonkBind top l env b (Pi fc nm icit a b) => Pi fc nm icit <$> zonk top l env a <*> zonkBind top l env b
(Let fc nm t u) => Let fc nm <$> zonk top l env t <*> zonkBind top l env u (Let fc nm t u) => Let fc nm <$> zonk top l env t <*> zonkBind top l env u
(LetRec fc nm t u) => LetRec fc nm <$> zonkBind top l env t <*> zonkBind top l env u
(Case fc sc alts) => Case fc <$> zonk top l env sc <*> traverse (zonkAlt top l env) alts (Case fc sc alts) => Case fc <$> zonk top l env sc <*> traverse (zonkAlt top l env) alts
_ => pure t U fc => pure $ U fc
Lit fc lit => pure $ Lit fc lit
Bnd fc ix => pure $ Bnd fc ix
Ref fc ix def => pure $ Ref fc ix def

View File

@@ -101,6 +101,8 @@ data Tm : Type where
Case : FC -> Tm -> List CaseAlt -> Tm Case : FC -> Tm -> List CaseAlt -> Tm
-- need type? -- need type?
Let : FC -> Name -> Tm -> Tm -> Tm Let : FC -> Name -> Tm -> Tm -> Tm
-- for desugaring where
LetRec : FC -> Name -> Tm -> Tm -> Tm
Lit : FC -> Literal -> Tm Lit : FC -> Literal -> Tm
%name Tm t, u, v %name Tm t, u, v
@@ -117,6 +119,7 @@ HasFC Tm where
getFC (Case fc t xs) = fc getFC (Case fc t xs) = fc
getFC (Lit fc _) = fc getFC (Lit fc _) = fc
getFC (Let fc _ _ _) = fc getFC (Let fc _ _ _) = fc
getFC (LetRec fc _ _ _) = fc
covering covering
Show Tm Show Tm
@@ -141,6 +144,7 @@ Show Tm where
show (Pi _ str Auto t u) = "(Pi {{\{str} : \{show t}}} => \{show u})" show (Pi _ str Auto t u) = "(Pi {{\{str} : \{show t}}} => \{show u})"
show (Case _ sc alts) = "(Case \{show sc} \{show alts})" show (Case _ sc alts) = "(Case \{show sc} \{show alts})"
show (Let _ nm t u) = "(Let \{nm} \{show t} \{show u})" show (Let _ nm t u) = "(Let \{nm} \{show t} \{show u})"
show (LetRec _ nm t u) = "(LetRec \{nm} \{show t} \{show u})"
public export public export
showTm : Tm -> String showTm : Tm -> String
@@ -214,6 +218,7 @@ pprint names tm = go 0 names tm
go p names (Case _ sc alts) = parens 0 p $ text "case" <+> go 0 names sc <+> text "of" ++ (nest 2 (line ++ stack (map (goAlt 0 names) alts))) go p names (Case _ sc alts) = parens 0 p $ text "case" <+> go 0 names sc <+> text "of" ++ (nest 2 (line ++ stack (map (goAlt 0 names) alts)))
go p names (Lit _ lit) = text (show lit) go p names (Lit _ lit) = text (show lit)
go p names (Let _ nm t u) = parens 0 p $ text "let" <+> text nm <+> ":=" <+> go 0 names t <+> "in" </> (nest 2 $ go 0 (nm :: names) u) go p names (Let _ nm t u) = parens 0 p $ text "let" <+> text nm <+> ":=" <+> go 0 names t <+> "in" </> (nest 2 $ go 0 (nm :: names) u)
go p names (LetRec _ nm t u) = parens 0 p $ text "letrec" <+> text nm <+> ":=" <+> go 0 names t <+> "in" </> (nest 2 $ go 0 (nm :: names) u)
data Val : Type data Val : Type
@@ -246,6 +251,7 @@ data Val : Type where
VLam : FC -> Name -> Closure -> Val VLam : FC -> Name -> Closure -> Val
VPi : FC -> Name -> Icit -> (a : Lazy Val) -> (b : Closure) -> Val VPi : FC -> Name -> Icit -> (a : Lazy Val) -> (b : Closure) -> Val
VLet : FC -> Name -> Val -> Val -> Val VLet : FC -> Name -> Val -> Val -> Val
VLetRec : FC -> Name -> Val -> Val -> Val
VU : FC -> Val VU : FC -> Val
VLit : FC -> Literal -> Val VLit : FC -> Literal -> Val
@@ -260,6 +266,7 @@ getValFC (VPi fc _ _ a b) = fc
getValFC (VU fc) = fc getValFC (VU fc) = fc
getValFC (VLit fc _) = fc getValFC (VLit fc _) = fc
getValFC (VLet fc _ _ _) = fc getValFC (VLet fc _ _ _) = fc
getValFC (VLetRec fc _ _ _) = fc
public export public export
@@ -281,6 +288,7 @@ Show Val where
show (VU _) = "U" show (VU _) = "U"
show (VLit _ lit) = show lit show (VLit _ lit) = show lit
show (VLet _ nm a b) = "(%let \{show nm} = \{show a} in \{show b}" show (VLet _ nm a b) = "(%let \{show nm} = \{show a} in \{show b}"
show (VLetRec _ nm a b) = "(%letrec \{show nm} = \{show a} in \{show b}"
public export public export
Env : Type Env : Type