From e3ae301c9c8576dcdf9b47d3841c3efdae17e8e0 Mon Sep 17 00:00:00 2001 From: Steve Dunham Date: Sat, 18 Jan 2025 21:33:49 -0800 Subject: [PATCH] performance and code size improvements - Use default case for constructors with no explicit match. - self-compile is 15s now - code size is 60% smaller code size and self compile time on par with the idris-built version --- TODO.md | 5 +- newt/Prelude.newt | 2 +- port/Lib/Elab.newt | 62 +++++++++- port/Lib/Prettier.newt | 3 + port/Lib/Types.newt | 7 ++ port/Main.newt | 7 +- port/Serialize.newt | 28 +++++ scripts/stats.py | 5 +- serializer.ts | 260 +++++++++++++++++++++++++++++++++++++++++ src/Lib/CompileExp.idr | 1 + tests/TestMap.newt | 1 + 11 files changed, 371 insertions(+), 10 deletions(-) create mode 100644 port/Serialize.newt mode change 100644 => 100755 scripts/stats.py create mode 100644 serializer.ts diff --git a/TODO.md b/TODO.md index d68323d..e8b7a08 100644 --- a/TODO.md +++ b/TODO.md @@ -5,9 +5,12 @@ - [ ] review pattern matching. goal is to have a sane context on the other end. secondary goal - bring it closer to the paper. - [x] redo code to determine base path +- [x] emit only one branch for default case when splitting inductives - [ ] save/load results of processing a module - [x] keep each module separate in context - - [x] search would include imported modules, collect ops into and from modules + - [x] search would include imported modules, collect ops into and from modules + - [x] serialize modules + - [ ] deserialize modules if up to date - should I allow the idris cross module assignment hack? - >>> sort out metas (maybe push them up to the main list) - eventually we may want to support resuming halfway through a file diff --git a/newt/Prelude.newt b/newt/Prelude.newt index d0081aa..ea1b8a8 100644 --- a/newt/Prelude.newt +++ b/newt/Prelude.newt @@ -343,7 +343,7 @@ IO a = World -> IORes a instance Monad IO where bind ma mab = \ w => case ma w of MkIORes a w => mab a w - pure x = \ w => MkIORes x w + pure = MkIORes bindList : ∀ a b. List a → (a → List b) → List b diff --git a/port/Lib/Elab.newt b/port/Lib/Elab.newt index ab55d9f..9a0973d 100644 --- a/port/Lib/Elab.newt +++ b/port/Lib/Elab.newt @@ -848,6 +848,25 @@ updateContext ctx ((k, val) :: cs) = replaceV Z x (y :: xs) = x :: xs replaceV (S k) x (y :: xs) = y :: replaceV k x xs +checkCase : Context → Problem → String → Val → (QName × Int × Tm) → M Bool +checkCase ctx prob scnm scty (dcName, arity, ty) = do + vty <- eval Nil CBN ty + (ctx', ty', vars, sc) <- extendPi ctx (vty) Lin Lin + (Just res) <- catchError (Just <$> unify ctx'.env UPattern ty' scty) + (\err => do + debug $ \ _ => "SKIP \{show dcName} because unify error \{errorMsg err}" + pure Nothing) + | _ => pure False + + (Right res) <- tryError (unify ctx'.env UPattern ty' scty) + | Left err => do + debug $ \ _ => "SKIP \{show dcName} because unify error \{errorMsg err}" + pure False + + case lookupDef ctx scnm of + Just val@(VRef fc nm sp) => pure $ nm == dcName + _ => pure True + -- ok, so this is a single constructor, CaseAlt -- return Nothing if dcon doesn't unify with scrut buildCase : Context -> Problem -> String -> Val -> (QName × Int × Tm) -> M (Maybe CaseAlt) @@ -1152,6 +1171,12 @@ getLits nm ((MkClause fc cons pats expr) :: cs) = case find ((_==_ nm) ∘ fst) Just (_, (PatLit _ lit)) => lit :: getLits nm cs _ => getLits nm cs +-- collect constructors that are matched on +matchedConstructors : String → List Clause → List QName +matchedConstructors nm Nil = Nil +matchedConstructors nm ((MkClause fc cons pats expr) :: cs) = case find ((_==_ nm) ∘ fst) cons of + Just (_, (PatCon _ _ dcon _ _)) => dcon :: matchedConstructors nm cs + _ => matchedConstructors nm cs -- then build a lit case for each of those @@ -1189,7 +1214,18 @@ buildLitCase ctx prob fc scnm scty lit = do cons <- rewriteConstraint cons Nil pure $ MkClause fc cons pats expr - +buildDefault : Context → Problem → FC → String → M CaseAlt +buildDefault ctx prob fc scnm = do + let defclauses = filter isDefault prob.clauses + when (length' defclauses == 0) $ \ _ => error fc "no default for literal slot on \{show scnm}" + CaseDefault <$> buildTree ctx (MkProb defclauses prob.ty) + where + isDefault : Clause -> Bool + isDefault cl = case find ((_==_ scnm) ∘ fst) cl.cons of + Just (_, (PatVar _ _ _)) => True + Just (_, (PatWild _ _)) => True + Nothing => True + _ => False buildLitCases : Context -> Problem -> FC -> String -> Val -> M (List CaseAlt) buildLitCases ctx prob fc scnm scty = do @@ -1289,12 +1325,26 @@ buildTree ctx prob@(MkProb ((MkClause fc constraints Nil expr) :: cs) ty) = do -- this is per the paper, but it would be nice to coalesce -- default cases cons <- getConstructors ctx (getFC pat) scty' - debug $ \ _ => "CONS \{show $ map fst cons}" - -- TODO collect the wild-card only cases into one - alts <- traverse (buildCase ctx prob scnm scty') cons + let matched = matchedConstructors scnm prob.clauses + let (hit,miss) = partition (flip elem matched ∘ fst) cons + -- need to check miss is possible + miss' <- filterM (checkCase ctx prob scnm scty') miss + + debug $ \ _ => "CONS \{show $ map fst cons} matched \{show matched} miss \{show miss} miss' \{show miss'}" + + -- process constructors with matches + alts <- traverse (buildCase ctx prob scnm scty') hit debug $ \ _ => "GOTALTS \{show alts}" - when (length' (mapMaybe id alts) == 0) $ \ _ => error (fc) "no alts for \{show scty'}" - pure $ Case fc sctm (mapMaybe id alts) + let alts' = mapMaybe id alts + when (length' alts' == 0) $ \ _ => error (fc) "no alts for \{show scty'}" + -- build a default case for missed constructors + case miss' of + Nil => pure $ Case fc sctm (mapMaybe id alts) + _ => do + -- ctx prob fc scnm + default <- buildDefault ctx prob fc scnm + pure $ Case fc sctm (snoc alts' default) + PatLit fc v => do let tyname = litTyName v case scty' of diff --git a/port/Lib/Prettier.newt b/port/Lib/Prettier.newt index a95ce1d..c30dbe2 100644 --- a/port/Lib/Prettier.newt +++ b/port/Lib/Prettier.newt @@ -4,6 +4,9 @@ module Lib.Prettier import Prelude import Data.Int +-- TODO I broke this when I converted from Nat to Int, and we're disabling it +-- by flattening the Doc for now. + -- `Doc` is a pretty printing document. Constructors are private, use -- methods below. `Alt` in particular has some invariants on it, see paper -- for details. (Something along the lines of "the first line of left is not diff --git a/port/Lib/Types.newt b/port/Lib/Types.newt index d80b4a9..0a8bdc1 100644 --- a/port/Lib/Types.newt +++ b/port/Lib/Types.newt @@ -453,6 +453,13 @@ catchError (MkM ma) handler = MkM $ \tc => do tryError : ∀ a. M a -> M (Either Error a) tryError ma = catchError (map Right ma) (pure ∘ Left) +filterM : ∀ a. (a → M Bool) → List a → M (List a) +filterM pred Nil = pure Nil +filterM pred (x :: xs) = do + check <- pred x + if check then _::_ x <$> filterM pred xs else filterM pred xs + + get : M TopContext get = MkM $ \ tc => pure $ Right (tc, tc) diff --git a/port/Main.newt b/port/Main.newt index 1a34c39..e262167 100644 --- a/port/Main.newt +++ b/port/Main.newt @@ -20,6 +20,7 @@ import Lib.Types import Lib.Syntax import Lib.Syntax import Node +import Serialize primNS : List String primNS = ("Prim" :: Nil) @@ -132,7 +133,11 @@ processModule importFC base stk qn@(QN ns nm) = do -- update modules with result, leave the rest of context in case this is top file top <- get mc <- readIORef top.metaCtx - let modules = updateMap modns (MkModCtx top.defs mc top.ops) top.modules + + let mod = MkModCtx top.defs mc top.ops + dumpModule qn src mod + + let modules = updateMap modns mod top.modules freshMC <- newIORef (MC EmptyMap 0 CheckAll) modify (\ top => MkTop modules top.imported top.ns top.defs top.metaCtx top.verbose top.errors top.ops) diff --git a/port/Serialize.newt b/port/Serialize.newt new file mode 100644 index 0000000..5dc7c2c --- /dev/null +++ b/port/Serialize.newt @@ -0,0 +1,28 @@ +module Serialize + +import Prelude +import Node +import Lib.Common +import Lib.Types +import Data.SortedMap + +-- this was an experiment, prepping for dumping module information +-- it ends up with out of memory dumping defs of some of the files. +-- Prelude is 114MB pretty-printed... gzip to 1M +pfunc dumpObject uses (MkIORes MkUnit fs): ∀ a. String → a → IO Unit := `(_,fn,a) => (w) => { + try { + let {EncFile} = require('./serializer') + let enc = EncFile.encode(a) + fs.writeFileSync(fn, enc) + } catch (e) {} + return MkIORes(null, MkUnit, w) +}` + +-- for now, include src and use that to see if something changed +dumpModule : QName → String → ModContext → M Unit +dumpModule qn src mod = do + let fn = "build/\{show qn}.newtmod" + let defs = listValues mod.modDefs + let ops = toList mod.ctxOps + let mctx = toList mod.modMetaCtx.metas + liftIO $ dumpObject fn (src,defs,ops,mctx) diff --git a/scripts/stats.py b/scripts/stats.py old mode 100644 new mode 100755 index fe72c54..4702c60 --- a/scripts/stats.py +++ b/scripts/stats.py @@ -1,7 +1,10 @@ +#!/usr/bin/env python3 +import sys +fn = sys.argv[1] stats = {} acc = '' name = '' -for line in open('newt.js'): +for line in open(fn): if line.startswith('const'): if name: stats[name] = len(acc) acc = line diff --git a/serializer.ts b/serializer.ts new file mode 100644 index 0000000..e938f2e --- /dev/null +++ b/serializer.ts @@ -0,0 +1,260 @@ +// Experimental serializer / deserializer for modules +// not completely wired in yet, serialization is running. + +const END = 0; +const LIST = 1; +const TUPLE = 2; +const INDUCT = 3; +const STRING = 4; +const NUMBER = 5; +const NULL = 6; +const te = new TextEncoder(); + +// TODO - next two functions are machine generated and need to be fixed +class DeserializationStream { + pos = 0; + buf: Uint8Array; + + constructor(buf: Uint8Array) { + this.buf = buf; + } + + readByte() { + return this.buf[this.pos++]; + } + + readVarint() { + let shift = 0; + let result = 0; + while (true) { + const byte = this.readByte(); + result |= (byte & 0x7f) << shift; + if ((byte & 0x80) === 0) break; + shift += 7; + } + return result; + } + + readSignedVarint() { + const n = this.readVarint(); + return (n >>> 1) ^ -(n & 1); + } + + readString() { + const length = this.readVarint(); + const bytes = this.buf.slice(this.pos, this.pos + length); + this.pos += length; + return new TextDecoder().decode(bytes); + } +} + +export class DecFile { + pool: string[] = [""]; + buf: DeserializationStream; + static decode(encoded: Uint8Array) { + return new DecFile(encoded).read() + } + constructor(data: Uint8Array) { + this.buf = new DeserializationStream(data); + this.readPool(); + } + + readPool() { + while (true) { + let str = this.buf.readString(); + if (!str.length) break + this.pool.push(str); + } + console.log('read pool', this.buf.pos) + } + + read(): any { + const type = this.buf.readByte(); + switch (type) { + case NULL: + return null; + case LIST: { + const list: any[] = []; + while (this.buf.buf[this.buf.pos] !== END) { + list.push(this.read()); + } + this.buf.pos++; + let rval: any = { tag: "Nil", 'h0': null }; + while (list.length) + rval = { tag: "_::_", h0: null, h1: list.pop(), h2: rval }; + return rval; + } + case TUPLE: { + const tuple: any[] = []; + while (this.buf.buf[this.buf.pos] !== END) { + tuple.push(this.read()); + } + this.buf.pos++; + let rval: any = tuple.pop(); + while (tuple.length) + rval = { tag: "_,_", h0: null, h1: null, h2: tuple.pop(), h3: rval }; + return rval; + } + case STRING: + return this.pool[this.buf.readVarint()]; + case NUMBER: + return this.buf.readSignedVarint(); + case INDUCT: + const tag = this.pool[this.buf.readVarint()]; + const obj: any = { tag }; + let i = 0; + while (this.buf.buf[this.buf.pos] !== END) { + obj[`h${i++}`] = this.read(); + } + this.buf.pos++; + return obj; + default: + debugger + throw new Error(`Unknown type: ${type}`); + } + } +} + +class SerializationStream { + pos = 0; + buf = new Uint8Array(1024 * 1024); + + ensure(size: number) { + if (this.buf.length - this.pos < size) { + const tmp = new Uint8Array(this.buf.length * 1.5); + tmp.set(this.buf); + this.buf = tmp; + } + } + + writeByte(n: number) { + this.ensure(1); + this.buf[this.pos++] = n % 256; + } + + writeVarint(n: number) { + while (n > 127) { + this.writeByte((n & 0x7f) | 0x80); + n >>= 7; + } + this.writeByte(n & 0x7f); + } + + writeSignedVarint(n: number) { + const zigzag = (n << 1) ^ (n >> 31); + this.writeVarint(zigzag); + } + + writeString(s: string) { + let data = te.encode(s); + this.ensure(data.byteLength + 4); + this.writeVarint(data.byteLength); + this.buf.set(data, this.pos); + this.pos += data.byteLength; + } + toUint8Array() { + return this.buf.slice(0, this.pos); + } +} + +export class EncFile { + poollen = 1; + pool = new SerializationStream(); + buf = new SerializationStream(); + pmap: Record = { "": 0 }; + + static encode(data: any) { + let f = new EncFile() + f.write(data) + f.pool.writeVarint(0) + return f.toUint8Array() + } + + writeString(s: string) { + let n = this.pmap[s]; + if (n === undefined) { + n = this.poollen++; + this.pool.writeString(s); + this.pmap[s] = n; + } + this.buf.writeVarint(n); + } + + write(a: any) { + // shouldn't happen? + if (a == null) { + this.buf.writeByte(NULL); + } else if (a.tag == "_::_") { + this.buf.writeByte(LIST); + for (; a.tag === "_::_"; a = a.h2) { + this.write(a.h1); + } + this.buf.writeByte(END); + } else if (a.tag == "_,_") { + this.buf.writeByte(TUPLE); + for (; a.tag === "_,_"; a = a.h3) { + this.write(a.h2); + } + this.write(a); + this.buf.writeByte(END); + } else if (typeof a === "string") { + this.buf.writeByte(STRING); + this.writeString(a); + } else if (typeof a === "number") { + this.buf.writeByte(NUMBER); + this.buf.writeSignedVarint(a); + } else if (a.tag) { + this.buf.writeByte(INDUCT); + this.writeString(a.tag); + // we're actually missing a bunch of data here... + // with null, hack is not needed. + let i = 0 + for (; i <= 20; i++) { + let key = 'h' + i + let v = a[key] + if (v === undefined) break + this.write(v); + } + if (a['h' + (i + 1)] !== undefined) { + throw new Error("BOOM") + } + this.buf.writeByte(END); + } else { + throw new Error(`handle ${typeof a} ${a} ${Object.keys(a)}`); + } + } + toUint8Array() { + const poolArray = this.pool.toUint8Array(); + const bufArray = this.buf.toUint8Array(); + const rval = new Uint8Array(poolArray.length + bufArray.length); + console.log('psize', poolArray.byteLength, poolArray.length) + rval.set(poolArray); + rval.set(bufArray, poolArray.length); + return rval; + } +} + +function deepEqual(a: any, b: any): boolean { + if (a === b) return true; + if (typeof a !== typeof b) return false; + if (a == null || b == null) return false; + if (typeof a !== "object") return false; + + if (Array.isArray(a)) { + if (!Array.isArray(b) || a.length !== b.length) return false; + for (let i = 0; i < a.length; i++) { + if (!deepEqual(a[i], b[i])) return false; + } + return true; + } + + const keysA = Object.keys(a); + const keysB = Object.keys(b); + if (keysA.length !== keysB.length) return false; + + for (const key of keysA) { + if (!deepEqual(a[key], b[key])) return false; + } + + return true; +} diff --git a/src/Lib/CompileExp.idr b/src/Lib/CompileExp.idr index b34b15c..2af000e 100644 --- a/src/Lib/CompileExp.idr +++ b/src/Lib/CompileExp.idr @@ -116,6 +116,7 @@ compileTerm (Meta _ k) = pure $ CRef "meta$\{show k}" -- FIXME compileTerm (Lam _ nm _ _ t) = pure $ CLam nm !(compileTerm t) compileTerm tm@(App _ _ _) with (funArgs tm) _ | (Meta _ k, args) = do + error (getFC tm) "Compiling an unsolved meta \{showTm tm}" info (getFC tm) "Compiling an unsolved meta \{showTm tm}" pure $ CApp (CRef "Meta\{show k}") [] Z _ | (t@(Ref fc nm _), args) = do diff --git a/tests/TestMap.newt b/tests/TestMap.newt index 5653ced..87c580a 100644 --- a/tests/TestMap.newt +++ b/tests/TestMap.newt @@ -1,5 +1,6 @@ module TestMap +import Prelude import SortedMap main : IO Unit