From 831a2b73293979715ee7eb125cd13d32eb2ecb6e Mon Sep 17 00:00:00 2001 From: Ori Ziv Date: Wed, 6 May 2026 14:40:10 +0300 Subject: [PATCH] feat(constant): Added support for const calculation for inner variants and lets. --- .../src/expr/test_data/constant | 129 ++++++++++++++++- .../cairo-lang-semantic/src/items/constant.rs | 135 +++++++++--------- tests/bug_samples/issue9887.cairo | 39 +++++ tests/bug_samples/lib.cairo | 1 + 4 files changed, 228 insertions(+), 76 deletions(-) create mode 100644 tests/bug_samples/issue9887.cairo diff --git a/crates/cairo-lang-semantic/src/expr/test_data/constant b/crates/cairo-lang-semantic/src/expr/test_data/constant index 8e8e2a747d4..79756c5b0c8 100644 --- a/crates/cairo-lang-semantic/src/expr/test_data/constant +++ b/crates/cairo-lang-semantic/src/expr/test_data/constant @@ -17,6 +17,10 @@ const MY_CONST: felt252 = 0x1234; const FELT_TUPLE: (felt252, felt252) = (1, 2); +mod my_module { + const CONST_IN_MODULE: felt252 = 1; +} + const FELT_FIXED_SIZE_ARRAY: [felt252; 3] = [1, 2, 3]; const OTHER_CONST_REF: felt252 = my_module::CONST_IN_MODULE; @@ -24,9 +28,106 @@ const OTHER_CONST_REF: felt252 = my_module::CONST_IN_MODULE; const FIVE: u32 = 5; const NON_ZERO_FIVE: NonZero = FIVE.try_into().unwrap(); -mod my_module { - const CONST_IN_MODULE: felt252 = 1; -} +const MATCH_WILDCARD: () = assert(match 3_u8 { + 2 => 0, + _ => 5, +} == 5, 'match_wildcard'); + +const MATCH_LITERAL_PATTERN: () = assert( + match 3_u8 { + 1 => 0, + 3 => 5, + _ => 0, + } == 5, 'match_literal_pattern', +); + +const MATCH_TUPLE_PATTERN: () = assert( + match (1_u8, 2_u8) { + (_, y) => y, + } == 2, 'match_tuple_pattern', +); + +const MATCH_ARRAY_PATTERN: () = assert( + match [1_u8, 2_u8] { + [x, _] => x, + } == 1, 'match_array_pattern', +); + +const LET_TUPLE_BIND: () = assert({ + let (x, y) = (1_u8, 2_u8); + x + y +} == 3, 'let_tuple_bind'); + +const LET_ELSE_MATCH: () = assert( + { + let Some(x) = Option::::Some(3) else { + core::panic_with_felt252('impossible') + }; + x + } == 3, + 'let_else_match', +); + +const IF_LET_MATCH: () = assert( + if let Some(x) = Option::::Some(4) { + x + } else { + 0 + } == 4, 'if_let_match', +); + +const IF_LET_CHAIN_0: () = assert( + if let Some(x) = Option::::Some(3) && let Some(y) = Option::::Some(2) { + x + y + } else { + 0 + } == 5, + 'if_let_chain', +); + +const IF_LET_CHAIN_1: () = assert( + if let Some(x) = Option::::Some(3) && let Some(y) = Option::::None { + x + y + } else { + 0 + } == 0, + 'if_let_chain', +); + +const IF_LET_CHAIN_2: () = assert( + if let Some(x) = Option::::None && let Some(y) = Option::::Some(2) { + x + y + } else { + 0 + } == 0, + 'if_let_chain', +); + +const IF_LET_CHAIN_3: () = assert( + if let Some(x) = Option::::Some(3) && x == 3 { + 1 + } else { + 0 + } == 1, 'if_let_chain', +); + +const IF_LET_CHAIN_4: () = assert( + if let Some(x) = Option::::Some(3) && x == 2 { + 1 + } else { + 0 + } == 0, 'if_let_chain', +); + +const PARTIAL_STRUCT1: () = assert({ + let u256 { low, .. } = 1337; + low +} == 1337, 'partial_struct'); + +const PARTIAL_STRUCT2: () = assert({ + let u256 { high, .. } = 1337; + high +} == 0, 'partial_struct'); //! > expected_diagnostics @@ -81,6 +182,10 @@ const VALID_LE: () = assert(1_usize <= 1); const VALID_GT: () = assert(2_usize > 1); const VALID_GE: () = assert(1_usize >= 1); const VALID_DIVREM: () = assert(DivRem::div_rem(5_u8, 2) == (2, 1)); + +const MATCH_VALUE: () = assert(match 1_u8 { + _ => 1, +} == 1); use core::num::traits::Pow; const VALID_POW_MIN2_0: () = assert((-2_i8).pow(0) == 1); @@ -119,6 +224,13 @@ const fn my_unwrap(v: Option) -> T { } } +const LET_ELSE_FAIL: u8 = { + let None = Option::::Some(5) else { + core::panic_with_felt252('should fail') + }; + 0 +}; + //! > expected_diagnostics error[E0006]: Type not found. --> lib.cairo:1:17 @@ -160,19 +272,24 @@ note: In `test::assert`: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ error[E2129]: Constant calculation depth exceeded. - --> lib.cairo:59:43 + --> lib.cairo:63:43 const FUNC_CALC_STACK_EXCEEDED: felt252 = call_myself(); ^^^^^^^^^^^^^ error[E2130]: Failed to calculate constant. - --> lib.cairo:66:37 + --> lib.cairo:70:37 const NON_ZERO_ZERO: NonZero = my_unwrap(ZERO.try_into()); ^^^^^^^^^^^^^^^^^^^^^^^^^^ note: In `test::my_unwrap::>`: - --> lib.cairo:72:17 + --> lib.cairo:76:17 None => core::panic_with_felt252('bad value'), ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +error[E2128]: Failed to calculate constant. + --> lib.cairo:82:9 + core::panic_with_felt252('should fail') + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + //! > ========================================================================== //! > Const of wrong type. diff --git a/crates/cairo-lang-semantic/src/items/constant.rs b/crates/cairo-lang-semantic/src/items/constant.rs index 83ed333fba7..fe77361d679 100644 --- a/crates/cairo-lang-semantic/src/items/constant.rs +++ b/crates/cairo-lang-semantic/src/items/constant.rs @@ -386,7 +386,6 @@ pub fn constant_semantic_data_helper<'db>( } let mut ctx = ComputationContext::new_global(db, &mut diagnostics, &mut resolver); - let value = compute_expr_semantic(&mut ctx, &constant_ast.value(db)); let const_value = resolve_const_expr_and_evaluate( db, @@ -526,6 +525,9 @@ impl<'a, 'r, 'mt> ConstantEvaluateContext<'a, 'r, 'mt> { match &self.arenas.statements[*statement_id] { Statement::Let(statement) => { self.validate(statement.expr); + if let Some(else_clause) = &statement.else_clause { + self.validate(*else_clause); + } } Statement::Expr(expr) => { self.validate(expr.expr); @@ -684,7 +686,18 @@ impl<'a, 'r, 'mt> ConstantEvaluateContext<'a, 'r, 'mt> { match &self.arenas.statements[*statement_id] { Statement::Let(statement) => { let value = self.evaluate(statement.expr); - self.destructure_pattern(statement.pattern, value); + if self.destructure_pattern(statement.pattern, value).is_none() { + if let Some(else_clause) = &statement.else_clause { + // Only exiting the block is possible - so this is a return of a + // panic. + return self.evaluate(*else_clause); + } else { + // Either the pattern is refutable and we are missing an else + // clause, or the pattern is irrefutable and the pattern have + // failed for some reason. Both should already cause diagstics. + return to_missing(skip_diagnostic()); + } + } } Statement::Expr(expr) => { self.evaluate(expr.expr); @@ -777,25 +790,11 @@ impl<'a, 'r, 'mt> ConstantEvaluateContext<'a, 'r, 'mt> { } Expr::Match(expr) => { let value = self.evaluate(expr.matched_expr); - let ConstValue::Enum(variant, value) = value.long(db) else { - return to_missing(skip_diagnostic()); - }; for arm in &expr.arms { for pattern_id in &arm.patterns { - let pattern = &self.arenas.patterns[*pattern_id]; - if matches!(pattern, Pattern::Otherwise(_)) { + if self.destructure_pattern(*pattern_id, value).is_some() { return self.evaluate(arm.expression); } - let Pattern::EnumVariant(pattern) = pattern else { - continue; - }; - if pattern.variant.idx != variant.idx { - continue; - } - if let Some(inner_pattern) = pattern.inner_pattern { - self.destructure_pattern(inner_pattern, *value); - } - return self.evaluate(arm.expression); } } to_missing( @@ -811,35 +810,23 @@ impl<'a, 'r, 'mt> ConstantEvaluateContext<'a, 'r, 'mt> { match condition { crate::Condition::BoolExpr(id) => { let condition = self.evaluate(*id); - let ConstValue::Enum(variant, _) = condition.long(db) else { - return to_missing(skip_diagnostic()); - }; - if *variant != true_variant(self.db) { + if condition == self.true_const { + continue; + } else if condition == self.false_const { if_condition = false; break; + } else { + return to_missing(skip_diagnostic()); } } crate::Condition::Let(id, patterns) => { let value = self.evaluate(*id); - let ConstValue::Enum(variant, value) = value.long(db) else { - return to_missing(skip_diagnostic()); - }; let mut found_pattern = false; for pattern_id in patterns { - let Pattern::EnumVariant(pattern) = - &self.arenas.patterns[*pattern_id] - else { - continue; - }; - if pattern.variant != *variant { - // Continue to the next option in the `|` list. - continue; + if self.destructure_pattern(*pattern_id, value).is_some() { + found_pattern = true; + break; } - if let Some(inner_pattern) = pattern.inner_pattern { - self.destructure_pattern(inner_pattern, *value); - } - found_pattern = true; - break; } if !found_pattern { if_condition = false; @@ -1165,56 +1152,64 @@ impl<'a, 'r, 'mt> ConstantEvaluateContext<'a, 'r, 'mt> { Ok(values[member_idx]) } - /// Destructures the pattern into the const value of the variables in scope. - fn destructure_pattern(&mut self, pattern_id: PatternId, value: ConstValueId<'a>) { - let pattern = &self.arenas.patterns[pattern_id]; + /// Destructures the pattern, binding variables to their values; returns `None` if the pattern + /// didn't match. + fn destructure_pattern( + &mut self, + pattern_id: PatternId, + value: ConstValueId<'a>, + ) -> Option<()> { let db = self.db; + let pattern = &self.arenas.patterns[pattern_id]; match pattern { - Pattern::Literal(_) - | Pattern::StringLiteral(_) - | Pattern::Otherwise(_) - | Pattern::Missing(_) => {} + Pattern::Missing(_) | Pattern::StringLiteral(_) => None, + Pattern::Otherwise(_) => Some(()), + Pattern::Literal(v) => require(numeric_arg_value(db, value)? == v.literal.value), Pattern::Variable(pattern) => { self.vars.insert(VarId::Local(pattern.var.id), value); + Some(()) } Pattern::Struct(pattern) => { - if let ConstValue::Struct(inner_values, _) = value.long(db) { - let member_order = match db.concrete_struct_members(pattern.concrete_struct_id) + let ConstValue::Struct(inner_values, _) = value.long(db) else { + return None; + }; + let member_order = db.concrete_struct_members(pattern.concrete_struct_id).ok()?; + for (member, inner_value) in zip(member_order.values(), inner_values) { + if let Some((inner_pattern, _)) = + pattern.field_patterns.iter().find(|(_, field)| member.id == field.id) { - Ok(member_order) => member_order, - Err(_) => return, - }; - for (member, inner_value) in zip(member_order.values(), inner_values) { - if let Some((inner_pattern, _)) = - pattern.field_patterns.iter().find(|(_, field)| member.id == field.id) - { - self.destructure_pattern(*inner_pattern, *inner_value); - } + self.destructure_pattern(*inner_pattern, *inner_value)?; } } + Some(()) } Pattern::Tuple(pattern) => { - if let ConstValue::Struct(inner_values, _) = value.long(db) { - for (inner_pattern, inner_value) in zip(&pattern.field_patterns, inner_values) { - self.destructure_pattern(*inner_pattern, *inner_value); - } + let ConstValue::Struct(inner_values, _) = value.long(db) else { + return None; + }; + for (inner_pattern, inner_value) in zip(&pattern.field_patterns, inner_values) { + self.destructure_pattern(*inner_pattern, *inner_value)?; } + Some(()) } Pattern::FixedSizeArray(pattern) => { - if let ConstValue::Struct(inner_values, _) = value.long(db) { - for (inner_pattern, inner_value) in - zip(&pattern.elements_patterns, inner_values) - { - self.destructure_pattern(*inner_pattern, *inner_value); - } + let ConstValue::Struct(inner_values, _) = value.long(db) else { + return None; + }; + for (inner_pattern, inner_value) in zip(&pattern.elements_patterns, inner_values) { + self.destructure_pattern(*inner_pattern, *inner_value)?; } + Some(()) } Pattern::EnumVariant(pattern) => { - if let ConstValue::Enum(variant, inner_value) = value.long(db) - && pattern.variant == *variant - && let Some(inner_pattern) = pattern.inner_pattern - { - self.destructure_pattern(inner_pattern, *inner_value); + let ConstValue::Enum(variant, inner_value) = value.long(db) else { + return None; + }; + require(pattern.variant.id == variant.id)?; + if let Some(inner_pattern) = pattern.inner_pattern { + self.destructure_pattern(inner_pattern, *inner_value) + } else { + Some(()) } } } diff --git a/tests/bug_samples/issue9887.cairo b/tests/bug_samples/issue9887.cairo new file mode 100644 index 00000000000..fdcad41d278 --- /dev/null +++ b/tests/bug_samples/issue9887.cairo @@ -0,0 +1,39 @@ +struct S { + v: felt252, +} + +// Wildcard pattern on non-enum type. +const CASE1: bool = match 1_u8 { + _ => true, +}; +// Literal pattern: non-matching arm skipped, wildcard fallthrough. +const CASE2: bool = match 2_u8 { + 1 => false, + _ => true, +}; +// Wildcard on felt252. +const CASE3: bool = match 1_felt252 { + _ => true, +}; +// Tuple destructuring in match. +const CASE4: bool = match (1_u8, 2_u8) { + (_, _) => true, +}; +// Struct destructuring in match. +const CASE5: bool = match (S { v: 1 }) { + S { v: _ } => true, +}; +// Unit pattern in match. +const CASE6: bool = match () { + () => true, +}; + +#[test] +fn all_cases() { + assert!(CASE1); + assert!(CASE2); + assert!(CASE3); + assert!(CASE4); + assert!(CASE5); + assert!(CASE6); +} diff --git a/tests/bug_samples/lib.cairo b/tests/bug_samples/lib.cairo index b5039a9f2a9..a7d22f516c8 100644 --- a/tests/bug_samples/lib.cairo +++ b/tests/bug_samples/lib.cairo @@ -67,6 +67,7 @@ mod issue7544; mod issue7640; mod issue7961; mod issue8105; +mod issue9887; mod loop_break_in_match; mod loop_only_change; mod merge_const_member;