Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 123 additions & 6 deletions crates/cairo-lang-semantic/src/expr/test_data/constant
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,117 @@ const MY_CONST: felt252 = 0x1234;

const FELT_TUPLE: (felt252, felt252) = (1, 2);

mod my_module {
const CONST_IN_MODULE: felt252 = 1;
}

const FELT_FIXED_SIZE_ARRAY: [felt252; 3] = [1, 2, 3];

const OTHER_CONST_REF: felt252 = my_module::CONST_IN_MODULE;

const FIVE: u32 = 5;
const NON_ZERO_FIVE: NonZero<u32> = FIVE.try_into().unwrap();

mod my_module {
const CONST_IN_MODULE: felt252 = 1;
}
const MATCH_WILDCARD: () = assert(match 3_u8 {
2 => 0,
_ => 5,
} == 5, 'match_wildcard');

const MATCH_LITERAL_PATTERN: () = assert(
match 3_u8 {
1 => 0,
3 => 5,
_ => 0,
} == 5, 'match_literal_pattern',
);

const MATCH_TUPLE_PATTERN: () = assert(
match (1_u8, 2_u8) {
(_, y) => y,
} == 2, 'match_tuple_pattern',
);

const MATCH_ARRAY_PATTERN: () = assert(
match [1_u8, 2_u8] {
[x, _] => x,
} == 1, 'match_array_pattern',
);

const LET_TUPLE_BIND: () = assert({
let (x, y) = (1_u8, 2_u8);
x + y
} == 3, 'let_tuple_bind');

const LET_ELSE_MATCH: () = assert(
{
let Some(x) = Option::<u8>::Some(3) else {
core::panic_with_felt252('impossible')
};
x
} == 3,
'let_else_match',
);

const IF_LET_MATCH: () = assert(
if let Some(x) = Option::<u8>::Some(4) {
x
} else {
0
} == 4, 'if_let_match',
);

const IF_LET_CHAIN_0: () = assert(
if let Some(x) = Option::<u8>::Some(3) && let Some(y) = Option::<u8>::Some(2) {
x + y
} else {
0
} == 5,
'if_let_chain',
);

const IF_LET_CHAIN_1: () = assert(
if let Some(x) = Option::<u8>::Some(3) && let Some(y) = Option::<u8>::None {
x + y
} else {
0
} == 0,
'if_let_chain',
);

const IF_LET_CHAIN_2: () = assert(
if let Some(x) = Option::<u8>::None && let Some(y) = Option::<u8>::Some(2) {
x + y
} else {
0
} == 0,
'if_let_chain',
);

const IF_LET_CHAIN_3: () = assert(
if let Some(x) = Option::<u8>::Some(3) && x == 3 {
1
} else {
0
} == 1, 'if_let_chain',
);

const IF_LET_CHAIN_4: () = assert(
if let Some(x) = Option::<u8>::Some(3) && x == 2 {
1
} else {
0
} == 0, 'if_let_chain',
);

const PARTIAL_STRUCT1: () = assert({
let u256 { low, .. } = 1337;
low
} == 1337, 'partial_struct');

const PARTIAL_STRUCT2: () = assert({
let u256 { high, .. } = 1337;
high
} == 0, 'partial_struct');

//! > expected_diagnostics

Expand Down Expand Up @@ -81,6 +182,10 @@ const VALID_LE: () = assert(1_usize <= 1);
const VALID_GT: () = assert(2_usize > 1);
const VALID_GE: () = assert(1_usize >= 1);
const VALID_DIVREM: () = assert(DivRem::div_rem(5_u8, 2) == (2, 1));

const MATCH_VALUE: () = assert(match 1_u8 {
_ => 1,
} == 1);
use core::num::traits::Pow;

const VALID_POW_MIN2_0: () = assert((-2_i8).pow(0) == 1);
Expand Down Expand Up @@ -119,6 +224,13 @@ const fn my_unwrap<T>(v: Option<T>) -> T {
}
}

const LET_ELSE_FAIL: u8 = {
let None = Option::<u8>::Some(5) else {
core::panic_with_felt252('should fail')
};
0
};

//! > expected_diagnostics
error[E0006]: Type not found.
--> lib.cairo:1:17
Expand Down Expand Up @@ -160,19 +272,24 @@ note: In `test::assert`:
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

error[E2129]: Constant calculation depth exceeded.
--> lib.cairo:59:43
--> lib.cairo:63:43
const FUNC_CALC_STACK_EXCEEDED: felt252 = call_myself();
^^^^^^^^^^^^^

error[E2130]: Failed to calculate constant.
--> lib.cairo:66:37
--> lib.cairo:70:37
const NON_ZERO_ZERO: NonZero<u64> = my_unwrap(ZERO.try_into());
^^^^^^^^^^^^^^^^^^^^^^^^^^
note: In `test::my_unwrap::<core::zeroable::NonZero::<core::integer::u64>>`:
--> lib.cairo:72:17
--> lib.cairo:76:17
None => core::panic_with_felt252('bad value'),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

error[E2128]: Failed to calculate constant.
--> lib.cairo:82:9
core::panic_with_felt252('should fail')
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

//! > ==========================================================================

//! > Const of wrong type.
Expand Down
135 changes: 65 additions & 70 deletions crates/cairo-lang-semantic/src/items/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,6 @@ pub fn constant_semantic_data_helper<'db>(
}

let mut ctx = ComputationContext::new_global(db, &mut diagnostics, &mut resolver);

let value = compute_expr_semantic(&mut ctx, &constant_ast.value(db));
let const_value = resolve_const_expr_and_evaluate(
db,
Expand Down Expand Up @@ -526,6 +525,9 @@ impl<'a, 'r, 'mt> ConstantEvaluateContext<'a, 'r, 'mt> {
match &self.arenas.statements[*statement_id] {
Statement::Let(statement) => {
self.validate(statement.expr);
if let Some(else_clause) = &statement.else_clause {
self.validate(*else_clause);
}
}
Statement::Expr(expr) => {
self.validate(expr.expr);
Expand Down Expand Up @@ -684,7 +686,18 @@ impl<'a, 'r, 'mt> ConstantEvaluateContext<'a, 'r, 'mt> {
match &self.arenas.statements[*statement_id] {
Statement::Let(statement) => {
let value = self.evaluate(statement.expr);
self.destructure_pattern(statement.pattern, value);
if self.destructure_pattern(statement.pattern, value).is_none() {
if let Some(else_clause) = &statement.else_clause {
// Only exiting the block is possible - so this is a return of a
// panic.
return self.evaluate(*else_clause);
} else {
// Either the pattern is refutable and we are missing an else
// clause, or the pattern is irrefutable and the pattern have
// failed for some reason. Both should already cause diagstics.
return to_missing(skip_diagnostic());
}
}
}
Statement::Expr(expr) => {
self.evaluate(expr.expr);
Expand Down Expand Up @@ -777,25 +790,11 @@ impl<'a, 'r, 'mt> ConstantEvaluateContext<'a, 'r, 'mt> {
}
Expr::Match(expr) => {
let value = self.evaluate(expr.matched_expr);
let ConstValue::Enum(variant, value) = value.long(db) else {
return to_missing(skip_diagnostic());
};
for arm in &expr.arms {
for pattern_id in &arm.patterns {
let pattern = &self.arenas.patterns[*pattern_id];
if matches!(pattern, Pattern::Otherwise(_)) {
if self.destructure_pattern(*pattern_id, value).is_some() {
return self.evaluate(arm.expression);
}
let Pattern::EnumVariant(pattern) = pattern else {
continue;
};
if pattern.variant.idx != variant.idx {
continue;
}
if let Some(inner_pattern) = pattern.inner_pattern {
self.destructure_pattern(inner_pattern, *value);
}
return self.evaluate(arm.expression);
}
}
to_missing(
Expand All @@ -811,35 +810,23 @@ impl<'a, 'r, 'mt> ConstantEvaluateContext<'a, 'r, 'mt> {
match condition {
crate::Condition::BoolExpr(id) => {
let condition = self.evaluate(*id);
let ConstValue::Enum(variant, _) = condition.long(db) else {
return to_missing(skip_diagnostic());
};
if *variant != true_variant(self.db) {
if condition == self.true_const {
continue;
} else if condition == self.false_const {
if_condition = false;
break;
} else {
return to_missing(skip_diagnostic());
}
}
crate::Condition::Let(id, patterns) => {
let value = self.evaluate(*id);
let ConstValue::Enum(variant, value) = value.long(db) else {
return to_missing(skip_diagnostic());
};
let mut found_pattern = false;
for pattern_id in patterns {
let Pattern::EnumVariant(pattern) =
&self.arenas.patterns[*pattern_id]
else {
continue;
};
if pattern.variant != *variant {
// Continue to the next option in the `|` list.
continue;
if self.destructure_pattern(*pattern_id, value).is_some() {
found_pattern = true;
break;
}
if let Some(inner_pattern) = pattern.inner_pattern {
self.destructure_pattern(inner_pattern, *value);
}
found_pattern = true;
break;
}
if !found_pattern {
if_condition = false;
Expand Down Expand Up @@ -1165,56 +1152,64 @@ impl<'a, 'r, 'mt> ConstantEvaluateContext<'a, 'r, 'mt> {
Ok(values[member_idx])
}

/// Destructures the pattern into the const value of the variables in scope.
fn destructure_pattern(&mut self, pattern_id: PatternId, value: ConstValueId<'a>) {
let pattern = &self.arenas.patterns[pattern_id];
/// Destructures the pattern, binding variables to their values; returns `None` if the pattern
/// didn't match.
fn destructure_pattern(
&mut self,
pattern_id: PatternId,
value: ConstValueId<'a>,
) -> Option<()> {
let db = self.db;
let pattern = &self.arenas.patterns[pattern_id];
match pattern {
Pattern::Literal(_)
| Pattern::StringLiteral(_)
| Pattern::Otherwise(_)
| Pattern::Missing(_) => {}
Pattern::Missing(_) | Pattern::StringLiteral(_) => None,
Pattern::Otherwise(_) => Some(()),
Pattern::Literal(v) => require(numeric_arg_value(db, value)? == v.literal.value),
Pattern::Variable(pattern) => {
self.vars.insert(VarId::Local(pattern.var.id), value);
Some(())
}
Pattern::Struct(pattern) => {
if let ConstValue::Struct(inner_values, _) = value.long(db) {
let member_order = match db.concrete_struct_members(pattern.concrete_struct_id)
let ConstValue::Struct(inner_values, _) = value.long(db) else {
return None;
};
let member_order = db.concrete_struct_members(pattern.concrete_struct_id).ok()?;
for (member, inner_value) in zip(member_order.values(), inner_values) {
if let Some((inner_pattern, _)) =
pattern.field_patterns.iter().find(|(_, field)| member.id == field.id)
{
Ok(member_order) => member_order,
Err(_) => return,
};
for (member, inner_value) in zip(member_order.values(), inner_values) {
if let Some((inner_pattern, _)) =
pattern.field_patterns.iter().find(|(_, field)| member.id == field.id)
{
self.destructure_pattern(*inner_pattern, *inner_value);
}
self.destructure_pattern(*inner_pattern, *inner_value)?;
}
}
Some(())
}
Pattern::Tuple(pattern) => {
if let ConstValue::Struct(inner_values, _) = value.long(db) {
for (inner_pattern, inner_value) in zip(&pattern.field_patterns, inner_values) {
self.destructure_pattern(*inner_pattern, *inner_value);
}
let ConstValue::Struct(inner_values, _) = value.long(db) else {
return None;
};
for (inner_pattern, inner_value) in zip(&pattern.field_patterns, inner_values) {
self.destructure_pattern(*inner_pattern, *inner_value)?;
}
Some(())
}
Pattern::FixedSizeArray(pattern) => {
if let ConstValue::Struct(inner_values, _) = value.long(db) {
for (inner_pattern, inner_value) in
zip(&pattern.elements_patterns, inner_values)
{
self.destructure_pattern(*inner_pattern, *inner_value);
}
let ConstValue::Struct(inner_values, _) = value.long(db) else {
return None;
};
for (inner_pattern, inner_value) in zip(&pattern.elements_patterns, inner_values) {
self.destructure_pattern(*inner_pattern, *inner_value)?;
}
Some(())
}
Pattern::EnumVariant(pattern) => {
if let ConstValue::Enum(variant, inner_value) = value.long(db)
&& pattern.variant == *variant
&& let Some(inner_pattern) = pattern.inner_pattern
{
self.destructure_pattern(inner_pattern, *inner_value);
let ConstValue::Enum(variant, inner_value) = value.long(db) else {
return None;
};
require(pattern.variant.id == variant.id)?;
if let Some(inner_pattern) = pattern.inner_pattern {
self.destructure_pattern(inner_pattern, *inner_value)
} else {
Some(())
}
}
}
Expand Down
Loading
Loading