first pass at where clauses

This commit is contained in:
2024-11-19 22:40:52 -08:00
parent c665310653
commit 1a9d44434c
5 changed files with 68 additions and 17 deletions

View File

@@ -15,6 +15,17 @@ import Lib.Eval
import Lib.TopContext
import Lib.Syntax
||| collectDecl collects multiple Def for one function into one
export
collectDecl : List Decl -> List Decl
collectDecl [] = []
collectDecl ((Def fc nm cl) :: rest@(Def _ nm' cl' :: xs)) =
if nm == nm' then collectDecl (Def fc nm (cl ++ cl') :: xs)
else (Def fc nm cl :: collectDecl rest)
collectDecl (x :: xs) = x :: collectDecl xs
-- renaming
-- dom gamma ren
data Pden = PR Nat Nat (List Nat)
@@ -233,9 +244,6 @@ unifySpine env mode True (xs :< x) (ys :< y) = [| unify env mode x y <+> (unifyS
unifySpine env mode True _ _ = error emptyFC "meta spine length mismatch"
unify env mode t u = do
debug "Unify lvl \{show $ length env}"
debug " \{show t}"
@@ -254,6 +262,9 @@ unify env mode t u = do
Pattern => unifyPattern t' u'
Normal => unify' t' u'
-- The case tree is still big here. It's hard for idris to sort
-- What we really want is what I wrote - handle meta, handle lam, handle var, etc
where
unify' : Val -> Val -> M UnifyResult
-- flex/flex
@@ -268,9 +279,6 @@ unify env mode t u = do
unify' (VPi fc _ _ a b) (VPi fc' _ _ a' b') = do
let fresh = VVar fc (length env) [<]
[| unify env mode a a' <+> unify (fresh :: env) mode !(b $$ fresh) !(b' $$ fresh) |]
unify' t'@(VVar fc k sp) u'@(VVar fc' k' sp') =
if k == k' then unifySpine env mode (k == k') sp sp'
else error fc "Failed to unify \{show t'} and \{show u'}"
-- we don't eta expand on LHS
unify' (VLam fc _ t) (VLam _ _ t') = do
@@ -285,7 +293,10 @@ unify env mode t u = do
let fresh = VVar fc (length env) [<]
unify (fresh :: env) mode !(t $$ fresh) !(t' `vapp` fresh)
-- We only want to do this for LHS pattern vars, otherwise, try expanding
unify' t'@(VVar fc k sp) u'@(VVar fc' k' sp') =
if k == k' then unifySpine env mode (k == k') sp sp'
else error fc "Failed to unify \{show t'} and \{show u'}"
unify' t'@(VVar fc k [<]) u = case !(tryEval env u) of
Just v => unify env mode t' v
Nothing => error fc "Failed to unify \{show t'} and \{show u}"
@@ -690,6 +701,32 @@ makeClause top (lhs, rhs) = do
-- we'll want both check and infer, we're augmenting a context
-- so probably a continuation:
-- Context -> List Decl -> (Context -> M a) -> M a
checkWhere : Context -> List Decl -> Raw -> Val -> M Tm
checkWhere ctx decls body ty = do
-- we're going to be very proscriptive here
let (TypeSig sigFC [name] rawtype :: decls) = decls
| x :: _ => error (getFC x) "expected type signature"
| _ => check ctx body ty
funTy <- check ctx rawtype (VU sigFC)
debug "where clause \{name} : \{pprint (names ctx) funTy}"
let (Def defFC name' clauses :: decls') = decls
| x :: _ => error (getFC x) "expected function definition"
| _ => error sigFC "expected function definition after this signature"
unless (name == name') $ error defFC "Expected def for \{name}"
-- REVIEW is this right, cribbed from my top level code
top <- get
clauses' <- traverse (makeClause top) clauses
vty <- eval ctx.env CBN funTy
debug "\{name} vty is \{show vty}"
tm <- buildTree ctx (MkProb clauses' vty)
vtm <- eval ctx.env CBN tm
let ctx' = define ctx name vtm vty
pure $ Let sigFC name tm !(checkWhere ctx' decls' body ty)
checkDone : Context -> List (String, Pattern) -> Raw -> Val -> M Tm
checkDone ctx [] body ty = do
debug "DONE-> check body \{show body} at \{show ty}"
@@ -864,6 +901,7 @@ undo ((DoLet fc nm tm) :: xs) = RLet fc nm (RImplicit fc) tm <$> undo xs
undo ((DoArrow fc nm tm) :: xs) = pure $ RApp fc (RApp fc (RVar fc "_>>=_") tm Explicit) (RLam fc nm Explicit !(undo xs)) Explicit
check ctx tm ty = case (tm, !(forceType ctx.env ty)) of
(RWhere fc decls body, ty) => checkWhere ctx (collectDecl decls) body ty
(RIf fc a b c, ty) =>
let tm' = RCase fc a [ MkAlt (RVar (getFC b) "True") b, MkAlt (RVar (getFC c) "False") c ] in
check ctx tm' ty

View File

@@ -366,10 +366,11 @@ parseDef = do
pats <- many patAtom
keyword "="
body <- typeExpr
wfc <- getPos
w <- optional $ do
keyword "where"
startBlock $ manySame $ (parseSig <|> parseDef)
let body = maybe body (\ decls => RWhere wfc decls body) w
-- these get collected later
pure $ Def fc nm [(t, body)] -- [MkClause fc [] t body]

View File

@@ -81,14 +81,7 @@ getArity _ = Z
-- Can metas live in context for now?
-- We'll have to be able to add them, which might put gamma in a ref
||| collectDecl collects multiple Def for one function into one
export
collectDecl : List Decl -> List Decl
collectDecl [] = []
collectDecl ((Def fc nm cl) :: rest@(Def _ nm' cl' :: xs)) =
if nm == nm' then collectDecl (Def fc nm (cl ++ cl') :: xs)
else (Def fc nm cl :: collectDecl rest)
collectDecl (x :: xs) = x :: collectDecl xs
-- Makes the arg for `solve` when we solve an auto
makeSpine : Nat -> Vect k BD -> SnocList Val

View File

@@ -65,6 +65,8 @@ data DoStmt : Type where
DoLet : (fc : FC) -> String -> Raw -> DoStmt
DoArrow : (fc: FC) -> String -> Raw -> DoStmt
data Decl : Type
data Raw : Type where
RVar : (fc : FC) -> (nm : Name) -> Raw
RLam : (fc : FC) -> (nm : String) -> (icit : Icit) -> (ty : Raw) -> Raw
@@ -79,10 +81,11 @@ data Raw : Type where
RHole : (fc : FC) -> Raw
RDo : (fc : FC) -> List DoStmt -> Raw
RIf : (fc : FC) -> Raw -> Raw -> Raw -> Raw
RWhere : (fc : FC) -> (List Decl) -> Raw -> Raw
%name Raw tm
export
HasFC Raw where
getFC (RVar fc nm) = fc
@@ -98,6 +101,8 @@ HasFC Raw where
getFC (RHole fc) = fc
getFC (RDo fc stmts) = fc
getFC (RIf fc _ _ _) = fc
getFC (RWhere fc _ _) = fc
-- derive some stuff - I'd like json, eq, show, ...
@@ -123,6 +128,17 @@ data Decl
| Class FC Name Telescope (List Decl)
| Instance FC Raw (List Decl)
public export
HasFC Decl where
getFC (TypeSig x strs tm) = x
getFC (Def x str xs) = x
getFC (DCheck x tm tm1) = x
getFC (Data x str tm xs) = x
getFC (PType x str mtm) = x
getFC (PFunc x str tm str1) = x
getFC (PMixFix x strs k y) = x
getFC (Class x str xs ys) = x
getFC (Instance x tm xs) = x
public export
record Module where
@@ -202,6 +218,7 @@ Show Raw where
show (RDo _ stmts) = foo [ "DO", "FIXME"]
show (RU _) = "U"
show (RIf _ x y z) = foo [ "If", show x, show y, show z]
show (RWhere _ _ _) = foo [ "Where", "FIXME"]
export
Pretty Literal where
@@ -254,6 +271,7 @@ Pretty Raw where
asDoc p (RHole _) = text "?"
asDoc p (RDo _ stmts) = text "TODO - RDo"
asDoc p (RIf _ x y z) = par p 0 $ text "if" <+> asDoc 0 x <+/> "then" <+> asDoc 0 y <+/> "else" <+> asDoc 0 z
asDoc p (RWhere _ dd b) = text "TODO pretty where"
export
Pretty Decl where

View File

@@ -11,6 +11,7 @@ import Data.IORef
-- import Lib.Elab
import Lib.Compile
import Lib.Parser
import Lib.Elab
import Lib.Parser.Impl
import Lib.Prettier
import Lib.ProcessDecl