From 1a9d44434cff98df67d6e13d0612aed4e340b06e Mon Sep 17 00:00:00 2001 From: Steve Dunham Date: Tue, 19 Nov 2024 22:40:52 -0800 Subject: [PATCH] first pass at where clauses --- src/Lib/Elab.idr | 52 +++++++++++++++++++++++++++++++++++------ src/Lib/Parser.idr | 3 ++- src/Lib/ProcessDecl.idr | 9 +------ src/Lib/Syntax.idr | 20 +++++++++++++++- src/Main.idr | 1 + 5 files changed, 68 insertions(+), 17 deletions(-) diff --git a/src/Lib/Elab.idr b/src/Lib/Elab.idr index e94bf0e..bd3aed9 100644 --- a/src/Lib/Elab.idr +++ b/src/Lib/Elab.idr @@ -15,6 +15,17 @@ import Lib.Eval import Lib.TopContext import Lib.Syntax + +||| collectDecl collects multiple Def for one function into one +export +collectDecl : List Decl -> List Decl +collectDecl [] = [] +collectDecl ((Def fc nm cl) :: rest@(Def _ nm' cl' :: xs)) = + if nm == nm' then collectDecl (Def fc nm (cl ++ cl') :: xs) + else (Def fc nm cl :: collectDecl rest) +collectDecl (x :: xs) = x :: collectDecl xs + + -- renaming -- dom gamma ren data Pden = PR Nat Nat (List Nat) @@ -233,9 +244,6 @@ unifySpine env mode True (xs :< x) (ys :< y) = [| unify env mode x y <+> (unifyS unifySpine env mode True _ _ = error emptyFC "meta spine length mismatch" - - - unify env mode t u = do debug "Unify lvl \{show $ length env}" debug " \{show t}" @@ -254,6 +262,9 @@ unify env mode t u = do Pattern => unifyPattern t' u' Normal => unify' t' u' + -- The case tree is still big here. It's hard for idris to sort + -- What we really want is what I wrote - handle meta, handle lam, handle var, etc + where unify' : Val -> Val -> M UnifyResult -- flex/flex @@ -268,9 +279,6 @@ unify env mode t u = do unify' (VPi fc _ _ a b) (VPi fc' _ _ a' b') = do let fresh = VVar fc (length env) [<] [| unify env mode a a' <+> unify (fresh :: env) mode !(b $$ fresh) !(b' $$ fresh) |] - unify' t'@(VVar fc k sp) u'@(VVar fc' k' sp') = - if k == k' then unifySpine env mode (k == k') sp sp' - else error fc "Failed to unify \{show t'} and \{show u'}" -- we don't eta expand on LHS unify' (VLam fc _ t) (VLam _ _ t') = do @@ -285,7 +293,10 @@ unify env mode t u = do let fresh = VVar fc (length env) [<] unify (fresh :: env) mode !(t $$ fresh) !(t' `vapp` fresh) - -- We only want to do this for LHS pattern vars, otherwise, try expanding + unify' t'@(VVar fc k sp) u'@(VVar fc' k' sp') = + if k == k' then unifySpine env mode (k == k') sp sp' + else error fc "Failed to unify \{show t'} and \{show u'}" + unify' t'@(VVar fc k [<]) u = case !(tryEval env u) of Just v => unify env mode t' v Nothing => error fc "Failed to unify \{show t'} and \{show u}" @@ -690,6 +701,32 @@ makeClause top (lhs, rhs) = do +-- we'll want both check and infer, we're augmenting a context +-- so probably a continuation: +-- Context -> List Decl -> (Context -> M a) -> M a +checkWhere : Context -> List Decl -> Raw -> Val -> M Tm +checkWhere ctx decls body ty = do + -- we're going to be very proscriptive here + let (TypeSig sigFC [name] rawtype :: decls) = decls + | x :: _ => error (getFC x) "expected type signature" + | _ => check ctx body ty + funTy <- check ctx rawtype (VU sigFC) + debug "where clause \{name} : \{pprint (names ctx) funTy}" + let (Def defFC name' clauses :: decls') = decls + | x :: _ => error (getFC x) "expected function definition" + | _ => error sigFC "expected function definition after this signature" + unless (name == name') $ error defFC "Expected def for \{name}" + -- REVIEW is this right, cribbed from my top level code + top <- get + clauses' <- traverse (makeClause top) clauses + vty <- eval ctx.env CBN funTy + debug "\{name} vty is \{show vty}" + tm <- buildTree ctx (MkProb clauses' vty) + vtm <- eval ctx.env CBN tm + let ctx' = define ctx name vtm vty + pure $ Let sigFC name tm !(checkWhere ctx' decls' body ty) + + checkDone : Context -> List (String, Pattern) -> Raw -> Val -> M Tm checkDone ctx [] body ty = do debug "DONE-> check body \{show body} at \{show ty}" @@ -864,6 +901,7 @@ undo ((DoLet fc nm tm) :: xs) = RLet fc nm (RImplicit fc) tm <$> undo xs undo ((DoArrow fc nm tm) :: xs) = pure $ RApp fc (RApp fc (RVar fc "_>>=_") tm Explicit) (RLam fc nm Explicit !(undo xs)) Explicit check ctx tm ty = case (tm, !(forceType ctx.env ty)) of + (RWhere fc decls body, ty) => checkWhere ctx (collectDecl decls) body ty (RIf fc a b c, ty) => let tm' = RCase fc a [ MkAlt (RVar (getFC b) "True") b, MkAlt (RVar (getFC c) "False") c ] in check ctx tm' ty diff --git a/src/Lib/Parser.idr b/src/Lib/Parser.idr index 92fe5fd..4bbf2d8 100644 --- a/src/Lib/Parser.idr +++ b/src/Lib/Parser.idr @@ -366,10 +366,11 @@ parseDef = do pats <- many patAtom keyword "=" body <- typeExpr + wfc <- getPos w <- optional $ do keyword "where" startBlock $ manySame $ (parseSig <|> parseDef) - + let body = maybe body (\ decls => RWhere wfc decls body) w -- these get collected later pure $ Def fc nm [(t, body)] -- [MkClause fc [] t body] diff --git a/src/Lib/ProcessDecl.idr b/src/Lib/ProcessDecl.idr index f09d2bb..b8892fd 100644 --- a/src/Lib/ProcessDecl.idr +++ b/src/Lib/ProcessDecl.idr @@ -81,14 +81,7 @@ getArity _ = Z -- Can metas live in context for now? -- We'll have to be able to add them, which might put gamma in a ref -||| collectDecl collects multiple Def for one function into one -export -collectDecl : List Decl -> List Decl -collectDecl [] = [] -collectDecl ((Def fc nm cl) :: rest@(Def _ nm' cl' :: xs)) = - if nm == nm' then collectDecl (Def fc nm (cl ++ cl') :: xs) - else (Def fc nm cl :: collectDecl rest) -collectDecl (x :: xs) = x :: collectDecl xs + -- Makes the arg for `solve` when we solve an auto makeSpine : Nat -> Vect k BD -> SnocList Val diff --git a/src/Lib/Syntax.idr b/src/Lib/Syntax.idr index 3e0c473..5234189 100644 --- a/src/Lib/Syntax.idr +++ b/src/Lib/Syntax.idr @@ -65,6 +65,8 @@ data DoStmt : Type where DoLet : (fc : FC) -> String -> Raw -> DoStmt DoArrow : (fc: FC) -> String -> Raw -> DoStmt + +data Decl : Type data Raw : Type where RVar : (fc : FC) -> (nm : Name) -> Raw RLam : (fc : FC) -> (nm : String) -> (icit : Icit) -> (ty : Raw) -> Raw @@ -79,10 +81,11 @@ data Raw : Type where RHole : (fc : FC) -> Raw RDo : (fc : FC) -> List DoStmt -> Raw RIf : (fc : FC) -> Raw -> Raw -> Raw -> Raw - + RWhere : (fc : FC) -> (List Decl) -> Raw -> Raw %name Raw tm + export HasFC Raw where getFC (RVar fc nm) = fc @@ -98,6 +101,8 @@ HasFC Raw where getFC (RHole fc) = fc getFC (RDo fc stmts) = fc getFC (RIf fc _ _ _) = fc + getFC (RWhere fc _ _) = fc + -- derive some stuff - I'd like json, eq, show, ... @@ -123,6 +128,17 @@ data Decl | Class FC Name Telescope (List Decl) | Instance FC Raw (List Decl) +public export +HasFC Decl where + getFC (TypeSig x strs tm) = x + getFC (Def x str xs) = x + getFC (DCheck x tm tm1) = x + getFC (Data x str tm xs) = x + getFC (PType x str mtm) = x + getFC (PFunc x str tm str1) = x + getFC (PMixFix x strs k y) = x + getFC (Class x str xs ys) = x + getFC (Instance x tm xs) = x public export record Module where @@ -202,6 +218,7 @@ Show Raw where show (RDo _ stmts) = foo [ "DO", "FIXME"] show (RU _) = "U" show (RIf _ x y z) = foo [ "If", show x, show y, show z] + show (RWhere _ _ _) = foo [ "Where", "FIXME"] export Pretty Literal where @@ -254,6 +271,7 @@ Pretty Raw where asDoc p (RHole _) = text "?" asDoc p (RDo _ stmts) = text "TODO - RDo" asDoc p (RIf _ x y z) = par p 0 $ text "if" <+> asDoc 0 x <+/> "then" <+> asDoc 0 y <+/> "else" <+> asDoc 0 z + asDoc p (RWhere _ dd b) = text "TODO pretty where" export Pretty Decl where diff --git a/src/Main.idr b/src/Main.idr index 1ae8fba..28ad1a4 100644 --- a/src/Main.idr +++ b/src/Main.idr @@ -11,6 +11,7 @@ import Data.IORef -- import Lib.Elab import Lib.Compile import Lib.Parser +import Lib.Elab import Lib.Parser.Impl import Lib.Prettier import Lib.ProcessDecl