first pass at where clauses

This commit is contained in:
2024-11-19 22:40:52 -08:00
parent c665310653
commit 1a9d44434c
5 changed files with 68 additions and 17 deletions

View File

@@ -15,6 +15,17 @@ import Lib.Eval
import Lib.TopContext
import Lib.Syntax
||| collectDecl collects multiple Def for one function into one
export
collectDecl : List Decl -> List Decl
collectDecl [] = []
collectDecl ((Def fc nm cl) :: rest@(Def _ nm' cl' :: xs)) =
if nm == nm' then collectDecl (Def fc nm (cl ++ cl') :: xs)
else (Def fc nm cl :: collectDecl rest)
collectDecl (x :: xs) = x :: collectDecl xs
-- renaming
-- dom gamma ren
data Pden = PR Nat Nat (List Nat)
@@ -233,9 +244,6 @@ unifySpine env mode True (xs :< x) (ys :< y) = [| unify env mode x y <+> (unifyS
unifySpine env mode True _ _ = error emptyFC "meta spine length mismatch"
unify env mode t u = do
debug "Unify lvl \{show $ length env}"
debug " \{show t}"
@@ -254,6 +262,9 @@ unify env mode t u = do
Pattern => unifyPattern t' u'
Normal => unify' t' u'
-- The case tree is still big here. It's hard for idris to sort
-- What we really want is what I wrote - handle meta, handle lam, handle var, etc
where
unify' : Val -> Val -> M UnifyResult
-- flex/flex
@@ -268,9 +279,6 @@ unify env mode t u = do
unify' (VPi fc _ _ a b) (VPi fc' _ _ a' b') = do
let fresh = VVar fc (length env) [<]
[| unify env mode a a' <+> unify (fresh :: env) mode !(b $$ fresh) !(b' $$ fresh) |]
unify' t'@(VVar fc k sp) u'@(VVar fc' k' sp') =
if k == k' then unifySpine env mode (k == k') sp sp'
else error fc "Failed to unify \{show t'} and \{show u'}"
-- we don't eta expand on LHS
unify' (VLam fc _ t) (VLam _ _ t') = do
@@ -285,7 +293,10 @@ unify env mode t u = do
let fresh = VVar fc (length env) [<]
unify (fresh :: env) mode !(t $$ fresh) !(t' `vapp` fresh)
-- We only want to do this for LHS pattern vars, otherwise, try expanding
unify' t'@(VVar fc k sp) u'@(VVar fc' k' sp') =
if k == k' then unifySpine env mode (k == k') sp sp'
else error fc "Failed to unify \{show t'} and \{show u'}"
unify' t'@(VVar fc k [<]) u = case !(tryEval env u) of
Just v => unify env mode t' v
Nothing => error fc "Failed to unify \{show t'} and \{show u}"
@@ -690,6 +701,32 @@ makeClause top (lhs, rhs) = do
-- we'll want both check and infer, we're augmenting a context
-- so probably a continuation:
-- Context -> List Decl -> (Context -> M a) -> M a
checkWhere : Context -> List Decl -> Raw -> Val -> M Tm
checkWhere ctx decls body ty = do
-- we're going to be very proscriptive here
let (TypeSig sigFC [name] rawtype :: decls) = decls
| x :: _ => error (getFC x) "expected type signature"
| _ => check ctx body ty
funTy <- check ctx rawtype (VU sigFC)
debug "where clause \{name} : \{pprint (names ctx) funTy}"
let (Def defFC name' clauses :: decls') = decls
| x :: _ => error (getFC x) "expected function definition"
| _ => error sigFC "expected function definition after this signature"
unless (name == name') $ error defFC "Expected def for \{name}"
-- REVIEW is this right, cribbed from my top level code
top <- get
clauses' <- traverse (makeClause top) clauses
vty <- eval ctx.env CBN funTy
debug "\{name} vty is \{show vty}"
tm <- buildTree ctx (MkProb clauses' vty)
vtm <- eval ctx.env CBN tm
let ctx' = define ctx name vtm vty
pure $ Let sigFC name tm !(checkWhere ctx' decls' body ty)
checkDone : Context -> List (String, Pattern) -> Raw -> Val -> M Tm
checkDone ctx [] body ty = do
debug "DONE-> check body \{show body} at \{show ty}"
@@ -864,6 +901,7 @@ undo ((DoLet fc nm tm) :: xs) = RLet fc nm (RImplicit fc) tm <$> undo xs
undo ((DoArrow fc nm tm) :: xs) = pure $ RApp fc (RApp fc (RVar fc "_>>=_") tm Explicit) (RLam fc nm Explicit !(undo xs)) Explicit
check ctx tm ty = case (tm, !(forceType ctx.env ty)) of
(RWhere fc decls body, ty) => checkWhere ctx (collectDecl decls) body ty
(RIf fc a b c, ty) =>
let tm' = RCase fc a [ MkAlt (RVar (getFC b) "True") b, MkAlt (RVar (getFC c) "False") c ] in
check ctx tm' ty