Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 98 additions & 24 deletions Ix/Aiur/Check.lean
Original file line number Diff line number Diff line change
Expand Up @@ -36,25 +36,99 @@ instance : ToString CheckError where
toString e := repr e |>.pretty

/--
Constructs a map of declarations from a toplevel, ensuring that there are no duplicate names
for functions and datatypes.
Eagerly expands type aliases, building the aliasMap on demand and detecting cycles.
-/
partial def expandTypeM (visited : Std.HashSet Global) (toplevelAliases : Array TypeAlias) :
Typ → StateT (Std.HashMap Global Typ) (Except CheckError) Typ
| .unit => pure .unit
| .field => pure .field
| .pointer t => do pure $ .pointer (← expandTypeM visited toplevelAliases t)
| .function inputs output => do
let inputs' ← inputs.mapM (expandTypeM visited toplevelAliases)
let output' ← expandTypeM visited toplevelAliases output
pure $ .function inputs' output'
| .tuple ts => do
let ts' ← ts.mapM (expandTypeM visited toplevelAliases)
pure $ .tuple ts'
| .array t n => do
let t' ← expandTypeM visited toplevelAliases t
pure $ .array t' n
| .typeRef g => do
let aliasMap ← get
-- Check if already expanded
if let some expanded := aliasMap[g]? then
return expanded
-- Check if it's an alias
if let some (alias : TypeAlias) := toplevelAliases.find? (·.name == g) then
-- Check for cycle
if visited.contains g then
throw $ CheckError.undefinedGlobal ⟨.mkSimple s!"Cycle detected in type alias `{g}`"⟩
-- Expand the alias recursively
let expanded ← expandTypeM (visited.insert g) toplevelAliases alias.expansion
-- Save to aliasMap
set (aliasMap.insert g expanded)
return expanded
else
-- It's a dataType, keep it
pure $ .typeRef g

/--
Constructs a map of declarations from a toplevel, expanding all type aliases.
Type aliases are not added to the declarations - they are eliminated during construction.
-/
def Toplevel.mkDecls (toplevel : Toplevel) : Except CheckError Decls := do
let map ← toplevel.functions.foldlM (init := default)
fun acc function => addDecl acc Function.name .function function
toplevel.dataTypes.foldlM (init := map) addDataType
where
ensureUnique name (map : IndexMap Global _) := do
if map.containsKey name then throw $ .duplicatedDefinition name
addDecl {α : Type} map (nameFn : α → Global) (wrapper : α → Declaration) (inner : α) := do
ensureUnique (nameFn inner) map
pure $ map.insert (nameFn inner) (wrapper inner)
addDataType map dataType := do
let dataTypeName := dataType.name
ensureUnique dataTypeName map
let map' := map.insert dataTypeName (.dataType dataType)
dataType.constructors.foldlM (init := map') fun acc (constructor : Constructor) =>
addDecl acc (dataTypeName.pushNamespace ∘ Constructor.nameHead) (.constructor dataType) constructor
-- Check for duplicate names among aliases
let mut allNames : Std.HashSet Global := {}
for alias in toplevel.typeAliases do
if allNames.contains alias.name then
throw $ CheckError.duplicatedDefinition alias.name
allNames := allNames.insert alias.name

-- Build aliasMap by expanding all aliases in order
let initAliasMap := {}
let (_, finalAliasMap) ← (toplevel.typeAliases.mapM fun (alias : TypeAlias) => do
-- Expand and save the alias
let expanded ← expandTypeM {} toplevel.typeAliases alias.expansion
modify fun (aliasMap : Std.HashMap Global Typ) => aliasMap.insert alias.name expanded
).run initAliasMap

-- Helper to expand types in the declarations using the built aliasMap
let expandTyp (typ : Typ) : Except CheckError Typ :=
(expandTypeM {} toplevel.typeAliases typ).run' finalAliasMap

-- Add functions with expanded types
let mut decls : Decls := default
for function in toplevel.functions do
if allNames.contains function.name then
throw $ CheckError.duplicatedDefinition function.name
allNames := allNames.insert function.name
let inputs' ← function.inputs.mapM fun (loc, typ) => do
let typ' ← expandTyp typ
pure (loc, typ')
let output' ← expandTyp function.output
let function' := { function with inputs := inputs', output := output' }
decls := decls.insert function.name (.function function')

-- Add datatypes with expanded types
for dataType in toplevel.dataTypes do
if allNames.contains dataType.name then
throw $ CheckError.duplicatedDefinition dataType.name
allNames := allNames.insert dataType.name
let mut constructors : List Constructor := []
for ctor in dataType.constructors do
let argTypes' ← ctor.argTypes.mapM expandTyp
constructors := constructors.concat { ctor with argTypes := argTypes' }
let dataType' := { dataType with constructors }
decls := decls.insert dataType.name (.dataType dataType')
-- Add constructors
for ctor in constructors do
let ctorName := dataType.name.pushNamespace ctor.nameHead
if allNames.contains ctorName then
throw $ CheckError.duplicatedDefinition ctorName
allNames := allNames.insert ctorName
decls := decls.insert ctorName (.constructor dataType' ctor)

pure decls

structure CheckContext where
decls : Decls
Expand All @@ -73,7 +147,7 @@ def refLookup (global : Global) : CheckM Typ := do
| some (.constructor dataType constructor) =>
let args := constructor.argTypes
unless args.isEmpty do (throw $ .wrongNumArgs global args.length 0)
pure $ .dataType $ dataType.name
pure $ .typeRef $ dataType.name
| some _ => throw $ .notAValue global
| none => throw $ .unboundVariable global

Expand Down Expand Up @@ -278,7 +352,7 @@ partial def inferUnqualifiedApp (func : Global) (unqualifiedFunc : String) (args
pure ⟨function.output, .app func args, false⟩
| some (.constructor dataType constr) => do
let args ← checkArgsAndInputs func args constr.argTypes
pure ⟨.dataType dataType.name, .app func args, false⟩
pure ⟨.typeRef dataType.name, .app func args, false⟩
| _ => throw $ .cannotApply func
where
checkArgsAndInputs func args inputs : CheckM (List TypedTerm) := do
Expand All @@ -298,7 +372,7 @@ partial def inferQualifiedApp (func : Global) (args : List Term) : CheckM TypedT
pure ⟨function.output, .app func args, false⟩
| some (.constructor dataType constr) =>
let args ← checkArgsAndInputs func args constr.argTypes
pure ⟨.dataType dataType.name, .app func args, false⟩
pure ⟨.typeRef dataType.name, .app func args, false⟩
| _ => throw $ .cannotApply func
where
checkArgsAndInputs func args inputs : CheckM (List TypedTerm) := do
Expand Down Expand Up @@ -399,7 +473,7 @@ where
let typ' := .function (function.inputs.map Prod.snd) function.output
unless typ == typ' do throw $ .typeMismatch typ typ'
pure []
| (.ref constrRef pats, .dataType dataTypeRef) => do
| (.ref constrRef pats, .typeRef dataTypeRef) => do
let ctx ← read
let some (.dataType dataType) := ctx.decls.getByKey dataTypeRef | unreachable!
let some (.constructor dataType' constr) := ctx.decls.getByKey constrRef | throw $ .notAConstructor constrRef
Expand Down Expand Up @@ -463,10 +537,10 @@ where
wellFormedType : Typ → EStateM CheckError (Std.HashSet Global) Unit
| .tuple typs => typs.forM wellFormedType
| .pointer pointerTyp => wellFormedType pointerTyp
| .dataType dataTypeRef => match decls.getByKey dataTypeRef with
| .typeRef typeRef => match decls.getByKey typeRef with
| some (.dataType _) => pure ()
| some _ => throw $ .notADataType dataTypeRef
| none => throw $ .undefinedGlobal dataTypeRef
| some _ => throw $ .notADataType typeRef
| none => throw $ .undefinedGlobal typeRef
| _ => pure ()

/-- Checks a function to ensure its body's type matches its declared output type. -/
Expand Down
2 changes: 1 addition & 1 deletion Ix/Aiur/Compile.lean
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def typSize (layoutMap : LayoutMap) : Typ → Except String Nat
| .array typ n => do
let size ← typSize layoutMap typ
pure $ n * size
| .dataType g => match layoutMap[g]? with
| .typeRef g => match layoutMap[g]? with
| some (.dataType layout) => pure layout.size
| _ => throw "Impossible case"

Expand Down
31 changes: 22 additions & 9 deletions Ix/Aiur/Meta.lean
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ partial def elabTyp : ElabStxCat `typ
mkAppM ``Typ.pointer #[← elabTyp t]
| `(typ| $[.]?$i:ident) => do
let g ← mkAppM ``Global.mk #[toExpr i.getId]
mkAppM ``Typ.dataType #[g]
mkAppM ``Typ.typeRef #[g]
| `(typ| fn() -> $t:typ) => do
mkAppM ``Typ.function #[← elabEmptyList ``Typ, ← elabTyp t]
| `(typ| fn($t$[, $ts:typ]*) -> $t':typ) => do
Expand Down Expand Up @@ -441,6 +441,15 @@ def elabDataType : ElabStxCat `data_type
mkAppM ``DataType.mk #[g, ← elabList c cs elabConstructor ``Constructor]
| stx => throw $ .error stx "Invalid syntax for data type"

declare_syntax_cat type_alias
syntax "type " ident " = " typ : type_alias

def elabTypeAlias : ElabStxCat `type_alias
| `(type_alias| type $n:ident = $t:typ) => do
let g ← mkAppM ``Global.mk #[toExpr n.getId]
mkAppM ``TypeAlias.mk #[g, ← elabTyp t]
| stx => throw $ .error stx "Invalid syntax for type alias"

declare_syntax_cat bind
syntax ident ": " typ : bind

Expand Down Expand Up @@ -477,27 +486,31 @@ where
| some typ => elabTyp typ

declare_syntax_cat declaration
syntax function : declaration
syntax data_type : declaration
syntax function : declaration
syntax data_type : declaration
syntax type_alias : declaration

def accElabDeclarations (declarations : (Array Expr × Array Expr))
(stx : TSyntax `declaration) : TermElabM (Array Expr × Array Expr) :=
let (dataTypes, functions) := declarations
def accElabDeclarations (declarations : (Array Expr × Array Expr × Array Expr))
(stx : TSyntax `declaration) : TermElabM (Array Expr × Array Expr × Array Expr) :=
let (dataTypes, typeAliases, functions) := declarations
match stx with
| `(declaration| $f:function) => do
pure (dataTypes, functions.push $ ← elabFunction f)
pure (dataTypes, typeAliases, functions.push $ ← elabFunction f)
| `(declaration| $d:data_type) => do
pure (dataTypes.push $ ← elabDataType d, functions)
pure (dataTypes.push $ ← elabDataType d, typeAliases, functions)
| `(declaration| $ta:type_alias) => do
pure (dataTypes, typeAliases.push $ ← elabTypeAlias ta, functions)
| stx => throw $ .error stx "Invalid syntax for declaration"

declare_syntax_cat toplevel
syntax declaration* : toplevel

def elabToplevel : ElabStxCat `toplevel
| `(toplevel| $[$ds:declaration]*) => do
let (dataTypes, functions) ← ds.foldlM (init := default) accElabDeclarations
let (dataTypes, typeAliases, functions) ← ds.foldlM (init := default) accElabDeclarations
mkAppM ``Toplevel.mk #[
← mkArrayLit (mkConst ``DataType) dataTypes.toList,
← mkArrayLit (mkConst ``TypeAlias) typeAliases.toList,
← mkArrayLit (mkConst ``Function) functions.toList,
]
| stx => throw $ .error stx "Invalid syntax for toplevel"
Expand Down
7 changes: 4 additions & 3 deletions Ix/Aiur/Simple.lean
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,10 @@ where
def Toplevel.checkAndSimplify (toplevel : Toplevel) : Except CheckError TypedDecls := do
let decls ← toplevel.mkDecls
wellFormedDecls decls
-- The first check happens on the original terms.
toplevel.functions.forM fun function => do
let _ ← (checkFunction function) (getFunctionContext function decls)
-- The first check happens on the original terms (but with expanded types).
decls.forM fun (_, decl) => do
if let .function f := decl then
let _ ← (checkFunction f) (getFunctionContext f decls)
let decls := decls.map fun decl => match decl with
| .function f => .function { f with body := simplifyTerm decls f.body }
| _ => decl
Expand Down
22 changes: 17 additions & 5 deletions Ix/Aiur/Term.lean
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ inductive Typ where
| tuple : Array Typ → Typ
| array : Typ → Nat → Typ
| pointer : Typ → Typ
| dataType : Global → Typ
| typeRef : Global → Typ
| function : List Typ → Typ → Typ
deriving Repr, BEq, Hashable, Inhabited

Expand Down Expand Up @@ -176,6 +176,11 @@ structure DataType where
constructors : List Constructor
deriving Repr, BEq, Inhabited

structure TypeAlias where
name : Global
expansion : Typ
deriving Repr, BEq, Inhabited

structure Function where
name : Global
inputs : List (Local × Typ)
Expand All @@ -186,29 +191,36 @@ structure Function where

structure Toplevel where
dataTypes : Array DataType
typeAliases : Array TypeAlias
functions : Array Function
deriving Repr

def Toplevel.getFuncIdx (toplevel : Toplevel) (funcName : Lean.Name) : Option Nat := do
toplevel.functions.findIdx? fun function => function.name.toName == funcName

def Toplevel.merge (x y : Toplevel) : Except Global Toplevel := do
let ⟨xDataTypes, xFunctions⟩ := x
let ⟨yDataTypes, yFunctions⟩ := y
let ⟨xDataTypes, xTypeAliases, xFunctions⟩ := x
let ⟨yDataTypes, yTypeAliases, yFunctions⟩ := y
let mut globals : Std.HashSet Global := ∅
let mut dataTypes := .emptyWithCapacity (xDataTypes.size + yDataTypes.size)
let mut typeAliases := .emptyWithCapacity (xTypeAliases.size + yTypeAliases.size)
let mut functions := .emptyWithCapacity (xFunctions.size + yFunctions.size)
for dtSet in [xDataTypes, yDataTypes] do
for dt in dtSet do
if globals.contains dt.name then throw dt.name
globals := globals.insert dt.name
dataTypes := dataTypes.push dt
for taSet in [xTypeAliases, yTypeAliases] do
for ta in taSet do
if globals.contains ta.name then throw ta.name
globals := globals.insert ta.name
typeAliases := typeAliases.push ta
for fSet in [xFunctions, yFunctions] do
for f in fSet do
if globals.contains f.name then throw f.name
globals := globals.insert f.name
functions := functions.push f
pure ⟨dataTypes, functions⟩
pure ⟨dataTypes, typeAliases, functions⟩

inductive Declaration
| function : Function → Declaration
Expand Down Expand Up @@ -255,7 +267,7 @@ partial def Typ.size (decls : TypedDecls) (visited : HashSet Global := {}) :
| .array t n => do
let tSize ← t.size decls visited
pure $ n * tSize
| .dataType g => match decls.getByKey g with
| .typeRef g => match decls.getByKey g with
| some (.dataType data) => data.size decls visited
| _ => throw s!"Datatype not found: `{g}`"

Expand Down
3 changes: 3 additions & 0 deletions Ix/IndexMap.lean
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ def map (f : β → β) : IndexMap α β := by
@[inline] def foldrM [Monad μ] (f : α × β → γ → μ γ) (init : γ) : μ γ :=
m.pairs.foldrM f init

@[inline] def forM [Monad μ] (f : α × β → μ PUnit) : μ PUnit :=
m.pairs.forM f

end IndexMap

end
4 changes: 2 additions & 2 deletions Ix/IxVM/Blake3.lean
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def blake3 := ⟦
block_buffer: [[G; 4]; 16],
block_index: G,
chunk_index: G,
chunk_count: [G; 8],
chunk_count: U64,
block_digest: [[G; 4]; 8],
layer: Layer
) -> Layer {
Expand Down Expand Up @@ -307,7 +307,7 @@ def blake3 := ⟦
fn blake3_compress(
chaining_value: [[G; 4]; 8],
block_words: [[G; 4]; 16],
counter: [G; 8],
counter: U64,
block_len: G,
flags: G
) -> [[G; 4]; 8] {
Expand Down
Loading
Loading