diff --git a/examples/abs.zy b/examples/abs.zy new file mode 100644 index 0000000..ef4378e --- /dev/null +++ b/examples/abs.zy @@ -0,0 +1,9 @@ +const std = import("std"); + +fn abs(x: i32): i32 { + return if x < 0 then -x else x +} + +fn main(): void { + std.debug.print(abs(-42)); +} diff --git a/src/codegen/zig/mod.rs b/src/codegen/zig/mod.rs index 346ea51..e2ab57d 100644 --- a/src/codegen/zig/mod.rs +++ b/src/codegen/zig/mod.rs @@ -432,7 +432,23 @@ impl ZigBackend { StmtKind::Return(Some(e)) => format!("{}return {};\n", ind, self.gen_expr(e)), StmtKind::Break => format!("{}break;\n", ind), StmtKind::Continue => format!("{}continue;\n", ind), - StmtKind::If { cond, body } => self.gen_cond_block("if", cond, body, level), + StmtKind::If { + cond, + body, + else_body, + } => { + let mut s = self.gen_cond_block("if", cond, body, level); + if let Some(else_stmts) = else_body { + let ind = Self::indent(level); + s.pop(); // remove trailing '\n' + s.push_str(" else {\n"); + for stmt in else_stmts { + s.push_str(&self.gen_stmt(stmt, level + 1)); + } + s.push_str(&format!("{}}}\n", ind)); + } + s + } StmtKind::While { cond, body } => self.gen_cond_block("while", cond, body, level), StmtKind::ExprStmt(e) => { let s = self.gen_expr(e); @@ -594,6 +610,15 @@ impl ZigBackend { out.push('}'); out } + + ExprKind::If { cond, then, else_ } => { + format!( + "if ({}) {} else {}", + self.gen_expr(cond), + self.gen_expr(then), + self.gen_expr(else_) + ) + } } } @@ -607,8 +632,16 @@ impl ZigBackend { Self::expr_references_var(value, var) } StmtKind::Return(Some(e)) => Self::expr_references_var(e, var), - StmtKind::If { cond, body } => { - Self::expr_references_var(cond, var) || Self::stmts_reference_var(body, var) + StmtKind::If { + cond, + body, + else_body, + } => { + Self::expr_references_var(cond, var) + || Self::stmts_reference_var(body, var) + || else_body + .as_ref() + .is_some_and(|s| Self::stmts_reference_var(s, var)) } StmtKind::While { cond, body } => { Self::expr_references_var(cond, var) || Self::stmts_reference_var(body, var) @@ -647,6 +680,11 @@ impl ZigBackend { ExprKind::Index { obj, idx } => { Self::expr_references_var(obj, var) || Self::expr_references_var(idx, var) } + ExprKind::If { cond, then, else_ } => { + Self::expr_references_var(cond, var) + || Self::expr_references_var(then, var) + || Self::expr_references_var(else_, var) + } _ => false, } } diff --git a/src/codegen/zig/stdlib/mod.rs b/src/codegen/zig/stdlib/mod.rs index 4fcc7fd..760682c 100644 --- a/src/codegen/zig/stdlib/mod.rs +++ b/src/codegen/zig/stdlib/mod.rs @@ -55,8 +55,14 @@ impl ZigBackend { self.expr_uses_allocator(value) } StmtKind::Return(Some(e)) => self.expr_uses_allocator(e), - StmtKind::If { cond, body } => { - self.expr_uses_allocator(cond) || self.uses_allocator(body) + StmtKind::If { + cond, + body, + else_body, + } => { + self.expr_uses_allocator(cond) + || self.uses_allocator(body) + || else_body.as_ref().is_some_and(|s| self.uses_allocator(s)) } StmtKind::While { cond, body } => { self.expr_uses_allocator(cond) || self.uses_allocator(body) @@ -96,6 +102,11 @@ impl ZigBackend { ExprKind::Index { obj, idx } => { self.expr_uses_allocator(obj) || self.expr_uses_allocator(idx) } + ExprKind::If { cond, then, else_ } => { + self.expr_uses_allocator(cond) + || self.expr_uses_allocator(then) + || self.expr_uses_allocator(else_) + } _ => false, } } diff --git a/src/codegen/zig/tracker.rs b/src/codegen/zig/tracker.rs index 4b10f62..d5cbfcc 100644 --- a/src/codegen/zig/tracker.rs +++ b/src/codegen/zig/tracker.rs @@ -54,7 +54,18 @@ impl ZigBackend { } StmtKind::Return(Some(e)) => self.expr_calls_any(e, targets), StmtKind::ExprStmt(e) => self.expr_calls_any(e, targets), - StmtKind::If { cond, body } | StmtKind::While { cond, body } => { + StmtKind::If { + cond, + body, + else_body, + } => { + self.expr_calls_any(cond, targets) + || self.stmts_call_any(body, targets) + || else_body + .as_ref() + .is_some_and(|s| self.stmts_call_any(s, targets)) + } + StmtKind::While { cond, body } => { self.expr_calls_any(cond, targets) || self.stmts_call_any(body, targets) } _ => false, @@ -89,6 +100,11 @@ impl ZigBackend { SwitchBody::Block(stmts) => self.stmts_call_any(stmts, targets), }) } + ExprKind::If { cond, then, else_ } => { + self.expr_calls_any(cond, targets) + || self.expr_calls_any(then, targets) + || self.expr_calls_any(else_, targets) + } _ => false, } } @@ -113,7 +129,15 @@ impl ZigBackend { self.expr_uses_std(value) } StmtKind::Return(Some(e)) => self.expr_uses_std(e), - StmtKind::If { cond, body } => self.expr_uses_std(cond) || self.stmts_use_std(body), + StmtKind::If { + cond, + body, + else_body, + } => { + self.expr_uses_std(cond) + || self.stmts_use_std(body) + || else_body.as_ref().is_some_and(|s| self.stmts_use_std(s)) + } StmtKind::While { cond, body } => self.expr_uses_std(cond) || self.stmts_use_std(body), StmtKind::ExprStmt(e) => self.expr_uses_std(e), _ => false, @@ -163,6 +187,9 @@ impl ZigBackend { } ExprKind::ArrayLiteral(elems) => elems.iter().any(|e| self.expr_uses_std(e)), ExprKind::Index { obj, idx } => self.expr_uses_std(obj) || self.expr_uses_std(idx), + ExprKind::If { cond, then, else_ } => { + self.expr_uses_std(cond) || self.expr_uses_std(then) || self.expr_uses_std(else_) + } _ => false, } } diff --git a/src/lexer.rs b/src/lexer.rs index 1a8d8ff..2f58850 100644 --- a/src/lexer.rs +++ b/src/lexer.rs @@ -14,6 +14,7 @@ pub enum Token { Switch, Catch, Else, + Then, // Type keywords Void, // Literals @@ -132,7 +133,7 @@ fn tokenize_raw(source: &str) -> Vec<(Token, Span)> { '\n' => { chars.next(); // Insert AutoSemi if the previous token can end a statement - if tokens.last().map(|(t, _)| can_end_stmt(t)).unwrap_or(false) { + if tokens.last().is_some_and(|(t, _)| can_end_stmt(t)) { tokens.push((Token::AutoSemi, (pos, pos + 1))); } } @@ -307,6 +308,7 @@ fn tokenize_raw(source: &str) -> Vec<(Token, Span)> { "switch" => Token::Switch, "catch" => Token::Catch, "else" => Token::Else, + "then" => Token::Then, "and" => Token::And, "or" => Token::Or, "void" => Token::Void, diff --git a/src/parser.rs b/src/parser.rs index 6c4f6bc..b849d7e 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -77,6 +77,11 @@ pub enum ExprKind { obj: Box, idx: Box, }, // arr[i] index expression + If { + cond: Box, + then: Box, + else_: Box, + }, } #[derive(Debug, Clone)] @@ -145,6 +150,7 @@ pub enum StmtKind { If { cond: Expr, body: Vec, + else_body: Option>, }, While { cond: Expr, @@ -549,8 +555,21 @@ impl Parser { self.expect(Token::LBrace); let body = self.parse_block(); self.expect(Token::RBrace); + let else_body = if self.peek() == &Token::Else { + self.advance(); + self.expect(Token::LBrace); + let b = self.parse_block(); + self.expect(Token::RBrace); + Some(b) + } else { + None + }; Stmt { - kind: StmtKind::If { cond, body }, + kind: StmtKind::If { + cond, + body, + else_body, + }, span, } } @@ -887,6 +906,24 @@ impl Parser { span, } } + Token::If => { + let span_start = self.peek_span().0; + self.advance(); + let cond = self.parse_expr(); + self.expect(Token::Then); + let then = self.parse_expr(); + self.expect(Token::Else); + let else_ = self.parse_expr(); + let span = (span_start, self.prev_end()); + Expr { + kind: ExprKind::If { + cond: Box::new(cond), + then: Box::new(then), + else_: Box::new(else_), + }, + span, + } + } Token::Import => { let span_start = self.peek_span().0; self.advance(); diff --git a/src/tests/codegen.rs b/src/tests/codegen.rs index 84b7766..8b021df 100644 --- a/src/tests/codegen.rs +++ b/src/tests/codegen.rs @@ -128,6 +128,144 @@ fn test_codegen_empty_return() { assert!(out.contains("return;"), "got:\n{}", out); } +#[test] +fn test_codegen_if_expr() { + let out = compile( + r#" + fn choose(x: i32): string { + return if x > 0 then "positive" else "non-positive" + } + "#, + ); + assert!( + out.contains("if ((x > 0)) \"positive\" else \"non-positive\""), + "got:\n{}", + out + ); +} + +#[test] +fn test_codegen_if_expr_uses_std() { + // std used inside if branches must trigger std import + let out = compile( + r#" + const std = import("std"); + fn greet(flag: bool): void { + std.debug.print(if flag then "yes" else "no") + } + "#, + ); + assert!( + out.contains("const std = @import(\"std\")"), + "expected std import, got:\n{}", + out + ); +} + +#[test] +fn test_codegen_if_expr_alloc_propagation() { + // allocator-requiring call inside if branch must propagate allocator to caller + let out = compile( + r#" + const std = import("std"); + fn load(flag: bool): !string { + return if flag then std.fs.readTextFile("a.txt") else std.fs.readTextFile("b.txt") + } + "#, + ); + assert!( + out.contains("__zyre_allocator"), + "expected allocator param, got:\n{}", + out + ); +} + +#[test] +fn test_codegen_if_stmt_else() { + let out = compile( + r#" + fn main(): void { + if true { + return + } else { + return + } + } + "#, + ); + assert!(out.contains("} else {"), "got:\n{}", out); +} + +#[test] +fn test_codegen_if_stmt_else_std() { + // std used only in else branch must still trigger std import + let out = compile( + r#" + const std = import("std"); + fn main(): void { + if false { + return + } else { + std.debug.print("else") + } + } + "#, + ); + assert!( + out.contains("const std = @import(\"std\")"), + "got:\n{}", + out + ); + assert!(out.contains("} else {"), "got:\n{}", out); +} + +#[test] +fn test_codegen_if_expr_nested() { + let out = compile( + r#" + fn classify(x: i32): string { + return if x > 0 then "pos" else if x == 0 then "zero" else "neg" + } + "#, + ); + assert!(out.contains("if ((x > 0))"), "got:\n{}", out); + assert!(out.contains("if ((x == 0))"), "got:\n{}", out); +} + +#[test] +fn test_codegen_if_expr_as_arg() { + // if expression passed directly as function argument + let out = compile( + r#" + const std = import("std"); + fn greet(flag: bool): void { + std.debug.print(if flag then "yes" else "no") + } + "#, + ); + assert!(out.contains("__zyre_print(if (flag)"), "got:\n{}", out); +} + +#[test] +fn test_codegen_if_stmt_no_else() { + // if without else must not emit "} else {" + let out = compile( + r#" + export fn check(x: i32): void { + if x > 0 { + return + } + } + "#, + ); + assert!(out.contains("if ((x > 0))"), "got:\n{}", out); + assert!( + !out.contains("} else {"), + "unexpected else branch, got:\n{}", + out + ); +} + #[test] fn test_codegen_export_const_hoisted() { let out = compile( diff --git a/src/tests/typechecker.rs b/src/tests/typechecker.rs index 1c7f784..476124a 100644 --- a/src/tests/typechecker.rs +++ b/src/tests/typechecker.rs @@ -201,3 +201,59 @@ fn test_empty_return_ok() { } "#); } + +// --- If statement with else --- + +#[test] +fn test_if_stmt_else_ok() { + ok(r#" + fn main(): void { + if true { + return + } else { + return + } + } + "#); +} + +#[test] +fn test_if_stmt_else_condition_not_bool() { + err( + r#" + fn main(): void { + if 42 { + return + } else { + return + } + } + "#, + "if condition must be bool", + ); +} + +// --- If expression --- + +#[test] +fn test_if_expr_ok() { + ok(r#" + fn choose(x: i32): string { + return if x > 0 then "positive" else "non-positive" + } + fn main(): void {} + "#); +} + +#[test] +fn test_if_expr_branch_type_mismatch() { + err( + r#" + fn choose(x: i32): string { + return if x > 0 then "positive" else 42 + } + fn main(): void {} + "#, + "if branches have different types", + ); +} diff --git a/src/typechecker.rs b/src/typechecker.rs index 0c4e3f4..6b9dff9 100644 --- a/src/typechecker.rs +++ b/src/typechecker.rs @@ -307,10 +307,14 @@ impl TypeChecker { fn collect_refs(stmts: &[Stmt]) -> std::collections::HashSet { let mut refs = std::collections::HashSet::new(); + Self::refs_stmts(stmts, &mut refs); + refs + } + + fn refs_stmts(stmts: &[Stmt], refs: &mut std::collections::HashSet) { for stmt in stmts { - Self::refs_stmt(stmt, &mut refs); + Self::refs_stmt(stmt, refs); } - refs } fn refs_stmt(stmt: &Stmt, refs: &mut std::collections::HashSet) { @@ -320,12 +324,21 @@ impl TypeChecker { } StmtKind::Return(Some(e)) => Self::refs_expr(e, refs), StmtKind::ExprStmt(e) => Self::refs_expr(e, refs), - StmtKind::If { cond, body } | StmtKind::While { cond, body } => { + StmtKind::If { + cond, + body, + else_body, + } => { Self::refs_expr(cond, refs); - for s in body { - Self::refs_stmt(s, refs); + Self::refs_stmts(body, refs); + if let Some(else_stmts) = else_body { + Self::refs_stmts(else_stmts, refs); } } + StmtKind::While { cond, body } => { + Self::refs_expr(cond, refs); + Self::refs_stmts(body, refs); + } _ => {} } } @@ -349,20 +362,14 @@ impl TypeChecker { ExprKind::UnOp { expr, .. } | ExprKind::Propagate(expr) => Self::refs_expr(expr, refs), ExprKind::Catch { expr, body, .. } => { Self::refs_expr(expr, refs); - for s in body { - Self::refs_stmt(s, refs); - } + Self::refs_stmts(body, refs); } ExprKind::Switch { expr, arms } => { Self::refs_expr(expr, refs); for arm in arms { match &arm.body { SwitchBody::Expr(e) => Self::refs_expr(e, refs), - SwitchBody::Block(stmts) => { - for s in stmts { - Self::refs_stmt(s, refs); - } - } + SwitchBody::Block(stmts) => Self::refs_stmts(stmts, refs), } } } @@ -375,6 +382,11 @@ impl TypeChecker { Self::refs_expr(obj, refs); Self::refs_expr(idx, refs); } + ExprKind::If { cond, then, else_ } => { + Self::refs_expr(cond, refs); + Self::refs_expr(then, refs); + Self::refs_expr(else_, refs); + } _ => {} } } @@ -417,7 +429,11 @@ impl TypeChecker { )); } } - StmtKind::If { cond, body } => { + StmtKind::If { + cond, + body, + else_body, + } => { let cond_ty = self.check_expr(cond, scope); if cond_ty != Ty::Bool { self.error(format!( @@ -427,6 +443,10 @@ impl TypeChecker { } let mut inner = scope.clone(); self.check_block(body, &mut inner, ret_ty); + if let Some(else_stmts) = else_body { + let mut inner = scope.clone(); + self.check_block(else_stmts, &mut inner, ret_ty); + } } StmtKind::While { cond, body } => { let cond_ty = self.check_expr(cond, scope); @@ -704,6 +724,23 @@ impl TypeChecker { } result_ty } + + ExprKind::If { cond, then, else_ } => { + let cond_ty = self.check_expr(cond, scope); + if cond_ty != Ty::Bool && cond_ty != Ty::Unknown { + self.error("if condition must be bool".to_string()); + } + let then_ty = self.check_expr(then, scope); + let else_ty = self.check_expr(else_, scope); + if then_ty != else_ty && then_ty != Ty::Unknown && else_ty != Ty::Unknown { + self.error(format!( + "if branches have different types: '{}' vs '{}'", + then_ty.display(), + else_ty.display() + )); + } + then_ty + } } }