From 24dafcdb1a1df220fd25a6560674af40ec7de2a7 Mon Sep 17 00:00:00 2001 From: Gabriel Barreto Date: Tue, 17 Mar 2026 17:23:05 -0300 Subject: [PATCH 1/6] U64 alias --- Ix/IxVM/Blake3.lean | 4 +-- Ix/IxVM/ByteStream.lean | 26 +++++++++--------- Ix/IxVM/KernelTypes.lean | 58 ++++++++++++++++++++-------------------- Ix/IxVM/Sha256.lean | 4 +-- 4 files changed, 47 insertions(+), 45 deletions(-) diff --git a/Ix/IxVM/Blake3.lean b/Ix/IxVM/Blake3.lean index 0603a4e0..df1a5ca8 100644 --- a/Ix/IxVM/Blake3.lean +++ b/Ix/IxVM/Blake3.lean @@ -91,7 +91,7 @@ def blake3 := ⟦ block_buffer: [[G; 4]; 16], block_index: G, chunk_index: G, - chunk_count: [G; 8], + chunk_count: U64, block_digest: [[G; 4]; 8], layer: Layer ) -> Layer { @@ -307,7 +307,7 @@ def blake3 := ⟦ fn blake3_compress( chaining_value: [[G; 4]; 8], block_words: [[G; 4]; 16], - counter: [G; 8], + counter: U64, block_len: G, flags: G ) -> [[G; 4]; 8] { diff --git a/Ix/IxVM/ByteStream.lean b/Ix/IxVM/ByteStream.lean index 425ba357..e2779905 100644 --- a/Ix/IxVM/ByteStream.lean +++ b/Ix/IxVM/ByteStream.lean @@ -11,6 +11,8 @@ def byteStream := ⟦ Nil } + type U64 = [G; 8] + fn byte_stream_concat(a: ByteStream, b: ByteStream) -> ByteStream { match a { ByteStream.Nil => b, @@ -19,7 +21,7 @@ def byteStream := ⟦ } } - fn byte_stream_length(bytes: ByteStream) -> [G; 8] { + fn byte_stream_length(bytes: ByteStream) -> U64 { match bytes { ByteStream.Nil => [0; 8], ByteStream.Cons(_, &rest) => relaxed_u64_succ(byte_stream_length(rest)), @@ -48,7 +50,7 @@ def byteStream := ⟦ -- Count bytes needed to represent a u64. -- Important: this implementation differs from the Lean and Rust ones, returning -- 1 for [0; 8] instead of 0. - fn u64_byte_count(x: [G; 8]) -> G { + fn u64_byte_count(x: U64) -> G { match x { [_, 0, 0, 0, 0, 0, 0, 0] => 1, [_, _, 0, 0, 0, 0, 0, 0] => 2, @@ -61,7 +63,7 @@ def byteStream := ⟦ } } - fn u64_is_zero(x: [G; 8]) -> G { + fn u64_is_zero(x: U64) -> G { match x { [0, 0, 0, 0, 0, 0, 0, 0] => 1, _ => 0, @@ -69,7 +71,7 @@ def byteStream := ⟦ } -- Reconstructs a byte from its bits in little-endian. - fn u8_recompose(bits: [G; 8]) -> G { + fn u8_recompose(bits: U64) -> G { let [b0, b1, b2, b3, b4, b5, b6, b7] = bits; b0 + 2 * b1 + 4 * b2 + 8 * b3 + 16 * b4 + 32 * b5 + 64 * b6 + 128 * b7 } @@ -115,7 +117,7 @@ def byteStream := ⟦ } -- Byte-by-byte `u64` equality - fn u64_eq(a: [G; 8], b: [G; 8]) -> G { + fn u64_eq(a: U64, b: U64) -> G { let [a0, a1, a2, a3, a4, a5, a6, a7] = a; let [b0, b1, b2, b3, b4, b5, b6, b7] = b; match [a0 - b0, a1 - b1, a2 - b2, a3 - b3, a4 - b4, a5 - b5, a6 - b6, a7 - b7] { @@ -125,7 +127,7 @@ def byteStream := ⟦ } -- `u64` addition with carry propagation (little-endian bytes) - fn u64_add(a: [G; 8], b: [G; 8]) -> [G; 8] { + fn u64_add(a: U64, b: U64) -> U64 { let [a0, a1, a2, a3, a4, a5, a6, a7] = a; let [b0, b1, b2, b3, b4, b5, b6, b7] = b; let (s0, c1) = u8_add(a0, b0); @@ -153,7 +155,7 @@ def byteStream := ⟦ } -- `u64` subtraction via repeated decrement (correct for small b) - fn u64_sub(a: [G; 8], b: [G; 8]) -> [G; 8] { + fn u64_sub(a: U64, b: U64) -> U64 { match u64_is_zero(b) { 1 => a, 0 => u64_sub(relaxed_u64_pred(a), relaxed_u64_pred(b)), @@ -169,7 +171,7 @@ def byteStream := ⟦ -- Computes the successor of an `u64` assumed to be properly represented in -- little-endian bytes. If that's not the case, this implementation has UB. - fn relaxed_u64_succ(bytes: [G; 8]) -> [G; 8] { + fn relaxed_u64_succ(bytes: U64) -> U64 { let [b0, b1, b2, b3, b4, b5, b6, b7] = bytes; match b0 { 255 => match b1 { @@ -198,7 +200,7 @@ def byteStream := ⟦ } } - fn relaxed_u64_be_add_2_bytes(u64: [G; 8], bs: [G; 2]) -> [G; 8] { + fn relaxed_u64_be_add_2_bytes(u64: U64, bs: [G; 2]) -> U64 { -- Byte 0, no initial carry let (sum0, carry1) = u8_add(u64[7], bs[1]); @@ -240,13 +242,13 @@ def byteStream := ⟦ } enum U64List { - Cons([G; 8], &U64List), + Cons(U64, &U64List), Nil } -- Computes the predecessor of an `u64` assumed to be properly represented in -- little-endian bytes. If that's not the case, this implementation has UB. - fn relaxed_u64_pred(bytes: [G; 8]) -> [G; 8] { + fn relaxed_u64_pred(bytes: U64) -> U64 { let [b0, b1, b2, b3, b4, b5, b6, b7] = bytes; match b0 { 0 => match b1 { @@ -275,7 +277,7 @@ def byteStream := ⟦ } } - fn u64_list_length(xs: U64List) -> [G; 8] { + fn u64_list_length(xs: U64List) -> U64 { match xs { U64List.Nil => [0; 8], U64List.Cons(_, rest) => relaxed_u64_succ(u64_list_length(load(rest))), diff --git a/Ix/IxVM/KernelTypes.lean b/Ix/IxVM/KernelTypes.lean index 945ea584..8d8aa71b 100644 --- a/Ix/IxVM/KernelTypes.lean +++ b/Ix/IxVM/KernelTypes.lean @@ -10,13 +10,13 @@ def kernelTypes := ⟦ -- Universe Levels -- ============================================================================ - -- TODO: Param index could be G instead of [G; 8] (Goldilocks is big enough) + -- TODO: Param index could be G instead of U64 (Goldilocks is big enough) enum KLevel { Zero, Succ(&KLevel), Max(&KLevel, &KLevel), IMax(&KLevel, &KLevel), - Param([G; 8]) + Param(U64) } enum KLevelList { @@ -28,37 +28,37 @@ def kernelTypes := ⟦ -- Literals -- ============================================================================ - -- TODO: [G; 8] is insufficient — Nat and String literals are arbitrarily large. + -- TODO: U64 is insufficient — Nat and String literals are arbitrarily large. -- Nat should be a list of u64 limbs (little-endian bignum). -- Str should be a list of bytes (or a ByteStream). -- This also requires fixing the blob ingress and conversion to produce these types. enum KLiteral { - Nat([G; 8]), - Str([G; 8]) + Nat(U64), + Str(U64) } -- ============================================================================ -- Expressions (de Bruijn indexed, no binder info or names) -- ============================================================================ - -- TODO: all [G; 8] here (BVar index, Const index, Proj indices) could be G + -- TODO: all U64 here (BVar index, Const index, Proj indices) could be G enum KExpr { - BVar([G; 8]), + BVar(U64), Srt(&KLevel), - Const([G; 8], &KLevelList), + Const(U64, &KLevelList), App(&KExpr, &KExpr), Lam(&KExpr, &KExpr), Forall(&KExpr, &KExpr), Let(&KExpr, &KExpr, &KExpr), Lit(KLiteral), - Proj([G; 8], [G; 8], &KExpr) + Proj(U64, U64, &KExpr) } -- ============================================================================ -- Values (NbE semantic domain) -- ============================================================================ - -- TODO: all [G; 8] here could be G. In particular, FVar's de Bruijn level + -- TODO: all U64 here could be G. In particular, FVar's de Bruijn level -- is a runtime counter (not from Ixon) and would benefit most from the change, -- since it would simplify depth tracking throughout the kernel to use plain G -- arithmetic instead of u64 operations. @@ -67,10 +67,10 @@ def kernelTypes := ⟦ Lit(KLiteral), Lam(&KVal, &KExpr, &KValEnv), Pi(&KVal, &KExpr, &KValEnv), - Ctor([G; 8], &KLevelList, [G; 8], &KValList), - FVar([G; 8], &KValList), - Const([G; 8], &KLevelList, &KValList), - Proj([G; 8], [G; 8], &KVal, &KValList), + Ctor(U64, &KLevelList, U64, &KValList), + FVar(U64, &KValList), + Const(U64, &KLevelList, &KValList), + Proj(U64, U64, &KVal, &KValList), Thunk(&KExpr, &KValEnv) } @@ -90,11 +90,11 @@ def kernelTypes := ⟦ -- Reducibility Hints -- ============================================================================ - -- TODO: Regular hint could be G instead of [G; 8] + -- TODO: Regular hint could be G instead of U64 enum KHints { Opaque, Abbrev, - Regular([G; 8]) + Regular(U64) } -- ============================================================================ @@ -122,9 +122,9 @@ def kernelTypes := ⟦ -- Recursor Rule: (ctor_const_idx, num_fields, rhs) -- ============================================================================ - -- TODO: ctor_const_idx and num_fields could be G instead of [G; 8] + -- TODO: ctor_const_idx and num_fields could be G instead of U64 enum KRecRule { - Mk([G; 8], [G; 8], &KExpr) + Mk(U64, U64, &KExpr) } enum KRecRuleList { @@ -148,26 +148,26 @@ def kernelTypes := ⟦ -- num_motives, num_minors, rules, k_flag, is_unsafe) -- ============================================================================ - -- TODO: could be a list of G instead of [G; 8] + -- TODO: could be a list of G instead of U64 enum KU64List { - Cons([G; 8], &KU64List), + Cons(U64, &KU64List), Nil } - -- TODO: all [G; 8] fields (num_levels, num_params, num_indices, etc.) + -- TODO: all U64 fields (num_levels, num_params, num_indices, etc.) -- could be G instead. The Goldilocks field is large enough for any -- realistic value, and using G would simplify arithmetic throughout -- the kernel (native field ops instead of u64_add/u64_sub/u64_eq/etc.). -- This requires a corresponding change in Convert.lean to emit G values. enum KConstantInfo { - Axiom([G; 8], &KExpr, G), - Defn([G; 8], &KExpr, &KExpr, KHints, KSafety), - Thm([G; 8], &KExpr, &KExpr), - Opaque([G; 8], &KExpr, &KExpr, G), - Quot([G; 8], &KExpr, KQuotKind), - Induct([G; 8], &KExpr, [G; 8], [G; 8], &KU64List, G, G, G), - Ctor([G; 8], &KExpr, [G; 8], [G; 8], [G; 8], [G; 8], G), - Rec([G; 8], &KExpr, [G; 8], [G; 8], [G; 8], [G; 8], &KRecRuleList, G, G) + Axiom(U64, &KExpr, G), + Defn(U64, &KExpr, &KExpr, KHints, KSafety), + Thm(U64, &KExpr, &KExpr), + Opaque(U64, &KExpr, &KExpr, G), + Quot(U64, &KExpr, KQuotKind), + Induct(U64, &KExpr, U64, U64, &KU64List, G, G, G), + Ctor(U64, &KExpr, U64, U64, U64, U64, G), + Rec(U64, &KExpr, U64, U64, U64, U64, &KRecRuleList, G, G) } -- The global environment: a list of constants indexed by position diff --git a/Ix/IxVM/Sha256.lean b/Ix/IxVM/Sha256.lean index 09e6528e..697d1b72 100644 --- a/Ix/IxVM/Sha256.lean +++ b/Ix/IxVM/Sha256.lean @@ -44,7 +44,7 @@ def sha256 := ⟦ sha256_aux(stream, [0; 8], state) } - fn sha256_aux(stream: ByteStream, len_be: [G; 8], state: [[G; 4]; 8]) -> [[G; 4]; 8] { + fn sha256_aux(stream: ByteStream, len_be: U64, state: [[G; 4]; 8]) -> [[G; 4]; 8] { let W = [[0; 4]; 16]; match stream { ByteStream.Nil => @@ -586,7 +586,7 @@ def sha256 := ⟦ } } - fn fill_W_with_length_and_run_rounds(len_be: [G; 8], state: [[G; 4]; 8]) -> [[G; 4]; 8] { + fn fill_W_with_length_and_run_rounds(len_be: U64, state: [[G; 4]; 8]) -> [[G; 4]; 8] { let W = [[0; 4]; 16]; let W = set(W, 14, [len_be[0], len_be[1], len_be[2], len_be[3]]); let W = set(W, 15, [len_be[4], len_be[5], len_be[6], len_be[7]]); From 06298553c136826f9d6945b024ef89d7c4ad1022 Mon Sep 17 00:00:00 2001 From: Gabriel Barreto Date: Thu, 19 Mar 2026 10:38:12 -0300 Subject: [PATCH 2/6] type alias with eager expansion --- Ix/Aiur/Check.lean | 21 ++++++----- Ix/Aiur/Compile.lean | 2 +- Ix/Aiur/Meta.lean | 31 +++++++++++----- Ix/Aiur/Simple.lean | 5 +++ Ix/Aiur/Term.lean | 84 +++++++++++++++++++++++++++++++++++++++++--- 5 files changed, 120 insertions(+), 23 deletions(-) diff --git a/Ix/Aiur/Check.lean b/Ix/Aiur/Check.lean index 38690d83..81a2b24b 100644 --- a/Ix/Aiur/Check.lean +++ b/Ix/Aiur/Check.lean @@ -37,11 +37,13 @@ instance : ToString CheckError where /-- Constructs a map of declarations from a toplevel, ensuring that there are no duplicate names -for functions and datatypes. +for functions, datatypes, and type aliases. -/ def Toplevel.mkDecls (toplevel : Toplevel) : Except CheckError Decls := do let map ← toplevel.functions.foldlM (init := default) fun acc function => addDecl acc Function.name .function function + let map ← toplevel.typeAliases.foldlM (init := map) + fun acc alias => addDecl (α := TypeAlias) acc (·.name) .typeAlias alias toplevel.dataTypes.foldlM (init := map) addDataType where ensureUnique name (map : IndexMap Global _) := do @@ -73,7 +75,7 @@ def refLookup (global : Global) : CheckM Typ := do | some (.constructor dataType constructor) => let args := constructor.argTypes unless args.isEmpty do (throw $ .wrongNumArgs global args.length 0) - pure $ .dataType $ dataType.name + pure $ .typeRef $ dataType.name | some _ => throw $ .notAValue global | none => throw $ .unboundVariable global @@ -278,7 +280,7 @@ partial def inferUnqualifiedApp (func : Global) (unqualifiedFunc : String) (args pure ⟨function.output, .app func args, false⟩ | some (.constructor dataType constr) => do let args ← checkArgsAndInputs func args constr.argTypes - pure ⟨.dataType dataType.name, .app func args, false⟩ + pure ⟨.typeRef dataType.name, .app func args, false⟩ | _ => throw $ .cannotApply func where checkArgsAndInputs func args inputs : CheckM (List TypedTerm) := do @@ -298,7 +300,7 @@ partial def inferQualifiedApp (func : Global) (args : List Term) : CheckM TypedT pure ⟨function.output, .app func args, false⟩ | some (.constructor dataType constr) => let args ← checkArgsAndInputs func args constr.argTypes - pure ⟨.dataType dataType.name, .app func args, false⟩ + pure ⟨.typeRef dataType.name, .app func args, false⟩ | _ => throw $ .cannotApply func where checkArgsAndInputs func args inputs : CheckM (List TypedTerm) := do @@ -399,7 +401,7 @@ where let typ' := .function (function.inputs.map Prod.snd) function.output unless typ == typ' do throw $ .typeMismatch typ typ' pure [] - | (.ref constrRef pats, .dataType dataTypeRef) => do + | (.ref constrRef pats, .typeRef dataTypeRef) => do let ctx ← read let some (.dataType dataType) := ctx.decls.getByKey dataTypeRef | unreachable! let some (.constructor dataType' constr) := ctx.decls.getByKey constrRef | throw $ .notAConstructor constrRef @@ -455,6 +457,8 @@ where if !map.contains dataType.name then set $ map.insert dataType.name dataType.constructors.flatMap (·.argTypes) |>.forM wellFormedType + | .typeAlias alias => do + wellFormedType alias.expansion | .function function => do wellFormedType function.output function.inputs.forM fun (_, typ) => wellFormedType typ @@ -463,10 +467,11 @@ where wellFormedType : Typ → EStateM CheckError (Std.HashSet Global) Unit | .tuple typs => typs.forM wellFormedType | .pointer pointerTyp => wellFormedType pointerTyp - | .dataType dataTypeRef => match decls.getByKey dataTypeRef with + | .typeRef typeRef => match decls.getByKey typeRef with | some (.dataType _) => pure () - | some _ => throw $ .notADataType dataTypeRef - | none => throw $ .undefinedGlobal dataTypeRef + | some (.typeAlias _) => pure () + | some _ => throw $ .notADataType typeRef + | none => throw $ .undefinedGlobal typeRef | _ => pure () /-- Checks a function to ensure its body's type matches its declared output type. -/ diff --git a/Ix/Aiur/Compile.lean b/Ix/Aiur/Compile.lean index 9b553b2a..a741e2d9 100644 --- a/Ix/Aiur/Compile.lean +++ b/Ix/Aiur/Compile.lean @@ -240,7 +240,7 @@ def typSize (layoutMap : LayoutMap) : Typ → Except String Nat | .array typ n => do let size ← typSize layoutMap typ pure $ n * size -| .dataType g => match layoutMap[g]? with +| .typeRef g => match layoutMap[g]? with | some (.dataType layout) => pure layout.size | _ => throw "Impossible case" diff --git a/Ix/Aiur/Meta.lean b/Ix/Aiur/Meta.lean index b0eaf280..2f077428 100644 --- a/Ix/Aiur/Meta.lean +++ b/Ix/Aiur/Meta.lean @@ -86,7 +86,7 @@ partial def elabTyp : ElabStxCat `typ mkAppM ``Typ.pointer #[← elabTyp t] | `(typ| $[.]?$i:ident) => do let g ← mkAppM ``Global.mk #[toExpr i.getId] - mkAppM ``Typ.dataType #[g] + mkAppM ``Typ.typeRef #[g] | `(typ| fn() -> $t:typ) => do mkAppM ``Typ.function #[← elabEmptyList ``Typ, ← elabTyp t] | `(typ| fn($t$[, $ts:typ]*) -> $t':typ) => do @@ -441,6 +441,15 @@ def elabDataType : ElabStxCat `data_type mkAppM ``DataType.mk #[g, ← elabList c cs elabConstructor ``Constructor] | stx => throw $ .error stx "Invalid syntax for data type" +declare_syntax_cat type_alias +syntax "type " ident " = " typ : type_alias + +def elabTypeAlias : ElabStxCat `type_alias + | `(type_alias| type $n:ident = $t:typ) => do + let g ← mkAppM ``Global.mk #[toExpr n.getId] + mkAppM ``TypeAlias.mk #[g, ← elabTyp t] + | stx => throw $ .error stx "Invalid syntax for type alias" + declare_syntax_cat bind syntax ident ": " typ : bind @@ -477,17 +486,20 @@ where | some typ => elabTyp typ declare_syntax_cat declaration -syntax function : declaration -syntax data_type : declaration +syntax function : declaration +syntax data_type : declaration +syntax type_alias : declaration -def accElabDeclarations (declarations : (Array Expr × Array Expr)) - (stx : TSyntax `declaration) : TermElabM (Array Expr × Array Expr) := - let (dataTypes, functions) := declarations +def accElabDeclarations (declarations : (Array Expr × Array Expr × Array Expr)) + (stx : TSyntax `declaration) : TermElabM (Array Expr × Array Expr × Array Expr) := + let (dataTypes, typeAliases, functions) := declarations match stx with | `(declaration| $f:function) => do - pure (dataTypes, functions.push $ ← elabFunction f) + pure (dataTypes, typeAliases, functions.push $ ← elabFunction f) | `(declaration| $d:data_type) => do - pure (dataTypes.push $ ← elabDataType d, functions) + pure (dataTypes.push $ ← elabDataType d, typeAliases, functions) + | `(declaration| $ta:type_alias) => do + pure (dataTypes, typeAliases.push $ ← elabTypeAlias ta, functions) | stx => throw $ .error stx "Invalid syntax for declaration" declare_syntax_cat toplevel @@ -495,9 +507,10 @@ syntax declaration* : toplevel def elabToplevel : ElabStxCat `toplevel | `(toplevel| $[$ds:declaration]*) => do - let (dataTypes, functions) ← ds.foldlM (init := default) accElabDeclarations + let (dataTypes, typeAliases, functions) ← ds.foldlM (init := default) accElabDeclarations mkAppM ``Toplevel.mk #[ ← mkArrayLit (mkConst ``DataType) dataTypes.toList, + ← mkArrayLit (mkConst ``TypeAlias) typeAliases.toList, ← mkArrayLit (mkConst ``Function) functions.toList, ] | stx => throw $ .error stx "Invalid syntax for toplevel" diff --git a/Ix/Aiur/Simple.lean b/Ix/Aiur/Simple.lean index f20e5e9d..2e5036d7 100644 --- a/Ix/Aiur/Simple.lean +++ b/Ix/Aiur/Simple.lean @@ -38,6 +38,10 @@ where recr := simplifyTerm decls def Toplevel.checkAndSimplify (toplevel : Toplevel) : Except CheckError TypedDecls := do + -- First expand all type aliases + let toplevel ← match toplevel.expandAliases with + | .ok t => pure t + | .error e => throw (.undefinedGlobal ⟨.mkSimple e⟩) -- TODO: better error handling let decls ← toplevel.mkDecls wellFormedDecls decls -- The first check happens on the original terms. @@ -49,6 +53,7 @@ def Toplevel.checkAndSimplify (toplevel : Toplevel) : Except CheckError TypedDec decls.foldlM (init := default) fun typedDecls (name, decl) => match decl with | .constructor d c => pure $ typedDecls.insert name (.constructor d c) | .dataType d => pure $ typedDecls.insert name (.dataType d) + | .typeAlias _ => pure typedDecls -- Type aliases should have been expanded by now | .function f => do -- The second check happens on the simplified terms. let f ← (checkFunction f) (getFunctionContext f decls) diff --git a/Ix/Aiur/Term.lean b/Ix/Aiur/Term.lean index ee662e79..78de10d6 100644 --- a/Ix/Aiur/Term.lean +++ b/Ix/Aiur/Term.lean @@ -58,7 +58,7 @@ inductive Typ where | tuple : Array Typ → Typ | array : Typ → Nat → Typ | pointer : Typ → Typ - | dataType : Global → Typ + | typeRef : Global → Typ | function : List Typ → Typ → Typ deriving Repr, BEq, Hashable, Inhabited @@ -176,6 +176,11 @@ structure DataType where constructors : List Constructor deriving Repr, BEq, Inhabited +structure TypeAlias where + name : Global + expansion : Typ + deriving Repr, BEq, Inhabited + structure Function where name : Global inputs : List (Local × Typ) @@ -186,6 +191,7 @@ structure Function where structure Toplevel where dataTypes : Array DataType + typeAliases : Array TypeAlias functions : Array Function deriving Repr @@ -193,31 +199,99 @@ def Toplevel.getFuncIdx (toplevel : Toplevel) (funcName : Lean.Name) : Option Na toplevel.functions.findIdx? fun function => function.name.toName == funcName def Toplevel.merge (x y : Toplevel) : Except Global Toplevel := do - let ⟨xDataTypes, xFunctions⟩ := x - let ⟨yDataTypes, yFunctions⟩ := y + let ⟨xDataTypes, xTypeAliases, xFunctions⟩ := x + let ⟨yDataTypes, yTypeAliases, yFunctions⟩ := y let mut globals : Std.HashSet Global := ∅ let mut dataTypes := .emptyWithCapacity (xDataTypes.size + yDataTypes.size) + let mut typeAliases := .emptyWithCapacity (xTypeAliases.size + yTypeAliases.size) let mut functions := .emptyWithCapacity (xFunctions.size + yFunctions.size) for dtSet in [xDataTypes, yDataTypes] do for dt in dtSet do if globals.contains dt.name then throw dt.name globals := globals.insert dt.name dataTypes := dataTypes.push dt + for taSet in [xTypeAliases, yTypeAliases] do + for ta in taSet do + if globals.contains ta.name then throw ta.name + globals := globals.insert ta.name + typeAliases := typeAliases.push ta for fSet in [xFunctions, yFunctions] do for f in fSet do if globals.contains f.name then throw f.name globals := globals.insert f.name functions := functions.push f - pure ⟨dataTypes, functions⟩ + pure ⟨dataTypes, typeAliases, functions⟩ inductive Declaration | function : Function → Declaration | dataType : DataType → Declaration + | typeAlias : TypeAlias → Declaration | constructor : DataType → Constructor → Declaration deriving Repr, Inhabited abbrev Decls := IndexMap Global Declaration +/-- Eagerly expands type aliases in a type, detecting cycles. -/ +partial def Typ.expandAliases (decls : Decls) (visited : Std.HashSet Global := {}) : + Typ → Except String Typ + | .unit => pure .unit + | .field => pure .field + | .pointer t => do pure $ .pointer (← t.expandAliases decls visited) + | .function inputs output => do + let inputs' ← inputs.mapM (·.expandAliases decls visited) + let output' ← output.expandAliases decls visited + pure $ .function inputs' output' + | .tuple ts => do + let ts' ← ts.mapM (·.expandAliases decls visited) + pure $ .tuple ts' + | .array t n => do + let t' ← t.expandAliases decls visited + pure $ .array t' n + | .typeRef g => match decls.getByKey g with + | some (.typeAlias alias) => + if visited.contains g then + throw s!"Cycle detected in type alias `{g}`" + else do + let visited' := visited.insert g + alias.expansion.expandAliases decls visited' + | some (.dataType _) => pure $ .typeRef g + | some _ => throw s!"Type reference `{g}` does not refer to a type" + | none => throw s!"Type reference `{g}` not found" + +/-- Expand all type aliases in a Toplevel, removing the aliases themselves. -/ +def Toplevel.expandAliases (toplevel : Toplevel) : Except String Toplevel := do + -- First create the Decls map to use for expansion + let mut decls : Decls := default + for ta in toplevel.typeAliases do + decls := decls.insert ta.name (.typeAlias ta) + for dt in toplevel.dataTypes do + decls := decls.insert dt.name (.dataType dt) + + -- Validate all type aliases can be expanded (checks for cycles) + for ta in toplevel.typeAliases do + let _ ← ta.expansion.expandAliases decls + + -- Expand all type references in data types + let mut dataTypes : Array DataType := #[] + for dt in toplevel.dataTypes do + let mut constructors : List Constructor := [] + for ctor in dt.constructors do + let argTypes' ← ctor.argTypes.mapM (·.expandAliases decls) + constructors := constructors.concat { ctor with argTypes := argTypes' } + dataTypes := dataTypes.push { dt with constructors } + + -- Expand all type references in functions + let mut functions : Array Function := #[] + for fn in toplevel.functions do + let inputs' ← fn.inputs.mapM fun (loc, typ) => do + let typ' ← typ.expandAliases decls + pure (loc, typ') + let output' ← fn.output.expandAliases decls + functions := functions.push { fn with inputs := inputs', output := output' } + + -- Return Toplevel without aliases + pure ⟨dataTypes, #[], functions⟩ + structure TypedFunction where name : Global inputs : List (Local × Typ) @@ -255,7 +329,7 @@ partial def Typ.size (decls : TypedDecls) (visited : HashSet Global := {}) : | .array t n => do let tSize ← t.size decls visited pure $ n * tSize - | .dataType g => match decls.getByKey g with + | .typeRef g => match decls.getByKey g with | some (.dataType data) => data.size decls visited | _ => throw s!"Datatype not found: `{g}`" From 16aecab26636164897079ad46957481f23ac2cd0 Mon Sep 17 00:00:00 2001 From: Gabriel Barreto Date: Thu, 19 Mar 2026 10:45:17 -0300 Subject: [PATCH 3/6] removed aliases from declarations --- Ix/Aiur/Check.lean | 104 ++++++++++++++++++++++++++++++++++---------- Ix/Aiur/Simple.lean | 12 ++--- Ix/Aiur/Term.lean | 62 -------------------------- 3 files changed, 86 insertions(+), 92 deletions(-) diff --git a/Ix/Aiur/Check.lean b/Ix/Aiur/Check.lean index 81a2b24b..ac89b1d1 100644 --- a/Ix/Aiur/Check.lean +++ b/Ix/Aiur/Check.lean @@ -36,27 +36,90 @@ instance : ToString CheckError where toString e := repr e |>.pretty /-- -Constructs a map of declarations from a toplevel, ensuring that there are no duplicate names -for functions, datatypes, and type aliases. +Eagerly expands type aliases, detecting cycles. +-/ +partial def expandType (aliasMap : Std.HashMap Global Typ) (visited : Std.HashSet Global := {}) : + Typ → Except CheckError Typ + | .unit => pure .unit + | .field => pure .field + | .pointer t => do pure $ .pointer (← expandType aliasMap visited t) + | .function inputs output => do + let inputs' ← inputs.mapM (expandType aliasMap visited) + let output' ← expandType aliasMap visited output + pure $ .function inputs' output' + | .tuple ts => do + let ts' ← ts.mapM (expandType aliasMap visited) + pure $ .tuple ts' + | .array t n => do + let t' ← expandType aliasMap visited t + pure $ .array t' n + | .typeRef g => + if let some expansion := aliasMap[g]? then + if visited.contains g then + throw $ CheckError.undefinedGlobal ⟨.mkSimple s!"Cycle detected in type alias `{g}`"⟩ + else + expandType aliasMap (visited.insert g) expansion + else + pure $ .typeRef g -- It's a dataType, keep it + +/-- +Constructs a map of declarations from a toplevel, expanding all type aliases. +Type aliases are not added to the declarations - they are eliminated during construction. -/ def Toplevel.mkDecls (toplevel : Toplevel) : Except CheckError Decls := do - let map ← toplevel.functions.foldlM (init := default) - fun acc function => addDecl acc Function.name .function function - let map ← toplevel.typeAliases.foldlM (init := map) - fun acc alias => addDecl (α := TypeAlias) acc (·.name) .typeAlias alias - toplevel.dataTypes.foldlM (init := map) addDataType -where - ensureUnique name (map : IndexMap Global _) := do - if map.containsKey name then throw $ .duplicatedDefinition name - addDecl {α : Type} map (nameFn : α → Global) (wrapper : α → Declaration) (inner : α) := do - ensureUnique (nameFn inner) map - pure $ map.insert (nameFn inner) (wrapper inner) - addDataType map dataType := do - let dataTypeName := dataType.name - ensureUnique dataTypeName map - let map' := map.insert dataTypeName (.dataType dataType) - dataType.constructors.foldlM (init := map') fun acc (constructor : Constructor) => - addDecl acc (dataTypeName.pushNamespace ∘ Constructor.nameHead) (.constructor dataType) constructor + -- First build the unexpanded alias map and check for name collisions + let mut rawAliasMap : Std.HashMap Global Typ := {} + let mut allNames : Std.HashSet Global := {} + + -- Collect alias definitions + for alias in toplevel.typeAliases do + if allNames.contains alias.name then + throw $ .duplicatedDefinition alias.name + allNames := allNames.insert alias.name + rawAliasMap := rawAliasMap.insert alias.name alias.expansion + + -- Now expand all aliases in terms of each other, detecting cycles + let mut aliasMap : Std.HashMap Global Typ := {} + for alias in toplevel.typeAliases do + let expanded ← expandType rawAliasMap {} alias.expansion + aliasMap := aliasMap.insert alias.name expanded + + -- Helper to expand types in the declarations + let expandTyp := expandType aliasMap {} + + -- Add functions with expanded types + let mut decls : Decls := default + for function in toplevel.functions do + if allNames.contains function.name then + throw $ .duplicatedDefinition function.name + allNames := allNames.insert function.name + let inputs' ← function.inputs.mapM fun (loc, typ) => do + let typ' ← expandTyp typ + pure (loc, typ') + let output' ← expandTyp function.output + let function' := { function with inputs := inputs', output := output' } + decls := decls.insert function.name (.function function') + + -- Add datatypes with expanded types + for dataType in toplevel.dataTypes do + if allNames.contains dataType.name then + throw $ .duplicatedDefinition dataType.name + allNames := allNames.insert dataType.name + let mut constructors : List Constructor := [] + for ctor in dataType.constructors do + let argTypes' ← ctor.argTypes.mapM expandTyp + constructors := constructors.concat { ctor with argTypes := argTypes' } + let dataType' := { dataType with constructors } + decls := decls.insert dataType.name (.dataType dataType') + -- Add constructors + for ctor in constructors do + let ctorName := dataType.name.pushNamespace ctor.nameHead + if allNames.contains ctorName then + throw $ .duplicatedDefinition ctorName + allNames := allNames.insert ctorName + decls := decls.insert ctorName (.constructor dataType' ctor) + + pure decls structure CheckContext where decls : Decls @@ -457,8 +520,6 @@ where if !map.contains dataType.name then set $ map.insert dataType.name dataType.constructors.flatMap (·.argTypes) |>.forM wellFormedType - | .typeAlias alias => do - wellFormedType alias.expansion | .function function => do wellFormedType function.output function.inputs.forM fun (_, typ) => wellFormedType typ @@ -469,7 +530,6 @@ where | .pointer pointerTyp => wellFormedType pointerTyp | .typeRef typeRef => match decls.getByKey typeRef with | some (.dataType _) => pure () - | some (.typeAlias _) => pure () | some _ => throw $ .notADataType typeRef | none => throw $ .undefinedGlobal typeRef | _ => pure () diff --git a/Ix/Aiur/Simple.lean b/Ix/Aiur/Simple.lean index 2e5036d7..37e420ae 100644 --- a/Ix/Aiur/Simple.lean +++ b/Ix/Aiur/Simple.lean @@ -38,22 +38,18 @@ where recr := simplifyTerm decls def Toplevel.checkAndSimplify (toplevel : Toplevel) : Except CheckError TypedDecls := do - -- First expand all type aliases - let toplevel ← match toplevel.expandAliases with - | .ok t => pure t - | .error e => throw (.undefinedGlobal ⟨.mkSimple e⟩) -- TODO: better error handling let decls ← toplevel.mkDecls wellFormedDecls decls - -- The first check happens on the original terms. - toplevel.functions.forM fun function => do - let _ ← (checkFunction function) (getFunctionContext function decls) + -- The first check happens on the original terms (but with expanded types). + decls.foldlM (init := ()) fun _ (_, decl) => match decl with + | .function f => do let _ ← (checkFunction f) (getFunctionContext f decls); pure () + | _ => pure () let decls := decls.map fun decl => match decl with | .function f => .function { f with body := simplifyTerm decls f.body } | _ => decl decls.foldlM (init := default) fun typedDecls (name, decl) => match decl with | .constructor d c => pure $ typedDecls.insert name (.constructor d c) | .dataType d => pure $ typedDecls.insert name (.dataType d) - | .typeAlias _ => pure typedDecls -- Type aliases should have been expanded by now | .function f => do -- The second check happens on the simplified terms. let f ← (checkFunction f) (getFunctionContext f decls) diff --git a/Ix/Aiur/Term.lean b/Ix/Aiur/Term.lean index 78de10d6..a913e654 100644 --- a/Ix/Aiur/Term.lean +++ b/Ix/Aiur/Term.lean @@ -225,73 +225,11 @@ def Toplevel.merge (x y : Toplevel) : Except Global Toplevel := do inductive Declaration | function : Function → Declaration | dataType : DataType → Declaration - | typeAlias : TypeAlias → Declaration | constructor : DataType → Constructor → Declaration deriving Repr, Inhabited abbrev Decls := IndexMap Global Declaration -/-- Eagerly expands type aliases in a type, detecting cycles. -/ -partial def Typ.expandAliases (decls : Decls) (visited : Std.HashSet Global := {}) : - Typ → Except String Typ - | .unit => pure .unit - | .field => pure .field - | .pointer t => do pure $ .pointer (← t.expandAliases decls visited) - | .function inputs output => do - let inputs' ← inputs.mapM (·.expandAliases decls visited) - let output' ← output.expandAliases decls visited - pure $ .function inputs' output' - | .tuple ts => do - let ts' ← ts.mapM (·.expandAliases decls visited) - pure $ .tuple ts' - | .array t n => do - let t' ← t.expandAliases decls visited - pure $ .array t' n - | .typeRef g => match decls.getByKey g with - | some (.typeAlias alias) => - if visited.contains g then - throw s!"Cycle detected in type alias `{g}`" - else do - let visited' := visited.insert g - alias.expansion.expandAliases decls visited' - | some (.dataType _) => pure $ .typeRef g - | some _ => throw s!"Type reference `{g}` does not refer to a type" - | none => throw s!"Type reference `{g}` not found" - -/-- Expand all type aliases in a Toplevel, removing the aliases themselves. -/ -def Toplevel.expandAliases (toplevel : Toplevel) : Except String Toplevel := do - -- First create the Decls map to use for expansion - let mut decls : Decls := default - for ta in toplevel.typeAliases do - decls := decls.insert ta.name (.typeAlias ta) - for dt in toplevel.dataTypes do - decls := decls.insert dt.name (.dataType dt) - - -- Validate all type aliases can be expanded (checks for cycles) - for ta in toplevel.typeAliases do - let _ ← ta.expansion.expandAliases decls - - -- Expand all type references in data types - let mut dataTypes : Array DataType := #[] - for dt in toplevel.dataTypes do - let mut constructors : List Constructor := [] - for ctor in dt.constructors do - let argTypes' ← ctor.argTypes.mapM (·.expandAliases decls) - constructors := constructors.concat { ctor with argTypes := argTypes' } - dataTypes := dataTypes.push { dt with constructors } - - -- Expand all type references in functions - let mut functions : Array Function := #[] - for fn in toplevel.functions do - let inputs' ← fn.inputs.mapM fun (loc, typ) => do - let typ' ← typ.expandAliases decls - pure (loc, typ') - let output' ← fn.output.expandAliases decls - functions := functions.push { fn with inputs := inputs', output := output' } - - -- Return Toplevel without aliases - pure ⟨dataTypes, #[], functions⟩ - structure TypedFunction where name : Global inputs : List (Local × Typ) From 7970733bcef40e7c049dc9d416c7d8d0ed56e39e Mon Sep 17 00:00:00 2001 From: Gabriel Barreto Date: Thu, 19 Mar 2026 10:58:53 -0300 Subject: [PATCH 4/6] better algorithm --- Ix/Aiur/Check.lean | 67 ++++++++++++++++++++++++++-------------------- 1 file changed, 38 insertions(+), 29 deletions(-) diff --git a/Ix/Aiur/Check.lean b/Ix/Aiur/Check.lean index ac89b1d1..62ed45fe 100644 --- a/Ix/Aiur/Check.lean +++ b/Ix/Aiur/Check.lean @@ -36,62 +36,71 @@ instance : ToString CheckError where toString e := repr e |>.pretty /-- -Eagerly expands type aliases, detecting cycles. +Eagerly expands type aliases, building the aliasMap on demand and detecting cycles. -/ -partial def expandType (aliasMap : Std.HashMap Global Typ) (visited : Std.HashSet Global := {}) : - Typ → Except CheckError Typ +partial def expandTypeM (visited : Std.HashSet Global) (toplevelAliases : Array TypeAlias) : + Typ → StateT (Std.HashMap Global Typ) (Except CheckError) Typ | .unit => pure .unit | .field => pure .field - | .pointer t => do pure $ .pointer (← expandType aliasMap visited t) + | .pointer t => do pure $ .pointer (← expandTypeM visited toplevelAliases t) | .function inputs output => do - let inputs' ← inputs.mapM (expandType aliasMap visited) - let output' ← expandType aliasMap visited output + let inputs' ← inputs.mapM (expandTypeM visited toplevelAliases) + let output' ← expandTypeM visited toplevelAliases output pure $ .function inputs' output' | .tuple ts => do - let ts' ← ts.mapM (expandType aliasMap visited) + let ts' ← ts.mapM (expandTypeM visited toplevelAliases) pure $ .tuple ts' | .array t n => do - let t' ← expandType aliasMap visited t + let t' ← expandTypeM visited toplevelAliases t pure $ .array t' n - | .typeRef g => - if let some expansion := aliasMap[g]? then + | .typeRef g => do + let aliasMap ← get + -- Check if already expanded + if let some expanded := aliasMap[g]? then + return expanded + -- Check if it's an alias + if let some (alias : TypeAlias) := toplevelAliases.find? (·.name == g) then + -- Check for cycle if visited.contains g then throw $ CheckError.undefinedGlobal ⟨.mkSimple s!"Cycle detected in type alias `{g}`"⟩ - else - expandType aliasMap (visited.insert g) expansion + -- Expand the alias recursively + let expanded ← expandTypeM (visited.insert g) toplevelAliases alias.expansion + -- Save to aliasMap + set (aliasMap.insert g expanded) + return expanded else - pure $ .typeRef g -- It's a dataType, keep it + -- It's a dataType, keep it + pure $ .typeRef g /-- Constructs a map of declarations from a toplevel, expanding all type aliases. Type aliases are not added to the declarations - they are eliminated during construction. -/ def Toplevel.mkDecls (toplevel : Toplevel) : Except CheckError Decls := do - -- First build the unexpanded alias map and check for name collisions - let mut rawAliasMap : Std.HashMap Global Typ := {} + -- Check for duplicate names among aliases let mut allNames : Std.HashSet Global := {} - - -- Collect alias definitions for alias in toplevel.typeAliases do if allNames.contains alias.name then - throw $ .duplicatedDefinition alias.name + throw $ CheckError.duplicatedDefinition alias.name allNames := allNames.insert alias.name - rawAliasMap := rawAliasMap.insert alias.name alias.expansion - -- Now expand all aliases in terms of each other, detecting cycles - let mut aliasMap : Std.HashMap Global Typ := {} - for alias in toplevel.typeAliases do - let expanded ← expandType rawAliasMap {} alias.expansion - aliasMap := aliasMap.insert alias.name expanded + -- Build aliasMap by expanding all aliases in order + let initAliasMap := {} + let (_, finalAliasMap) ← (toplevel.typeAliases.mapM fun (alias : TypeAlias) => do + -- Expand and save the alias + let expanded ← expandTypeM {} toplevel.typeAliases alias.expansion + modify fun (aliasMap : Std.HashMap Global Typ) => aliasMap.insert alias.name expanded + ).run initAliasMap - -- Helper to expand types in the declarations - let expandTyp := expandType aliasMap {} + -- Helper to expand types in the declarations using the built aliasMap + let expandTyp (typ : Typ) : Except CheckError Typ := + (expandTypeM {} toplevel.typeAliases typ).run' finalAliasMap -- Add functions with expanded types let mut decls : Decls := default for function in toplevel.functions do if allNames.contains function.name then - throw $ .duplicatedDefinition function.name + throw $ CheckError.duplicatedDefinition function.name allNames := allNames.insert function.name let inputs' ← function.inputs.mapM fun (loc, typ) => do let typ' ← expandTyp typ @@ -103,7 +112,7 @@ def Toplevel.mkDecls (toplevel : Toplevel) : Except CheckError Decls := do -- Add datatypes with expanded types for dataType in toplevel.dataTypes do if allNames.contains dataType.name then - throw $ .duplicatedDefinition dataType.name + throw $ CheckError.duplicatedDefinition dataType.name allNames := allNames.insert dataType.name let mut constructors : List Constructor := [] for ctor in dataType.constructors do @@ -115,7 +124,7 @@ def Toplevel.mkDecls (toplevel : Toplevel) : Except CheckError Decls := do for ctor in constructors do let ctorName := dataType.name.pushNamespace ctor.nameHead if allNames.contains ctorName then - throw $ .duplicatedDefinition ctorName + throw $ CheckError.duplicatedDefinition ctorName allNames := allNames.insert ctorName decls := decls.insert ctorName (.constructor dataType' ctor) From b0a5401db8f2c63d7d18024277f12cda8a9bd581 Mon Sep 17 00:00:00 2001 From: Gabriel Barreto Date: Thu, 19 Mar 2026 12:40:36 -0300 Subject: [PATCH 5/6] type alias tests --- Tests/Aiur/Aiur.lean | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/Tests/Aiur/Aiur.lean b/Tests/Aiur/Aiur.lean index bd0414d2..b4ef11b5 100644 --- a/Tests/Aiur/Aiur.lean +++ b/Tests/Aiur/Aiur.lean @@ -364,6 +364,19 @@ def toplevel := ⟦ ) ) } + + --------------------------------------------------------------------------- + -- Type aliases: basic, nested, in patterns + --------------------------------------------------------------------------- + type U8 = G + type U16 = (U8, U8) + type U32 = (U16, U16) + type U64 = [U8; 8] + type Pair = (U8, U8) + + fn alias_conversion(x: U64) -> U32 { + ((x[0], x[1]), (x[2], x[3])) + } ⟧ def aiurTestCases : List AiurTestCase := [ @@ -488,6 +501,10 @@ def aiurTestCases : List AiurTestCase := [ -- Fold/iteration .noIO `fold_matrix_sum #[1, 2, 3, 4] #[10], + + -- Type aliases + { AiurTestCase.noIO `alias_conversion #[1, 2, 3, 4, 5, 6, 7, 8] #[1, 2, 3, 4] + with label := "alias_conversion (U64 = [U8; 8], U32 = (U16, U16))" }, ] end From e210bb292e3e7d00eb00e70cdfde51db79aa6b15 Mon Sep 17 00:00:00 2001 From: Arthur Paulino Date: Thu, 19 Mar 2026 14:20:37 -0700 Subject: [PATCH 6/6] Use idiomatic forM instead of foldlM with unit accumulator Add IndexMap.forM and use it in checkAndSimplify for cleaner iteration over declarations when only side effects (type-checking) are needed. --- Ix/Aiur/Simple.lean | 6 +++--- Ix/IndexMap.lean | 3 +++ 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/Ix/Aiur/Simple.lean b/Ix/Aiur/Simple.lean index 37e420ae..4e7b780a 100644 --- a/Ix/Aiur/Simple.lean +++ b/Ix/Aiur/Simple.lean @@ -41,9 +41,9 @@ def Toplevel.checkAndSimplify (toplevel : Toplevel) : Except CheckError TypedDec let decls ← toplevel.mkDecls wellFormedDecls decls -- The first check happens on the original terms (but with expanded types). - decls.foldlM (init := ()) fun _ (_, decl) => match decl with - | .function f => do let _ ← (checkFunction f) (getFunctionContext f decls); pure () - | _ => pure () + decls.forM fun (_, decl) => do + if let .function f := decl then + let _ ← (checkFunction f) (getFunctionContext f decls) let decls := decls.map fun decl => match decl with | .function f => .function { f with body := simplifyTerm decls f.body } | _ => decl diff --git a/Ix/IndexMap.lean b/Ix/IndexMap.lean index 27ecc8d1..b6177613 100644 --- a/Ix/IndexMap.lean +++ b/Ix/IndexMap.lean @@ -69,6 +69,9 @@ def map (f : β → β) : IndexMap α β := by @[inline] def foldrM [Monad μ] (f : α × β → γ → μ γ) (init : γ) : μ γ := m.pairs.foldrM f init +@[inline] def forM [Monad μ] (f : α × β → μ PUnit) : μ PUnit := + m.pairs.forM f + end IndexMap end