performance and code size improvements

- Use default case for constructors with no explicit match.
- self-compile is 15s now
- code size is 60% smaller

code size and self compile time on par with the idris-built version
This commit is contained in:
2025-01-18 21:33:49 -08:00
parent f991ca0d52
commit e3ae301c9c
11 changed files with 371 additions and 10 deletions

View File

@@ -5,9 +5,12 @@
- [ ] review pattern matching. goal is to have a sane context on the other end. secondary goal - bring it closer to the paper.
- [x] redo code to determine base path
- [x] emit only one branch for default case when splitting inductives
- [ ] save/load results of processing a module
- [x] keep each module separate in context
- [x] search would include imported modules, collect ops into and from modules
- [x] serialize modules
- [ ] deserialize modules if up to date
- should I allow the idris cross module assignment hack?
- >>> sort out metas (maybe push them up to the main list)
- eventually we may want to support resuming halfway through a file

View File

@@ -343,7 +343,7 @@ IO a = World -> IORes a
instance Monad IO where
bind ma mab = \ w => case ma w of
MkIORes a w => mab a w
pure x = \ w => MkIORes x w
pure = MkIORes
bindList : a b. List a (a List b) List b

View File

@@ -848,6 +848,25 @@ updateContext ctx ((k, val) :: cs) =
replaceV Z x (y :: xs) = x :: xs
replaceV (S k) x (y :: xs) = y :: replaceV k x xs
checkCase : Context Problem String Val (QName × Int × Tm) M Bool
checkCase ctx prob scnm scty (dcName, arity, ty) = do
vty <- eval Nil CBN ty
(ctx', ty', vars, sc) <- extendPi ctx (vty) Lin Lin
(Just res) <- catchError (Just <$> unify ctx'.env UPattern ty' scty)
(\err => do
debug $ \ _ => "SKIP \{show dcName} because unify error \{errorMsg err}"
pure Nothing)
| _ => pure False
(Right res) <- tryError (unify ctx'.env UPattern ty' scty)
| Left err => do
debug $ \ _ => "SKIP \{show dcName} because unify error \{errorMsg err}"
pure False
case lookupDef ctx scnm of
Just val@(VRef fc nm sp) => pure $ nm == dcName
_ => pure True
-- ok, so this is a single constructor, CaseAlt
-- return Nothing if dcon doesn't unify with scrut
buildCase : Context -> Problem -> String -> Val -> (QName × Int × Tm) -> M (Maybe CaseAlt)
@@ -1152,6 +1171,12 @@ getLits nm ((MkClause fc cons pats expr) :: cs) = case find ((_==_ nm) ∘ fst)
Just (_, (PatLit _ lit)) => lit :: getLits nm cs
_ => getLits nm cs
-- collect constructors that are matched on
matchedConstructors : String List Clause List QName
matchedConstructors nm Nil = Nil
matchedConstructors nm ((MkClause fc cons pats expr) :: cs) = case find ((_==_ nm) fst) cons of
Just (_, (PatCon _ _ dcon _ _)) => dcon :: matchedConstructors nm cs
_ => matchedConstructors nm cs
-- then build a lit case for each of those
@@ -1189,7 +1214,18 @@ buildLitCase ctx prob fc scnm scty lit = do
cons <- rewriteConstraint cons Nil
pure $ MkClause fc cons pats expr
buildDefault : Context Problem FC String M CaseAlt
buildDefault ctx prob fc scnm = do
let defclauses = filter isDefault prob.clauses
when (length' defclauses == 0) $ \ _ => error fc "no default for literal slot on \{show scnm}"
CaseDefault <$> buildTree ctx (MkProb defclauses prob.ty)
where
isDefault : Clause -> Bool
isDefault cl = case find ((_==_ scnm) fst) cl.cons of
Just (_, (PatVar _ _ _)) => True
Just (_, (PatWild _ _)) => True
Nothing => True
_ => False
buildLitCases : Context -> Problem -> FC -> String -> Val -> M (List CaseAlt)
buildLitCases ctx prob fc scnm scty = do
@@ -1289,12 +1325,26 @@ buildTree ctx prob@(MkProb ((MkClause fc constraints Nil expr) :: cs) ty) = do
-- this is per the paper, but it would be nice to coalesce
-- default cases
cons <- getConstructors ctx (getFC pat) scty'
debug $ \ _ => "CONS \{show $ map fst cons}"
-- TODO collect the wild-card only cases into one
alts <- traverse (buildCase ctx prob scnm scty') cons
let matched = matchedConstructors scnm prob.clauses
let (hit,miss) = partition (flip elem matched fst) cons
-- need to check miss is possible
miss' <- filterM (checkCase ctx prob scnm scty') miss
debug $ \ _ => "CONS \{show $ map fst cons} matched \{show matched} miss \{show miss} miss' \{show miss'}"
-- process constructors with matches
alts <- traverse (buildCase ctx prob scnm scty') hit
debug $ \ _ => "GOTALTS \{show alts}"
when (length' (mapMaybe id alts) == 0) $ \ _ => error (fc) "no alts for \{show scty'}"
pure $ Case fc sctm (mapMaybe id alts)
let alts' = mapMaybe id alts
when (length' alts' == 0) $ \ _ => error (fc) "no alts for \{show scty'}"
-- build a default case for missed constructors
case miss' of
Nil => pure $ Case fc sctm (mapMaybe id alts)
_ => do
-- ctx prob fc scnm
default <- buildDefault ctx prob fc scnm
pure $ Case fc sctm (snoc alts' default)
PatLit fc v => do
let tyname = litTyName v
case scty' of

View File

@@ -4,6 +4,9 @@ module Lib.Prettier
import Prelude
import Data.Int
-- TODO I broke this when I converted from Nat to Int, and we're disabling it
-- by flattening the Doc for now.
-- `Doc` is a pretty printing document. Constructors are private, use
-- methods below. `Alt` in particular has some invariants on it, see paper
-- for details. (Something along the lines of "the first line of left is not

View File

@@ -453,6 +453,13 @@ catchError (MkM ma) handler = MkM $ \tc => do
tryError : a. M a -> M (Either Error a)
tryError ma = catchError (map Right ma) (pure Left)
filterM : a. (a M Bool) List a M (List a)
filterM pred Nil = pure Nil
filterM pred (x :: xs) = do
check <- pred x
if check then _::_ x <$> filterM pred xs else filterM pred xs
get : M TopContext
get = MkM $ \ tc => pure $ Right (tc, tc)

View File

@@ -20,6 +20,7 @@ import Lib.Types
import Lib.Syntax
import Lib.Syntax
import Node
import Serialize
primNS : List String
primNS = ("Prim" :: Nil)
@@ -132,7 +133,11 @@ processModule importFC base stk qn@(QN ns nm) = do
-- update modules with result, leave the rest of context in case this is top file
top <- get
mc <- readIORef top.metaCtx
let modules = updateMap modns (MkModCtx top.defs mc top.ops) top.modules
let mod = MkModCtx top.defs mc top.ops
dumpModule qn src mod
let modules = updateMap modns mod top.modules
freshMC <- newIORef (MC EmptyMap 0 CheckAll)
modify (\ top => MkTop modules top.imported top.ns top.defs top.metaCtx top.verbose top.errors top.ops)

28
port/Serialize.newt Normal file
View File

@@ -0,0 +1,28 @@
module Serialize
import Prelude
import Node
import Lib.Common
import Lib.Types
import Data.SortedMap
-- this was an experiment, prepping for dumping module information
-- it ends up with out of memory dumping defs of some of the files.
-- Prelude is 114MB pretty-printed... gzip to 1M
pfunc dumpObject uses (MkIORes MkUnit fs): a. String a IO Unit := `(_,fn,a) => (w) => {
try {
let {EncFile} = require('./serializer')
let enc = EncFile.encode(a)
fs.writeFileSync(fn, enc)
} catch (e) {}
return MkIORes(null, MkUnit, w)
}`
-- for now, include src and use that to see if something changed
dumpModule : QName String ModContext M Unit
dumpModule qn src mod = do
let fn = "build/\{show qn}.newtmod"
let defs = listValues mod.modDefs
let ops = toList mod.ctxOps
let mctx = toList mod.modMetaCtx.metas
liftIO $ dumpObject fn (src,defs,ops,mctx)

5
scripts/stats.py Normal file → Executable file
View File

@@ -1,7 +1,10 @@
#!/usr/bin/env python3
import sys
fn = sys.argv[1]
stats = {}
acc = ''
name = ''
for line in open('newt.js'):
for line in open(fn):
if line.startswith('const'):
if name: stats[name] = len(acc)
acc = line

260
serializer.ts Normal file
View File

@@ -0,0 +1,260 @@
// Experimental serializer / deserializer for modules
// not completely wired in yet, serialization is running.
const END = 0;
const LIST = 1;
const TUPLE = 2;
const INDUCT = 3;
const STRING = 4;
const NUMBER = 5;
const NULL = 6;
const te = new TextEncoder();
// TODO - next two functions are machine generated and need to be fixed
class DeserializationStream {
pos = 0;
buf: Uint8Array;
constructor(buf: Uint8Array) {
this.buf = buf;
}
readByte() {
return this.buf[this.pos++];
}
readVarint() {
let shift = 0;
let result = 0;
while (true) {
const byte = this.readByte();
result |= (byte & 0x7f) << shift;
if ((byte & 0x80) === 0) break;
shift += 7;
}
return result;
}
readSignedVarint() {
const n = this.readVarint();
return (n >>> 1) ^ -(n & 1);
}
readString() {
const length = this.readVarint();
const bytes = this.buf.slice(this.pos, this.pos + length);
this.pos += length;
return new TextDecoder().decode(bytes);
}
}
export class DecFile {
pool: string[] = [""];
buf: DeserializationStream;
static decode(encoded: Uint8Array) {
return new DecFile(encoded).read()
}
constructor(data: Uint8Array) {
this.buf = new DeserializationStream(data);
this.readPool();
}
readPool() {
while (true) {
let str = this.buf.readString();
if (!str.length) break
this.pool.push(str);
}
console.log('read pool', this.buf.pos)
}
read(): any {
const type = this.buf.readByte();
switch (type) {
case NULL:
return null;
case LIST: {
const list: any[] = [];
while (this.buf.buf[this.buf.pos] !== END) {
list.push(this.read());
}
this.buf.pos++;
let rval: any = { tag: "Nil", 'h0': null };
while (list.length)
rval = { tag: "_::_", h0: null, h1: list.pop(), h2: rval };
return rval;
}
case TUPLE: {
const tuple: any[] = [];
while (this.buf.buf[this.buf.pos] !== END) {
tuple.push(this.read());
}
this.buf.pos++;
let rval: any = tuple.pop();
while (tuple.length)
rval = { tag: "_,_", h0: null, h1: null, h2: tuple.pop(), h3: rval };
return rval;
}
case STRING:
return this.pool[this.buf.readVarint()];
case NUMBER:
return this.buf.readSignedVarint();
case INDUCT:
const tag = this.pool[this.buf.readVarint()];
const obj: any = { tag };
let i = 0;
while (this.buf.buf[this.buf.pos] !== END) {
obj[`h${i++}`] = this.read();
}
this.buf.pos++;
return obj;
default:
debugger
throw new Error(`Unknown type: ${type}`);
}
}
}
class SerializationStream {
pos = 0;
buf = new Uint8Array(1024 * 1024);
ensure(size: number) {
if (this.buf.length - this.pos < size) {
const tmp = new Uint8Array(this.buf.length * 1.5);
tmp.set(this.buf);
this.buf = tmp;
}
}
writeByte(n: number) {
this.ensure(1);
this.buf[this.pos++] = n % 256;
}
writeVarint(n: number) {
while (n > 127) {
this.writeByte((n & 0x7f) | 0x80);
n >>= 7;
}
this.writeByte(n & 0x7f);
}
writeSignedVarint(n: number) {
const zigzag = (n << 1) ^ (n >> 31);
this.writeVarint(zigzag);
}
writeString(s: string) {
let data = te.encode(s);
this.ensure(data.byteLength + 4);
this.writeVarint(data.byteLength);
this.buf.set(data, this.pos);
this.pos += data.byteLength;
}
toUint8Array() {
return this.buf.slice(0, this.pos);
}
}
export class EncFile {
poollen = 1;
pool = new SerializationStream();
buf = new SerializationStream();
pmap: Record<string, number> = { "": 0 };
static encode(data: any) {
let f = new EncFile()
f.write(data)
f.pool.writeVarint(0)
return f.toUint8Array()
}
writeString(s: string) {
let n = this.pmap[s];
if (n === undefined) {
n = this.poollen++;
this.pool.writeString(s);
this.pmap[s] = n;
}
this.buf.writeVarint(n);
}
write(a: any) {
// shouldn't happen?
if (a == null) {
this.buf.writeByte(NULL);
} else if (a.tag == "_::_") {
this.buf.writeByte(LIST);
for (; a.tag === "_::_"; a = a.h2) {
this.write(a.h1);
}
this.buf.writeByte(END);
} else if (a.tag == "_,_") {
this.buf.writeByte(TUPLE);
for (; a.tag === "_,_"; a = a.h3) {
this.write(a.h2);
}
this.write(a);
this.buf.writeByte(END);
} else if (typeof a === "string") {
this.buf.writeByte(STRING);
this.writeString(a);
} else if (typeof a === "number") {
this.buf.writeByte(NUMBER);
this.buf.writeSignedVarint(a);
} else if (a.tag) {
this.buf.writeByte(INDUCT);
this.writeString(a.tag);
// we're actually missing a bunch of data here...
// with null, hack is not needed.
let i = 0
for (; i <= 20; i++) {
let key = 'h' + i
let v = a[key]
if (v === undefined) break
this.write(v);
}
if (a['h' + (i + 1)] !== undefined) {
throw new Error("BOOM")
}
this.buf.writeByte(END);
} else {
throw new Error(`handle ${typeof a} ${a} ${Object.keys(a)}`);
}
}
toUint8Array() {
const poolArray = this.pool.toUint8Array();
const bufArray = this.buf.toUint8Array();
const rval = new Uint8Array(poolArray.length + bufArray.length);
console.log('psize', poolArray.byteLength, poolArray.length)
rval.set(poolArray);
rval.set(bufArray, poolArray.length);
return rval;
}
}
function deepEqual(a: any, b: any): boolean {
if (a === b) return true;
if (typeof a !== typeof b) return false;
if (a == null || b == null) return false;
if (typeof a !== "object") return false;
if (Array.isArray(a)) {
if (!Array.isArray(b) || a.length !== b.length) return false;
for (let i = 0; i < a.length; i++) {
if (!deepEqual(a[i], b[i])) return false;
}
return true;
}
const keysA = Object.keys(a);
const keysB = Object.keys(b);
if (keysA.length !== keysB.length) return false;
for (const key of keysA) {
if (!deepEqual(a[key], b[key])) return false;
}
return true;
}

View File

@@ -116,6 +116,7 @@ compileTerm (Meta _ k) = pure $ CRef "meta$\{show k}" -- FIXME
compileTerm (Lam _ nm _ _ t) = pure $ CLam nm !(compileTerm t)
compileTerm tm@(App _ _ _) with (funArgs tm)
_ | (Meta _ k, args) = do
error (getFC tm) "Compiling an unsolved meta \{showTm tm}"
info (getFC tm) "Compiling an unsolved meta \{showTm tm}"
pure $ CApp (CRef "Meta\{show k}") [] Z
_ | (t@(Ref fc nm _), args) = do

View File

@@ -1,5 +1,6 @@
module TestMap
import Prelude
import SortedMap
main : IO Unit