diff --git a/examples/abs.zy b/examples/abs.zy index 34aabd9..2441887 100644 --- a/examples/abs.zy +++ b/examples/abs.zy @@ -4,6 +4,4 @@ fn abs(x: i32): i32 { return if x < 0 then -x else x } -fn main(): void { - std.debug.print(abs(-42)) -} +std.debug.print(abs(-42)) diff --git a/examples/explicit-main.zy b/examples/explicit-main.zy new file mode 100644 index 0000000..e29c1d7 --- /dev/null +++ b/examples/explicit-main.zy @@ -0,0 +1,7 @@ +const std = import("std") + +fn main(): void { + std.debug.print("hello from main") +} + +main() diff --git a/examples/get-readme.zy b/examples/get-readme.zy index d020bc8..597f4ac 100644 --- a/examples/get-readme.zy +++ b/examples/get-readme.zy @@ -6,10 +6,8 @@ fn getReadme(filename: string): !string { return data } -fn main(): void { - const data = getReadme("README.md") catch err { - std.debug.print(err) - return - } - std.debug.print(data) +const data = getReadme("README.md") catch err { + std.debug.print(err) + return } +std.debug.print(data) diff --git a/examples/switch.zy b/examples/switch.zy index f7f6442..6365736 100644 --- a/examples/switch.zy +++ b/examples/switch.zy @@ -9,9 +9,7 @@ fn label(n: i32): string { } } -fn main(): void { - std.debug.print(label(1)) - std.debug.print(label(2)) - std.debug.print(label(3)) - std.debug.print(label(99)) -} +std.debug.print(label(1)) +std.debug.print(label(2)) +std.debug.print(label(3)) +std.debug.print(label(99)) diff --git a/src/codegen/zig/mod.rs b/src/codegen/zig/mod.rs index 53cca07..2b3eab7 100644 --- a/src/codegen/zig/mod.rs +++ b/src/codegen/zig/mod.rs @@ -2,6 +2,8 @@ use crate::codegen::Backend; use crate::parser::*; use std::collections::{HashMap, HashSet}; +const MAIN_MANGLED: &str = "__zyre_fn_main"; + mod stdlib; mod tracker; @@ -165,15 +167,13 @@ impl ZigBackend { // Hoisting: collect export const and their dependencies let hoisted = Self::collect_hoisted(program); - let has_explicit_main = program - .iter() - .any(|item| matches!(item, TopLevel::FnDecl(f) if f.name == "main")); - // Pass 2: collect body_stmts first (needed for header generation) let mut body_stmts: Vec = Vec::new(); for item in program { match item { - TopLevel::ConstDecl { name, value, .. } => { + TopLevel::ConstDecl { + name, ty, value, .. + } => { let is_import = matches!(&value.kind, ExprKind::Import(_)); let is_module_alias = if let ExprKind::MemberAccess { obj, .. } = &value.kind { matches!(&obj.kind, ExprKind::Var(m) if import_names.contains(m) || self.aliases.contains_key(m)) @@ -189,7 +189,7 @@ impl ZigBackend { body_stmts.push(Stmt { kind: StmtKind::ConstDecl { name: name.clone(), - ty: None, + ty: ty.clone(), value: value.clone(), }, span: (0, 0), @@ -201,9 +201,8 @@ impl ZigBackend { } } - // main / implicit main uses std.heap, so std is required - let needs_std = - self.program_uses_std(program) || has_explicit_main || !body_stmts.is_empty(); + // implicit main uses std.heap, so std is required + let needs_std = self.program_uses_std(program) || !body_stmts.is_empty(); let mut out = String::new(); if needs_std { @@ -289,7 +288,7 @@ impl ZigBackend { } } - if !body_stmts.is_empty() && !has_explicit_main { + if !body_stmts.is_empty() { let needs_alloc = self.uses_allocator(&body_stmts); out.push_str("pub fn main() !void {\n"); out.push_str(" try __zyre_runtime.Output.init();\n"); @@ -323,6 +322,14 @@ impl ZigBackend { } } + fn is_allocating_call(&self, callee: &Expr) -> bool { + if let ExprKind::Var(name) = &callee.kind { + self.allocating_fns.contains(name) + } else { + false + } + } + fn gen_fn(&mut self, f: &FnDecl) -> String { let params: Vec = f .params @@ -330,19 +337,12 @@ impl ZigBackend { .map(|(name, ty)| format!("{}: {}", name, self.gen_type(ty))) .collect(); - let (ret, pub_prefix) = if f.name == "main" { - ("!void".to_string(), "pub ") - } else if f.exported { - (self.gen_type(&f.ret), "pub ") - } else { - (self.gen_type(&f.ret), "") - }; + let pub_prefix = if f.exported { "pub " } else { "" }; + let ret = self.gen_type(&f.ret); let needs_alloc = self.allocating_fns.contains(&f.name); - let params_str = if f.name == "main" { - "".to_string() - } else if needs_alloc { + let params_str = if needs_alloc { let mut all = vec!["__zyre_allocator: std.mem.Allocator".to_string()]; all.extend(params); all.join(", ") @@ -350,13 +350,13 @@ impl ZigBackend { params.join(", ") }; - let mut out = format!("{}fn {}({}) {} {{\n", pub_prefix, f.name, params_str, ret); - - if f.name == "main" { - out.push_str(" try __zyre_runtime.Output.init();\n"); - out.push_str(" defer __zyre_runtime.Output.restore();\n"); - out.push_str(&Self::gen_arena_setup(needs_alloc)); - } + // "main" is reserved for the implicit entry point in Zig; mangle user-defined fn main + let zig_name = if f.name == "main" { + MAIN_MANGLED + } else { + &f.name + }; + let mut out = format!("{}fn {}({}) {} {{\n", pub_prefix, zig_name, params_str, ret); for stmt in &f.body { out.push_str(&self.gen_stmt(stmt, 1)); @@ -508,13 +508,16 @@ impl ZigBackend { } } } - let callee_s = self.gen_expr(callee); + // Mangle calls to user-defined "main" to match the renamed fn + let callee_s = if matches!(&callee.kind, ExprKind::Var(n) if n == "main") { + MAIN_MANGLED.to_string() + } else { + self.gen_expr(callee) + }; let mut args_str = self.gen_args(args); // Insert __zyre_allocator as the first argument when calling an allocating fn - if let ExprKind::Var(name) = &callee.kind { - if self.allocating_fns.contains(name) { - args_str.insert(0, "__zyre_allocator".to_string()); - } + if self.is_allocating_call(callee) { + args_str.insert(0, "__zyre_allocator".to_string()); } format!("{}({})", callee_s, args_str.join(", ")) } diff --git a/src/codegen/zig/stdlib/mod.rs b/src/codegen/zig/stdlib/mod.rs index 760682c..c277e52 100644 --- a/src/codegen/zig/stdlib/mod.rs +++ b/src/codegen/zig/stdlib/mod.rs @@ -80,6 +80,10 @@ impl ZigBackend { // Check recursively match &expr.kind { ExprKind::Call { callee, args } => { + // Calling a user function that transitively needs an allocator + if self.is_allocating_call(callee) { + return true; + } self.expr_uses_allocator(callee) || args.iter().any(|a| self.expr_uses_allocator(a)) } ExprKind::MemberAccess { obj, .. } => self.expr_uses_allocator(obj), diff --git a/src/fmt.rs b/src/fmt.rs index 7bdb398..4a35b66 100644 --- a/src/fmt.rs +++ b/src/fmt.rs @@ -61,13 +61,18 @@ impl Formatter { match item { TopLevel::ConstDecl { name, + ty, value, exported, .. } => { let prefix = if *exported { "export " } else { "" }; + let type_ann = ty + .as_ref() + .map(|t| format!(": {}", fmt_type(t))) + .unwrap_or_default(); let val = self.fmt_expr(value); - self.line(&format!("{}const {} = {}", prefix, name, val)); + self.line(&format!("{}const {}{} = {}", prefix, name, type_ann, val)); } TopLevel::FnDecl(f) => self.fmt_fn(f), TopLevel::StructDecl { diff --git a/src/parser.rs b/src/parser.rs index 90c9e81..86860d8 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -190,6 +190,7 @@ pub enum TopLevel { ConstDecl { name: String, name_span: Span, + ty: Option, value: Expr, exported: bool, }, @@ -394,6 +395,12 @@ impl Parser { Token::Const => { self.advance(); let (name, name_span) = self.expect_ident_spanned(); + let ty = if self.peek() == &Token::Colon { + self.advance(); + Some(self.parse_type()) + } else { + None + }; self.expect(Token::Eq); match self.peek().clone() { Token::Import => { @@ -402,6 +409,7 @@ impl Parser { TopLevel::ConstDecl { name, name_span, + ty, value, exported, } @@ -454,6 +462,7 @@ impl Parser { TopLevel::ConstDecl { name, name_span, + ty, value, exported, } diff --git a/src/tests/fixtures/alloc_propagate.zy b/src/tests/fixtures/alloc_propagate.zy index bb10ac9..3785450 100644 --- a/src/tests/fixtures/alloc_propagate.zy +++ b/src/tests/fixtures/alloc_propagate.zy @@ -4,9 +4,7 @@ fn readIt(): !string { return std.fs.readTextFile("x.txt")? } -fn main(): void { - const data = readIt() catch _err { - return - } - std.debug.print(data) +const data = readIt() catch _err { + return } +std.debug.print(data) diff --git a/src/tests/fixtures/array.zy b/src/tests/fixtures/array.zy index 6dd6cb1..6e74540 100644 --- a/src/tests/fixtures/array.zy +++ b/src/tests/fixtures/array.zy @@ -1,6 +1,4 @@ const std = import("std") -fn main(): void { - const arr: i32[3] = [1, 2, 3] - std.debug.print(arr[0]) -} +const arr: i32[3] = [1, 2, 3] +std.debug.print(arr[0]) diff --git a/src/tests/fixtures/enum_decl.zy b/src/tests/fixtures/enum_decl.zy index c191878..70acdb1 100644 --- a/src/tests/fixtures/enum_decl.zy +++ b/src/tests/fixtures/enum_decl.zy @@ -16,6 +16,4 @@ fn describe(d: Direction): void { } } -fn main(): void { - std.debug.print("ok") -} +std.debug.print("ok") diff --git a/src/tests/fixtures/explicit_main.zy b/src/tests/fixtures/explicit_main.zy new file mode 100644 index 0000000..e29c1d7 --- /dev/null +++ b/src/tests/fixtures/explicit_main.zy @@ -0,0 +1,7 @@ +const std = import("std") + +fn main(): void { + std.debug.print("hello from main") +} + +main() diff --git a/src/tests/fixtures/export_const.zy b/src/tests/fixtures/export_const.zy index 84235cb..f213f0f 100644 --- a/src/tests/fixtures/export_const.zy +++ b/src/tests/fixtures/export_const.zy @@ -4,6 +4,4 @@ const a = 2 const b = 3 export const c = a * b -fn main(): void { - std.debug.print(c) -} +std.debug.print(c) diff --git a/src/tests/fixtures/fn_call.zy b/src/tests/fixtures/fn_call.zy index eb42962..f75b6b0 100644 --- a/src/tests/fixtures/fn_call.zy +++ b/src/tests/fixtures/fn_call.zy @@ -4,6 +4,4 @@ fn add(a: i32, b: i32): i32 { return a + b } -fn main(): void { - std.debug.print(add(1, 2)) -} +std.debug.print(add(1, 2)) diff --git a/src/tests/fixtures/if_stmt.zy b/src/tests/fixtures/if_stmt.zy index 439e1fc..4c1b2ce 100644 --- a/src/tests/fixtures/if_stmt.zy +++ b/src/tests/fixtures/if_stmt.zy @@ -1,8 +1,6 @@ const std = import("std") -fn main(): void { - const x = true - if x { - std.debug.print("yes") - } +const x = true +if x { + std.debug.print("yes") } diff --git a/src/tests/fixtures/print_int.zy b/src/tests/fixtures/print_int.zy index 9083c9a..380e870 100644 --- a/src/tests/fixtures/print_int.zy +++ b/src/tests/fixtures/print_int.zy @@ -1,6 +1,4 @@ const std = import("std") -fn main(): void { - const x: i32 = 42 - std.debug.print(x) -} +const x: i32 = 42 +std.debug.print(x) diff --git a/src/tests/fixtures/print_str.zy b/src/tests/fixtures/print_str.zy index bc711d5..be4adb0 100644 --- a/src/tests/fixtures/print_str.zy +++ b/src/tests/fixtures/print_str.zy @@ -1,5 +1,3 @@ const std = import("std") -fn main(): void { - std.debug.print("hello") -} +std.debug.print("hello") diff --git a/src/tests/fixtures/read_text_file.zy b/src/tests/fixtures/read_text_file.zy index 1302264..f0fdb4b 100644 --- a/src/tests/fixtures/read_text_file.zy +++ b/src/tests/fixtures/read_text_file.zy @@ -1,8 +1,6 @@ const std = import("std") -fn main(): void { - const data = std.fs.readTextFile("file.txt") catch _err { - return - } - std.debug.print(data) +const data = std.fs.readTextFile("file.txt") catch _err { + return } +std.debug.print(data) diff --git a/src/tests/fixtures/struct_decl.zy b/src/tests/fixtures/struct_decl.zy index 77273e3..18fe344 100644 --- a/src/tests/fixtures/struct_decl.zy +++ b/src/tests/fixtures/struct_decl.zy @@ -9,6 +9,4 @@ fn getX(p: Point): i32 { return p.x } -fn main(): void { - std.debug.print("ok") -} +std.debug.print("ok") diff --git a/src/tests/fixtures/switch_stmt.zy b/src/tests/fixtures/switch_stmt.zy index 094b71e..79eae18 100644 --- a/src/tests/fixtures/switch_stmt.zy +++ b/src/tests/fixtures/switch_stmt.zy @@ -8,6 +8,4 @@ fn check(n: i32): void { } } -fn main(): void { - check(1) -} +check(1) diff --git a/src/tests/fixtures/while_loop.zy b/src/tests/fixtures/while_loop.zy index 207380a..3439c76 100644 --- a/src/tests/fixtures/while_loop.zy +++ b/src/tests/fixtures/while_loop.zy @@ -1,9 +1,7 @@ const std = import("std") -fn main(): void { - const x = true - while x { - std.debug.print("loop") - break - } +const x = true +while x { + std.debug.print("loop") + break } diff --git a/src/tests/snapshots/zyre__tests__snapshots__codegen_snapshots@explicit_main.zy.snap b/src/tests/snapshots/zyre__tests__snapshots__codegen_snapshots@explicit_main.zy.snap new file mode 100644 index 0000000..d1e1fd2 --- /dev/null +++ b/src/tests/snapshots/zyre__tests__snapshots__codegen_snapshots@explicit_main.zy.snap @@ -0,0 +1,31 @@ +--- +source: src/tests/snapshots.rs +assertion_line: 10 +expression: output +input_file: src/tests/fixtures/explicit_main.zy +--- +const std = @import("std"); +const __zyre_runtime = @import("zyre_runtime.zig"); + +fn __zyre_print(val: anytype) void { + if (comptime @typeInfo(@TypeOf(val)) == .pointer) { + std.debug.print("{s}\n", .{val}); + } else { + std.debug.print("{}\n", .{val}); + } +} + +fn __zyre_fn_main() void { + __zyre_print("hello from main"); +} + +pub fn main() !void { + try __zyre_runtime.Output.init(); + defer __zyre_runtime.Output.restore(); + var __zyre_arena = std.heap.ArenaAllocator.init(std.heap.page_allocator); + defer __zyre_arena.deinit(); + const __zyre_allocator = __zyre_arena.allocator(); + _ = __zyre_allocator; + + __zyre_fn_main(); +} diff --git a/src/typechecker.rs b/src/typechecker.rs index 05fc427..f9685b6 100644 --- a/src/typechecker.rs +++ b/src/typechecker.rs @@ -157,11 +157,14 @@ impl TypeChecker { let mut import_names: std::collections::HashSet = std::collections::HashSet::new(); for item in program { match item { - TopLevel::ConstDecl { name, value, .. } => { + TopLevel::ConstDecl { + name, ty, value, .. + } => { let is_import = matches!(&value.kind, ExprKind::Import(_)); // Register all top-level ConstDecls into globals in order - let ty = self.check_expr(value, &self.globals.clone()); - self.globals.insert(name.clone(), ty); + let val_ty = self.check_expr(value, &self.globals.clone()); + let inferred_ty = self.resolve_decl_type(name, ty, val_ty); + self.globals.insert(name.clone(), inferred_ty); if is_import { import_names.insert(name.clone()); // Pre-collect exports of local .zy modules @@ -391,24 +394,31 @@ impl TypeChecker { } } + /// Resolve the declared type of a variable, checking for mismatches. + /// Returns the declared type if annotated, otherwise the inferred type. + fn resolve_decl_type(&mut self, name: &str, ty: &Option, val_ty: Ty) -> Ty { + if let Some(t) = ty { + let decl_ty = Ty::from_type_expr(t, &self.types); + if !decl_ty.is_assignable_from(&val_ty) { + self.error(format!( + "Type mismatch: '{}' is declared as '{}' but got '{}'", + name, + decl_ty.display(), + val_ty.display() + )); + } + decl_ty + } else { + val_ty + } + } + fn check_stmt(&mut self, stmt: &Stmt, scope: &mut HashMap, ret_ty: &Ty) { self.current_span = stmt.span; match &stmt.kind { StmtKind::ConstDecl { name, ty, value } | StmtKind::LetDecl { name, ty, value } => { let val_ty = self.check_expr(value, scope); - let decl_ty = if let Some(t) = ty { - Ty::from_type_expr(t, &self.types) - } else { - val_ty.clone() - }; - if !decl_ty.is_assignable_from(&val_ty) { - self.error(format!( - "Type mismatch: '{}' is declared as '{}' but got '{}'", - name, - decl_ty.display(), - val_ty.display() - )); - } + let decl_ty = self.resolve_decl_type(name, ty, val_ty); scope.insert(name.clone(), decl_ty); } StmtKind::Return(None) => {