From dc5e6695b2a108adddd571858cf9beff976aca8f Mon Sep 17 00:00:00 2001 From: Luca Warmenhoven Date: Sun, 15 Mar 2026 14:50:34 +0100 Subject: [PATCH 01/31] refactor: replace `enumerable` implementation with `enum` type, adjusted parsing logic --- .../compiler/include/ast/nodes/enumerables.h | 103 ---------- .../include/ast/nodes/type_definition.h | 15 +- packages/compiler/include/ast/nodes/types.h | 70 ++++++- .../compiler/include/ast/parsing_context.h | 2 - packages/compiler/src/ast/ast.cpp | 3 +- .../src/ast/context/parsing_context.cpp | 7 - .../compiler/src/ast/nodes/enumerables.cpp | 162 --------------- .../ast/nodes/expressions/member_accessor.cpp | 3 +- .../src/ast/nodes/types/enum_type.cpp | 186 ++++++++++++++++++ 9 files changed, 261 insertions(+), 290 deletions(-) delete mode 100644 packages/compiler/include/ast/nodes/enumerables.h delete mode 100644 packages/compiler/src/ast/nodes/enumerables.cpp create mode 100644 packages/compiler/src/ast/nodes/types/enum_type.cpp diff --git a/packages/compiler/include/ast/nodes/enumerables.h b/packages/compiler/include/ast/nodes/enumerables.h deleted file mode 100644 index db2cafa9..00000000 --- a/packages/compiler/include/ast/nodes/enumerables.h +++ /dev/null @@ -1,103 +0,0 @@ -#pragma once - -#include "ast_node.h" -#include "ast/nodes/literal_values.h" - -#include - -namespace stride::ast -{ - class AstEnumerableMember - : public IAstNode - { - friend class AstEnumerable; - - std::string _name; - std::unique_ptr _value; - - public: - explicit AstEnumerableMember( - const SourceFragment& source, - const std::shared_ptr& context, - std::string name, - std::unique_ptr value - ) : - IAstNode(source, context), - _name(std::move(name)), - _value(std::move(value)) {} - - [[nodiscard]] - const std::string& get_name() const - { - return this->_name; - } - - [[nodiscard]] - AstLiteral& value() const - { - return *this->_value; - } - - std::string to_string() override; - - llvm::Value* codegen(llvm::Module* module, 
llvm::IRBuilderBase* builder) override - { - return nullptr; - } - - std::unique_ptr clone() override; - }; - - class AstEnumerable - : public IAstNode - { - const std::vector> _members; - std::string _name; - - public: - explicit AstEnumerable( - const SourceFragment& source, - const std::shared_ptr& context, - std::vector> members, - std::string name - ) : - IAstNode(source, context), - _members(std::move(members)), - _name(std::move(name)) {} - - [[nodiscard]] - const std::vector>& - get_members() const - { - return this->_members; - } - - [[nodiscard]] - const std::string& get_name() const - { - return this->_name; - } - - std::string to_string() override; - - llvm::Value* codegen(llvm::Module* module, llvm::IRBuilderBase* builder) override - { - return nullptr; - } - - std::unique_ptr clone() override; - - void resolve_forward_references(llvm::Module* module, llvm::IRBuilderBase* builder) override; - }; - - std::unique_ptr parse_enumerable_member( - const std::shared_ptr& context, - TokenSet& set, - size_t element_index - ); - - std::unique_ptr parse_enumerable_declaration( - const std::shared_ptr& context, - TokenSet& set, - VisibilityModifier modifier); -} // namespace stride::ast diff --git a/packages/compiler/include/ast/nodes/type_definition.h b/packages/compiler/include/ast/nodes/type_definition.h index 311ed751..04a0c5c3 100644 --- a/packages/compiler/include/ast/nodes/type_definition.h +++ b/packages/compiler/include/ast/nodes/type_definition.h @@ -6,7 +6,7 @@ #include #include -namespace stride:: ast +namespace stride::ast { enum class VisibilityModifier; class TokenSet; @@ -26,7 +26,7 @@ namespace stride:: ast std::string name, std::unique_ptr type, const VisibilityModifier visibility, - GenericParameterList generic_parameters + GenericParameterList generic_parameters = {} ) : IAstNode(source, context), _name(std::move(name)), @@ -83,4 +83,15 @@ namespace stride:: ast TokenSet& set, VisibilityModifier modifier ); + + EnumMemberPair 
parse_enumerable_member( + const std::shared_ptr& context, + TokenSet& set, + size_t element_index + ); + + std::unique_ptr parse_enum_type_definition( + const std::shared_ptr& context, + TokenSet& set, + VisibilityModifier modifier); } diff --git a/packages/compiler/include/ast/nodes/types.h b/packages/compiler/include/ast/nodes/types.h index 2f85817b..7e8753bd 100644 --- a/packages/compiler/include/ast/nodes/types.h +++ b/packages/compiler/include/ast/nodes/types.h @@ -4,11 +4,13 @@ #include "formatting.h" #include "ast/flags.h" #include "ast/generics.h" +#include "ast/modifiers.h" #include #include #include #include +#include namespace llvm { @@ -18,6 +20,9 @@ namespace llvm namespace stride::ast { + enum class VisibilityModifier; + class AstLiteral; + namespace definition { class TypeDefinition; @@ -28,6 +33,9 @@ namespace stride::ast using ObjectTypeMemberPair = std::pair>; using ObjectTypeMemberList = std::vector; + using EnumMemberValueTy = std::unique_ptr; + using EnumMemberPair = std::pair; + enum class PrimitiveType { INT8, @@ -440,11 +448,6 @@ namespace stride::ast [[nodiscard]] bool equals(IAstType* other) override; - bool is_castable_to(IAstType* other) override - { - return IAstType::is_castable_to(other); - } - private: bool is_assignable_to_impl(IAstType* other) override; @@ -506,6 +509,58 @@ namespace stride::ast llvm::Type* get_llvm_type_impl(llvm::Module* module) override; }; + class AstEnumType + : public IAstType + { + + std::vector _members; + std::string _name; + + public: + explicit AstEnumType( + const SourceFragment& source, + const std::shared_ptr& context, + std::string enum_name, + std::vector members, + int flags = SRFLAG_NONE + ); + + [[nodiscard]] + std::unique_ptr clone() override; + + [[nodiscard]] + const std::vector& get_members() const + { + return _members; + } + + [[nodiscard]] + const std::string& get_name() const + { + return _name; + } + + [[nodiscard]] + std::string get_type_name() override + { + return _name; + } + + 
[[nodiscard]] + bool equals(IAstType* other) override; + + std::string to_string() override; + + llvm::Value* codegen(llvm::Module* module, llvm::IRBuilderBase* builder) override; + + void resolve_forward_references(llvm::Module* module, llvm::IRBuilderBase* builder) override; + + private: + bool is_assignable_to_impl(IAstType* other) override; + + llvm::Type* get_llvm_type_impl(llvm::Module* module) override; + }; + class AstTupleType : public IAstType { @@ -540,11 +595,6 @@ namespace stride::ast [[nodiscard]] bool equals(IAstType* other) override; - bool is_castable_to(IAstType* other) override - { - return IAstType::is_castable_to(other); - } - llvm::Value* codegen(llvm::Module* module, llvm::IRBuilderBase* builder) override; private: diff --git a/packages/compiler/include/ast/parsing_context.h b/packages/compiler/include/ast/parsing_context.h index 95dc0d3c..94facadb 100644 --- a/packages/compiler/include/ast/parsing_context.h +++ b/packages/compiler/include/ast/parsing_context.h @@ -418,8 +418,6 @@ namespace stride::ast [[nodiscard]] bool is_type_defined(const std::string& type_name) const; - void define_symbol(const Symbol& symbol_name, definition::SymbolType type); - void define(std::unique_ptr definition); /// Checks whether the provided variable name is defined in the current context. 
diff --git a/packages/compiler/src/ast/ast.cpp b/packages/compiler/src/ast/ast.cpp index 87e5b3df..7d49dda6 100644 --- a/packages/compiler/src/ast/ast.cpp +++ b/packages/compiler/src/ast/ast.cpp @@ -6,7 +6,6 @@ #include "ast/nodes/blocks.h" #include "ast/nodes/conditional_statement.h" #include "ast/nodes/control_flow_statements.h" -#include "ast/nodes/enumerables.h" #include "ast/nodes/for_loop.h" #include "ast/nodes/function_declaration.h" #include "ast/nodes/import.h" @@ -156,7 +155,7 @@ std::unique_ptr stride::ast::parse_next_statement( case TokenType::KEYWORD_TYPE: return parse_type_definition(context, set, visibility_modifier); case TokenType::KEYWORD_ENUM: - return parse_enumerable_declaration(context, set, visibility_modifier); + return parse_enum_type_definition(context, set, visibility_modifier); case TokenType::KEYWORD_FOR: return parse_for_loop_statement(context, set, visibility_modifier); case TokenType::KEYWORD_WHILE: diff --git a/packages/compiler/src/ast/context/parsing_context.cpp b/packages/compiler/src/ast/context/parsing_context.cpp index a444141f..b6a1cbc0 100644 --- a/packages/compiler/src/ast/context/parsing_context.cpp +++ b/packages/compiler/src/ast/context/parsing_context.cpp @@ -37,13 +37,6 @@ const ParsingContext& ParsingContext::traverse_to_root() const return *current; } -void ParsingContext::define_symbol(const Symbol& symbol_name, const SymbolType type) -{ - this->_symbols.push_back( - std::make_unique(type, symbol_name) - ); -} - void ParsingContext::define(std::unique_ptr definition) { this->_symbols.push_back(std::move(definition)); diff --git a/packages/compiler/src/ast/nodes/enumerables.cpp b/packages/compiler/src/ast/nodes/enumerables.cpp deleted file mode 100644 index 6c683eb5..00000000 --- a/packages/compiler/src/ast/nodes/enumerables.cpp +++ /dev/null @@ -1,162 +0,0 @@ -#include "ast/nodes/enumerables.h" - -#include "ast/parsing_context.h" -#include "ast/nodes/blocks.h" -#include "ast/tokens/token_set.h" - -#include - -using 
namespace stride::ast; -using namespace stride::ast::definition; - -/** - * Parses member entries with the following sequence: - * - * IDENTIFIER: ; - * - */ -std::unique_ptr stride::ast::parse_enumerable_member( - const std::shared_ptr& context, - TokenSet& set, - size_t element_index -) -{ - const auto member_name_tok = set.expect(TokenType::IDENTIFIER); - auto member_sym = member_name_tok.get_lexeme(); - - context->define_symbol( - Symbol( - member_name_tok.get_source_fragment(), - context->get_name(), - member_sym - ), - SymbolType::ENUM_MEMBER - ); - - // Using index as element value if no explicit value is provided, and allowing optional trailing comma - if (!set.has_next() || !set.peek_next_eq(TokenType::COLON)) - { - return std::make_unique( - member_name_tok.get_source_fragment(), - context, - std::move(member_sym), - std::make_unique( - member_name_tok.get_source_fragment(), - context, - PrimitiveType::INT32, - element_index - ) - ); - } - - set.expect(TokenType::COLON, "Expected a colon after enum member name"); - - auto value = parse_literal_optional(context, set); - - if (!value.has_value()) - set.throw_error("Expected a literal value for enum member"); - - return std::make_unique( - member_name_tok.get_source_fragment(), - context, - std::move(member_sym), - std::move(value.value()) - ); -} - -std::unique_ptr stride::ast::parse_enumerable_declaration( - const std::shared_ptr& context, - TokenSet& set, - [[maybe_unused]] VisibilityModifier modifier -) -{ - const auto reference_token = set.expect(TokenType::KEYWORD_ENUM); - const auto enumerable_name = set.expect(TokenType::IDENTIFIER).get_lexeme(); - - context->define_symbol( - Symbol(reference_token.get_source_fragment(), - context->get_name(), - enumerable_name), - SymbolType::ENUM - ); - - auto enum_body_subset = collect_block_required(set, "Expected a block in enum declaration"); - - std::vector> members; - - auto enum_definition_context = std::make_shared( - context, - 
context->get_context_type()); - - members.push_back(parse_enumerable_member(enum_definition_context, enum_body_subset, 0)); - - for (size_t i = 1; enum_body_subset.has_next(); ++i) - { - enum_body_subset.expect(TokenType::COMMA, "Expected a comma between enum members"); - members.push_back(parse_enumerable_member(enum_definition_context, enum_body_subset, i)); - } - - return std::make_unique( - reference_token.get_source_fragment(), - enum_definition_context, - std::move(members), - enumerable_name - ); -} - -void AstEnumerable::resolve_forward_references(llvm::Module* module, llvm::IRBuilderBase* builder) -{ - for (const auto& member : this->_members) - { - member->resolve_forward_references(module, builder); - } -} - -std::unique_ptr AstEnumerable::clone() -{ - std::vector> cloned_members; - cloned_members.reserve(this->_members.size()); - - for (const auto& member : this->_members) - { - cloned_members.push_back(member->clone_as()); - } - - return std::make_unique( - this->get_source_fragment(), - this->get_context(), - std::move(cloned_members), - this->get_name() - ); -} - -std::unique_ptr AstEnumerableMember::clone() -{ - return std::make_unique( - this->get_source_fragment(), - this->get_context(), - this->get_name(), - this->value().clone_as() - ); -} - -std::string AstEnumerableMember::to_string() -{ - auto member_value_str = this->value().to_string(); - return std::format("{}: {}", this->get_name(), member_value_str); -} - -std::string AstEnumerable::to_string() -{ - std::vector members; - - for (const auto& member : this->get_members()) - { - members.push_back(member->to_string()); - } - - return std::format( - "Enumerable {} (\n {}\n)", - this->get_name(), - join(members, ",\n ")); -} diff --git a/packages/compiler/src/ast/nodes/expressions/member_accessor.cpp b/packages/compiler/src/ast/nodes/expressions/member_accessor.cpp index 32fb9514..0759dc65 100644 --- a/packages/compiler/src/ast/nodes/expressions/member_accessor.cpp +++ 
b/packages/compiler/src/ast/nodes/expressions/member_accessor.cpp @@ -3,9 +3,8 @@ #include "ast/casting.h" #include "ast/closures.h" #include "ast/parsing_context.h" -#include "ast/nodes/enumerables.h" -#include "ast/nodes/expression.h" #include "ast/nodes/blocks.h" +#include "ast/nodes/expression.h" #include "ast/tokens/token_set.h" #include diff --git a/packages/compiler/src/ast/nodes/types/enum_type.cpp b/packages/compiler/src/ast/nodes/types/enum_type.cpp new file mode 100644 index 00000000..cef50449 --- /dev/null +++ b/packages/compiler/src/ast/nodes/types/enum_type.cpp @@ -0,0 +1,186 @@ +#include "ast/casting.h" +#include "ast/parsing_context.h" +#include "ast/nodes/blocks.h" +#include "ast/nodes/literal_values.h" +#include "ast/nodes/types.h" +#include "ast/nodes/type_definition.h" +#include "ast/tokens/token.h" +#include "ast/tokens/token_set.h" + +#include +#include +#include + +using namespace stride::ast; + +AstEnumType::AstEnumType( + const SourceFragment& source, + const std::shared_ptr& context, + std::string enum_name, + std::vector members, + const int flags +) : + IAstType(source, context, flags), + _members(std::move(members)), + _name(std::move(enum_name)) {} + +/** + * Parses member entries with the following sequence: + * + * IDENTIFIER: ; + * + */ +EnumMemberPair stride::ast::parse_enumerable_member( + const std::shared_ptr& context, + TokenSet& set, + size_t element_index +) +{ + const auto member_name_tok = set.expect(TokenType::IDENTIFIER); + auto member_name = member_name_tok.get_lexeme(); + + // Using index as element value if no explicit value is provided, and allowing optional trailing comma + if (!set.has_next() || !set.peek_next_eq(TokenType::COLON)) + { + auto value = std::make_unique( + member_name_tok.get_source_fragment(), + context, + PrimitiveType::INT32, + element_index + ); + + return { + member_name, + std::move(value) + }; + } + + set.expect(TokenType::COLON, "Expected a colon after enum member name"); + + auto value = 
parse_literal_optional(context, set); + + if (!value.has_value()) + set.throw_error("Expected a literal value for enum member"); + + return { + member_name, + std::move(value.value()) + }; +} + +std::unique_ptr stride::ast::parse_enum_type_definition( + const std::shared_ptr& context, + TokenSet& set, + [[maybe_unused]] VisibilityModifier modifier +) +{ + const auto reference_token = set.expect(TokenType::KEYWORD_ENUM); + const auto enumerable_name = set.expect(TokenType::IDENTIFIER).get_lexeme(); + + auto enum_body_subset = collect_block_required(set, "Expected a block in enum declaration"); + + std::vector members; + + members.push_back(parse_enumerable_member(context, enum_body_subset, 0)); + + for (size_t i = 1; enum_body_subset.has_next(); ++i) + { + enum_body_subset.expect(TokenType::COMMA, "Expected a comma between enum members"); + members.push_back(parse_enumerable_member(context, enum_body_subset, i)); + } + + auto type = std::make_unique( + reference_token.get_source_fragment(), + context, + enumerable_name, + std::move(members) + ); + + context->define_type( + Symbol(reference_token.get_source_fragment(), + context->get_name(), + enumerable_name), + type->clone_ty(), + {}, + modifier + ); + + return std::make_unique( + reference_token.get_source_fragment(), + context, + enumerable_name, + std::move(type), + modifier + ); +} + + +std::unique_ptr AstEnumType::clone() +{ + std::vector cloned_members; + cloned_members.reserve(this->_members.size()); + + for (const auto& [field_name, field_ty] : this->_members) + { + cloned_members.emplace_back(field_name, field_ty->clone()); + } + + return std::make_unique( + this->get_source_fragment(), + this->get_context(), + this->get_name(), + std::move(cloned_members), + this->get_flags() + ); +} + +llvm::Value* AstEnumType::codegen(llvm::Module* module, llvm::IRBuilderBase* builder) +{ + return nullptr; +} + +bool AstEnumType::equals(IAstType* other) +{ + if (const auto other_enum = cast_type(other)) + { + return 
this->get_name() == other_enum->get_name(); + } + return false; +} + +llvm::Type* AstEnumType::get_llvm_type_impl(llvm::Module* module) +{ + throw parsing_error( + ErrorType::COMPILATION_ERROR, + std::format("Cannot get LLVM type for enum type '{}'", this->get_name()), + this->get_source_fragment() + ); +} + +bool AstEnumType::is_assignable_to_impl(IAstType* other) +{ + return false; // Already managed by `AstType` (equality check) +} + +void AstEnumType::resolve_forward_references(llvm::Module* module, llvm::IRBuilderBase* builder) +{ + for (const auto& val : this->_members | std::views::values) + { + val->resolve_forward_references(module, builder); + } +} + +std::string AstEnumType::to_string() +{ + std::vector members; + + for (const auto& [field_name, field_val] : this->get_members()) + { + members.push_back(std::format("{} = {}", field_name, field_val->to_string())); + } + + return std::format( + "Enumerable {} (\n {}\n)", + this->get_name(), + join(members, ",\n ")); +} From 0163eb077655a887c02210c399c0148f9d6d640b Mon Sep 17 00:00:00 2001 From: Luca Warmenhoven Date: Mon, 16 Mar 2026 09:09:50 +0100 Subject: [PATCH 02/31] Partially added support for generic overload function lookup and codegen --- packages/compiler/include/ast/generics.h | 6 +- .../compiler/include/ast/nodes/expression.h | 12 +- .../include/ast/nodes/function_declaration.h | 40 ++-- packages/compiler/include/ast/nodes/types.h | 19 +- .../compiler/include/ast/parsing_context.h | 52 ++--- .../src/ast/context/function_registry.cpp | 55 ++++- .../src/ast/context/parsing_context.cpp | 17 -- packages/compiler/src/ast/generics.cpp | 11 + .../src/ast/nodes/functions/function_call.cpp | 124 ++++++---- .../nodes/functions/function_declaration.cpp | 216 ++++++++++++++---- packages/compiler/src/ast/type_inference.cpp | 10 +- 11 files changed, 390 insertions(+), 172 deletions(-) diff --git a/packages/compiler/include/ast/generics.h b/packages/compiler/include/ast/generics.h index 993dc718..440a3bcc 
100644 --- a/packages/compiler/include/ast/generics.h +++ b/packages/compiler/include/ast/generics.h @@ -1,4 +1,5 @@ #pragma once + #include #include #include @@ -18,8 +19,7 @@ namespace stride::ast class IAstType; class TokenSet; - using GenericParameter = std::string; - using GenericParameterList = std::vector; + using GenericParameterList = std::vector; using GenericTypeList = std::vector>; #define EMPTY_GENERIC_PARAMETER_LIST (GenericParameterList{}) @@ -45,4 +45,6 @@ namespace stride::ast AstObjectType* type, const definition::TypeDefinition* type_definition ); + + GenericTypeList copy_generic_type_list(const GenericTypeList& list); } diff --git a/packages/compiler/include/ast/nodes/expression.h b/packages/compiler/include/ast/nodes/expression.h index a46928fd..95231d3e 100644 --- a/packages/compiler/include/ast/nodes/expression.h +++ b/packages/compiler/include/ast/nodes/expression.h @@ -20,6 +20,7 @@ namespace stride::ast namespace definition { + class FunctionDefinition; class IDefinition; } @@ -377,8 +378,11 @@ namespace stride::ast { ExpressionList _arguments; std::unique_ptr _function_name_identifier; + GenericTypeList _generic_type_arguments{}; int _flags; + definition::FunctionDefinition* _definition = nullptr; + public: explicit AstFunctionCall( const std::shared_ptr& context, @@ -391,6 +395,9 @@ namespace stride::ast _function_name_identifier(std::move(function_name_identifier)), _flags(flags) {} + [[nodiscard]] + definition::FunctionDefinition* get_function_definition(); + [[nodiscard]] const ExpressionList& get_arguments() const { @@ -444,6 +451,9 @@ namespace stride::ast void resolve_forward_references(llvm::Module* module, llvm::IRBuilderBase* builder) override; + [[nodiscard]] + const GenericTypeList& get_generic_type_arguments(); + private: [[nodiscard]] std::string format_function_name() const; @@ -458,7 +468,7 @@ namespace stride::ast [[nodiscard]] llvm::Function* resolve_regular_callee( llvm::Module* module - ) const; + ); llvm::Value* 
codegen_regular_function_call( llvm::Function* callee, diff --git a/packages/compiler/include/ast/nodes/function_declaration.h b/packages/compiler/include/ast/nodes/function_declaration.h index 9dd7dbe7..13bb7a34 100644 --- a/packages/compiler/include/ast/nodes/function_declaration.h +++ b/packages/compiler/include/ast/nodes/function_declaration.h @@ -4,6 +4,7 @@ #include "blocks.h" #include "expression.h" #include "ast/modifiers.h" +#include "ast/parsing_context.h" #include @@ -78,6 +79,7 @@ namespace stride::ast std::vector _captured_variables; VisibilityModifier _visibility; GenericParameterList _generic_parameters; + definition::FunctionDefinition* _function_definition = nullptr; int _flags; /// Cached LLVM function pointer for anonymous functions. @@ -99,7 +101,7 @@ namespace stride::ast std::unique_ptr return_type, const VisibilityModifier visibility, const int flags, - const GenericParameterList& generic_parameters + GenericParameterList generic_parameters ) : IAstExpression(source, context), _body(std::move(body)), @@ -107,7 +109,7 @@ namespace stride::ast _parameters(std::move(parameters)), _annotated_return_type(std::move(return_type)), _visibility(visibility), - _generic_parameters(generic_parameters), + _generic_parameters(std::move(generic_parameters)), _flags(flags) {} [[nodiscard]] @@ -117,11 +119,23 @@ namespace stride::ast } [[nodiscard]] - const std::string& get_scoped_function_name() const + std::vector> get_parameter_types() const; + + [[nodiscard]] + const std::string& get_internalized_function_name() const { return this->_symbol.internal_name; } + /// Returns a list of overloads for this function. For example, whenever the + /// function is defined with generic parameters, there will be several overloads generated + /// for each generic instantiation. The internalized name of each overload is returned by this function. 
+ [[nodiscard]] + std::vector get_internalized_overload_names(); + + [[nodiscard]] + std::string get_internalized_overload_name(const GenericTypeList& overload) const; + [[nodiscard]] AstBlock* get_body() override { @@ -129,18 +143,7 @@ namespace stride::ast } [[nodiscard]] - std::vector> get_parameters() const - { - std::vector> cloned_params; - cloned_params.reserve(this->_parameters.size()); - - for (const auto& param : this->_parameters) - { - cloned_params.push_back(param->clone_as()); - } - - return cloned_params; - } + std::vector> get_parameters() const; /// Returns a non-owning const reference to the parameter list, avoiding the /// clone overhead of get_parameters() when only read access is needed. @@ -221,6 +224,8 @@ namespace stride::ast this->_captured_variables.push_back(symbol); } + definition::FunctionDefinition* get_function_definition(); + llvm::Value* codegen( llvm::Module* module, llvm::IRBuilderBase* builder) override; @@ -234,9 +239,10 @@ namespace stride::ast std::unique_ptr clone() override; private: - llvm::FunctionType* get_llvm_function_type( + llvm::FunctionType* get_overloaded_llvm_function_type( llvm::Module* module, - std::vector captured_variables + std::vector captured_variables, + const GenericTypeList& generic_instantiation_types = {} ) const; }; diff --git a/packages/compiler/include/ast/nodes/types.h b/packages/compiler/include/ast/nodes/types.h index 7e8753bd..0fd2dc20 100644 --- a/packages/compiler/include/ast/nodes/types.h +++ b/packages/compiler/include/ast/nodes/types.h @@ -4,13 +4,11 @@ #include "formatting.h" #include "ast/flags.h" #include "ast/generics.h" -#include "ast/modifiers.h" #include #include #include #include -#include namespace llvm { @@ -343,6 +341,7 @@ namespace stride::ast { std::vector> _parameters; std::unique_ptr _return_type; + GenericParameterList _generic_param_names; public: explicit AstFunctionType( @@ -350,11 +349,13 @@ namespace stride::ast const std::shared_ptr& context, std::vector> parameters, 
std::unique_ptr return_type, + GenericParameterList generic_parameter_names = {}, const int flags = SRFLAG_NONE ) : IAstType(source, context, flags | SRFLAG_TYPE_FUNCTION | SRFLAG_TYPE_PTR), _parameters(std::move(parameters)), - _return_type(std::move(return_type)) {} + _return_type(std::move(return_type)), + _generic_param_names(std::move(generic_parameter_names)) {} [[nodiscard]] const std::vector>& get_parameter_types() const @@ -362,6 +363,18 @@ namespace stride::ast return _parameters; } + [[nodiscard]] + const GenericParameterList& get_generic_parameter_names() const + { + return this->_generic_param_names; + } + + [[nodiscard]] + bool is_generic() const + { + return !this->_generic_param_names.empty(); // generic iff at least one generic parameter name is declared + } + [[nodiscard]] const std::unique_ptr& get_return_type() const { diff --git a/packages/compiler/include/ast/parsing_context.h b/packages/compiler/include/ast/parsing_context.h index 94facadb..5854e364 100644 --- a/packages/compiler/include/ast/parsing_context.h +++ b/packages/compiler/include/ast/parsing_context.h @@ -77,31 +77,6 @@ namespace stride::ast virtual std::unique_ptr clone() const = 0; }; - class IdentifiableSymbolDef : public IDefinition - { - SymbolType _type; - - public: - explicit IdentifiableSymbolDef( - const SymbolType type, - const Symbol& symbol - ) : - IDefinition(symbol, VisibilityModifier::PRIVATE), - _type(type) {} - - [[nodiscard]] - SymbolType get_symbol_type() const - { - return this->_type; - } - - [[nodiscard]] - std::unique_ptr clone() const override - { - return std::make_unique(_type, get_symbol()); - } - }; - class TypeDefinition : public IDefinition { @@ -185,7 +160,9 @@ namespace stride::ast : public IDefinition { std::unique_ptr _function_type; + std::vector _generic_type_overloads{}; int _flags; + llvm::Function* _llvm_function = nullptr; public: @@ -223,6 +200,17 @@ namespace stride::ast return (this->_flags & SRFLAG_FN_TYPE_VARIADIC) != 0; } + void 
add_generic_instantiation(GenericTypeList generic_types); + + 
[[nodiscard]] + const std::vector& get_generic_instantiations() const + { + return this->_generic_type_overloads; + } + + [[nodiscard]] + bool has_generic_instantiation(const GenericTypeList& generic_types) const; + ~FunctionDefinition() override = default; bool matches_type_signature(const std::string& name, const AstFunctionType* signature) const; @@ -241,7 +229,8 @@ namespace stride::ast [[nodiscard]] bool matches_parameter_signature( const std::string& internal_function_name, - const std::vector>& other_parameter_types + const std::vector>& other_parameter_types, + size_t generic_argument_count ) const; [[nodiscard]] @@ -269,8 +258,7 @@ namespace stride::ast // Stack of loop blocks for break and continue: pair // This isn't used during parsing, hence it not needing to be moved when creating a new ParsingContext. - static inline - std::vector> control_flow_loop_blocks; + static inline std::vector> control_flow_loop_blocks; public: explicit ParsingContext( @@ -342,7 +330,8 @@ namespace stride::ast [[nodiscard]] std::optional get_function_definition( const std::string& function_name, - const std::vector>& parameter_types + const std::vector>& parameter_types, + size_t instantiated_generic_count = 0 ) const; std::optional get_function_definition( @@ -358,11 +347,6 @@ namespace stride::ast [[nodiscard]] std::optional get_object_type(const std::string& name) const; - [[nodiscard]] - const definition::IdentifiableSymbolDef* get_symbol_def( - const std::string& symbol_name - ) const; - [[nodiscard]] std::optional> get_definition_by_internal_name( const std::string& internal_name) const; diff --git a/packages/compiler/src/ast/context/function_registry.cpp b/packages/compiler/src/ast/context/function_registry.cpp index 2e1e6a21..e04eedc2 100644 --- a/packages/compiler/src/ast/context/function_registry.cpp +++ b/packages/compiler/src/ast/context/function_registry.cpp @@ -9,7 +9,8 @@ using namespace stride::ast::definition; std::optional 
ParsingContext::get_function_definition( const std::string& function_name, - const std::vector>& parameter_types + const std::vector>& parameter_types, + size_t instantiated_generic_count ) const { for (const auto& global_scope = this->traverse_to_root(); @@ -17,7 +18,10 @@ std::optional ParsingContext::get_function_definition( { if (auto* fn_def = dynamic_cast(symbol_def.get())) { - if (fn_def->matches_parameter_signature(function_name, parameter_types)) + if (fn_def->matches_parameter_signature( + function_name, + parameter_types, + instantiated_generic_count)) { return fn_def; } @@ -28,6 +32,7 @@ std::optional ParsingContext::get_function_definition( std::optional ParsingContext::get_function_definition( const std::string& function_name, + // We might call this function with an anonymous type, hence not having `AstFunctionType` IAstType* function_type ) const { @@ -62,26 +67,36 @@ bool FunctionDefinition::matches_type_signature( const auto& other_params = signature->get_parameter_types(); - return matches_parameter_signature(name, other_params); + return matches_parameter_signature( + name, + other_params, + signature->get_generic_parameter_names().size() + ); } bool FunctionDefinition::matches_parameter_signature( const std::string& internal_function_name, - const std::vector>& other_parameter_types + const std::vector>& other_parameter_types, + const size_t generic_argument_count ) const { if (this->get_internal_symbol_name() != internal_function_name) return false; + // Ensure we have the right generic overload variant of this function. + // This allows us to create several functions with the same signature / name, but with + // different generic parameter overloads. 
+ if (this->_function_type->get_generic_parameter_names().size() != generic_argument_count) + return false; + const auto& self_params = this->_function_type->get_parameter_types(); if ((this->get_flags() & SRFLAG_FN_TYPE_VARIADIC) != 0) { if (other_parameter_types.size() < self_params.size()) return false; - } else { @@ -143,3 +158,33 @@ bool ParsingContext::is_function_defined_globally( } ); } + + +bool FunctionDefinition::has_generic_instantiation(const std::vector>& generic_types) const +{ + for (const auto& instantiation : this->_generic_type_overloads) + { + bool all_equal = instantiation.size() == generic_types.size(); // differing arity can never match; also keeps instantiation[i] in bounds + for (size_t i = 0; all_equal && i < generic_types.size(); i++) + { + if (!instantiation[i]->equals(generic_types[i].get())) + { + all_equal = false; + break; + } + } + if (all_equal) + { + return true; + } + } + return false; +} + +void FunctionDefinition::add_generic_instantiation(GenericTypeList generic_types) +{ + if (has_generic_instantiation(generic_types)) + return; // Already instantiated + + this->_generic_type_overloads.push_back(std::move(generic_types)); +} diff --git a/packages/compiler/src/ast/context/parsing_context.cpp b/packages/compiler/src/ast/context/parsing_context.cpp index b6a1cbc0..81daeb2d 100644 --- a/packages/compiler/src/ast/context/parsing_context.cpp +++ b/packages/compiler/src/ast/context/parsing_context.cpp @@ -59,23 +59,6 @@ std::optional> ParsingContext::get_definition_by_in return std::nullopt; } -const IdentifiableSymbolDef* ParsingContext::get_symbol_def( - const std::string& symbol_name) const -{ - for (const auto& symbol_def : this->_symbols) - { - if (const auto* identifier_def = - dynamic_cast(symbol_def.get())) - { - if (identifier_def->get_internal_symbol_name() == symbol_name) - { - return identifier_def; - } - } - } - return nullptr; -} - static size_t levenshtein_distance(const std::string& a, const std::string& b) { const size_t 
len_a = a.size(); diff --git a/packages/compiler/src/ast/generics.cpp b/packages/compiler/src/ast/generics.cpp index 
a272424f..25a0037d 100644 --- a/packages/compiler/src/ast/generics.cpp +++ b/packages/compiler/src/ast/generics.cpp @@ -277,3 +277,14 @@ std::unique_ptr stride::ast::instantiate_generic_type( std::move(resolved_args) ); } + +GenericTypeList stride::ast::copy_generic_type_list(const GenericTypeList& list) +{ + GenericTypeList copy; + copy.reserve(list.size()); + for (const auto& type : list) + { + copy.push_back(type->clone_as()); + } + return copy; +} diff --git a/packages/compiler/src/ast/nodes/functions/function_call.cpp b/packages/compiler/src/ast/nodes/functions/function_call.cpp index fd43dda6..541c80e7 100644 --- a/packages/compiler/src/ast/nodes/functions/function_call.cpp +++ b/packages/compiler/src/ast/nodes/functions/function_call.cpp @@ -171,60 +171,54 @@ llvm::Value* AstFunctionCall::codegen( ); } -llvm::Function* AstFunctionCall::resolve_regular_callee(llvm::Module* module) const +llvm::Function* AstFunctionCall::resolve_regular_callee(llvm::Module* module) { - if (const auto definition = - this->get_context()->get_function_definition(this->get_scoped_function_name(), this->get_argument_types()); - definition.has_value()) + const auto& function_definition = this->get_function_definition(); + if (llvm::Function* callee = module->getFunction(function_definition->get_internal_symbol_name())) { - if (llvm::Function* callee = module->getFunction(definition.value()->get_internal_symbol_name())) - { - return callee; - } - - const auto fn_def = definition.value(); - const auto fn_type = fn_def->get_type(); - std::vector param_types; - param_types.reserve(fn_type->get_parameter_types().size()); + return callee; + } - for (const auto& param : fn_type->get_parameter_types()) - { - param_types.push_back(param->get_llvm_type(module)); - } + const auto fn_def = function_definition; + const auto fn_type = fn_def->get_type(); + std::vector param_types; + param_types.reserve(fn_type->get_parameter_types().size()); - llvm::Type* ret_type = 
fn_type->get_return_type()->get_llvm_type(module); + for (const auto& param : fn_type->get_parameter_types()) + { + param_types.push_back(param->get_llvm_type(module)); + } - // When propagating varargs (call has '...'), the callee receives the caller's - // va_list as an extra fixed pointer argument rather than as true variadic args. - // This lets the callee forward the va_list directly to vprintf/vscanf-style APIs. - bool llvm_is_variadic = false; - if (this->is_variadic()) - { - param_types.push_back(llvm::PointerType::get(module->getContext(), 0)); - } - else - { - llvm_is_variadic = fn_def->is_variadic(); - } + llvm::Type* ret_type = fn_type->get_return_type()->get_llvm_type(module); - llvm::FunctionType* llvm_fn_type = llvm::FunctionType::get( - ret_type, - param_types, - llvm_is_variadic - ); + // When propagating varargs (call has '...'), the callee receives the caller's + // va_list as an extra fixed pointer argument rather than as true variadic args. + // This lets the callee forward the va_list directly to vprintf/vscanf-style APIs. + bool llvm_is_variadic = false; + if (this->is_variadic()) + { + param_types.push_back(llvm::PointerType::get(module->getContext(), 0)); + } + else + { + llvm_is_variadic = fn_def->is_variadic(); + } - // If we are calling a variadic function and propagating '...', - // the callee is actually a non-variadic function that takes a va_list. - // But we should use the actual function name for the lookup. - auto callee_cand = module->getOrInsertFunction( - fn_def->get_internal_symbol_name(), - llvm_fn_type - ); + llvm::FunctionType* llvm_fn_type = llvm::FunctionType::get( + ret_type, + param_types, + llvm_is_variadic + ); - return llvm::dyn_cast(callee_cand.getCallee()); - } + // If we are calling a variadic function and propagating '...', + // the callee is actually a non-variadic function that takes a va_list. + // But we should use the actual function name for the lookup. 
+ auto callee_cand = module->getOrInsertFunction( + fn_def->get_internal_symbol_name(), + llvm_fn_type + ); - return nullptr; + return llvm::dyn_cast(callee_cand.getCallee()); } llvm::Value* AstFunctionCall::codegen_regular_function_call( @@ -366,6 +360,7 @@ llvm::Value* AstFunctionCall::codegen_anonymous_function_call( // First: check if the variable's internal name maps to a named function in the // symbol table with a matching type signature. If so, call it directly without // any pointer indirection. + // We sadly cannot use `get_function_definition` here if (this->get_context()->get_function_definition( field_def->get_internal_symbol_name(), field_def->get_type()).has_value()) @@ -610,6 +605,17 @@ void AstFunctionCall::validate() void AstFunctionCall::resolve_forward_references(llvm::Module* module, llvm::IRBuilderBase* builder) { + // Add generic types to function definition's generic instantiations + if (!this->_generic_type_arguments.empty()) + { + const auto& definition = this->get_function_definition(); + + definition->add_generic_instantiation( + copy_generic_type_list(this->_generic_type_arguments) + ); + // nice, got that over with + } + for (const auto& arg : this->_arguments) { arg->resolve_forward_references(module, builder); @@ -637,6 +643,34 @@ std::vector> AstFunctionCall::get_argument_types() con return param_types; } +const GenericTypeList& AstFunctionCall::get_generic_type_arguments() +{ + return this->_generic_type_arguments; +} + +FunctionDefinition* AstFunctionCall::get_function_definition() +{ + if (this->_definition != nullptr) + return this->_definition; + + if (const auto def = this->get_context()->get_function_definition( + this->get_scoped_function_name(), + this->get_argument_types(), + this->get_generic_type_arguments().size() + ); + def.has_value()) + { + this->_definition = def.value(); + return this->_definition; + } + + throw parsing_error( + ErrorType::REFERENCE_ERROR, + std::format("Function '{}' was not found in this scope", 
this->format_function_name()), + this->get_source_fragment() + ); +} + std::unique_ptr AstFunctionCall::clone() { ExpressionList cloned_args; diff --git a/packages/compiler/src/ast/nodes/functions/function_declaration.cpp b/packages/compiler/src/ast/nodes/functions/function_declaration.cpp index 850e3552..4329a63a 100644 --- a/packages/compiler/src/ast/nodes/functions/function_declaration.cpp +++ b/packages/compiler/src/ast/nodes/functions/function_declaration.cpp @@ -515,6 +515,19 @@ void collect_free_variables( } } +std::vector> IAstFunction::get_parameters() const +{ + std::vector> cloned_params; + cloned_params.reserve(this->_parameters.size()); + + for (const auto& param : this->_parameters) + { + cloned_params.push_back(param->clone_as()); + } + + return cloned_params; +} + void IAstFunction::validate() { if (this->is_anonymous()) @@ -679,13 +692,13 @@ llvm::Value* IAstFunction::codegen( } else { - function = module->getFunction(this->get_scoped_function_name()); + function = module->getFunction(this->get_internalized_function_name()); if (!function) { module->print(llvm::errs(), nullptr); throw parsing_error( ErrorType::COMPILATION_ERROR, - std::format("Function symbol '{}' missing", this->get_scoped_function_name()), + std::format("Function symbol '{}' missing", this->get_internalized_function_name()), this->get_source_fragment() ); } @@ -722,22 +735,22 @@ llvm::Value* IAstFunction::codegen( // Captured variable handling // Map captured variables to function arguments with __capture_ prefix // - auto arg_it = function->arg_begin(); + auto fn_parameter_argument = function->arg_begin(); for (const auto& capture : this->_captured_variables) { - if (arg_it != function->arg_end()) + if (fn_parameter_argument != function->arg_end()) { - arg_it->setName(closures::format_captured_variable_name(capture.internal_name)); + fn_parameter_argument->setName(closures::format_captured_variable_name(capture.internal_name)); // Create alloca with __capture_ prefix so 
identifier lookup can find it llvm::AllocaInst* alloca = prologue_builder.CreateAlloca( - arg_it->getType(), + fn_parameter_argument->getType(), nullptr, closures::format_captured_variable_name_internal(capture.internal_name) ); - builder->CreateStore(arg_it, alloca); - ++arg_it; + builder->CreateStore(fn_parameter_argument, alloca); + ++fn_parameter_argument; } } @@ -747,21 +760,21 @@ llvm::Value* IAstFunction::codegen( // for (const auto& param : this->_parameters) { - if (arg_it != function->arg_end()) + if (fn_parameter_argument != function->arg_end()) { - arg_it->setName(param->get_name() + ".arg"); + fn_parameter_argument->setName(param->get_name() + ".arg"); // Create a memory slot on the stack for the parameter llvm::AllocaInst* alloca = prologue_builder.CreateAlloca( - arg_it->getType(), + fn_parameter_argument->getType(), nullptr, param->get_name() ); // Store the initial argument value into the alloca - builder->CreateStore(arg_it, alloca); + builder->CreateStore(fn_parameter_argument, alloca); - ++arg_it; + ++fn_parameter_argument; } } @@ -888,54 +901,76 @@ void IAstFunction::resolve_forward_references( } } - llvm::FunctionType* function_type = this->get_llvm_function_type(module, captured_types); - - if (const auto fn = module->getFunction(this->get_scoped_function_name()); - fn != nullptr && fn->getFunctionType() != function_type) - { - throw parsing_error( - ErrorType::COMPILATION_ERROR, - std::format( - "Function symbol '{}' already exists with a different signature", - this->get_scoped_function_name() - ), - this->get_source_fragment() - ); - } - const auto linkage = this->_visibility == VisibilityModifier::PRIVATE ? llvm::Function::PrivateLinkage : llvm::Function::ExternalLinkage; - // Anonymous functions are created with a stable name prefix and a numeric ID - // so they are easily findable in the module without ambiguity. 
- const std::string llvm_function_name = this->get_scoped_function_name(); + if (const auto& definition = this->get_function_definition(); + !definition->get_generic_instantiations().empty()) + { + for (const auto& overload : definition->get_generic_instantiations()) + { + const auto overloaded_fn_name = get_internalized_overload_name(overload); + llvm::FunctionType* generic_function_type = this->get_overloaded_llvm_function_type( + module, + captured_types, + overload + ); - llvm::Function* created_fn = llvm::Function::Create( - function_type, - linkage, - llvm_function_name, - module - ); + llvm::Function* generic_function = llvm::Function::Create( + generic_function_type, + linkage, + overloaded_fn_name, + module + ); - if (this->is_anonymous()) + if (this->is_anonymous()) + { + generic_function->addFnAttr("stride.anonymous"); + this->_llvm_function = generic_function; + } + } + } + else { - created_fn->addFnAttr("stride.anonymous"); - this->_llvm_function = created_fn; + // Anonymous functions are created with a stable name prefix and a numeric ID + // so they are easily findable in the module without ambiguity. 
+ const std::string llvm_function_name = this->get_internalized_function_name(); + + llvm::FunctionType* function_type = this->get_overloaded_llvm_function_type(module, captured_types); + + llvm::Function* function = llvm::Function::Create( + function_type, + linkage, + llvm_function_name, + module + ); + + if (this->is_anonymous()) + { + function->addFnAttr("stride.anonymous"); + this->_llvm_function = function; + } } this->_body->resolve_forward_references(module, builder); } -llvm::FunctionType* IAstFunction::get_llvm_function_type( +llvm::FunctionType* IAstFunction::get_overloaded_llvm_function_type( llvm::Module* module, - std::vector captured_variables + std::vector captured_variables, + const GenericTypeList& generic_instantiation_types ) const { std::vector base_parameter_types; for (const auto& param : this->_parameters) { - base_parameter_types.push_back(param->get_type()->get_llvm_type(module)); + const auto& resolved_generic_param_type = resolve_generics( + param->get_type(), + this->_generic_parameters, + generic_instantiation_types + ); + base_parameter_types.push_back(resolved_generic_param_type->get_llvm_type(module)); } std::vector parameter_types; @@ -956,7 +991,98 @@ llvm::FunctionType* IAstFunction::get_llvm_function_type( ); } - return llvm::FunctionType::get(return_type, parameter_types, this->is_variadic()); + const auto& llvm_function_ty = llvm::FunctionType::get(return_type, parameter_types, this->is_variadic());; + + if (const auto fn = module->getFunction(this->get_internalized_function_name()); + fn != nullptr && fn->getFunctionType() != llvm_function_ty) + { + throw parsing_error( + ErrorType::COMPILATION_ERROR, + std::format( + "Function symbol '{}' already exists with a different signature", + this->get_internalized_function_name() + ), + this->get_source_fragment() + ); + } + + return llvm_function_ty; +} + +std::vector> IAstFunction::get_parameter_types() const +{ + std::vector> types; + types.reserve(this->_parameters.size()); + + 
for (const auto& param : this->_parameters) + { + types.push_back(param->get_type()->clone_ty()); + } + + return types; +} + +FunctionDefinition* IAstFunction::get_function_definition() +{ + if (this->_function_definition != nullptr) + return this->_function_definition; + + const auto& definition = this->get_context()->get_function_definition( + this->get_function_name(), + this->get_parameter_types(), + this->get_generic_parameters().size() + ); + + if (!definition.has_value()) + { + throw parsing_error( + ErrorType::REFERENCE_ERROR, + std::format("Function definition for '{}' not found in context", this->get_internalized_function_name()), + this->get_source_fragment() + ); + } + + this->_function_definition = definition.value(); + return this->_function_definition; +} + +std::string IAstFunction::get_internalized_overload_name(const GenericTypeList& overload) const +{ + std::vector generic_instantiation_type_names; + generic_instantiation_type_names.reserve(overload.size()); + + for (const auto& type : overload) + { + generic_instantiation_type_names.push_back(type->get_type_name()); + } + + return std::format( + "{}${}", + this->get_internalized_function_name(), + join(generic_instantiation_type_names, "_") + ); +} + +std::vector IAstFunction::get_internalized_overload_names() +{ + const auto& definition = this->get_function_definition(); + + // If the function is not generic, we just return a singular name (the regular internalized name) + if (!definition->get_type()->is_generic()) + { + return { this->get_internalized_function_name() }; + } + + std::vector overload_names; + + for (const auto& overload : definition->get_generic_instantiations()) + { + overload_names.push_back( + get_internalized_overload_name(overload) + ); + } + + return overload_names; } std::unique_ptr AstFunctionParameter::clone() @@ -1009,7 +1135,7 @@ std::string AstFunctionDeclaration::to_string() return std::format( "FunctionDeclaration(name: {}(internal: {}), params: [{}], body: {}{} -> 
{})", this->get_function_name(), - this->get_scoped_function_name(), + this->get_internalized_function_name(), params, body_str, this->is_extern() ? " (extern)" : "", diff --git a/packages/compiler/src/ast/type_inference.cpp b/packages/compiler/src/ast/type_inference.cpp index 40077d3e..4e3873b4 100644 --- a/packages/compiler/src/ast/type_inference.cpp +++ b/packages/compiler/src/ast/type_inference.cpp @@ -18,7 +18,8 @@ std::unique_ptr stride::ast::infer_expression_literal_type(const AstLi return std::make_unique( literal->get_source_fragment(), literal->get_context(), - literal->get_primitive_type()); + literal->get_primitive_type() + ); } std::unique_ptr stride::ast::infer_function_call_return_type(AstFunctionCall* fn_call) @@ -26,8 +27,11 @@ std::unique_ptr stride::ast::infer_function_call_return_type(AstFuncti /// --- Basic function lookup, find based on parameter signature (ignoring return type) const auto& context = fn_call->get_context(); - if (const auto fn_def = - context->get_function_definition(fn_call->get_scoped_function_name(), fn_call->get_argument_types()); + if (const auto fn_def = context->get_function_definition( + fn_call->get_scoped_function_name(), + fn_call->get_argument_types(), + fn_call->get_generic_type_arguments().size() + ); fn_def.has_value()) { return fn_def.value()->get_type()->get_return_type()->clone_ty(); From cd8adefae9aa047046cee79edb484c6df08eb82f Mon Sep 17 00:00:00 2001 From: Luca Warmenhoven Date: Mon, 16 Mar 2026 10:09:56 +0100 Subject: [PATCH 03/31] Rename function arguments --- example.sr | 23 ++++++++++++------- .../include/ast/nodes/function_declaration.h | 7 ++++-- .../src/ast/context/function_registry.cpp | 2 +- 3 files changed, 21 insertions(+), 11 deletions(-) diff --git a/example.sr b/example.sr index d60698a8..25d74fcd 100644 --- a/example.sr +++ b/example.sr @@ -1,18 +1,25 @@ import System::{ - io::print, +io::print }; -enum Test { - First: 123, - Second: 456, - Third: 789 +type Array = T[]; + +type Car = { + 
drive: () -> void; + names: Array; +}; + +fn makeCar(): Car { + const names = ["Toyota", "Honda", "Ford"]; + return Car::{ drive: (): void -> { + io::print("Driving: %s, %s, %s", names[0], names[1], names[2]); + }, names, }; } fn main(): i32 { - const k: Test = Test::First; - - io::print("k[0] = %d", k); + const myCar = makeCar(); + myCar.drive(); return 0; } \ No newline at end of file diff --git a/packages/compiler/include/ast/nodes/function_declaration.h b/packages/compiler/include/ast/nodes/function_declaration.h index 13bb7a34..c886ed34 100644 --- a/packages/compiler/include/ast/nodes/function_declaration.h +++ b/packages/compiler/include/ast/nodes/function_declaration.h @@ -101,7 +101,7 @@ namespace stride::ast std::unique_ptr return_type, const VisibilityModifier visibility, const int flags, - GenericParameterList generic_parameters + GenericParameterList generic_parameters ) : IAstExpression(source, context), _body(std::move(body)), @@ -309,7 +309,10 @@ namespace stride::ast ~AstLambdaFunctionExpression() override = default; - std::string get_mangled_name() const { return ""; }; // TODO: Implement + std::string get_mangled_name() const + { + return ""; + }; // TODO: Implement }; std::unique_ptr parse_fn_declaration( diff --git a/packages/compiler/src/ast/context/function_registry.cpp b/packages/compiler/src/ast/context/function_registry.cpp index e04eedc2..f54f8a06 100644 --- a/packages/compiler/src/ast/context/function_registry.cpp +++ b/packages/compiler/src/ast/context/function_registry.cpp @@ -10,7 +10,7 @@ using namespace stride::ast::definition; std::optional ParsingContext::get_function_definition( const std::string& function_name, const std::vector>& parameter_types, - size_t instantiated_generic_count + const size_t instantiated_generic_count ) const { for (const auto& global_scope = this->traverse_to_root(); From 87921f16c39591d540795c0fe63d376cefb9074a Mon Sep 17 00:00:00 2001 From: Luca Warmenhoven Date: Mon, 16 Mar 2026 13:25:50 +0100 Subject: 
[PATCH 04/31] Add support for generic function overloads and improve handling of LLVM function definitions --- example.sr | 14 +- packages/compiler/include/ast/generics.h | 2 + .../compiler/include/ast/nodes/expression.h | 14 +- .../include/ast/nodes/function_declaration.h | 26 +- .../compiler/include/ast/parsing_context.h | 12 +- packages/compiler/include/ast/tokens/token.h | 26 ++ .../compiler/include/ast/type_inference.h | 2 +- .../src/ast/context/function_registry.cpp | 16 +- packages/compiler/src/ast/generics.cpp | 22 ++ .../src/ast/nodes/expressions/expression.cpp | 2 +- .../src/ast/nodes/functions/function_call.cpp | 78 +++- .../nodes/functions/function_declaration.cpp | 351 ++++++++---------- .../src/ast/nodes/types/alias_type.cpp | 1 + .../src/ast/nodes/types/function_type.cpp | 15 +- .../src/ast/traversal/function_visitor.cpp | 15 +- packages/compiler/src/ast/type_inference.cpp | 17 +- 16 files changed, 375 insertions(+), 238 deletions(-) diff --git a/example.sr b/example.sr index 25d74fcd..cf56e0de 100644 --- a/example.sr +++ b/example.sr @@ -1,5 +1,5 @@ import System::{ -io::print + io::print }; type Array = T[]; @@ -9,10 +9,15 @@ type Car = { names: Array; }; +fn some_comparison(a: T, b: T): bool { + return a == b; +} + fn makeCar(): Car { - const names = ["Toyota", "Honda", "Ford"]; + const names = ["Toyota", "Honda", "Ford", "Toyota"]; + return Car::{ drive: (): void -> { - io::print("Driving: %s, %s, %s", names[0], names[1], names[2]); + io::print("\x1b[32mDriving the car"); }, names, }; } @@ -20,6 +25,9 @@ fn main(): i32 { const myCar = makeCar(); myCar.drive(); + io::print("Driving: %s, %s, %s", myCar.names[0], myCar.names[1], myCar.names[2]); + + io::print("Compared: %b", some_comparison(5, 5)); return 0; } \ No newline at end of file diff --git a/packages/compiler/include/ast/generics.h b/packages/compiler/include/ast/generics.h index 440a3bcc..0e71ec13 100644 --- a/packages/compiler/include/ast/generics.h +++ 
b/packages/compiler/include/ast/generics.h @@ -47,4 +47,6 @@ namespace stride::ast ); GenericTypeList copy_generic_type_list(const GenericTypeList& list); + + std::string get_overloaded_function_name(std::string function_name, const GenericTypeList& overload_types); } diff --git a/packages/compiler/include/ast/nodes/expression.h b/packages/compiler/include/ast/nodes/expression.h index 95231d3e..96e4ee3a 100644 --- a/packages/compiler/include/ast/nodes/expression.h +++ b/packages/compiler/include/ast/nodes/expression.h @@ -378,7 +378,7 @@ namespace stride::ast { ExpressionList _arguments; std::unique_ptr _function_name_identifier; - GenericTypeList _generic_type_arguments{}; + GenericTypeList _generic_type_arguments; int _flags; definition::FunctionDefinition* _definition = nullptr; @@ -388,11 +388,13 @@ namespace stride::ast const std::shared_ptr& context, std::unique_ptr function_name_identifier, ExpressionList arguments, + GenericTypeList generic_type_arguments, const int flags = SRFLAG_NONE ) : IAstExpression(function_name_identifier->get_source_fragment(), context), _arguments(std::move(arguments)), _function_name_identifier(std::move(function_name_identifier)), + _generic_type_arguments(std::move(generic_type_arguments)), _flags(flags) {} [[nodiscard]] @@ -1033,15 +1035,13 @@ namespace stride::ast TokenSet& set); /// Parses a variable assignment statement - std::optional> - parse_variable_reassignment( + std::optional> parse_variable_reassignment( const std::shared_ptr& context, AstIdentifier* identifier, TokenSet& set); /// Parses a binary arithmetic operation using precedence climbing - std::optional> - parse_arithmetic_binary_operation_optional( + std::optional> parse_arithmetic_binary_operation_optional( const std::shared_ptr& context, TokenSet& set, std::unique_ptr lhs, @@ -1145,4 +1145,8 @@ namespace stride::ast /// Checks whether the next tokens begin a member access: `.identifier` bool is_member_accessor(const TokenSet& set); + + /// Checks whether the 
subsequent tokens can be considered a function call, after an identifier + /// An example would be function_name() or function_name() + bool is_direct_function_call(const TokenSet& set); } // namespace stride::ast diff --git a/packages/compiler/include/ast/nodes/function_declaration.h b/packages/compiler/include/ast/nodes/function_declaration.h index c886ed34..c2561c45 100644 --- a/packages/compiler/include/ast/nodes/function_declaration.h +++ b/packages/compiler/include/ast/nodes/function_declaration.h @@ -68,6 +68,13 @@ namespace stride::ast * Function declaration definitions * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ + + struct GenericFunctionMetadata + { + std::string overload_function_name; + llvm::Function* llvm_function; + }; + class IAstFunction : public IAstContainer, public IAstExpression @@ -82,12 +89,6 @@ namespace stride::ast definition::FunctionDefinition* _function_definition = nullptr; int _flags; - /// Cached LLVM function pointer for anonymous functions. - /// Named functions are always looked up by their scoped name in the module, - /// but anonymous functions are created with an empty name (LLVM auto-assigns - /// a numeric ID), so we must track them by pointer instead. - llvm::Function* _llvm_function = nullptr; - friend class AstFunctionDeclaration; friend class AstFunctionParameter; @@ -127,14 +128,11 @@ namespace stride::ast return this->_symbol.internal_name; } - /// Returns a list of overlaods for this function. For example, whenever the - /// function is defined with generic parameters, there will be several overlaods generated - /// for each generic instantiation. The internalized name of each overload is returned by this function. - [[nodiscard]] - std::vector get_internalized_overload_names(); - + /// Returns a list of overloads for this function. For example, whenever the + /// function is defined with generic parameters, there will be several overloads generated + /// for each generic instantiation. 
This function returns the internalized name of each overload. [[nodiscard]] - std::string get_internalized_overload_name(const GenericTypeList& overload) const; + std::vector get_function_overload_metadata(); [[nodiscard]] AstBlock* get_body() override @@ -239,7 +237,7 @@ namespace stride::ast std::unique_ptr clone() override; private: - llvm::FunctionType* get_overloaded_llvm_function_type( + llvm::FunctionType* get_generic_instantiated_llvm_function_type( llvm::Module* module, std::vector captured_variables, const GenericTypeList& generic_instantiation_types = {} diff --git a/packages/compiler/include/ast/parsing_context.h b/packages/compiler/include/ast/parsing_context.h index 5854e364..0463bace 100644 --- a/packages/compiler/include/ast/parsing_context.h +++ b/packages/compiler/include/ast/parsing_context.h @@ -156,11 +156,17 @@ namespace stride::ast } }; + struct GenericFunctionOverload + { + GenericTypeList types; + mutable llvm::Function* function; + }; + class FunctionDefinition : public IDefinition { std::unique_ptr _function_type; - std::vector _generic_type_overloads{}; + std::vector _generic_type_overloads{}; int _flags; llvm::Function* _llvm_function = nullptr; @@ -200,10 +206,10 @@ namespace stride::ast return (this->_flags & SRFLAG_FN_TYPE_VARIADIC) != 0; } - void add_generic_instantiation(GenericTypeList generic_types); + void add_generic_instantiation(GenericTypeList generic_overload_types); [[nodiscard]] - const std::vector& get_generic_instantiations() const + const std::vector& get_generic_instantiations() const { return this->_generic_type_overloads; } diff --git a/packages/compiler/include/ast/tokens/token.h b/packages/compiler/include/ast/tokens/token.h index 690e1458..6d109769 100644 --- a/packages/compiler/include/ast/tokens/token.h +++ b/packages/compiler/include/ast/tokens/token.h @@ -466,6 +466,32 @@ namespace stride::ast { return _type == other; } + + bool is_type_token() const + { + switch (this->_type) + { + case 
TokenType::PRIMITIVE_UINT8: + case TokenType::PRIMITIVE_UINT16: + case TokenType::PRIMITIVE_UINT32: + case TokenType::PRIMITIVE_UINT64: + case TokenType::PRIMITIVE_INT8: + case TokenType::PRIMITIVE_INT16: + case TokenType::PRIMITIVE_INT32: + case TokenType::PRIMITIVE_INT64: + case TokenType::PRIMITIVE_FLOAT32: + case TokenType::PRIMITIVE_FLOAT64: + case TokenType::PRIMITIVE_BOOL: + case TokenType::PRIMITIVE_STRING: + case TokenType::PRIMITIVE_CHAR: + case TokenType::PRIMITIVE_VOID: + case TokenType::PRIMITIVE_AUTO: + case TokenType::IDENTIFIER: + return true; + default: + return false; + } + } }; extern std::vector tokenTypes; diff --git a/packages/compiler/include/ast/type_inference.h b/packages/compiler/include/ast/type_inference.h index 172941ca..c9fff24f 100644 --- a/packages/compiler/include/ast/type_inference.h +++ b/packages/compiler/include/ast/type_inference.h @@ -57,5 +57,5 @@ namespace stride::ast std::unique_ptr infer_array_accessor_type(const AstArrayMemberAccessor* accessor, int recursion_guard); std::unique_ptr infer_function_type( - const IAstFunction* expression); + const IAstFunction* function); } diff --git a/packages/compiler/src/ast/context/function_registry.cpp b/packages/compiler/src/ast/context/function_registry.cpp index f54f8a06..3191a639 100644 --- a/packages/compiler/src/ast/context/function_registry.cpp +++ b/packages/compiler/src/ast/context/function_registry.cpp @@ -88,8 +88,10 @@ const // Ensure we have the right generic overload variant of this function. // This allows us to create several functions with the same signature / name, but with // different generic parameter overloads. - if (this->_function_type->get_generic_parameter_names().size() != generic_argument_count) - return false; + // + // For generic overloads, we just check whether the name and generic count is equal. 
+ if (!this->_function_type->get_generic_parameter_names().empty() && generic_argument_count > 0) + return true; const auto& self_params = this->_function_type->get_parameter_types(); @@ -162,12 +164,12 @@ bool ParsingContext::is_function_defined_globally( bool FunctionDefinition::has_generic_instantiation(const std::vector>& generic_types) const { - for (const auto& instantiation : this->_generic_type_overloads) + for (const auto& [types, function] : this->_generic_type_overloads) { bool all_equal = true; for (size_t i = 0; i < generic_types.size(); i++) { - if (!instantiation[i]->equals(generic_types[i].get())) + if (!types[i]->equals(generic_types[i].get())) { all_equal = false; break; @@ -181,10 +183,10 @@ bool FunctionDefinition::has_generic_instantiation(const std::vector_generic_type_overloads.push_back(std::move(generic_types)); + this->_generic_type_overloads.push_back({std::move(generic_overload_types)}); } diff --git a/packages/compiler/src/ast/generics.cpp b/packages/compiler/src/ast/generics.cpp index 25a0037d..57537914 100644 --- a/packages/compiler/src/ast/generics.cpp +++ b/packages/compiler/src/ast/generics.cpp @@ -159,6 +159,8 @@ std::unique_ptr stride::ast::resolve_generics( func_type->get_context(), std::move(resolved_params), resolve_generics(func_type->get_return_type().get(), param_names, instantiated_types), + EMPTY_GENERIC_PARAMETER_LIST, + // No more generics; they're resolved. 
func_type->get_flags() ); } @@ -288,3 +290,23 @@ GenericTypeList stride::ast::copy_generic_type_list(const GenericTypeList& list) } return copy; } + +std::string stride::ast::get_overloaded_function_name(std::string function_name, const GenericTypeList& overload_types) +{ + if (overload_types.empty()) + return function_name; + + std::vector generic_instantiation_type_names; + generic_instantiation_type_names.reserve(overload_types.size()); + + for (const auto& type : overload_types) + { + generic_instantiation_type_names.push_back(type->get_type_name()); + } + + return std::format( + "{}${}", + function_name, + join(generic_instantiation_type_names, "_") + ); +} diff --git a/packages/compiler/src/ast/nodes/expressions/expression.cpp b/packages/compiler/src/ast/nodes/expressions/expression.cpp index 0ce7383a..9e5f95d6 100644 --- a/packages/compiler/src/ast/nodes/expressions/expression.cpp +++ b/packages/compiler/src/ast/nodes/expressions/expression.cpp @@ -63,7 +63,7 @@ std::unique_ptr stride::ast::parse_inline_expression_part( } // Named function invocations, e.g., `(...)` or `::(...)` - if (set.peek_next_eq(TokenType::LPAREN)) + if (is_direct_function_call(set)) { result = parse_function_call(context, identifier.get(), set); } diff --git a/packages/compiler/src/ast/nodes/functions/function_call.cpp b/packages/compiler/src/ast/nodes/functions/function_call.cpp index 541c80e7..5b56ba83 100644 --- a/packages/compiler/src/ast/nodes/functions/function_call.cpp +++ b/packages/compiler/src/ast/nodes/functions/function_call.cpp @@ -31,6 +31,7 @@ std::unique_ptr stride::ast::parse_function_call( ) { const auto reference_token = set.peek(-1); + auto generic_types = parse_generic_type_arguments(context, set); auto function_parameter_set = collect_parenthesized_block(set); ExpressionList function_arg_nodes; @@ -96,10 +97,53 @@ std::unique_ptr stride::ast::parse_function_call( context, identifier->clone_as(), std::move(function_arg_nodes), + std::move(generic_types), 
function_call_flags ); } +// The previous token must be an identifier, otherwise this is not a function call +// This logic is not handled here. +bool stride::ast::is_direct_function_call(const TokenSet& set) +{ + // Function call for sure, "identifier::other(` + if (set.peek_next_eq(TokenType::LPAREN)) + return true; + + // If the subsequent token is a LT, it might just be a generic function instantiation + // We have to do some lookahead to make sure this is the case + if (set.peek_next_eq(TokenType::LT)) + { + int depth = 0; + for (size_t offset = 0; set.position() + offset < set.size(); offset++) + { + switch (const auto next_token = set.at(set.position() + offset); + next_token.get_type()) + { + case TokenType::LT: + ++depth; + break; + case TokenType::GT: + --depth; + break; + + default: + // Optimization, where we know for sure it can't be part of a generic instantiation + if (!next_token.is_type_token() && next_token.get_type() != TokenType::COMMA) + { + return false; + } + } + + if (depth == 0) + { + return set.peek_eq(TokenType::LPAREN, offset + 1); + } + } + } + return false; +} + std::string AstFunctionCall::format_suggestion(const IDefinition* suggestion) { if (const auto fn_call = dynamic_cast(suggestion)) @@ -653,8 +697,12 @@ FunctionDefinition* AstFunctionCall::get_function_definition() if (this->_definition != nullptr) return this->_definition; + const auto& internal_function_name = get_overloaded_function_name( + this->get_scoped_function_name(), + this->get_generic_type_arguments()); + if (const auto def = this->get_context()->get_function_definition( - this->get_scoped_function_name(), + internal_function_name, this->get_argument_types(), this->get_generic_type_arguments().size() ); @@ -674,16 +722,26 @@ FunctionDefinition* AstFunctionCall::get_function_definition() std::unique_ptr AstFunctionCall::clone() { ExpressionList cloned_args; + GenericTypeList generic_type_list_cloned; + + 
generic_type_list_cloned.reserve(this->_generic_type_arguments.size()); cloned_args.reserve(this->get_arguments().size()); + for (const auto& arg : this->get_arguments()) { cloned_args.push_back(arg->clone_as()); } + for (const auto& generic_arg : this->get_generic_type_arguments()) + { + generic_type_list_cloned.push_back(generic_arg->clone_ty()); + } + return std::make_unique( this->get_context(), this->get_function_name_identifier()->clone_as(), std::move(cloned_args), + std::move(generic_type_list_cloned), this->_flags ); } @@ -706,11 +764,29 @@ std::string AstFunctionCall::get_formatted_call() const std::vector arg_names; arg_names.reserve(this->_arguments.size()); + std::vector formatted_generics; + formatted_generics.reserve(this->_generic_type_arguments.size()); + + for (const auto& generic_arg : this->_generic_type_arguments) + { + formatted_generics.push_back(generic_arg->get_type_name()); + } + for (const auto& arg : this->_arguments) { arg_names.push_back(arg->get_type()->get_type_name()); } + if (!formatted_generics.empty()) + { + return std::format( + "{}<{}>({})", + this->get_function_name(), + join(formatted_generics, ", "), + join(arg_names, ", ") + ); + } + return std::format( "{}({})", this->get_function_name(), diff --git a/packages/compiler/src/ast/nodes/functions/function_declaration.cpp b/packages/compiler/src/ast/nodes/functions/function_declaration.cpp index 4329a63a..e059162f 100644 --- a/packages/compiler/src/ast/nodes/functions/function_declaration.cpp +++ b/packages/compiler/src/ast/nodes/functions/function_declaration.cpp @@ -674,209 +674,203 @@ llvm::Value* IAstFunction::codegen( llvm::IRBuilderBase* builder ) { - // Anonymous functions are tracked by their cached pointer (they have no stable - // string name in the module). Named functions are looked up the normal way. 
llvm::Function* function = nullptr; - if (this->is_anonymous()) + + for (const auto& [function_name, llvm_function_val] : this->get_function_overload_metadata()) { - function = this->_llvm_function; - if (!function) + if (!llvm_function_val) { - module->print(llvm::errs(), nullptr); throw parsing_error( ErrorType::COMPILATION_ERROR, - "Anonymous function pointer missing — resolve_forward_references must run first", + std::format("Function symbol '{}' missing", function_name), this->get_source_fragment() ); } - } - else - { - function = module->getFunction(this->get_internalized_function_name()); - if (!function) + + + // If the function body has already been generated (has basic blocks), just return the function pointer + if (this->is_extern() || !llvm_function_val->empty()) { - module->print(llvm::errs(), nullptr); - throw parsing_error( - ErrorType::COMPILATION_ERROR, - std::format("Function symbol '{}' missing", this->get_internalized_function_name()), - this->get_source_fragment() - ); + return llvm_function_val; } - } - - // If the function body has already been generated (has basic blocks), just return the function pointer - if (this->is_extern() || !function->empty()) - { - return function; - } - // Save the current insert point to restore it later - // This is important when generating nested lambdas - llvm::BasicBlock* saved_insert_block = builder->GetInsertBlock(); - llvm::BasicBlock::iterator saved_insert_point; - const bool has_insert_point = saved_insert_block != nullptr; + // Save the current insert point to restore it later + // This is important when generating nested lambdas + llvm::BasicBlock* saved_insert_block = builder->GetInsertBlock(); + llvm::BasicBlock::iterator saved_insert_point; + const bool has_insert_point = saved_insert_block != nullptr; - if (saved_insert_block) - { - saved_insert_point = builder->GetInsertPoint(); - } + if (saved_insert_block) + { + saved_insert_point = builder->GetInsertPoint(); + } - llvm::BasicBlock* entry_bb = 
llvm::BasicBlock::Create( - module->getContext(), - "entry", - function - ); - builder->SetInsertPoint(entry_bb); + llvm::BasicBlock* entry_bb = llvm::BasicBlock::Create( + module->getContext(), + "entry", + llvm_function_val + ); + builder->SetInsertPoint(entry_bb); - // We create a new builder for the prologue to ensure allocas are at the very top - llvm::IRBuilder prologue_builder(&function->getEntryBlock(), function->getEntryBlock().begin()); + // We create a new builder for the prologue to ensure allocas are at the very top + llvm::IRBuilder prologue_builder( + &llvm_function_val->getEntryBlock(), + llvm_function_val->getEntryBlock().begin() + ); - // - // Captured variable handling - // Map captured variables to function arguments with __capture_ prefix - // - auto fn_parameter_argument = function->arg_begin(); - for (const auto& capture : this->_captured_variables) - { - if (fn_parameter_argument != function->arg_end()) + // + // Captured variable handling + // Map captured variables to function arguments with __capture_ prefix + // + auto fn_parameter_argument = llvm_function_val->arg_begin(); + for (const auto& capture : this->_captured_variables) { - fn_parameter_argument->setName(closures::format_captured_variable_name(capture.internal_name)); + if (fn_parameter_argument != llvm_function_val->arg_end()) + { + fn_parameter_argument->setName(closures::format_captured_variable_name(capture.internal_name)); - // Create alloca with __capture_ prefix so identifier lookup can find it - llvm::AllocaInst* alloca = prologue_builder.CreateAlloca( - fn_parameter_argument->getType(), - nullptr, - closures::format_captured_variable_name_internal(capture.internal_name) - ); + // Create alloca with __capture_ prefix so identifier lookup can find it + llvm::AllocaInst* alloca = prologue_builder.CreateAlloca( + fn_parameter_argument->getType(), + nullptr, + closures::format_captured_variable_name_internal(capture.internal_name) + ); - 
builder->CreateStore(fn_parameter_argument, alloca); - ++fn_parameter_argument; + builder->CreateStore(fn_parameter_argument, alloca); + ++fn_parameter_argument; + } } - } - // - // Function parameter handling - // Here we define the parameters on the stack as memory slots for the function - // - for (const auto& param : this->_parameters) - { - if (fn_parameter_argument != function->arg_end()) + // + // Function parameter handling + // Here we define the parameters on the stack as memory slots for the function + // + for (const auto& param : this->_parameters) { - fn_parameter_argument->setName(param->get_name() + ".arg"); + if (fn_parameter_argument != llvm_function_val->arg_end()) + { + fn_parameter_argument->setName(param->get_name() + ".arg"); - // Create a memory slot on the stack for the parameter - llvm::AllocaInst* alloca = prologue_builder.CreateAlloca( - fn_parameter_argument->getType(), - nullptr, - param->get_name() - ); + // Create a memory slot on the stack for the parameter + llvm::AllocaInst* alloca = prologue_builder.CreateAlloca( + fn_parameter_argument->getType(), + nullptr, + param->get_name() + ); - // Store the initial argument value into the alloca - builder->CreateStore(fn_parameter_argument, alloca); + // Store the initial argument value into the alloca + builder->CreateStore(fn_parameter_argument, alloca); - ++fn_parameter_argument; + ++fn_parameter_argument; + } } - } - // Generate Body - llvm::Value* function_body_value = this->_body->codegen(module, builder); + // Generate Body + llvm::Value* function_body_value = this->_body->codegen(module, builder); - // Final Safety: Implicit Return - // If the get_body didn't explicitly return (no terminator found), add one. 
- if (llvm::BasicBlock* current_bb = builder->GetInsertBlock(); - current_bb && !current_bb->getTerminator()) - { - if (llvm::Type* ret_type = function->getReturnType(); - ret_type->isVoidTy()) - { - builder->CreateRetVoid(); - } - else if (function_body_value && function_body_value->getType() == ret_type) + // Final Safety: Implicit Return + // If the get_body didn't explicitly return (no terminator found), add one. + if (llvm::BasicBlock* current_bb = builder->GetInsertBlock(); + current_bb && !current_bb->getTerminator()) { - builder->CreateRet(function_body_value); - } - else - { - // Default return to keep IR valid (useful for main or incomplete functions) - if (ret_type->isFloatingPointTy()) + if (llvm::Type* ret_type = llvm_function_val->getReturnType(); + ret_type->isVoidTy()) { - builder->CreateRet(llvm::ConstantFP::get(ret_type, 0.0)); + builder->CreateRetVoid(); } - else if (ret_type->isIntegerTy()) + else if (function_body_value && function_body_value->getType() == ret_type) { - builder->CreateRet(llvm::ConstantInt::get(ret_type, 0)); + builder->CreateRet(function_body_value); } + // Default return to keep IR valid (useful for main or incomplete functions) else { - throw parsing_error( - ErrorType::COMPILATION_ERROR, - std::format( - "Function '{}' is missing a return path.", - this->get_function_name() - ), - this->get_source_fragment() - ); + if (ret_type->isFloatingPointTy()) + { + builder->CreateRet(llvm::ConstantFP::get(ret_type, 0.0)); + } + else if (ret_type->isIntegerTy()) + { + builder->CreateRet(llvm::ConstantInt::get(ret_type, 0)); + } + else + { + throw parsing_error( + ErrorType::COMPILATION_ERROR, + std::format( + "Function '{}' is missing a return path.", + this->get_function_name() + ), + this->get_source_fragment() + ); + } } } - } - if (llvm::verifyFunction(*function, &llvm::errs())) - { - module->print(llvm::errs(), nullptr); - throw parsing_error( - ErrorType::COMPILATION_ERROR, - "LLVM Function Verification Failed for: " + 
this->get_function_name(), - this->get_source_fragment() - ); - } + if (llvm::verifyFunction(*llvm_function_val, &llvm::errs())) + { + module->print(llvm::errs(), nullptr); + throw parsing_error( + ErrorType::COMPILATION_ERROR, + "LLVM Function Verification Failed for: " + this->get_function_name(), + this->get_source_fragment() + ); + } - // Restore the previous insert point for nested lambda generation - if (has_insert_point && saved_insert_block) - { - builder->SetInsertPoint(saved_insert_block, saved_insert_point); - } + // Restore the previous insert point for nested lambda generation + if (has_insert_point && saved_insert_block) + { + builder->SetInsertPoint(saved_insert_block, saved_insert_point); + } - // For anonymous lambdas with captured variables, create a closure structure - // that bundles the function pointer with the current values of captured variables - if (this->is_anonymous()) - { - // Collect the current values of captured variables from the enclosing scope - std::vector captured_values; - for (const auto& capture : this->get_captured_variables()) + // For anonymous lambdas with captured variables, create a closure structure + // that bundles the function pointer with the current values of captured variables + if (this->is_anonymous()) { - if (const auto block = builder->GetInsertBlock()) + // Collect the current values of captured variables from the enclosing scope + std::vector captured_values; + for (const auto& capture : this->get_captured_variables()) { - llvm::Function* current_fn = block->getParent(); - llvm::Value* captured_val = closures::lookup_variable_or_capture(current_fn, capture.internal_name); - - if (!captured_val) + if (const auto block = builder->GetInsertBlock()) { - captured_val = closures::lookup_variable_by_base_name(current_fn, capture.name); - } + llvm::Function* current_fn = block->getParent(); + llvm::Value* captured_val = closures::lookup_variable_or_capture(current_fn, capture.internal_name); - if (captured_val) - { 
- // Load the value if it's an alloca - if (auto* alloca = llvm::dyn_cast(captured_val)) + if (!captured_val) { - captured_val = builder->CreateLoad( - alloca->getAllocatedType(), - alloca, - capture.internal_name - ); + captured_val = closures::lookup_variable_by_base_name(current_fn, capture.name); + } + + if (captured_val) + { + // Load the value if it's an alloca + if (auto* alloca = llvm::dyn_cast(captured_val)) + { + captured_val = builder->CreateLoad( + alloca->getAllocatedType(), + alloca, + capture.internal_name + ); + } + captured_values.push_back(captured_val); } - captured_values.push_back(captured_val); } } + + // Create and return a closure instead of the raw function pointer + return closures::create_closure(module, builder, llvm_function_val, captured_values); } - // Create and return a closure instead of the raw function pointer - return closures::create_closure(module, builder, function, captured_values); + function = llvm_function_val; } return function; } + +/** + * Here we define the function in the symbol table, so it can be looked up in the codegen phase. + */ void IAstFunction::resolve_forward_references( llvm::Module* module, llvm::IRBuilderBase* builder @@ -885,7 +879,7 @@ void IAstFunction::resolve_forward_references( // Avoid re-registering if already declared. // Named functions are looked up by their scoped name; anonymous functions are // tracked by the cached _llvm_function pointer (they have no stable string name). - if (this->is_anonymous() && this->_llvm_function) + if (this->is_anonymous() && this->get_function_definition()->get_llvm_function() != nullptr) return; // Add captured variables as first parameters @@ -905,19 +899,20 @@ void IAstFunction::resolve_forward_references( ? 
llvm::Function::PrivateLinkage : llvm::Function::ExternalLinkage; + // --- Generic function instantiation if (const auto& definition = this->get_function_definition(); !definition->get_generic_instantiations().empty()) { - for (const auto& overload : definition->get_generic_instantiations()) + for (const auto& [types, function] : definition->get_generic_instantiations()) { - const auto overloaded_fn_name = get_internalized_overload_name(overload); - llvm::FunctionType* generic_function_type = this->get_overloaded_llvm_function_type( + const auto overloaded_fn_name = get_overloaded_function_name(this->get_internalized_function_name(), types); + llvm::FunctionType* generic_function_type = this->get_generic_instantiated_llvm_function_type( module, captured_types, - overload + types ); - llvm::Function* generic_function = llvm::Function::Create( + function = llvm::Function::Create( generic_function_type, linkage, overloaded_fn_name, @@ -925,10 +920,7 @@ void IAstFunction::resolve_forward_references( ); if (this->is_anonymous()) - { - generic_function->addFnAttr("stride.anonymous"); - this->_llvm_function = generic_function; - } + function->addFnAttr("stride.anonymous"); } } else @@ -937,7 +929,7 @@ void IAstFunction::resolve_forward_references( // so they are easily findable in the module without ambiguity. 
const std::string llvm_function_name = this->get_internalized_function_name(); - llvm::FunctionType* function_type = this->get_overloaded_llvm_function_type(module, captured_types); + llvm::FunctionType* function_type = this->get_generic_instantiated_llvm_function_type(module, captured_types); llvm::Function* function = llvm::Function::Create( function_type, @@ -949,14 +941,14 @@ void IAstFunction::resolve_forward_references( if (this->is_anonymous()) { function->addFnAttr("stride.anonymous"); - this->_llvm_function = function; + definition->set_llvm_function(function); } } this->_body->resolve_forward_references(module, builder); } -llvm::FunctionType* IAstFunction::get_overloaded_llvm_function_type( +llvm::FunctionType* IAstFunction::get_generic_instantiated_llvm_function_type( llvm::Module* module, std::vector captured_variables, const GenericTypeList& generic_instantiation_types @@ -1046,43 +1038,30 @@ FunctionDefinition* IAstFunction::get_function_definition() return this->_function_definition; } -std::string IAstFunction::get_internalized_overload_name(const GenericTypeList& overload) const -{ - std::vector generic_instantiation_type_names; - generic_instantiation_type_names.reserve(overload.size()); - - for (const auto& type : overload) - { - generic_instantiation_type_names.push_back(type->get_type_name()); - } - - return std::format( - "{}${}", - this->get_internalized_function_name(), - join(generic_instantiation_type_names, "_") - ); -} - -std::vector IAstFunction::get_internalized_overload_names() +std::vector IAstFunction::get_function_overload_metadata() { const auto& definition = this->get_function_definition(); // If the function is not generic, we just return a singular name (the regular internalized name) if (!definition->get_type()->is_generic()) { - return { this->get_internalized_function_name() }; + return { + GenericFunctionMetadata{ this->get_internalized_function_name(), definition->get_llvm_function() } + }; } - std::vector overload_names; 
+ std::vector metadata; + - for (const auto& overload : definition->get_generic_instantiations()) + for (const auto& [types, llvm_function] : definition->get_generic_instantiations()) { - overload_names.push_back( - get_internalized_overload_name(overload) + metadata.emplace_back( + get_overloaded_function_name(this->get_internalized_function_name(), types), + llvm_function ); } - return overload_names; + return metadata; } std::unique_ptr AstFunctionParameter::clone() diff --git a/packages/compiler/src/ast/nodes/types/alias_type.cpp b/packages/compiler/src/ast/nodes/types/alias_type.cpp index 067fc9a0..5316432f 100644 --- a/packages/compiler/src/ast/nodes/types/alias_type.cpp +++ b/packages/compiler/src/ast/nodes/types/alias_type.cpp @@ -148,6 +148,7 @@ static std::unique_ptr resolve_nested_underlying_types(std::unique_ptr func->get_context(), std::move(resolved_params), std::move(resolved_return), + func->get_generic_parameter_names(), func->get_flags() ); } diff --git a/packages/compiler/src/ast/nodes/types/function_type.cpp b/packages/compiler/src/ast/nodes/types/function_type.cpp index e4a457cb..cb34ea32 100644 --- a/packages/compiler/src/ast/nodes/types/function_type.cpp +++ b/packages/compiler/src/ast/nodes/types/function_type.cpp @@ -65,12 +65,15 @@ std::optional> stride::ast::parse_function_type_option set.expect(TokenType::RPAREN, "Expected secondary ')' after function type notation"); } + // TODO: Resolve generic parameters in type resolution + return parse_type_metadata( std::make_unique( reference_token.get_source_fragment(), context, std::move(parameters), std::move(return_type), + EMPTY_GENERIC_PARAMETER_LIST, flags ), set @@ -80,17 +83,27 @@ std::optional> stride::ast::parse_function_type_option std::unique_ptr AstFunctionType::clone() { std::vector> parameters; + GenericParameterList generic_parameters_clone; + parameters.reserve(this->_parameters.size()); + generic_parameters_clone.reserve(this->_generic_param_names.size()); + for (const auto& p : 
this->_parameters) { parameters.push_back(p->clone_ty()); } + generic_parameters_clone.insert( + generic_parameters_clone.end(), + this->_generic_param_names.begin(), + this->_generic_param_names.end()); + return std::make_unique( this->get_source_fragment(), this->get_context(), std::move(parameters), - this->_return_type->clone_ty(), + this->get_return_type()->clone_ty(), + std::move(generic_parameters_clone), this->get_flags() ); } diff --git a/packages/compiler/src/ast/traversal/function_visitor.cpp b/packages/compiler/src/ast/traversal/function_visitor.cpp index 0e4680e1..62f05197 100644 --- a/packages/compiler/src/ast/traversal/function_visitor.cpp +++ b/packages/compiler/src/ast/traversal/function_visitor.cpp @@ -24,13 +24,10 @@ void FunctionVisitor::accept(IAstFunction* fn_declaration) } // Forward declare the function in the symbol registry - if (dynamic_cast(fn_declaration)) - { - fn_declaration->get_context()->define_function( - fn_declaration->get_symbol(), - fn_declaration->get_type()->clone_as(), - fn_declaration->get_visibility(), - fn_declaration->get_flags() - ); - } + fn_declaration->get_context()->define_function( + fn_declaration->get_symbol(), + fn_declaration->get_type()->clone_as(), + fn_declaration->get_visibility(), + fn_declaration->get_flags() + ); } diff --git a/packages/compiler/src/ast/type_inference.cpp b/packages/compiler/src/ast/type_inference.cpp index 4e3873b4..50599d5d 100644 --- a/packages/compiler/src/ast/type_inference.cpp +++ b/packages/compiler/src/ast/type_inference.cpp @@ -267,21 +267,23 @@ std::unique_ptr stride::ast::infer_object_initializer_type(const AstOb ); } -std::unique_ptr stride::ast::infer_function_type(const IAstFunction* expression) +std::unique_ptr stride::ast::infer_function_type(const IAstFunction* function) { std::vector> param_types; - param_types.reserve(expression->get_parameters().size()); + param_types.reserve(function->get_parameters().size()); - for (const auto& param : expression->get_parameters()) 
+ for (const auto& param : function->get_parameters()) { param_types.emplace_back(param->get_type()->clone_ty()); } return std::make_unique( - expression->get_source_fragment(), - expression->get_context(), + function->get_source_fragment(), + function->get_context(), std::move(param_types), - expression->get_return_type()->clone_ty() + function->get_return_type()->clone_ty(), + function->get_generic_parameters(), + function->get_flags() ); } @@ -312,7 +314,8 @@ std::unique_ptr stride::ast::infer_identifier_type(const AstIdentifier identifier->get_source_fragment(), identifier->get_context(), std::move(param_types), - callable->get_type()->get_return_type()->clone_ty() + callable->get_type()->get_return_type()->clone_ty(), + callable->get_type()->get_generic_parameter_names() ); } From 9325a5cecd5afd587c1a39f88e8444cb90a31391 Mon Sep 17 00:00:00 2001 From: Luca Warmenhoven Date: Mon, 16 Mar 2026 13:42:35 +0100 Subject: [PATCH 05/31] refactor: replace `combine` with `join` in `SourceFragment` for clarity and consistency --- packages/compiler/include/files.h | 2 +- .../src/ast/nodes/expressions/array_initializer.cpp | 2 +- .../src/ast/nodes/expressions/array_member_accessor.cpp | 2 +- .../src/ast/nodes/expressions/comparison_operation.cpp | 5 ++++- .../compiler/src/ast/nodes/expressions/expression.cpp | 6 +++--- .../src/ast/nodes/expressions/member_accessor.cpp | 4 ++-- packages/compiler/src/ast/nodes/expressions/type_cast.cpp | 2 +- packages/compiler/src/ast/nodes/import.cpp | 2 +- packages/compiler/src/ast/nodes/module.cpp | 2 +- packages/compiler/src/ast/nodes/package.cpp | 2 +- packages/compiler/src/ast/nodes/return_statement.cpp | 2 +- packages/compiler/src/ast/nodes/types/type_definition.cpp | 2 +- packages/compiler/src/ast/type_inference.cpp | 4 +++- packages/compiler/src/errors.cpp | 8 ++++---- packages/compiler/src/files.cpp | 2 +- 15 files changed, 26 insertions(+), 21 deletions(-) diff --git a/packages/compiler/include/files.h 
b/packages/compiler/include/files.h index 629a3dd0..573e771f 100644 --- a/packages/compiler/include/files.h +++ b/packages/compiler/include/files.h @@ -43,7 +43,7 @@ namespace stride return *this; } - static SourceFragment combine(const SourceFragment& source_fragment, const SourceFragment& get_source_fragment); + static SourceFragment join(const SourceFragment& source_fragment, const SourceFragment& get_source_fragment); }; std::shared_ptr read_file(const std::string& path); diff --git a/packages/compiler/src/ast/nodes/expressions/array_initializer.cpp b/packages/compiler/src/ast/nodes/expressions/array_initializer.cpp index 73d68a32..be316bd2 100644 --- a/packages/compiler/src/ast/nodes/expressions/array_initializer.cpp +++ b/packages/compiler/src/ast/nodes/expressions/array_initializer.cpp @@ -45,7 +45,7 @@ std::unique_ptr stride::ast::parse_array_initializer( const auto& last_token_pos = set.peek(-1).get_source_fragment(); // `]` is already consumed, so we peek back at it - const auto position = SourceFragment::combine(reference_token.get_source_fragment(), last_token_pos); + const auto position = SourceFragment::join(reference_token.get_source_fragment(), last_token_pos); return std::make_unique( position, diff --git a/packages/compiler/src/ast/nodes/expressions/array_member_accessor.cpp b/packages/compiler/src/ast/nodes/expressions/array_member_accessor.cpp index eb9cdd31..db10cd7e 100644 --- a/packages/compiler/src/ast/nodes/expressions/array_member_accessor.cpp +++ b/packages/compiler/src/ast/nodes/expressions/array_member_accessor.cpp @@ -33,7 +33,7 @@ std::unique_ptr stride::ast::parse_array_member_accessor auto index_expression = parse_inline_expression(context, expression_block.value()); - const auto source_pos = SourceFragment::combine(array_base->get_source_fragment(), last_src_pos); + const auto source_pos = SourceFragment::join(array_base->get_source_fragment(), last_src_pos); return std::make_unique( source_pos, diff --git 
a/packages/compiler/src/ast/nodes/expressions/comparison_operation.cpp b/packages/compiler/src/ast/nodes/expressions/comparison_operation.cpp index a89c7a85..b875d984 100644 --- a/packages/compiler/src/ast/nodes/expressions/comparison_operation.cpp +++ b/packages/compiler/src/ast/nodes/expressions/comparison_operation.cpp @@ -93,7 +93,10 @@ void AstComparisonOp::validate() throw parsing_error( ErrorType::SEMANTIC_ERROR, "Comparison operation operands must be used on primitive or optional types", - this->get_source_fragment() + { + ErrorSourceReference("type " + lhs_type->get_type_name(), _lhs->get_source_fragment()), + ErrorSourceReference("type " + rhs_type->get_type_name(), _rhs->get_source_fragment()) + } ); } diff --git a/packages/compiler/src/ast/nodes/expressions/expression.cpp b/packages/compiler/src/ast/nodes/expressions/expression.cpp index 9e5f95d6..4b5a0d41 100644 --- a/packages/compiler/src/ast/nodes/expressions/expression.cpp +++ b/packages/compiler/src/ast/nodes/expressions/expression.cpp @@ -187,7 +187,7 @@ std::unique_ptr parse_comparison_tier( const auto token = set.next(); auto rhs = parse_arithmetic_tier(context, set); lhs = std::make_unique( - token.get_source_fragment(), + stride::SourceFragment::join(lhs->get_source_fragment(), rhs->get_source_fragment()), context, std::move(lhs), op.value(), @@ -209,7 +209,7 @@ std::unique_ptr parse_logical_tier( const auto token = set.next(); auto rhs = parse_comparison_tier(context, set); lhs = std::make_unique( - token.get_source_fragment(), + stride::SourceFragment::join(lhs->get_source_fragment(), rhs->get_source_fragment()), context, std::move(lhs), op.value(), @@ -336,7 +336,7 @@ std::unique_ptr stride::ast::parse_segmented_identifier( } const auto source_pos = last_fragment.has_value() - ? SourceFragment::combine(initial_identifier.get_source_fragment(), last_fragment.value()) + ? 
SourceFragment::join(initial_identifier.get_source_fragment(), last_fragment.value()) : initial_identifier.get_source_fragment(); return std::make_unique( diff --git a/packages/compiler/src/ast/nodes/expressions/member_accessor.cpp b/packages/compiler/src/ast/nodes/expressions/member_accessor.cpp index 0759dc65..1b8167ef 100644 --- a/packages/compiler/src/ast/nodes/expressions/member_accessor.cpp +++ b/packages/compiler/src/ast/nodes/expressions/member_accessor.cpp @@ -34,7 +34,7 @@ std::unique_ptr stride::ast::parse_chained_member_access( Symbol(member_tok.get_source_fragment(), member_tok.get_lexeme()) ); - const auto source = SourceFragment::combine(lhs->get_source_fragment(), member_tok.get_source_fragment()); + const auto source = SourceFragment::join(lhs->get_source_fragment(), member_tok.get_source_fragment()); return std::make_unique( source, @@ -70,7 +70,7 @@ std::unique_ptr stride::ast::parse_indirect_call( } const auto close_src = set.peek(-1).get_source_fragment(); - const auto source = SourceFragment::combine(callee_src, close_src); + const auto source = SourceFragment::join(callee_src, close_src); return std::make_unique( source, diff --git a/packages/compiler/src/ast/nodes/expressions/type_cast.cpp b/packages/compiler/src/ast/nodes/expressions/type_cast.cpp index 9b734024..00d642c9 100644 --- a/packages/compiler/src/ast/nodes/expressions/type_cast.cpp +++ b/packages/compiler/src/ast/nodes/expressions/type_cast.cpp @@ -19,7 +19,7 @@ std::optional> stride::ast::parse_type_cast_op( auto type = parse_type(context, set, { "Expected type after 'as' in type cast operation" }); - const auto source_fragment = SourceFragment::combine(lhs->get_source_fragment(), type->get_source_fragment()); + const auto source_fragment = SourceFragment::join(lhs->get_source_fragment(), type->get_source_fragment()); return std::make_unique( source_fragment, diff --git a/packages/compiler/src/ast/nodes/import.cpp b/packages/compiler/src/ast/nodes/import.cpp index 
87cc2e2f..0dd01ca6 100644 --- a/packages/compiler/src/ast/nodes/import.cpp +++ b/packages/compiler/src/ast/nodes/import.cpp @@ -79,7 +79,7 @@ std::unique_ptr stride::ast::parse_import_statement( auto import_list = consume_import_submodules(context, set); return std::make_unique( - SourceFragment::combine(reference_token.get_source_fragment(), package_identifier->get_source_fragment()), + SourceFragment::join(reference_token.get_source_fragment(), package_identifier->get_source_fragment()), context, std::move(package_identifier), std::move(import_list) diff --git a/packages/compiler/src/ast/nodes/module.cpp b/packages/compiler/src/ast/nodes/module.cpp index 1124c6f9..35f42ab9 100644 --- a/packages/compiler/src/ast/nodes/module.cpp +++ b/packages/compiler/src/ast/nodes/module.cpp @@ -46,7 +46,7 @@ std::unique_ptr stride::ast::parse_module_statement( auto module_body = parse_block(module_context, set); return std::make_unique( - SourceFragment::combine(reference_token.get_source_fragment(), module_identifier_tok.get_source_fragment()), + SourceFragment::join(reference_token.get_source_fragment(), module_identifier_tok.get_source_fragment()), module_context, module_name, std::move(module_body) diff --git a/packages/compiler/src/ast/nodes/package.cpp b/packages/compiler/src/ast/nodes/package.cpp index b3ab9721..26265b8b 100644 --- a/packages/compiler/src/ast/nodes/package.cpp +++ b/packages/compiler/src/ast/nodes/package.cpp @@ -35,7 +35,7 @@ std::unique_ptr stride::ast::parse_package_declaration( const auto& last_pos = set.expect(TokenType::SEMICOLON, "Expected semicolon after package declaration"); return std::make_unique( - SourceFragment::combine(reference_token.get_source_fragment(), last_pos.get_source_fragment()), + SourceFragment::join(reference_token.get_source_fragment(), last_pos.get_source_fragment()), context, package_name ); diff --git a/packages/compiler/src/ast/nodes/return_statement.cpp b/packages/compiler/src/ast/nodes/return_statement.cpp index 
95aaae9b..e226761d 100644 --- a/packages/compiler/src/ast/nodes/return_statement.cpp +++ b/packages/compiler/src/ast/nodes/return_statement.cpp @@ -45,7 +45,7 @@ std::unique_ptr stride::ast::parse_return_statement( .get_source_fragment(); return std::make_unique( - SourceFragment::combine(ref_pos, end_pos), + SourceFragment::join(ref_pos, end_pos), context, std::move(return_value) ); diff --git a/packages/compiler/src/ast/nodes/types/type_definition.cpp b/packages/compiler/src/ast/nodes/types/type_definition.cpp index 96f612a0..6f6f5e02 100644 --- a/packages/compiler/src/ast/nodes/types/type_definition.cpp +++ b/packages/compiler/src/ast/nodes/types/type_definition.cpp @@ -29,7 +29,7 @@ std::unique_ptr stride::ast::parse_type_definition( const auto& last_token = set.expect(TokenType::SEMICOLON, "Expected ';' after type definition"); const auto& last_pos = last_token.get_source_fragment(); - const auto source_fragment = SourceFragment::combine(ref_pos, last_pos); + const auto source_fragment = SourceFragment::join(ref_pos, last_pos); const auto type_name_symbol = resolve_internal_name( context->get_name(), source_fragment, diff --git a/packages/compiler/src/ast/type_inference.cpp b/packages/compiler/src/ast/type_inference.cpp index 50599d5d..d89df6af 100644 --- a/packages/compiler/src/ast/type_inference.cpp +++ b/packages/compiler/src/ast/type_inference.cpp @@ -67,7 +67,9 @@ std::unique_ptr stride::ast::infer_binary_op_type(IBinaryOp* operation if (cast_expr(operation) || cast_expr(operation)) { return std::make_unique( - operation->get_source_fragment(), + SourceFragment::join( + operation->get_left()->get_source_fragment(), + operation->get_right()->get_source_fragment()), operation->get_context(), PrimitiveType::BOOL ); diff --git a/packages/compiler/src/errors.cpp b/packages/compiler/src/errors.cpp index 3492bd9f..6a92d21c 100644 --- a/packages/compiler/src/errors.cpp +++ b/packages/compiler/src/errors.cpp @@ -252,20 +252,20 @@ std::string 
stride::make_source_error( // Line for the right-hand reference (placed first in order to be above the left message). // right_line: pipe connector (for left_ref) + message (for right_ref) std::string right_line(base_padding + col1, ' '); - right_line += "┃"; + right_line += "│"; right_line += std::string(col2 - col1 - 1, ' '); - right_line += "┗ "; + right_line += "└ "; right_line += right_ref.message; result += std::format("\n\033[0;31m┃ {}", right_line); // Connector line: vertical pipe for left reference. std::string connector(base_padding + col1, ' '); - connector += "┃"; + connector += "│"; result += std::format("\n\033[0;31m┃ {}", connector); // Line for the left-hand reference message. std::string left_line(base_padding + col1, ' '); - left_line += "┗ " + left_ref.message; + left_line += "└─ " + left_ref.message; result += std::format("\n\033[0;31m┃ {}", left_line); } else diff --git a/packages/compiler/src/files.cpp b/packages/compiler/src/files.cpp index 37d8d67f..4aa7a216 100644 --- a/packages/compiler/src/files.cpp +++ b/packages/compiler/src/files.cpp @@ -23,7 +23,7 @@ std::shared_ptr stride::read_file(const std::string& path) return std::make_shared(path, std::move(content)); } -SourceFragment SourceFragment::combine(const SourceFragment& first, const SourceFragment& last) +SourceFragment SourceFragment::join(const SourceFragment& first, const SourceFragment& last) { return SourceFragment( first.source, From 340cb3655a0f12d83bedadd46b5f3c111fb08410 Mon Sep 17 00:00:00 2001 From: Luca Warmenhoven Date: Mon, 16 Mar 2026 14:55:30 +0100 Subject: [PATCH 06/31] refactor: restructure function definition declarations and improve handling of generics --- example.sr | 4 +- .../ast/definitions/function_definition.h | 103 ++++++ ...on_declaration.h => function_definition.h} | 28 +- .../compiler/include/ast/parsing_context.h | 97 +----- packages/compiler/include/ast/visitor.h | 2 +- packages/compiler/src/ast/ast.cpp | 2 +- .../src/ast/context/function_registry.cpp 
| 7 +- packages/compiler/src/ast/generics.cpp | 11 + packages/compiler/src/ast/nodes/blocks.cpp | 2 +- .../expressions/variable_declaration.cpp | 2 +- .../src/ast/nodes/functions/function_call.cpp | 2 + ...eclaration.cpp => function_definition.cpp} | 306 ++++++++++-------- .../nodes/functions/function_parameters.cpp | 2 +- .../src/ast/traversal/function_visitor.cpp | 23 +- .../compiler/src/ast/traversal/traversal.cpp | 2 +- packages/compiler/src/ast/type_inference.cpp | 2 +- .../compiler/tests/test_type_inference.cpp | 2 +- 17 files changed, 324 insertions(+), 273 deletions(-) create mode 100644 packages/compiler/include/ast/definitions/function_definition.h rename packages/compiler/include/ast/nodes/{function_declaration.h => function_definition.h} (95%) rename packages/compiler/src/ast/nodes/functions/{function_declaration.cpp => function_definition.cpp} (82%) diff --git a/example.sr b/example.sr index cf56e0de..5a1a7d33 100644 --- a/example.sr +++ b/example.sr @@ -16,9 +16,7 @@ fn some_comparison(a: T, b: T): bool { fn makeCar(): Car { const names = ["Toyota", "Honda", "Ford", "Toyota"]; - return Car::{ drive: (): void -> { - io::print("\x1b[32mDriving the car"); - }, names, }; + return Car::{ drive: (): void -> io::print("\x1b[32mDriving the car"), names }; } fn main(): i32 { diff --git a/packages/compiler/include/ast/definitions/function_definition.h b/packages/compiler/include/ast/definitions/function_definition.h new file mode 100644 index 00000000..03eb3044 --- /dev/null +++ b/packages/compiler/include/ast/definitions/function_definition.h @@ -0,0 +1,103 @@ +#pragma once + +#include "ast/parsing_context.h" +#include "ast/nodes/function_definition.h" + +#include + +namespace stride::ast::definition +{ + struct GenericFunctionOverload + { + mutable llvm::Function* function; + std::unique_ptr node; + }; + + class FunctionDefinition + : public IDefinition + { + std::unique_ptr _function_type; + std::vector _function_candidates{}; + int _flags; + + 
llvm::Function* _llvm_function = nullptr; + + public: + explicit FunctionDefinition( + std::unique_ptr function_type, + const Symbol& symbol, + const VisibilityModifier visibility, + const int flags + ) : + IDefinition(symbol, visibility), + _function_type(std::move(function_type)), + _flags(flags) {} + + [[nodiscard]] + AstFunctionType* get_type() const + { + return this->_function_type.get(); + } + + [[nodiscard]] + std::string get_function_name() const + { + return this->get_symbol().name; + } + + [[nodiscard]] + int get_flags() const + { + return this->_flags; + } + + [[nodiscard]] + bool is_variadic() const + { + return (this->_flags & SRFLAG_FN_TYPE_VARIADIC) != 0; + } + + void add_generic_instantiation(GenericTypeList generic_overload_types); + + [[nodiscard]] + const std::vector& get_instantiations() const + { + return this->_function_candidates; + } + + [[nodiscard]] + bool has_generic_instantiation(const GenericTypeList& generic_types) const; + + ~FunctionDefinition() override = default; + + bool matches_type_signature(const std::string& name, const AstFunctionType* signature) const; + + void set_llvm_function(llvm::Function* function) + { + this->_llvm_function = function; + } + + [[nodiscard]] + llvm::Function* get_llvm_function() const + { + return this->_llvm_function; + } + + [[nodiscard]] + bool matches_parameter_signature( + const std::string& internal_function_name, + const std::vector>& other_parameter_types, + size_t generic_argument_count + ) const; + + [[nodiscard]] + std::unique_ptr clone() const override + { + return std::make_unique( + _function_type->clone_as(), + get_symbol(), + get_visibility(), + _flags); + } + }; +} diff --git a/packages/compiler/include/ast/nodes/function_declaration.h b/packages/compiler/include/ast/nodes/function_definition.h similarity index 95% rename from packages/compiler/include/ast/nodes/function_declaration.h rename to packages/compiler/include/ast/nodes/function_definition.h index c2561c45..4605bfa5 100644 
--- a/packages/compiler/include/ast/nodes/function_declaration.h +++ b/packages/compiler/include/ast/nodes/function_definition.h @@ -39,7 +39,10 @@ namespace stride::ast _name(std::move(param_name)), _type(std::move(param_type)) {} - std::string to_string() override; + std::string to_string() override + { + return std::format("{}({})", this->get_name(), this->get_type()->get_type_name()); + } [[nodiscard]] const std::string& get_name() const @@ -166,19 +169,19 @@ namespace stride::ast [[nodiscard]] bool is_extern() const { - return this->_flags & SRFLAG_FN_TYPE_EXTERN; + return (this->_flags & SRFLAG_FN_TYPE_EXTERN) != 0; } [[nodiscard]] bool is_variadic() const { - return this->_flags & SRFLAG_FN_TYPE_VARIADIC; + return (this->_flags & SRFLAG_FN_TYPE_VARIADIC) != 0; } [[nodiscard]] bool is_anonymous() const { - return this->_flags & SRFLAG_FN_TYPE_ANONYMOUS; + return (this->_flags & SRFLAG_FN_TYPE_ANONYMOUS) != 0; } [[nodiscard]] @@ -206,7 +209,7 @@ namespace stride::ast } [[nodiscard]] - bool is_generic_function() const + bool is_generic() const { return !this->_generic_parameters.empty(); } @@ -236,12 +239,16 @@ namespace stride::ast std::unique_ptr clone() override; + std::string to_string() override; + private: llvm::FunctionType* get_generic_instantiated_llvm_function_type( llvm::Module* module, std::vector captured_variables, const GenericTypeList& generic_instantiation_types = {} ) const; + + static void validate_candidate(IAstFunction* candidate); }; class AstFunctionDeclaration @@ -249,8 +256,6 @@ namespace stride::ast public IAstStatement { public: - using IAstStatement::IAstStatement; - explicit AstFunctionDeclaration( const std::shared_ptr& context, Symbol symbol, @@ -273,8 +278,6 @@ namespace stride::ast generic_parameters ) {} - std::string to_string() override; - ~AstFunctionDeclaration() override = default; }; @@ -303,14 +306,7 @@ namespace stride::ast {} ) {} - std::string to_string() override; - ~AstLambdaFunctionExpression() override = default; 
- - std::string get_mangled_name() const - { - return ""; - }; // TODO: Implement }; std::unique_ptr parse_fn_declaration( diff --git a/packages/compiler/include/ast/parsing_context.h b/packages/compiler/include/ast/parsing_context.h index 0463bace..096651db 100644 --- a/packages/compiler/include/ast/parsing_context.h +++ b/packages/compiler/include/ast/parsing_context.h @@ -17,6 +17,7 @@ namespace llvm namespace stride::ast { + class AstFunctionDeclaration; enum class VisibilityModifier; enum class ContextType @@ -30,6 +31,8 @@ namespace stride::ast namespace definition { + class FunctionDefinition; + enum class SymbolType { CLASS, @@ -155,99 +158,6 @@ namespace stride::ast return std::make_unique(get_symbol(), _type->clone_ty(), get_visibility()); } }; - - struct GenericFunctionOverload - { - GenericTypeList types; - mutable llvm::Function* function; - }; - - class FunctionDefinition - : public IDefinition - { - std::unique_ptr _function_type; - std::vector _generic_type_overloads{}; - int _flags; - - llvm::Function* _llvm_function = nullptr; - - public: - explicit FunctionDefinition( - std::unique_ptr function_type, - const Symbol& symbol, - const VisibilityModifier visibility, - const int flags - ) : - IDefinition(symbol, visibility), - _function_type(std::move(function_type)), - _flags(flags) {} - - [[nodiscard]] - AstFunctionType* get_type() const - { - return this->_function_type.get(); - } - - [[nodiscard]] - std::string get_function_name() const - { - return this->get_symbol().name; - } - - [[nodiscard]] - int get_flags() const - { - return this->_flags; - } - - [[nodiscard]] - bool is_variadic() const - { - return (this->_flags & SRFLAG_FN_TYPE_VARIADIC) != 0; - } - - void add_generic_instantiation(GenericTypeList generic_overload_types); - - [[nodiscard]] - const std::vector& get_generic_instantiations() const - { - return this->_generic_type_overloads; - } - - [[nodiscard]] - bool has_generic_instantiation(const GenericTypeList& generic_types) const; - 
- ~FunctionDefinition() override = default; - - bool matches_type_signature(const std::string& name, const AstFunctionType* signature) const; - - void set_llvm_function(llvm::Function* function) - { - this->_llvm_function = function; - } - - [[nodiscard]] - llvm::Function* get_llvm_function() const - { - return this->_llvm_function; - } - - [[nodiscard]] - bool matches_parameter_signature( - const std::string& internal_function_name, - const std::vector>& other_parameter_types, - size_t generic_argument_count - ) const; - - [[nodiscard]] - std::unique_ptr clone() const override - { - return std::make_unique(_function_type->clone_as(), - get_symbol(), - get_visibility(), - _flags); - } - }; } // namespace definition class ParsingContext @@ -433,7 +343,6 @@ namespace stride::ast return this->_context_name; } - private: [[nodiscard]] const ParsingContext& traverse_to_root() const; }; diff --git a/packages/compiler/include/ast/visitor.h b/packages/compiler/include/ast/visitor.h index 65008451..4779b60a 100644 --- a/packages/compiler/include/ast/visitor.h +++ b/packages/compiler/include/ast/visitor.h @@ -33,7 +33,7 @@ namespace stride::ast class FunctionVisitor : public IVisitor { public: - void accept(IAstFunction* fn_declaration) override; + void accept(IAstFunction* function) override; }; class ImportVisitor : public IVisitor diff --git a/packages/compiler/src/ast/ast.cpp b/packages/compiler/src/ast/ast.cpp index 7d49dda6..d06f8911 100644 --- a/packages/compiler/src/ast/ast.cpp +++ b/packages/compiler/src/ast/ast.cpp @@ -7,7 +7,7 @@ #include "ast/nodes/conditional_statement.h" #include "ast/nodes/control_flow_statements.h" #include "ast/nodes/for_loop.h" -#include "ast/nodes/function_declaration.h" +#include "ast/nodes/function_definition.h" #include "ast/nodes/import.h" #include "ast/nodes/module.h" #include "ast/nodes/package.h" diff --git a/packages/compiler/src/ast/context/function_registry.cpp b/packages/compiler/src/ast/context/function_registry.cpp index 
3191a639..37e31317 100644 --- a/packages/compiler/src/ast/context/function_registry.cpp +++ b/packages/compiler/src/ast/context/function_registry.cpp @@ -1,6 +1,7 @@ #include "errors.h" #include "ast/casting.h" #include "ast/parsing_context.h" +#include "ast/definitions/function_definition.h" #include @@ -164,12 +165,12 @@ bool ParsingContext::is_function_defined_globally( bool FunctionDefinition::has_generic_instantiation(const std::vector>& generic_types) const { - for (const auto& [types, function] : this->_generic_type_overloads) + for (const auto& [instantiated_generic_types, function, declaration] : this->_function_candidates) { bool all_equal = true; for (size_t i = 0; i < generic_types.size(); i++) { - if (!types[i]->equals(generic_types[i].get())) + if (!instantiated_generic_types[i]->equals(generic_types[i].get())) { all_equal = false; break; @@ -188,5 +189,5 @@ void FunctionDefinition::add_generic_instantiation(GenericTypeList generic_overl if (has_generic_instantiation(generic_overload_types)) return; // Already instantiated - this->_generic_type_overloads.push_back({std::move(generic_overload_types)}); + this->_function_candidates.push_back({ std::move(generic_overload_types) }); } diff --git a/packages/compiler/src/ast/generics.cpp b/packages/compiler/src/ast/generics.cpp index 57537914..fc3669b6 100644 --- a/packages/compiler/src/ast/generics.cpp +++ b/packages/compiler/src/ast/generics.cpp @@ -62,6 +62,17 @@ std::unique_ptr stride::ast::resolve_generics( const GenericTypeList& instantiated_types ) { + if (param_names.size() != instantiated_types.size()) + { + throw parsing_error( + ErrorType::TYPE_ERROR, + std::format( + "Failed to resolve generic type: expected {} parameters, got ", + param_names.size(), + instantiated_types.size()), + type->get_source_fragment() + ); + } if (auto* named_type = cast_type(type)) { for (size_t i = 0; i < param_names.size(); i++) diff --git a/packages/compiler/src/ast/nodes/blocks.cpp 
b/packages/compiler/src/ast/nodes/blocks.cpp index e29fcce4..323d8cde 100644 --- a/packages/compiler/src/ast/nodes/blocks.cpp +++ b/packages/compiler/src/ast/nodes/blocks.cpp @@ -1,7 +1,7 @@ #include "ast/nodes/blocks.h" #include "ast/ast.h" -#include "ast/nodes/function_declaration.h" +#include "ast/nodes/function_definition.h" #include "ast/tokens/token_set.h" #include diff --git a/packages/compiler/src/ast/nodes/expressions/variable_declaration.cpp b/packages/compiler/src/ast/nodes/expressions/variable_declaration.cpp index df747aaa..556f39cf 100644 --- a/packages/compiler/src/ast/nodes/expressions/variable_declaration.cpp +++ b/packages/compiler/src/ast/nodes/expressions/variable_declaration.cpp @@ -5,7 +5,7 @@ #include "ast/optionals.h" #include "ast/parsing_context.h" #include "ast/nodes/expression.h" -#include "ast/nodes/function_declaration.h" +#include "ast/nodes/function_definition.h" #include "ast/nodes/literal_values.h" #include "ast/tokens/token_set.h" diff --git a/packages/compiler/src/ast/nodes/functions/function_call.cpp b/packages/compiler/src/ast/nodes/functions/function_call.cpp index 5b56ba83..52eca1f9 100644 --- a/packages/compiler/src/ast/nodes/functions/function_call.cpp +++ b/packages/compiler/src/ast/nodes/functions/function_call.cpp @@ -6,6 +6,7 @@ #include "ast/optionals.h" #include "ast/parsing_context.h" #include "ast/symbols.h" +#include "ast/definitions/function_definition.h" #include "ast/nodes/blocks.h" #include "ast/nodes/expression.h" #include "ast/nodes/types.h" @@ -218,6 +219,7 @@ llvm::Value* AstFunctionCall::codegen( llvm::Function* AstFunctionCall::resolve_regular_callee(llvm::Module* module) { const auto& function_definition = this->get_function_definition(); + if (llvm::Function* callee = module->getFunction(function_definition->get_internal_symbol_name())) { return callee; diff --git a/packages/compiler/src/ast/nodes/functions/function_declaration.cpp b/packages/compiler/src/ast/nodes/functions/function_definition.cpp 
similarity index 82% rename from packages/compiler/src/ast/nodes/functions/function_declaration.cpp rename to packages/compiler/src/ast/nodes/functions/function_definition.cpp index e059162f..2185265b 100644 --- a/packages/compiler/src/ast/nodes/functions/function_declaration.cpp +++ b/packages/compiler/src/ast/nodes/functions/function_definition.cpp @@ -1,4 +1,4 @@ -#include "ast/nodes/function_declaration.h" +#include "ast/nodes/function_definition.h" #include "errors.h" #include "ast/casting.h" @@ -6,6 +6,7 @@ #include "ast/modifiers.h" #include "ast/parsing_context.h" #include "ast/symbols.h" +#include "ast/definitions/function_definition.h" #include "ast/nodes/blocks.h" #include "ast/nodes/conditional_statement.h" #include "ast/nodes/expression.h" @@ -15,6 +16,7 @@ #include "ast/tokens/token.h" #include "ast/tokens/token_set.h" +#include #include #include #include @@ -59,6 +61,11 @@ std::unique_ptr stride::ast::parse_fn_declaration( GenericParameterList generic_parameter_names = parse_generic_declaration(set); + if (function_flags & SRFLAG_FN_TYPE_EXTERN && !generic_parameter_names.empty()) + { + set.throw_error("Extern functions cannot have generic parameters"); + } + set.expect(TokenType::LPAREN, "Expected '(' after function name"); std::vector> parameters; @@ -528,68 +535,36 @@ std::vector> IAstFunction::get_parameters( return cloned_params; } -void IAstFunction::validate() +void IAstFunction::validate_candidate(IAstFunction* candidate) { - if (this->is_anonymous()) - { - std::vector captures; - const auto outer_context = this->get_context()->get_parent_context() != nullptr - ? 
this->get_context()->get_parent_context() - : this->get_context(); - - collect_free_variables(this->get_body(), this->get_context(), outer_context, captures); + candidate->get_body()->validate(); - // Register captured variables in the lambda's context so they can be referenced - for (const auto& capture : captures) - { - this->add_captured_variable(capture); - - // Also define the capture in the lambda's context so identifier lookup works - if (const auto outer_var = this->get_context()->lookup_variable(capture.name, true)) - { - this->get_context()->define_variable( - capture, - outer_var->get_type()->clone_ty(), - VisibilityModifier::PRIVATE - ); - } - } - } - - // Extern functions don't require return statements and have no function body, so no validation - // needed. - if (this->is_extern()) - { - return; - } - - this->get_body()->validate(); - - const auto return_statements = collect_return_statements(this->get_body()); + const auto& ret_ty = candidate->get_return_type(); + const auto return_statements = collect_return_statements(candidate->get_body()); // For void types, we only disallow returning expressions, as this is redundant. - if (const auto void_ret = cast_type(this->get_return_type()); + if (const auto void_ret = cast_type(ret_ty); void_ret != nullptr && void_ret->get_primitive_type() == PrimitiveType::VOID) { for (const auto& return_stmt : return_statements) { if (return_stmt->get_return_expression().has_value()) { - throw parsing_error( - ErrorType::TYPE_ERROR, + throw stride::parsing_error( + stride::ErrorType::TYPE_ERROR, std::format( "{} has return type 'void' and cannot return a value.", - this->is_anonymous() + candidate->is_anonymous() ? 
"Anonymous function" - : std::format("Function '{}'", this->get_function_name())), + : std::format("Function '{}'", candidate->get_function_name())), { - ErrorSourceReference( + stride::ErrorSourceReference( "unexpected return value", return_stmt->get_source_fragment() ), - ErrorSourceReference( + stride::ErrorSourceReference( "Function returning void type", - this->get_source_fragment() + candidate->get_source_fragment() ) } @@ -601,24 +576,24 @@ void IAstFunction::validate() if (return_statements.empty()) { - if (cast_type(this->get_return_type())) + if (cast_type(ret_ty)) { - throw parsing_error( - ErrorType::TYPE_ERROR, + throw stride::parsing_error( + stride::ErrorType::TYPE_ERROR, std::format( "Function '{}' returns a struct type, but no return statement is present.", - this->get_function_name()), - this->get_source_fragment()); + candidate->get_function_name()), + candidate->get_source_fragment()); } - throw parsing_error( - ErrorType::COMPILATION_ERROR, + throw stride::parsing_error( + stride::ErrorType::COMPILATION_ERROR, std::format( "{} is missing a return statement.", - this->is_anonymous() + candidate->is_anonymous() ? "Anonymous function" - : std::format("Function '{}'", this->get_function_name())), - this->get_source_fragment() + : std::format("Function '{}'", candidate->get_function_name())), + candidate->get_source_fragment() ); } @@ -626,14 +601,14 @@ void IAstFunction::validate() { if (return_stmt->is_void_type()) { - if (!this->get_return_type()->is_void_ty()) + if (!ret_ty->is_void_ty()) { - throw parsing_error( - ErrorType::TYPE_ERROR, + throw stride::parsing_error( + stride::ErrorType::TYPE_ERROR, std::format( "Function '{}' returns a value of type '{}', but no return statement is present.", - this->is_anonymous() ? "" : this->get_function_name(), - this->get_return_type()->to_string()), + candidate->is_anonymous() ? 
"" : candidate->get_function_name(), + ret_ty->to_string()), return_stmt->get_source_fragment() ); } @@ -641,34 +616,87 @@ void IAstFunction::validate() } if (const auto& ret_expr = return_stmt->get_return_expression().value(); - !ret_expr->get_type()->equals(this->get_return_type()) && - !ret_expr->get_type()->is_assignable_to(this->get_return_type())) + !ret_expr->get_type()->equals(ret_ty) && + !ret_expr->get_type()->is_assignable_to(ret_ty)) { - const auto error_fragment = ErrorSourceReference( + const auto error_fragment = stride::ErrorSourceReference( std::format( "expected {}{}", - this->get_return_type()->is_primitive() + candidate->get_return_type()->is_primitive() ? "" - : this->get_return_type()->is_function() + : candidate->get_return_type()->is_function() ? "function-type " : "struct-type ", - this->get_return_type()->to_string()), + candidate->get_return_type()->to_string()), ret_expr->get_source_fragment() ); - throw parsing_error( - ErrorType::TYPE_ERROR, + throw stride::parsing_error( + stride::ErrorType::TYPE_ERROR, std::format( "Function '{}' expected a return type of '{}', but received '{}'.", - this->is_anonymous() ? "" : this->get_function_name(), - this->get_return_type()->to_string(), - ret_expr->get_type()->to_string()), + candidate->is_anonymous() ? "" : candidate->get_function_name(), + ret_ty->get_type_name(), + ret_expr->get_type()->get_type_name()), { error_fragment } ); } } } +void IAstFunction::validate() +{ + // Extern functions have no body to validate + if (this->is_extern()) + return; + + std::vector> validatable_candidates; + + // + // For generic functions, we create a new copy of the function with all parameters resolved, and do validation + // on that copy. This is because we want to validate the function body with the actual types that will be used in + // the function, rather than the generic placeholders. 
+ // + // create a copy of this function with the parameters instantiated + for (const auto definition = this->get_function_definition(); + const auto& [types, function, node] : definition->get_instantiations()) + { + auto instantiated_return_ty = resolve_generics( + this->_annotated_return_type.get(), + this->_generic_parameters, + types + ); + + std::vector> instantiated_function_params; + instantiated_function_params.reserve(this->_parameters.size()); + + for (const auto& param : this->_parameters) + { + instantiated_function_params.push_back( + std::make_unique( + param->get_source_fragment(), + param->get_context(), + param->get_name(), + resolve_generics(param->get_type(), this->_generic_parameters, types) + ) + ); + } + + const auto& candidate = std::make_unique( + this->get_context(), + this->_symbol, + std::move(instantiated_function_params), + this->_body->clone_as(), + std::move(instantiated_return_ty), + this->get_visibility(), + this->_flags, + this->get_generic_parameters() + ); + + validate_candidate(candidate.get()); + } +} + llvm::Value* IAstFunction::codegen( llvm::Module* module, llvm::IRBuilderBase* builder @@ -867,7 +895,6 @@ llvm::Value* IAstFunction::codegen( return function; } - /** * Here we define the function in the symbol table, so it can be looked up in the codegen phase. */ @@ -876,6 +903,29 @@ void IAstFunction::resolve_forward_references( llvm::IRBuilderBase* builder ) { + std::vector captures; + const auto outer_context = this->get_context()->get_parent_context() != nullptr + ? 
this->get_context()->get_parent_context() + : this->get_context(); + + collect_free_variables(this->get_body(), this->get_context(), outer_context, captures); + + // Register captured variables in the lambda's context so they can be referenced + for (const auto& capture : captures) + { + this->add_captured_variable(capture); + + // Also define the capture in the lambda's context so identifier lookup works + if (const auto outer_var = this->get_context()->lookup_variable(capture.name, true)) + { + this->get_context()->define_variable( + capture, + outer_var->get_type()->clone_ty(), + VisibilityModifier::PRIVATE + ); + } + } + // Avoid re-registering if already declared. // Named functions are looked up by their scoped name; anonymous functions are // tracked by the cached _llvm_function pointer (they have no stable string name). @@ -899,53 +949,35 @@ void IAstFunction::resolve_forward_references( ? llvm::Function::PrivateLinkage : llvm::Function::ExternalLinkage; - // --- Generic function instantiation - if (const auto& definition = this->get_function_definition(); - !definition->get_generic_instantiations().empty()) - { - for (const auto& [types, function] : definition->get_generic_instantiations()) - { - const auto overloaded_fn_name = get_overloaded_function_name(this->get_internalized_function_name(), types); - llvm::FunctionType* generic_function_type = this->get_generic_instantiated_llvm_function_type( - module, - captured_types, - types - ); - - function = llvm::Function::Create( - generic_function_type, - linkage, - overloaded_fn_name, - module - ); + const auto& definition = this->get_function_definition(); - if (this->is_anonymous()) - function->addFnAttr("stride.anonymous"); - } + if (definition->get_instantiations().empty()) + { + std::cerr << "Warning: No instantiations found for function '" << this->get_function_name() << + "'. 
This function will not be emitted in the LLVM IR.\n"; + return; } - else + for (const auto& [types, instantiation_fn, node] : definition->get_instantiations()) { - // Anonymous functions are created with a stable name prefix and a numeric ID - // so they are easily findable in the module without ambiguity. - const std::string llvm_function_name = this->get_internalized_function_name(); - - llvm::FunctionType* function_type = this->get_generic_instantiated_llvm_function_type(module, captured_types); + const auto overloaded_fn_name = get_overloaded_function_name(node->get_internalized_function_name(), types); + llvm::FunctionType* generic_function_type = node->get_generic_instantiated_llvm_function_type( + module, + captured_types, + types + ); - llvm::Function* function = llvm::Function::Create( - function_type, + instantiation_fn = llvm::Function::Create( + generic_function_type, linkage, - llvm_function_name, + overloaded_fn_name, module ); - if (this->is_anonymous()) - { - function->addFnAttr("stride.anonymous"); - definition->set_llvm_function(function); - } - } + if (node->is_anonymous()) + instantiation_fn->addFnAttr("stride.anonymous"); - this->_body->resolve_forward_references(module, builder); + this->_body->resolve_forward_references(module, builder); + } } llvm::FunctionType* IAstFunction::get_generic_instantiated_llvm_function_type( @@ -955,14 +987,28 @@ llvm::FunctionType* IAstFunction::get_generic_instantiated_llvm_function_type( ) const { std::vector base_parameter_types; - for (const auto& param : this->_parameters) + + if (generic_instantiation_types.empty()) { - const auto& resolved_generic_param_type = resolve_generics( - param->get_type(), - this->_generic_parameters, - generic_instantiation_types - ); - base_parameter_types.push_back(resolved_generic_param_type->get_llvm_type(module)); + for (const auto& param : this->_parameters) + { + if (llvm::Type* param_type = param->get_type()->get_llvm_type(module)) + { + 
base_parameter_types.push_back(param_type); + } + } + } + else + { + for (const auto& param : this->_parameters) + { + const auto& resolved_generic_param_type = resolve_generics( + param->get_type(), + this->_generic_parameters, + generic_instantiation_types + ); + base_parameter_types.push_back(resolved_generic_param_type->get_llvm_type(module)); + } } std::vector parameter_types; @@ -1052,11 +1098,10 @@ std::vector IAstFunction::get_function_overload_metadat std::vector metadata; - - for (const auto& [types, llvm_function] : definition->get_generic_instantiations()) + for (const auto& [types, llvm_function, node] : definition->get_instantiations()) { metadata.emplace_back( - get_overloaded_function_name(this->get_internalized_function_name(), types), + get_overloaded_function_name(node->get_internalized_function_name(), types), llvm_function ); } @@ -1097,7 +1142,7 @@ std::unique_ptr IAstFunction::clone() ); } -std::string AstFunctionDeclaration::to_string() +std::string IAstFunction::to_string() { std::string params; for (const auto& param : this->_parameters) @@ -1112,23 +1157,12 @@ std::string AstFunctionDeclaration::to_string() : this->get_body()->to_string(); return std::format( - "FunctionDeclaration(name: {}(internal: {}), params: [{}], body: {}{} -> {})", - this->get_function_name(), + "Function(name: {}(internal: {}), params: [{}], body: {}{} -> {})", + this->is_anonymous() ? "" : this->get_function_name(), this->get_internalized_function_name(), params, body_str, this->is_extern() ? 
" (extern)" : "", - this->get_return_type()->to_string()); -} - -std::string AstLambdaFunctionExpression::to_string() -{ - return "LambdaFunction"; -} - -std::string AstFunctionParameter::to_string() -{ - const auto name = this->get_name(); - auto type_str = this->get_type()->to_string(); - return std::format("{}({})", name, type_str); + this->get_return_type()->to_string() + ); } diff --git a/packages/compiler/src/ast/nodes/functions/function_parameters.cpp b/packages/compiler/src/ast/nodes/functions/function_parameters.cpp index afbe9d6f..c7b980e3 100644 --- a/packages/compiler/src/ast/nodes/functions/function_parameters.cpp +++ b/packages/compiler/src/ast/nodes/functions/function_parameters.cpp @@ -1,4 +1,4 @@ -#include "ast/nodes/function_declaration.h" +#include "ast/nodes/function_definition.h" #include "ast/tokens/token.h" #include "ast/tokens/token_set.h" diff --git a/packages/compiler/src/ast/traversal/function_visitor.cpp b/packages/compiler/src/ast/traversal/function_visitor.cpp index 62f05197..9b58f9eb 100644 --- a/packages/compiler/src/ast/traversal/function_visitor.cpp +++ b/packages/compiler/src/ast/traversal/function_visitor.cpp @@ -1,22 +1,19 @@ #include "ast/parsing_context.h" -#include "ast/type_inference.h" #include "ast/visitor.h" #include "ast/nodes/expression.h" -#include "ast/nodes/function_declaration.h" +#include "ast/nodes/function_definition.h" #include "ast/nodes/types.h" using namespace stride::ast; -void FunctionVisitor::accept(IAstFunction* fn_declaration) +void FunctionVisitor::accept(IAstFunction* function) { - fn_declaration->set_type(infer_expression_type(fn_declaration)); - // Define parameters in the function's own context BEFORE traversing the body, // so that identifiers referencing params resolve correctly inside the body. 
- for (const auto& param : fn_declaration->get_parameters_ref()) + for (const auto& param : function->get_parameters_ref()) { const auto param_symbol = Symbol(param->get_source_fragment(), param->get_name()); - fn_declaration->get_context()->define_variable( + function->get_context()->define_variable( param_symbol, param->get_type()->clone_ty(), VisibilityModifier::PRIVATE @@ -24,10 +21,10 @@ void FunctionVisitor::accept(IAstFunction* fn_declaration) } // Forward declare the function in the symbol registry - fn_declaration->get_context()->define_function( - fn_declaration->get_symbol(), - fn_declaration->get_type()->clone_as(), - fn_declaration->get_visibility(), - fn_declaration->get_flags() + function->get_context()->define_function( + function->get_symbol(), + function->get_type()->clone_as(), + function->get_visibility(), + function->get_flags() ); -} +} \ No newline at end of file diff --git a/packages/compiler/src/ast/traversal/traversal.cpp b/packages/compiler/src/ast/traversal/traversal.cpp index 5f4a2c5a..1e2d9a11 100644 --- a/packages/compiler/src/ast/traversal/traversal.cpp +++ b/packages/compiler/src/ast/traversal/traversal.cpp @@ -7,7 +7,7 @@ #include "ast/nodes/conditional_statement.h" #include "ast/nodes/expression.h" #include "ast/nodes/for_loop.h" -#include "ast/nodes/function_declaration.h" +#include "ast/nodes/function_definition.h" #include "ast/nodes/import.h" #include "ast/nodes/module.h" #include "ast/nodes/package.h" diff --git a/packages/compiler/src/ast/type_inference.cpp b/packages/compiler/src/ast/type_inference.cpp index d89df6af..cf053524 100644 --- a/packages/compiler/src/ast/type_inference.cpp +++ b/packages/compiler/src/ast/type_inference.cpp @@ -4,7 +4,7 @@ #include "ast/casting.h" #include "ast/flags.h" #include "ast/parsing_context.h" -#include "ast/nodes/function_declaration.h" +#include "ast/nodes/function_definition.h" #include "ast/nodes/literal_values.h" #include "ast/nodes/types.h" diff --git 
a/packages/compiler/tests/test_type_inference.cpp b/packages/compiler/tests/test_type_inference.cpp index b7409bb0..97a26a0a 100644 --- a/packages/compiler/tests/test_type_inference.cpp +++ b/packages/compiler/tests/test_type_inference.cpp @@ -2,7 +2,7 @@ #include "ast/nodes/literal_values.h" #include "ast/nodes/expression.h" #include "ast/nodes/types.h" -#include "ast/nodes/function_declaration.h" +#include "ast/nodes/function_definition.h" #include "ast/parsing_context.h" #include "ast/symbols.h" #include "errors.h" From 7172c23c15526be7d5aa7e355b8df354d59a77d2 Mon Sep 17 00:00:00 2001 From: Luca Warmenhoven Date: Mon, 16 Mar 2026 15:39:29 +0100 Subject: [PATCH 07/31] refactor: rename `function_definition` to `function_declaration` and improve generic overload handling --- .../ast/definitions/function_definition.h | 11 ++- ...on_definition.h => function_declaration.h} | 2 +- packages/compiler/include/ast/nodes/types.h | 2 + packages/compiler/src/ast/ast.cpp | 2 +- .../src/ast/context/function_registry.cpp | 9 +- .../src/ast/context/parsing_context.cpp | 1 + packages/compiler/src/ast/nodes/blocks.cpp | 2 +- .../expressions/variable_declaration.cpp | 2 +- .../src/ast/nodes/functions/function_call.cpp | 3 +- ...efinition.cpp => function_declaration.cpp} | 99 ++++++++++++------- .../nodes/functions/function_parameters.cpp | 2 +- .../src/ast/nodes/types/function_type.cpp | 25 +++-- .../src/ast/traversal/function_visitor.cpp | 5 +- .../compiler/src/ast/traversal/traversal.cpp | 2 +- packages/compiler/src/ast/type_inference.cpp | 3 +- packages/compiler/src/compilation/program.cpp | 11 ++- 16 files changed, 115 insertions(+), 66 deletions(-) rename packages/compiler/include/ast/nodes/{function_definition.h => function_declaration.h} (99%) rename packages/compiler/src/ast/nodes/functions/{function_definition.cpp => function_declaration.cpp} (93%) diff --git a/packages/compiler/include/ast/definitions/function_definition.h 
b/packages/compiler/include/ast/definitions/function_definition.h index 03eb3044..13cfb2b4 100644 --- a/packages/compiler/include/ast/definitions/function_definition.h +++ b/packages/compiler/include/ast/definitions/function_definition.h @@ -1,7 +1,7 @@ #pragma once #include "ast/parsing_context.h" -#include "ast/nodes/function_definition.h" +#include "ast/nodes/function_declaration.h" #include @@ -9,15 +9,16 @@ namespace stride::ast::definition { struct GenericFunctionOverload { + std::vector> generic_overload_types; mutable llvm::Function* function; - std::unique_ptr node; + mutable std::unique_ptr node; }; class FunctionDefinition : public IDefinition { std::unique_ptr _function_type; - std::vector _function_candidates{}; + std::vector _generic_overloads{}; int _flags; llvm::Function* _llvm_function = nullptr; @@ -60,9 +61,9 @@ namespace stride::ast::definition void add_generic_instantiation(GenericTypeList generic_overload_types); [[nodiscard]] - const std::vector& get_instantiations() const + const std::vector& get_generic_overloads() const { - return this->_function_candidates; + return this->_generic_overloads; } [[nodiscard]] diff --git a/packages/compiler/include/ast/nodes/function_definition.h b/packages/compiler/include/ast/nodes/function_declaration.h similarity index 99% rename from packages/compiler/include/ast/nodes/function_definition.h rename to packages/compiler/include/ast/nodes/function_declaration.h index 4605bfa5..60147496 100644 --- a/packages/compiler/include/ast/nodes/function_definition.h +++ b/packages/compiler/include/ast/nodes/function_declaration.h @@ -242,7 +242,7 @@ namespace stride::ast std::string to_string() override; private: - llvm::FunctionType* get_generic_instantiated_llvm_function_type( + llvm::FunctionType* get_llvm_function_type( llvm::Module* module, std::vector captured_variables, const GenericTypeList& generic_instantiation_types = {} diff --git a/packages/compiler/include/ast/nodes/types.h 
b/packages/compiler/include/ast/nodes/types.h index 0fd2dc20..d8267cbd 100644 --- a/packages/compiler/include/ast/nodes/types.h +++ b/packages/compiler/include/ast/nodes/types.h @@ -34,6 +34,8 @@ namespace stride::ast using EnumMemberValueTy = std::unique_ptr; using EnumMemberPair = std::pair; + using FunctionParameters = std::vector>>; + enum class PrimitiveType { INT8, diff --git a/packages/compiler/src/ast/ast.cpp b/packages/compiler/src/ast/ast.cpp index d06f8911..7d49dda6 100644 --- a/packages/compiler/src/ast/ast.cpp +++ b/packages/compiler/src/ast/ast.cpp @@ -7,7 +7,7 @@ #include "ast/nodes/conditional_statement.h" #include "ast/nodes/control_flow_statements.h" #include "ast/nodes/for_loop.h" -#include "ast/nodes/function_definition.h" +#include "ast/nodes/function_declaration.h" #include "ast/nodes/import.h" #include "ast/nodes/module.h" #include "ast/nodes/package.h" diff --git a/packages/compiler/src/ast/context/function_registry.cpp b/packages/compiler/src/ast/context/function_registry.cpp index 37e31317..7d85d122 100644 --- a/packages/compiler/src/ast/context/function_registry.cpp +++ b/packages/compiler/src/ast/context/function_registry.cpp @@ -73,7 +73,6 @@ bool FunctionDefinition::matches_type_signature( other_params, signature->get_generic_parameter_names().size() ); - } bool FunctionDefinition::matches_parameter_signature( @@ -165,7 +164,7 @@ bool ParsingContext::is_function_defined_globally( bool FunctionDefinition::has_generic_instantiation(const std::vector>& generic_types) const { - for (const auto& [instantiated_generic_types, function, declaration] : this->_function_candidates) + for (const auto& [instantiated_generic_types, llvm_function, node] : this->_generic_overloads) { bool all_equal = true; for (size_t i = 0; i < generic_types.size(); i++) @@ -181,6 +180,7 @@ bool FunctionDefinition::has_generic_instantiation(const std::vector_function_candidates.push_back({ std::move(generic_overload_types) }); + // All other fields will be populated 
in later stages + this->_generic_overloads.push_back({ + std::move(generic_overload_types) + }); } diff --git a/packages/compiler/src/ast/context/parsing_context.cpp b/packages/compiler/src/ast/context/parsing_context.cpp index 81daeb2d..0e8b05b3 100644 --- a/packages/compiler/src/ast/context/parsing_context.cpp +++ b/packages/compiler/src/ast/context/parsing_context.cpp @@ -2,6 +2,7 @@ #include "errors.h" #include "ast/symbols.h" +#include "ast/definitions/function_definition.h" #include #include diff --git a/packages/compiler/src/ast/nodes/blocks.cpp b/packages/compiler/src/ast/nodes/blocks.cpp index 323d8cde..e29fcce4 100644 --- a/packages/compiler/src/ast/nodes/blocks.cpp +++ b/packages/compiler/src/ast/nodes/blocks.cpp @@ -1,7 +1,7 @@ #include "ast/nodes/blocks.h" #include "ast/ast.h" -#include "ast/nodes/function_definition.h" +#include "ast/nodes/function_declaration.h" #include "ast/tokens/token_set.h" #include diff --git a/packages/compiler/src/ast/nodes/expressions/variable_declaration.cpp b/packages/compiler/src/ast/nodes/expressions/variable_declaration.cpp index 556f39cf..df747aaa 100644 --- a/packages/compiler/src/ast/nodes/expressions/variable_declaration.cpp +++ b/packages/compiler/src/ast/nodes/expressions/variable_declaration.cpp @@ -5,7 +5,7 @@ #include "ast/optionals.h" #include "ast/parsing_context.h" #include "ast/nodes/expression.h" -#include "ast/nodes/function_definition.h" +#include "ast/nodes/function_declaration.h" #include "ast/nodes/literal_values.h" #include "ast/tokens/token_set.h" diff --git a/packages/compiler/src/ast/nodes/functions/function_call.cpp b/packages/compiler/src/ast/nodes/functions/function_call.cpp index 52eca1f9..d1fc42fe 100644 --- a/packages/compiler/src/ast/nodes/functions/function_call.cpp +++ b/packages/compiler/src/ast/nodes/functions/function_call.cpp @@ -13,7 +13,7 @@ #include "ast/tokens/token_set.h" #include -#include +#include #include #include #include @@ -659,7 +659,6 @@ void 
AstFunctionCall::resolve_forward_references(llvm::Module* module, llvm::IRB definition->add_generic_instantiation( copy_generic_type_list(this->_generic_type_arguments) ); - // nice, got that over with } for (const auto& arg : this->_arguments) diff --git a/packages/compiler/src/ast/nodes/functions/function_definition.cpp b/packages/compiler/src/ast/nodes/functions/function_declaration.cpp similarity index 93% rename from packages/compiler/src/ast/nodes/functions/function_definition.cpp rename to packages/compiler/src/ast/nodes/functions/function_declaration.cpp index 2185265b..f9e0e20f 100644 --- a/packages/compiler/src/ast/nodes/functions/function_definition.cpp +++ b/packages/compiler/src/ast/nodes/functions/function_declaration.cpp @@ -1,4 +1,4 @@ -#include "ast/nodes/function_definition.h" +#include "ast/definitions/function_definition.h" #include "errors.h" #include "ast/casting.h" @@ -6,17 +6,16 @@ #include "ast/modifiers.h" #include "ast/parsing_context.h" #include "ast/symbols.h" -#include "ast/definitions/function_definition.h" #include "ast/nodes/blocks.h" #include "ast/nodes/conditional_statement.h" #include "ast/nodes/expression.h" #include "ast/nodes/for_loop.h" +#include "ast/nodes/function_declaration.h" #include "ast/nodes/return_statement.h" #include "ast/nodes/while_loop.h" #include "ast/tokens/token.h" #include "ast/tokens/token_set.h" -#include #include #include #include @@ -550,19 +549,19 @@ void IAstFunction::validate_candidate(IAstFunction* candidate) { if (return_stmt->get_return_expression().has_value()) { - throw stride::parsing_error( - stride::ErrorType::TYPE_ERROR, + throw parsing_error( + ErrorType::TYPE_ERROR, std::format( "{} has return type 'void' and cannot return a value.", candidate->is_anonymous() ? 
"Anonymous function" : std::format("Function '{}'", candidate->get_function_name())), { - stride::ErrorSourceReference( + ErrorSourceReference( "unexpected return value", return_stmt->get_source_fragment() ), - stride::ErrorSourceReference( + ErrorSourceReference( "Function returning void type", candidate->get_source_fragment() ) @@ -578,16 +577,16 @@ void IAstFunction::validate_candidate(IAstFunction* candidate) { if (cast_type(ret_ty)) { - throw stride::parsing_error( - stride::ErrorType::TYPE_ERROR, + throw parsing_error( + ErrorType::TYPE_ERROR, std::format( "Function '{}' returns a struct type, but no return statement is present.", candidate->get_function_name()), candidate->get_source_fragment()); } - throw stride::parsing_error( - stride::ErrorType::COMPILATION_ERROR, + throw parsing_error( + ErrorType::COMPILATION_ERROR, std::format( "{} is missing a return statement.", candidate->is_anonymous() @@ -603,8 +602,8 @@ void IAstFunction::validate_candidate(IAstFunction* candidate) { if (!ret_ty->is_void_ty()) { - throw stride::parsing_error( - stride::ErrorType::TYPE_ERROR, + throw parsing_error( + ErrorType::TYPE_ERROR, std::format( "Function '{}' returns a value of type '{}', but no return statement is present.", candidate->is_anonymous() ? 
"" : candidate->get_function_name(), @@ -619,7 +618,7 @@ void IAstFunction::validate_candidate(IAstFunction* candidate) !ret_expr->get_type()->equals(ret_ty) && !ret_expr->get_type()->is_assignable_to(ret_ty)) { - const auto error_fragment = stride::ErrorSourceReference( + const auto error_fragment = ErrorSourceReference( std::format( "expected {}{}", candidate->get_return_type()->is_primitive() @@ -631,8 +630,8 @@ void IAstFunction::validate_candidate(IAstFunction* candidate) ret_expr->get_source_fragment() ); - throw stride::parsing_error( - stride::ErrorType::TYPE_ERROR, + throw parsing_error( + ErrorType::TYPE_ERROR, std::format( "Function '{}' expected a return type of '{}', but received '{}'.", candidate->is_anonymous() ? "" : candidate->get_function_name(), @@ -650,7 +649,11 @@ void IAstFunction::validate() if (this->is_extern()) return; - std::vector> validatable_candidates; + if (!this->is_generic()) + { + validate_candidate(this); + return; + } // // For generic functions, we create a new copy of the function with all parameters resolved, and do validation @@ -659,7 +662,7 @@ void IAstFunction::validate() // // create a copy of this function with the parameters instantiated for (const auto definition = this->get_function_definition(); - const auto& [types, function, node] : definition->get_instantiations()) + const auto& [types, function, node] : definition->get_generic_overloads()) { auto instantiated_return_ty = resolve_generics( this->_annotated_return_type.get(), @@ -949,24 +952,51 @@ void IAstFunction::resolve_forward_references( ? llvm::Function::PrivateLinkage : llvm::Function::ExternalLinkage; - const auto& definition = this->get_function_definition(); - - if (definition->get_instantiations().empty()) - { - std::cerr << "Warning: No instantiations found for function '" << this->get_function_name() << - "'. 
This function will not be emitted in the LLVM IR.\n"; - return; - } - for (const auto& [types, instantiation_fn, node] : definition->get_instantiations()) + for (const auto& definition = this->get_function_definition(); + const auto& [types, llvm_function, node] : definition->get_generic_overloads()) { + auto instantiated_return_ty = resolve_generics( + this->_annotated_return_type.get(), + this->_generic_parameters, + types + ); + + std::vector> instantiated_function_params; + instantiated_function_params.reserve(this->_parameters.size()); + + for (const auto& param : this->_parameters) + { + instantiated_function_params.push_back( + std::make_unique( + param->get_source_fragment(), + param->get_context(), + param->get_name(), + resolve_generics(param->get_type(), this->_generic_parameters, types) + ) + ); + } + + auto candidate_node = std::make_unique( + this->get_context(), + this->_symbol, + std::move(instantiated_function_params), + this->_body->clone_as(), + std::move(instantiated_return_ty), + this->get_visibility(), + this->_flags, + this->get_generic_parameters() + ); + + node = std::move(candidate_node); + const auto overloaded_fn_name = get_overloaded_function_name(node->get_internalized_function_name(), types); - llvm::FunctionType* generic_function_type = node->get_generic_instantiated_llvm_function_type( + llvm::FunctionType* generic_function_type = node->get_llvm_function_type( module, captured_types, types ); - instantiation_fn = llvm::Function::Create( + llvm_function = llvm::Function::Create( generic_function_type, linkage, overloaded_fn_name, @@ -974,13 +1004,13 @@ void IAstFunction::resolve_forward_references( ); if (node->is_anonymous()) - instantiation_fn->addFnAttr("stride.anonymous"); + llvm_function->addFnAttr("stride.anonymous"); this->_body->resolve_forward_references(module, builder); } } -llvm::FunctionType* IAstFunction::get_generic_instantiated_llvm_function_type( +llvm::FunctionType* IAstFunction::get_llvm_function_type( 
llvm::Module* module, std::vector captured_variables, const GenericTypeList& generic_instantiation_types @@ -1007,7 +1037,10 @@ llvm::FunctionType* IAstFunction::get_generic_instantiated_llvm_function_type( this->_generic_parameters, generic_instantiation_types ); - base_parameter_types.push_back(resolved_generic_param_type->get_llvm_type(module)); + if (llvm::Type* param_type = resolved_generic_param_type->get_llvm_type(module)) + { + base_parameter_types.push_back(param_type); + } } } @@ -1098,7 +1131,7 @@ std::vector IAstFunction::get_function_overload_metadat std::vector metadata; - for (const auto& [types, llvm_function, node] : definition->get_instantiations()) + for (const auto& [types, llvm_function, node] : definition->get_generic_overloads()) { metadata.emplace_back( get_overloaded_function_name(node->get_internalized_function_name(), types), diff --git a/packages/compiler/src/ast/nodes/functions/function_parameters.cpp b/packages/compiler/src/ast/nodes/functions/function_parameters.cpp index c7b980e3..afbe9d6f 100644 --- a/packages/compiler/src/ast/nodes/functions/function_parameters.cpp +++ b/packages/compiler/src/ast/nodes/functions/function_parameters.cpp @@ -1,4 +1,4 @@ -#include "ast/nodes/function_definition.h" +#include "ast/nodes/function_declaration.h" #include "ast/tokens/token.h" #include "ast/tokens/token_set.h" diff --git a/packages/compiler/src/ast/nodes/types/function_type.cpp b/packages/compiler/src/ast/nodes/types/function_type.cpp index cb34ea32..7865a74f 100644 --- a/packages/compiler/src/ast/nodes/types/function_type.cpp +++ b/packages/compiler/src/ast/nodes/types/function_type.cpp @@ -2,6 +2,7 @@ #include "ast/nodes/types.h" #include "ast/tokens/token_set.h" +#include #include using namespace stride::ast; @@ -38,13 +39,11 @@ std::optional> stride::ast::parse_function_type_option while (set.has_next() && !set.peek_next_eq(TokenType::RPAREN)) { - parameters.push_back( - parse_type( - context, - set, - { "Expected parameter type", 
options.type_name, flags } - ) - ); + parameters.emplace_back(parse_type( + context, + set, + { "Expected parameter type", options.type_name, flags } + )); if (set.peek_next_eq(TokenType::RPAREN)) { break; @@ -88,9 +87,9 @@ std::unique_ptr AstFunctionType::clone() parameters.reserve(this->_parameters.size()); generic_parameters_clone.reserve(this->_generic_param_names.size()); - for (const auto& p : this->_parameters) + for (const auto& param_ty : this->_parameters) { - parameters.push_back(p->clone_ty()); + parameters.emplace_back(param_ty->clone_ty()); } generic_parameters_clone.insert( @@ -151,9 +150,9 @@ llvm::Type* AstFunctionType::get_llvm_type_impl(llvm::Module* module) std::vector param_types; param_types.reserve(this->_parameters.size()); - for (const auto& param : this->_parameters) + for (const auto& param_ty : this->_parameters) { - param_types.push_back(param->get_llvm_type(module)); + param_types.push_back(param_ty->get_llvm_type(module)); } llvm::Type* ret_type = this->_return_type->get_llvm_type(module); @@ -169,8 +168,8 @@ std::string AstFunctionType::get_type_name() std::vector param_strings; param_strings.reserve(this->_parameters.size()); - for (const auto& p : this->_parameters) - param_strings.push_back(p->to_string()); + for (const auto& param_ty : this->_parameters) + param_strings.push_back(param_ty->to_string()); return std::format( "({}) -> {}", diff --git a/packages/compiler/src/ast/traversal/function_visitor.cpp b/packages/compiler/src/ast/traversal/function_visitor.cpp index 9b58f9eb..0b3aac48 100644 --- a/packages/compiler/src/ast/traversal/function_visitor.cpp +++ b/packages/compiler/src/ast/traversal/function_visitor.cpp @@ -1,13 +1,16 @@ #include "ast/parsing_context.h" +#include "ast/type_inference.h" #include "ast/visitor.h" #include "ast/nodes/expression.h" -#include "ast/nodes/function_definition.h" +#include "ast/nodes/function_declaration.h" #include "ast/nodes/types.h" using namespace stride::ast; void 
FunctionVisitor::accept(IAstFunction* function) { + function->set_type(infer_function_type(function)); + // Define parameters in the function's own context BEFORE traversing the body, // so that identifiers referencing params resolve correctly inside the body. for (const auto& param : function->get_parameters_ref()) diff --git a/packages/compiler/src/ast/traversal/traversal.cpp b/packages/compiler/src/ast/traversal/traversal.cpp index 1e2d9a11..5f4a2c5a 100644 --- a/packages/compiler/src/ast/traversal/traversal.cpp +++ b/packages/compiler/src/ast/traversal/traversal.cpp @@ -7,7 +7,7 @@ #include "ast/nodes/conditional_statement.h" #include "ast/nodes/expression.h" #include "ast/nodes/for_loop.h" -#include "ast/nodes/function_definition.h" +#include "ast/nodes/function_declaration.h" #include "ast/nodes/import.h" #include "ast/nodes/module.h" #include "ast/nodes/package.h" diff --git a/packages/compiler/src/ast/type_inference.cpp b/packages/compiler/src/ast/type_inference.cpp index cf053524..8ddebfa6 100644 --- a/packages/compiler/src/ast/type_inference.cpp +++ b/packages/compiler/src/ast/type_inference.cpp @@ -4,7 +4,8 @@ #include "ast/casting.h" #include "ast/flags.h" #include "ast/parsing_context.h" -#include "ast/nodes/function_definition.h" +#include "ast/definitions/function_definition.h" +#include "ast/nodes/function_declaration.h" #include "ast/nodes/literal_values.h" #include "ast/nodes/types.h" diff --git a/packages/compiler/src/compilation/program.cpp b/packages/compiler/src/compilation/program.cpp index 6ddef586..3f754ef0 100644 --- a/packages/compiler/src/compilation/program.cpp +++ b/packages/compiler/src/compilation/program.cpp @@ -52,25 +52,32 @@ std::unique_ptr Program::prepare_module( ast::FunctionVisitor function_visitor; ast::ImportVisitor import_visitor; + /// --- First step - Cross-file symbol registration (imports and function signatures) for (const auto& [file_name, node] : this->_ast->get_files()) { 
import_visitor.set_current_file_name(file_name); traverser.visit_block(&import_visitor, node.get()); - traverser.visit_block(&function_visitor, node.get()); + traverser.visit_block(&function_visitor, node.get()); // Ensures functions are defined in our symbol table } import_visitor.cross_register_symbols(this->_ast.get()); + /// --- Second step - Type resolution and symbol forward declarations for (const auto& node : this->_ast->get_files() | std::views::values) { runtime::register_runtime_symbols(node->get_context()); traverser.visit_block(&type_visitor, node.get()); - node->validate(); node->resolve_forward_references( module.get(), &builder ); + } + + /// --- Final step - LLVM IR validation and code generation + for (const auto& node : this->_ast->get_files() | std::views::values) + { + node->validate(); node->codegen(module.get(), &builder); } From a963b5f7a0b2516d84797fb0820e30f13daa85c8 Mon Sep 17 00:00:00 2001 From: Luca Warmenhoven Date: Mon, 16 Mar 2026 15:51:43 +0100 Subject: [PATCH 08/31] refactor: improve function name handling and streamline generic overload resolution --- .../include/ast/nodes/function_declaration.h | 4 +- .../src/ast/context/function_registry.cpp | 2 +- .../src/ast/nodes/functions/function_call.cpp | 6 +- .../nodes/functions/function_declaration.cpp | 70 +++++++++++++------ 4 files changed, 51 insertions(+), 31 deletions(-) diff --git a/packages/compiler/include/ast/nodes/function_declaration.h b/packages/compiler/include/ast/nodes/function_declaration.h index 60147496..a2e3bb4c 100644 --- a/packages/compiler/include/ast/nodes/function_declaration.h +++ b/packages/compiler/include/ast/nodes/function_declaration.h @@ -117,7 +117,7 @@ namespace stride::ast _flags(flags) {} [[nodiscard]] - const std::string& get_function_name() const + const std::string& get_plain_function_name() const { return this->_symbol.name; } @@ -126,7 +126,7 @@ namespace stride::ast std::vector> get_parameter_types() const; [[nodiscard]] - const std::string& 
get_internalized_function_name() const + const std::string& get_registered_function_name() const { return this->_symbol.internal_name; } diff --git a/packages/compiler/src/ast/context/function_registry.cpp b/packages/compiler/src/ast/context/function_registry.cpp index 7d85d122..fa7ae86f 100644 --- a/packages/compiler/src/ast/context/function_registry.cpp +++ b/packages/compiler/src/ast/context/function_registry.cpp @@ -95,7 +95,7 @@ const const auto& self_params = this->_function_type->get_parameter_types(); - if ((this->get_flags() & SRFLAG_FN_TYPE_VARIADIC) != 0) + if (this->is_variadic()) { if (other_parameter_types.size() < self_params.size()) return false; diff --git a/packages/compiler/src/ast/nodes/functions/function_call.cpp b/packages/compiler/src/ast/nodes/functions/function_call.cpp index d1fc42fe..091b8500 100644 --- a/packages/compiler/src/ast/nodes/functions/function_call.cpp +++ b/packages/compiler/src/ast/nodes/functions/function_call.cpp @@ -698,12 +698,8 @@ FunctionDefinition* AstFunctionCall::get_function_definition() if (this->_definition != nullptr) return this->_definition; - const auto& internal_function_name = get_overloaded_function_name( - this->get_scoped_function_name(), - this->get_generic_type_arguments()); - if (const auto def = this->get_context()->get_function_definition( - internal_function_name, + this->get_scoped_function_name(), this->get_argument_types(), this->get_generic_type_arguments().size() ); diff --git a/packages/compiler/src/ast/nodes/functions/function_declaration.cpp b/packages/compiler/src/ast/nodes/functions/function_declaration.cpp index f9e0e20f..5f1574a0 100644 --- a/packages/compiler/src/ast/nodes/functions/function_declaration.cpp +++ b/packages/compiler/src/ast/nodes/functions/function_declaration.cpp @@ -38,6 +38,7 @@ std::unique_ptr stride::ast::parse_fn_declaration( ) { int function_flags = 0; + const auto reference_token = set.peek_next(); if (set.peek_next_eq(TokenType::KEYWORD_EXTERN)) { set.next(); @@ 
-50,7 +51,7 @@ std::unique_ptr stride::ast::parse_fn_declaration( function_flags |= SRFLAG_FN_TYPE_ASYNC; } - auto reference_token = set.expect(TokenType::KEYWORD_FN); + set.expect(TokenType::KEYWORD_FN); // Here we expect to receive the function name const auto fn_name_tok = set.expect(TokenType::IDENTIFIER, "Expected function name"); @@ -87,7 +88,9 @@ std::unique_ptr stride::ast::parse_fn_declaration( // Return type doesn't have the same flags as the function, hence NONE auto return_type = parse_type(context, set, { "Expected return type in function header" }); - const auto& position = reference_token.get_source_fragment(); + const auto& position = SourceFragment::join( + reference_token.get_source_fragment(), + return_type->get_source_fragment()); auto sym_function_name = Symbol(position, context->get_name(), fn_name); std::unique_ptr body = nullptr; @@ -555,7 +558,7 @@ void IAstFunction::validate_candidate(IAstFunction* candidate) "{} has return type 'void' and cannot return a value.", candidate->is_anonymous() ? "Anonymous function" - : std::format("Function '{}'", candidate->get_function_name())), + : std::format("Function '{}'", candidate->get_plain_function_name())), { ErrorSourceReference( "unexpected return value", @@ -581,7 +584,7 @@ void IAstFunction::validate_candidate(IAstFunction* candidate) ErrorType::TYPE_ERROR, std::format( "Function '{}' returns a struct type, but no return statement is present.", - candidate->get_function_name()), + candidate->get_plain_function_name()), candidate->get_source_fragment()); } @@ -591,7 +594,7 @@ void IAstFunction::validate_candidate(IAstFunction* candidate) "{} is missing a return statement.", candidate->is_anonymous() ? 
"Anonymous function" - : std::format("Function '{}'", candidate->get_function_name())), + : std::format("Function '{}'", candidate->get_plain_function_name())), candidate->get_source_fragment() ); } @@ -606,7 +609,7 @@ void IAstFunction::validate_candidate(IAstFunction* candidate) ErrorType::TYPE_ERROR, std::format( "Function '{}' returns a value of type '{}', but no return statement is present.", - candidate->is_anonymous() ? "" : candidate->get_function_name(), + candidate->is_anonymous() ? "" : candidate->get_plain_function_name(), ret_ty->to_string()), return_stmt->get_source_fragment() ); @@ -634,7 +637,7 @@ void IAstFunction::validate_candidate(IAstFunction* candidate) ErrorType::TYPE_ERROR, std::format( "Function '{}' expected a return type of '{}', but received '{}'.", - candidate->is_anonymous() ? "" : candidate->get_function_name(), + candidate->is_anonymous() ? "" : candidate->get_plain_function_name(), ret_ty->get_type_name(), ret_expr->get_type()->get_type_name()), { error_fragment } @@ -830,7 +833,7 @@ llvm::Value* IAstFunction::codegen( ErrorType::COMPILATION_ERROR, std::format( "Function '{}' is missing a return path.", - this->get_function_name() + this->get_plain_function_name() ), this->get_source_fragment() ); @@ -843,7 +846,7 @@ llvm::Value* IAstFunction::codegen( module->print(llvm::errs(), nullptr); throw parsing_error( ErrorType::COMPILATION_ERROR, - "LLVM Function Verification Failed for: " + this->get_function_name(), + "LLVM Function Verification Failed for: " + this->get_plain_function_name(), this->get_source_fragment() ); } @@ -948,12 +951,33 @@ void IAstFunction::resolve_forward_references( } } + const auto& definition = this->get_function_definition(); + const auto linkage = this->_visibility == VisibilityModifier::PRIVATE ? 
llvm::Function::PrivateLinkage : llvm::Function::ExternalLinkage; - for (const auto& definition = this->get_function_definition(); - const auto& [types, llvm_function, node] : definition->get_generic_overloads()) + if (!this->is_generic()) + { + const auto overloaded_fn_name = this->get_registered_function_name(); + llvm::FunctionType* generic_function_type = this->get_llvm_function_type( + module, + captured_types, + {} + ); + + definition->set_llvm_function(llvm::Function::Create( + generic_function_type, + linkage, + overloaded_fn_name, + module + )); + this->_body->resolve_forward_references(module, builder); + + return; + } + + for (const auto& [types, llvm_function, node] : definition->get_generic_overloads()) { auto instantiated_return_ty = resolve_generics( this->_annotated_return_type.get(), @@ -989,8 +1013,8 @@ void IAstFunction::resolve_forward_references( node = std::move(candidate_node); - const auto overloaded_fn_name = get_overloaded_function_name(node->get_internalized_function_name(), types); - llvm::FunctionType* generic_function_type = node->get_llvm_function_type( + const auto overloaded_fn_name = get_overloaded_function_name(this->get_registered_function_name(), types); + llvm::FunctionType* generic_function_type = this->get_llvm_function_type( module, captured_types, types @@ -1003,7 +1027,7 @@ void IAstFunction::resolve_forward_references( module ); - if (node->is_anonymous()) + if (this->is_anonymous()) llvm_function->addFnAttr("stride.anonymous"); this->_body->resolve_forward_references(module, builder); @@ -1057,21 +1081,21 @@ llvm::FunctionType* IAstFunction::get_llvm_function_type( { throw parsing_error( ErrorType::COMPILATION_ERROR, - "Could not get LLVM return type for function: " + this->get_function_name(), + "Could not get LLVM return type for function: " + this->get_plain_function_name(), this->get_source_fragment() ); } const auto& llvm_function_ty = llvm::FunctionType::get(return_type, parameter_types, this->is_variadic());; - if 
(const auto fn = module->getFunction(this->get_internalized_function_name()); + if (const auto fn = module->getFunction(this->get_registered_function_name()); fn != nullptr && fn->getFunctionType() != llvm_function_ty) { throw parsing_error( ErrorType::COMPILATION_ERROR, std::format( "Function symbol '{}' already exists with a different signature", - this->get_internalized_function_name() + this->get_registered_function_name() ), this->get_source_fragment() ); @@ -1099,7 +1123,7 @@ FunctionDefinition* IAstFunction::get_function_definition() return this->_function_definition; const auto& definition = this->get_context()->get_function_definition( - this->get_function_name(), + this->get_registered_function_name(), this->get_parameter_types(), this->get_generic_parameters().size() ); @@ -1108,7 +1132,7 @@ FunctionDefinition* IAstFunction::get_function_definition() { throw parsing_error( ErrorType::REFERENCE_ERROR, - std::format("Function definition for '{}' not found in context", this->get_internalized_function_name()), + std::format("Function definition for '{}' not found in context", this->get_registered_function_name()), this->get_source_fragment() ); } @@ -1125,7 +1149,7 @@ std::vector IAstFunction::get_function_overload_metadat if (!definition->get_type()->is_generic()) { return { - GenericFunctionMetadata{ this->get_internalized_function_name(), definition->get_llvm_function() } + GenericFunctionMetadata{ this->get_registered_function_name(), definition->get_llvm_function() } }; } @@ -1134,7 +1158,7 @@ std::vector IAstFunction::get_function_overload_metadat for (const auto& [types, llvm_function, node] : definition->get_generic_overloads()) { metadata.emplace_back( - get_overloaded_function_name(node->get_internalized_function_name(), types), + get_overloaded_function_name(node->get_registered_function_name(), types), llvm_function ); } @@ -1191,8 +1215,8 @@ std::string IAstFunction::to_string() return std::format( "Function(name: {}(internal: {}), params: [{}], 
body: {}{} -> {})", - this->is_anonymous() ? "" : this->get_function_name(), - this->get_internalized_function_name(), + this->is_anonymous() ? "" : this->get_plain_function_name(), + this->get_registered_function_name(), params, body_str, this->is_extern() ? " (extern)" : "", From 13253b27a5396a7f4a8215e10a2b82c2c916d302 Mon Sep 17 00:00:00 2001 From: Luca Warmenhoven Date: Mon, 16 Mar 2026 16:11:20 +0100 Subject: [PATCH 09/31] refactor: add `set_type` and `set_visibility` methods, enable variable overwrite in parsing context --- .../compiler/include/ast/parsing_context.h | 16 +++++++- .../src/ast/context/field_registry.cpp | 41 +++++++++++++++++-- .../nodes/functions/function_declaration.cpp | 15 +++++-- .../src/ast/traversal/function_visitor.cpp | 2 +- 4 files changed, 65 insertions(+), 9 deletions(-) diff --git a/packages/compiler/include/ast/parsing_context.h b/packages/compiler/include/ast/parsing_context.h index 096651db..302ad845 100644 --- a/packages/compiler/include/ast/parsing_context.h +++ b/packages/compiler/include/ast/parsing_context.h @@ -78,6 +78,11 @@ namespace stride::ast [[nodiscard]] virtual std::unique_ptr clone() const = 0; + + void set_visibility(const VisibilityModifier visibility) + { + this->_visibility = visibility; + } }; class TypeDefinition @@ -157,6 +162,11 @@ namespace stride::ast { return std::make_unique(get_symbol(), _type->clone_ty(), get_visibility()); } + + void set_type(std::unique_ptr type) + { + this->_type = std::move(type); + } }; } // namespace definition @@ -300,13 +310,15 @@ namespace stride::ast void define_variable( Symbol variable_sym, std::unique_ptr type, - VisibilityModifier visibility + VisibilityModifier visibility, + bool overwrite = false ); void define_variable_globally( Symbol variable_symbol, std::unique_ptr type, - VisibilityModifier visibility + VisibilityModifier visibility, + bool overwrite = false ) const; [[nodiscard]] diff --git a/packages/compiler/src/ast/context/field_registry.cpp 
b/packages/compiler/src/ast/context/field_registry.cpp index 1711943c..75bc677e 100644 --- a/packages/compiler/src/ast/context/field_registry.cpp +++ b/packages/compiler/src/ast/context/field_registry.cpp @@ -59,11 +59,28 @@ bool ParsingContext::is_field_defined_globally( void ParsingContext::define_variable_globally( Symbol variable_symbol, std::unique_ptr type, - VisibilityModifier visibility + VisibilityModifier visibility, + const bool overwrite ) const { if (is_field_defined_globally(variable_symbol.internal_name)) { + if (overwrite) + { + for (auto& symbol_def : this->_symbols) + { + if (auto* var_def = dynamic_cast(symbol_def.get()); + var_def != nullptr && + var_def->get_internal_symbol_name() == variable_symbol.internal_name) + { + var_def->set_type(std::move(type)); + var_def->set_visibility(visibility); + return; + } + } + return; + } + throw parsing_error( ErrorType::SEMANTIC_ERROR, std::format("Variable '{}' is already defined in global scope", variable_symbol.name), @@ -84,7 +101,8 @@ void ParsingContext::define_variable_globally( void ParsingContext::define_variable( Symbol variable_sym, std::unique_ptr type, - VisibilityModifier visibility + VisibilityModifier visibility, + const bool overwrite ) { if (this->is_global_scope()) @@ -92,13 +110,30 @@ void ParsingContext::define_variable( this->define_variable_globally( std::move(variable_sym), std::move(type), - visibility + visibility, + overwrite ); return; } if (is_field_defined_in_scope(variable_sym.internal_name)) { + if (overwrite) + { + for (auto& symbol_def : this->_symbols) + { + if (auto* var_def = dynamic_cast(symbol_def.get()); + var_def != nullptr && + var_def->get_internal_symbol_name() == variable_sym.internal_name) + { + var_def->set_type(std::move(type)); + var_def->set_visibility(visibility); + return; + } + } + return; + } + throw parsing_error( ErrorType::SEMANTIC_ERROR, std::format("Variable '{}' is already defined in this scope", variable_sym.name), diff --git 
a/packages/compiler/src/ast/nodes/functions/function_declaration.cpp b/packages/compiler/src/ast/nodes/functions/function_declaration.cpp index 5f1574a0..e3cdf639 100644 --- a/packages/compiler/src/ast/nodes/functions/function_declaration.cpp +++ b/packages/compiler/src/ast/nodes/functions/function_declaration.cpp @@ -1000,7 +1000,7 @@ void IAstFunction::resolve_forward_references( ); } - auto candidate_node = std::make_unique( + node = std::make_unique( this->get_context(), this->_symbol, std::move(instantiated_function_params), @@ -1011,8 +1011,6 @@ void IAstFunction::resolve_forward_references( this->get_generic_parameters() ); - node = std::move(candidate_node); - const auto overloaded_fn_name = get_overloaded_function_name(this->get_registered_function_name(), types); llvm::FunctionType* generic_function_type = this->get_llvm_function_type( module, @@ -1027,6 +1025,17 @@ void IAstFunction::resolve_forward_references( module ); + for (const auto &param : instantiated_function_params) + { + const auto param_symbol = Symbol(param->get_source_fragment(), param->get_name()); + this->get_context()->define_variable( + param_symbol, + param->get_type()->clone_ty(), + VisibilityModifier::PRIVATE, + true + ); + } + if (this->is_anonymous()) llvm_function->addFnAttr("stride.anonymous"); diff --git a/packages/compiler/src/ast/traversal/function_visitor.cpp b/packages/compiler/src/ast/traversal/function_visitor.cpp index 0b3aac48..a8764346 100644 --- a/packages/compiler/src/ast/traversal/function_visitor.cpp +++ b/packages/compiler/src/ast/traversal/function_visitor.cpp @@ -30,4 +30,4 @@ void FunctionVisitor::accept(IAstFunction* function) function->get_visibility(), function->get_flags() ); -} \ No newline at end of file +} From 7e41c4139f199e0c5b33714446825275161a1646 Mon Sep 17 00:00:00 2001 From: Luca Warmenhoven Date: Mon, 16 Mar 2026 16:49:04 +0100 Subject: [PATCH 10/31] refactor: update const qualifiers for variable definitions and improve function handling in parser -
Removed `const` qualifiers from certain `FieldDefinition` pointers to allow modifications. - Added `AstFunctionCall` and `AstReturnStatement` classes in the AST structure. - Introduced utilities for collecting free variables, return statements, and resolving forward references. - Added LLVM code generation logic for function declarations and calls. --- .../compiler/include/ast/nodes/expression.h | 2 +- .../include/ast/nodes/function_declaration.h | 12 +- .../compiler/include/ast/nodes/traversal.h | 3 + .../compiler/include/ast/parsing_context.h | 4 +- packages/compiler/include/ast/visitor.h | 1 + .../src/ast/context/field_registry.cpp | 8 +- .../src/ast/nodes/expressions/identifier.cpp | 4 +- .../nodes/functions/call/function_call.cpp | 348 +++++ .../function_call_codegen.cpp} | 350 +---- .../declaration/function_declaration.cpp | 342 +++++ .../function_declaration_codegen.cpp | 206 +++ .../function_declaration_forward_refs.cpp | 497 +++++++ .../function_declaration_validation.cpp | 217 +++ .../nodes/functions/function_declaration.cpp | 1234 ----------------- packages/compiler/src/ast/type_inference.cpp | 2 +- packages/compiler/src/compilation/program.cpp | 3 + 16 files changed, 1641 insertions(+), 1592 deletions(-) create mode 100644 packages/compiler/src/ast/nodes/functions/call/function_call.cpp rename packages/compiler/src/ast/nodes/functions/{function_call.cpp => call/function_call_codegen.cpp} (62%) create mode 100644 packages/compiler/src/ast/nodes/functions/declaration/function_declaration.cpp create mode 100644 packages/compiler/src/ast/nodes/functions/declaration/function_declaration_codegen.cpp create mode 100644 packages/compiler/src/ast/nodes/functions/declaration/function_declaration_forward_refs.cpp create mode 100644 packages/compiler/src/ast/nodes/functions/declaration/function_declaration_validation.cpp delete mode 100644 packages/compiler/src/ast/nodes/functions/function_declaration.cpp diff --git 
a/packages/compiler/include/ast/nodes/expression.h b/packages/compiler/include/ast/nodes/expression.h index 96e4ee3a..fd8241b6 100644 --- a/packages/compiler/include/ast/nodes/expression.h +++ b/packages/compiler/include/ast/nodes/expression.h @@ -188,7 +188,7 @@ namespace stride::ast _symbol(std::move(symbol)) {} [[nodiscard]] - std::optional get_definition() const; + std::optional get_definition() const; [[nodiscard]] const std::string& get_name() const diff --git a/packages/compiler/include/ast/nodes/function_declaration.h b/packages/compiler/include/ast/nodes/function_declaration.h index a2e3bb4c..8515f9a6 100644 --- a/packages/compiler/include/ast/nodes/function_declaration.h +++ b/packages/compiler/include/ast/nodes/function_declaration.h @@ -15,6 +15,7 @@ namespace llvm namespace stride::ast { + class AstReturnStatement; #define MAX_FUNCTION_PARAMETERS (32) /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * @@ -135,7 +136,7 @@ namespace stride::ast /// function is defined with generic parameters, there will be several overloads generated /// for each generic instantiation. This function returns the internalized name of each overload. 
[[nodiscard]] - std::vector get_function_overload_metadata(); + std::vector get_generic_function_metadata(); [[nodiscard]] AstBlock* get_body() override @@ -249,6 +250,15 @@ namespace stride::ast ) const; static void validate_candidate(IAstFunction* candidate); + + static void collect_free_variables( + IAstNode* node, + const std::shared_ptr& lambda_context, + const std::shared_ptr& outer_context, + std::vector& captures + ); + + static std::vector collect_return_statements(const AstBlock* body); }; class AstFunctionDeclaration diff --git a/packages/compiler/include/ast/nodes/traversal.h b/packages/compiler/include/ast/nodes/traversal.h index 540ea8a7..438e43ea 100644 --- a/packages/compiler/include/ast/nodes/traversal.h +++ b/packages/compiler/include/ast/nodes/traversal.h @@ -2,6 +2,7 @@ namespace stride::ast { + class AstFunctionCall; class AstPackage; class AstImport; class IAstNode; @@ -30,6 +31,8 @@ namespace stride::ast virtual void accept(AstImport* node) {} virtual void accept(AstPackage* node) {} + + virtual void accept(AstFunctionCall* function_call) {} }; /// Traverses an AST tree and invokes an IVisitor for each expression node. 
diff --git a/packages/compiler/include/ast/parsing_context.h b/packages/compiler/include/ast/parsing_context.h index 302ad845..2892de2c 100644 --- a/packages/compiler/include/ast/parsing_context.h +++ b/packages/compiler/include/ast/parsing_context.h @@ -246,7 +246,7 @@ namespace stride::ast } [[nodiscard]] - const definition::FieldDefinition* get_variable_def( + definition::FieldDefinition* get_variable_def( const std::string& variable_name, bool use_raw_name = false ) const; @@ -284,7 +284,7 @@ namespace stride::ast } [[nodiscard]] - const definition::FieldDefinition* lookup_variable( + definition::FieldDefinition* lookup_variable( const std::string& name, bool use_raw_name = false ) const; diff --git a/packages/compiler/include/ast/visitor.h b/packages/compiler/include/ast/visitor.h index 4779b60a..11b1e6d8 100644 --- a/packages/compiler/include/ast/visitor.h +++ b/packages/compiler/include/ast/visitor.h @@ -8,6 +8,7 @@ namespace stride::ast { + class AstFunctionCall; class AstImport; class AstPackage; class AstFunctionDeclaration; diff --git a/packages/compiler/src/ast/context/field_registry.cpp b/packages/compiler/src/ast/context/field_registry.cpp index 75bc677e..ee5307ef 100644 --- a/packages/compiler/src/ast/context/field_registry.cpp +++ b/packages/compiler/src/ast/context/field_registry.cpp @@ -5,14 +5,14 @@ using namespace stride::ast; -const definition::FieldDefinition* ParsingContext::get_variable_def( +definition::FieldDefinition* ParsingContext::get_variable_def( const std::string& variable_name, const bool use_raw_name ) const { for (const auto& symbol_def : this->_symbols) { - if (const auto* field_definition = dynamic_cast( + if (auto* field_definition = dynamic_cast( symbol_def.get())) { if (field_definition->get_internal_symbol_name() == variable_name @@ -149,7 +149,7 @@ void ParsingContext::define_variable( ); } -const definition::FieldDefinition* ParsingContext::lookup_variable( +definition::FieldDefinition* ParsingContext::lookup_variable( 
const std::string& name, const bool use_raw_name ) @@ -158,7 +158,7 @@ const auto current = this; while (current != nullptr) { - if (const auto def = current->get_variable_def(name, use_raw_name)) + if (auto def = current->get_variable_def(name, use_raw_name)) { return def; } diff --git a/packages/compiler/src/ast/nodes/expressions/identifier.cpp b/packages/compiler/src/ast/nodes/expressions/identifier.cpp index cf14a845..e3f476b3 100644 --- a/packages/compiler/src/ast/nodes/expressions/identifier.cpp +++ b/packages/compiler/src/ast/nodes/expressions/identifier.cpp @@ -8,11 +8,11 @@ using namespace stride::ast; -std::optional AstIdentifier::get_definition() const +std::optional AstIdentifier::get_definition() const { const std::string internal_name = this->get_scoped_name(); - if (const auto var_def = this->get_context()->lookup_variable(internal_name, false)) + if (auto var_def = this->get_context()->lookup_variable(internal_name, false)) { return var_def; } diff --git a/packages/compiler/src/ast/nodes/functions/call/function_call.cpp b/packages/compiler/src/ast/nodes/functions/call/function_call.cpp new file mode 100644 index 00000000..5a13ce23 --- /dev/null +++ b/packages/compiler/src/ast/nodes/functions/call/function_call.cpp @@ -0,0 +1,348 @@ +#include "errors.h" +#include "formatting.h" +#include "ast/casting.h" +#include "ast/flags.h" +#include "ast/parsing_context.h" +#include "ast/symbols.h" +#include "ast/definitions/function_definition.h" +#include "ast/nodes/blocks.h" +#include "ast/nodes/expression.h" +#include "ast/nodes/types.h" +#include "ast/tokens/token_set.h" + +#include +#include +#include +#include + +using namespace stride::ast; +using namespace stride::ast::definition; + +std::unique_ptr stride::ast::parse_function_call( + const std::shared_ptr& context, + AstIdentifier* identifier, + TokenSet& set +) +{ + const auto reference_token = set.peek(-1); + auto generic_types = parse_generic_type_arguments(context, set); + auto 
function_parameter_set = collect_parenthesized_block(set); + + ExpressionList function_arg_nodes; + + int function_call_flags = SRFLAG_NONE; + + // Parsing function parameter values + if (function_parameter_set.has_value()) + { + auto subset = function_parameter_set.value(); + + if (auto initial_arg = parse_inline_expression(context, subset)) + { + function_arg_nodes.push_back(std::move(initial_arg)); + + // Consume next parameters + while (subset.has_next()) + { + const auto preceding = subset.expect( + TokenType::COMMA, + "Expected ',' between function arguments" + ); + + auto function_argument = parse_inline_expression(context, subset); + + if (!function_argument) + { + // Since the RParen is already consumed, we have to manually extract its + // position with the following assumption It's possible this yields END_OF_FILE + const auto len = + set.peek(-1).get_source_fragment().offset - 1 - + preceding.get_source_fragment().offset; + throw parsing_error( + ErrorType::SYNTAX_ERROR, + "Expected expression for function argument", + SourceFragment( + subset.get_source(), + preceding.get_source_fragment().offset + 1, + len) + ); + } + + // If the next argument is a variadic argument reference, we stop parsing more arguments and mark this function call as variadic + if (cast_expr(function_argument.get())) + { + if (subset.has_next()) + { + subset.throw_error( + "Variadic argument propagation must be the last parameter in a function call" + ); + } + function_call_flags |= SRFLAG_FN_TYPE_VARIADIC; + + break; + } + + function_arg_nodes.push_back(std::move(function_argument)); + } + } + } + + return std::make_unique( + context, + identifier->clone_as(), + std::move(function_arg_nodes), + std::move(generic_types), + function_call_flags + ); +} + +// The previous token must be an identifier, otherwise this is not a function call +// This logic is not handled here. 
+bool stride::ast::is_direct_function_call(const TokenSet& set) +{ + // Function call for sure, "identifier::other(` + if (set.peek_next_eq(TokenType::LPAREN)) + return true; + + // If the subsequent token is a LT, it might just be a generic function instantiation + // We have to do some lookahead to make sure this is the case + if (set.peek_next_eq(TokenType::LT)) + { + int depth = 0; + for (size_t offset = 0; set.position() + offset < set.size(); offset++) + { + switch (const auto next_token = set.at(set.position() + offset); + next_token.get_type()) + { + case TokenType::LT: + ++depth; + break; + case TokenType::GT: + --depth; + break; + + default: + // Optimization, where we know for sure it can't be part of a generic instantiation + if (!next_token.is_type_token() && next_token.get_type() != TokenType::COMMA) + { + return false; + } + } + + if (depth == 0) + { + return set.peek_eq(TokenType::LPAREN, offset + 1); + } + } + } + return false; +} + +std::string AstFunctionCall::format_suggestion(const IDefinition* suggestion) +{ + if (const auto fn_call = dynamic_cast(suggestion)) + { + // We'll format the arguments + std::vector arg_types; + + for (const auto& arg : fn_call->get_type()->get_parameter_types()) + { + arg_types.push_back(arg->get_type_name()); + } + + if (arg_types.empty()) + arg_types.push_back(primitive_type_to_str(PrimitiveType::VOID)); + + return std::format("{}({})", + fn_call->get_symbol().name, + join(arg_types, ", ")); + } + + return suggestion->get_internal_symbol_name(); +} + +std::string AstFunctionCall::format_function_name() const +{ + std::vector arg_types; + + arg_types.reserve(this->_arguments.size()); + + for (const auto& arg : this->_arguments) + { + arg_types.push_back(arg->get_type()->to_string()); + } + + if (arg_types.empty()) + { + arg_types.push_back(primitive_type_to_str(PrimitiveType::VOID)); + } + + return std::format("{}({})", this->get_function_name(), join(arg_types, ", ")); +} + +void AstFunctionCall::validate() +{ + 
for (const auto& arg : this->_arguments) + { + arg->validate(); + } +} + +void AstFunctionCall::resolve_forward_references(llvm::Module* module, llvm::IRBuilderBase* builder) +{ + // Add generic types to function definition's generic instantiations + if (!this->_generic_type_arguments.empty()) + { + const auto& definition = this->get_function_definition(); + + definition->add_generic_instantiation( + copy_generic_type_list(this->_generic_type_arguments) + ); + } + + for (const auto& arg : this->_arguments) + { + arg->resolve_forward_references(module, builder); + } +} + +std::vector> AstFunctionCall::get_argument_types() const +{ + if (this->_arguments.empty()) + return {}; + + std::vector> param_types; + param_types.reserve(this->_arguments.size()); + for (const auto& arg : this->_arguments) + { + // The last parameter should be the variadic argument reference, + // which is not included in the type list, as this is dynamically expanded + // Additionally, this would mess with function lookup by signature + if (cast_expr(arg.get())) + { + break; + } + param_types.push_back(arg->get_type()->clone_ty()); + } + return param_types; +} + +const GenericTypeList& AstFunctionCall::get_generic_type_arguments() +{ + return this->_generic_type_arguments; +} + +FunctionDefinition* AstFunctionCall::get_function_definition() +{ + if (this->_definition != nullptr) + return this->_definition; + + if (const auto def = this->get_context()->get_function_definition( + this->get_scoped_function_name(), + this->get_argument_types(), + this->get_generic_type_arguments().size() + ); + def.has_value()) + { + this->_definition = def.value(); + return this->_definition; + } + + throw parsing_error( + ErrorType::REFERENCE_ERROR, + std::format("Function '{}' was not found in this scope", this->format_function_name()), + this->get_source_fragment() + ); +} + +std::unique_ptr AstFunctionCall::clone() +{ + ExpressionList cloned_args; + GenericTypeList generic_type_list_cloned; + + 
generic_type_list_cloned.reserve(this->_generic_type_arguments.size()); + cloned_args.reserve(this->get_arguments().size()); + + for (const auto& arg : this->get_arguments()) + { + cloned_args.push_back(arg->clone_as()); + } + + for (const auto& generic_arg : this->get_generic_type_arguments()) + { + generic_type_list_cloned.push_back(generic_arg->clone_ty()); + } + + return std::make_unique( + this->get_context(), + this->get_function_name_identifier()->clone_as(), + std::move(cloned_args), + std::move(generic_type_list_cloned), + this->_flags + ); +} + +bool AstFunctionCall::is_reducible() +{ + // TODO: implement + // Function calls can be reducible if the function returns + // a constant value or if all arguments are reducible. + return false; +} + +std::optional> AstFunctionCall::reduce() +{ + return std::nullopt; +} + +std::string AstFunctionCall::get_formatted_call() const +{ + std::vector arg_names; + arg_names.reserve(this->_arguments.size()); + + std::vector formatted_generics; + formatted_generics.reserve(this->_generic_type_arguments.size()); + + for (const auto& generic_arg : this->_generic_type_arguments) + { + formatted_generics.push_back(generic_arg->get_type_name()); + } + + for (const auto& arg : this->_arguments) + { + arg_names.push_back(arg->get_type()->get_type_name()); + } + + if (!formatted_generics.empty()) + { + return std::format( + "{}<{}>({})", + this->get_function_name(), + join(formatted_generics, ", "), + join(arg_names, ", ") + ); + } + + return std::format( + "{}({})", + this->get_function_name(), + join(arg_names, ", ") + ); +} + +std::string AstFunctionCall::to_string() +{ + std::ostringstream oss; + + std::vector arg_types; + for (const auto& arg : this->get_arguments()) + { + arg_types.push_back(arg->to_string()); + } + + return std::format( + "FunctionCall({} [{}])", + this->get_function_name(), + join(arg_types, ", ") + ); +} diff --git a/packages/compiler/src/ast/nodes/functions/function_call.cpp 
b/packages/compiler/src/ast/nodes/functions/call/function_call_codegen.cpp similarity index 62% rename from packages/compiler/src/ast/nodes/functions/function_call.cpp rename to packages/compiler/src/ast/nodes/functions/call/function_call_codegen.cpp index 091b8500..d8a7bdb4 100644 --- a/packages/compiler/src/ast/nodes/functions/function_call.cpp +++ b/packages/compiler/src/ast/nodes/functions/call/function_call_codegen.cpp @@ -1,191 +1,14 @@ -#include "errors.h" -#include "formatting.h" #include "ast/casting.h" #include "ast/closures.h" -#include "ast/flags.h" #include "ast/optionals.h" #include "ast/parsing_context.h" -#include "ast/symbols.h" #include "ast/definitions/function_definition.h" -#include "ast/nodes/blocks.h" #include "ast/nodes/expression.h" -#include "ast/nodes/types.h" -#include "ast/tokens/token_set.h" - -#include -#include -#include -#include -#include + #include #include -#include -#include using namespace stride::ast; -using namespace stride::ast::definition; - -std::unique_ptr stride::ast::parse_function_call( - const std::shared_ptr& context, - AstIdentifier* identifier, - TokenSet& set -) -{ - const auto reference_token = set.peek(-1); - auto generic_types = parse_generic_type_arguments(context, set); - auto function_parameter_set = collect_parenthesized_block(set); - - ExpressionList function_arg_nodes; - - int function_call_flags = SRFLAG_NONE; - - // Parsing function parameter values - if (function_parameter_set.has_value()) - { - auto subset = function_parameter_set.value(); - - if (auto initial_arg = parse_inline_expression(context, subset)) - { - function_arg_nodes.push_back(std::move(initial_arg)); - - // Consume next parameters - while (subset.has_next()) - { - const auto preceding = subset.expect( - TokenType::COMMA, - "Expected ',' between function arguments" - ); - - auto function_argument = parse_inline_expression(context, subset); - - if (!function_argument) - { - // Since the RParen is already consumed, we have to manually 
extract its - // position with the following assumption It's possible this yields END_OF_FILE - const auto len = - set.peek(-1).get_source_fragment().offset - 1 - - preceding.get_source_fragment().offset; - throw parsing_error( - ErrorType::SYNTAX_ERROR, - "Expected expression for function argument", - SourceFragment( - subset.get_source(), - preceding.get_source_fragment().offset + 1, - len) - ); - } - - // If the next argument is a variadic argument reference, we stop parsing more arguments and mark this function call as variadic - if (cast_expr(function_argument.get())) - { - if (subset.has_next()) - { - subset.throw_error( - "Variadic argument propagation must be the last parameter in a function call" - ); - } - function_call_flags |= SRFLAG_FN_TYPE_VARIADIC; - - break; - } - - function_arg_nodes.push_back(std::move(function_argument)); - } - } - } - - return std::make_unique( - context, - identifier->clone_as(), - std::move(function_arg_nodes), - std::move(generic_types), - function_call_flags - ); -} - -// The previous token must be an identifier, otherwise this is not a function call -// This logic is not handled here. 
-bool stride::ast::is_direct_function_call(const TokenSet& set) -{ - // Function call for sure, "identifier::other(` - if (set.peek_next_eq(TokenType::LPAREN)) - return true; - - // If the subsequent token is a LT, it might just be a generic function instantiation - // We have to do some lookahead to make sure this is the case - if (set.peek_next_eq(TokenType::LT)) - { - int depth = 0; - for (size_t offset = 0; set.position() + offset < set.size(); offset++) - { - switch (const auto next_token = set.at(set.position() + offset); - next_token.get_type()) - { - case TokenType::LT: - ++depth; - break; - case TokenType::GT: - --depth; - break; - - default: - // Optimization, where we know for sure it can't be part of a generic instantiation - if (!next_token.is_type_token() && next_token.get_type() != TokenType::COMMA) - { - return false; - } - } - - if (depth == 0) - { - return set.peek_eq(TokenType::LPAREN, offset + 1); - } - } - } - return false; -} - -std::string AstFunctionCall::format_suggestion(const IDefinition* suggestion) -{ - if (const auto fn_call = dynamic_cast(suggestion)) - { - // We'll format the arguments - std::vector arg_types; - - for (const auto& arg : fn_call->get_type()->get_parameter_types()) - { - arg_types.push_back(arg->get_type_name()); - } - - if (arg_types.empty()) - arg_types.push_back(primitive_type_to_str(PrimitiveType::VOID)); - - return std::format("{}({})", - fn_call->get_symbol().name, - join(arg_types, ", ")); - } - - return suggestion->get_internal_symbol_name(); -} - -std::string AstFunctionCall::format_function_name() const -{ - std::vector arg_types; - - arg_types.reserve(this->_arguments.size()); - - for (const auto& arg : this->_arguments) - { - arg_types.push_back(arg->get_type()->to_string()); - } - - if (arg_types.empty()) - { - arg_types.push_back(primitive_type_to_str(PrimitiveType::VOID)); - } - - return std::format("{}({})", this->get_function_name(), join(arg_types, ", ")); -} llvm::Value* AstFunctionCall::codegen( 
llvm::Module* module, @@ -383,7 +206,7 @@ llvm::Value* AstFunctionCall::codegen_anonymous_function_call( if (const auto var_def = this->get_function_name_identifier()->get_definition(); var_def.has_value()) { - const auto field_def = dynamic_cast(var_def.value()); + const auto field_def = dynamic_cast(var_def.value()); if (!field_def) { @@ -639,171 +462,4 @@ llvm::Value* AstFunctionCall::codegen_anonymous_function_call( } return nullptr; -} - -void AstFunctionCall::validate() -{ - for (const auto& arg : this->_arguments) - { - arg->validate(); - } -} - -void AstFunctionCall::resolve_forward_references(llvm::Module* module, llvm::IRBuilderBase* builder) -{ - // Add generic types to function definition's generic instantiations - if (!this->_generic_type_arguments.empty()) - { - const auto& definition = this->get_function_definition(); - - definition->add_generic_instantiation( - copy_generic_type_list(this->_generic_type_arguments) - ); - } - - for (const auto& arg : this->_arguments) - { - arg->resolve_forward_references(module, builder); - } -} - -std::vector> AstFunctionCall::get_argument_types() const -{ - if (this->_arguments.empty()) - return {}; - - std::vector> param_types; - param_types.reserve(this->_arguments.size()); - for (const auto& arg : this->_arguments) - { - // The last parameter should be the variadic argument reference, - // which is not included in the type list, as this is dynamically expanded - // Additionally, this would mess with function lookup by signature - if (cast_expr(arg.get())) - { - break; - } - param_types.push_back(arg->get_type()->clone_ty()); - } - return param_types; -} - -const GenericTypeList& AstFunctionCall::get_generic_type_arguments() -{ - return this->_generic_type_arguments; -} - -FunctionDefinition* AstFunctionCall::get_function_definition() -{ - if (this->_definition != nullptr) - return this->_definition; - - if (const auto def = this->get_context()->get_function_definition( - this->get_scoped_function_name(), - 
this->get_argument_types(), - this->get_generic_type_arguments().size() - ); - def.has_value()) - { - this->_definition = def.value(); - return this->_definition; - } - - throw parsing_error( - ErrorType::REFERENCE_ERROR, - std::format("Function '{}' was not found in this scope", this->format_function_name()), - this->get_source_fragment() - ); -} - -std::unique_ptr AstFunctionCall::clone() -{ - ExpressionList cloned_args; - GenericTypeList generic_type_list_cloned; - - generic_type_list_cloned.reserve(this->_generic_type_arguments.size()); - cloned_args.reserve(this->get_arguments().size()); - - for (const auto& arg : this->get_arguments()) - { - cloned_args.push_back(arg->clone_as()); - } - - for (const auto& generic_arg : this->get_generic_type_arguments()) - { - generic_type_list_cloned.push_back(generic_arg->clone_ty()); - } - - return std::make_unique( - this->get_context(), - this->get_function_name_identifier()->clone_as(), - std::move(cloned_args), - std::move(generic_type_list_cloned), - this->_flags - ); -} - -bool AstFunctionCall::is_reducible() -{ - // TODO: implement - // Function calls can be reducible if the function returns - // a constant value or if all arguments are reducible. 
- return false; -} - -std::optional> AstFunctionCall::reduce() -{ - return std::nullopt; -} - -std::string AstFunctionCall::get_formatted_call() const -{ - std::vector arg_names; - arg_names.reserve(this->_arguments.size()); - - std::vector formatted_generics; - formatted_generics.reserve(this->_generic_type_arguments.size()); - - for (const auto& generic_arg : this->_generic_type_arguments) - { - formatted_generics.push_back(generic_arg->get_type_name()); - } - - for (const auto& arg : this->_arguments) - { - arg_names.push_back(arg->get_type()->get_type_name()); - } - - if (!formatted_generics.empty()) - { - return std::format( - "{}<{}>({})", - this->get_function_name(), - join(formatted_generics, ", "), - join(arg_names, ", ") - ); - } - - return std::format( - "{}({})", - this->get_function_name(), - join(arg_names, ", ") - ); -} - -std::string AstFunctionCall::to_string() -{ - std::ostringstream oss; - - std::vector arg_types; - for (const auto& arg : this->get_arguments()) - { - arg_types.push_back(arg->to_string()); - } - - return std::format( - "FunctionCall({} [{}])", - this->get_function_name(), - join(arg_types, ", ") - ); -} +} \ No newline at end of file diff --git a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration.cpp b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration.cpp new file mode 100644 index 00000000..6f7b4268 --- /dev/null +++ b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration.cpp @@ -0,0 +1,342 @@ +#include "ast/definitions/function_definition.h" + +#include "errors.h" +#include "ast/casting.h" +#include "ast/closures.h" +#include "ast/modifiers.h" +#include "ast/parsing_context.h" +#include "ast/symbols.h" +#include "ast/nodes/blocks.h" +#include "ast/nodes/conditional_statement.h" +#include "ast/nodes/expression.h" +#include "ast/nodes/for_loop.h" +#include "ast/nodes/function_declaration.h" +#include "ast/nodes/return_statement.h" +#include 
"ast/nodes/while_loop.h" +#include "ast/tokens/token.h" +#include "ast/tokens/token_set.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace stride::ast; +using namespace stride::ast::definition; + +/** + * Will attempt to parse the provided token stream into an AstFunctionDefinitionNode. + */ +std::unique_ptr stride::ast::parse_fn_declaration( + const std::shared_ptr& context, + TokenSet& set, + VisibilityModifier modifier +) +{ + int function_flags = 0; + const auto reference_token = set.peek_next(); + if (set.peek_next_eq(TokenType::KEYWORD_EXTERN)) + { + set.next(); + function_flags |= SRFLAG_FN_TYPE_EXTERN; + } + + if (set.peek_next_eq(TokenType::KEYWORD_ASYNC)) + { + set.next(); + function_flags |= SRFLAG_FN_TYPE_ASYNC; + } + + set.expect(TokenType::KEYWORD_FN); + + // Here we expect to receive the function name + const auto fn_name_tok = set.expect(TokenType::IDENTIFIER, "Expected function name"); + const auto& fn_name = fn_name_tok.get_lexeme(); + + auto function_context = std::make_shared(context, ContextType::FUNCTION); + + GenericParameterList generic_parameter_names = parse_generic_declaration(set); + + if (function_flags & SRFLAG_FN_TYPE_EXTERN && !generic_parameter_names.empty()) + { + set.throw_error("Extern functions cannot have generic parameters"); + } + + set.expect(TokenType::LPAREN, "Expected '(' after function name"); + std::vector> parameters; + + // Parameter parsing + if (!set.peek_next_eq(TokenType::RPAREN)) + { + parse_function_parameters(function_context, set, parameters, function_flags); + + if (!set.peek_next_eq(TokenType::RPAREN)) + { + set.throw_error( + "Expected closing parenthesis after variadic parameter; variadic parameter must be the last parameter in the function signature" + ); + } + } + + set.expect(TokenType::RPAREN, "Expected ')' after function parameters"); + set.expect(TokenType::COLON, "Expected a colon after function definition"); + + // Return type doesn't have the 
same flags as the function, hence NONE + auto return_type = parse_type(context, set, { "Expected return type in function header" }); + + const auto& position = SourceFragment::join( + reference_token.get_source_fragment(), + return_type->get_source_fragment()); + auto sym_function_name = Symbol(position, context->get_name(), fn_name); + + std::unique_ptr body = nullptr; + + if (function_flags & SRFLAG_FN_TYPE_EXTERN) + { + set.expect(TokenType::SEMICOLON, "Expected ';' after extern function declaration"); + body = AstBlock::create_empty(function_context, position); + } + else + { + body = parse_block(function_context, set); + } + + return std::make_unique( + function_context, + sym_function_name, + std::move(parameters), + std::move(body), + std::move(return_type), + modifier, + function_flags, + std::move(generic_parameter_names) + ); +} + +std::unique_ptr consume_anonymous_fn_body( + const std::shared_ptr& context, + TokenSet& set) +{ + if (!set.peek_next_eq(TokenType::LBRACE)) + { + auto expr = parse_inline_expression(context, set); + + const auto src_frag = expr->get_source_fragment(); + std::vector> body_nodes; + body_nodes.push_back(std::move(expr)); + + return std::make_unique( + src_frag, + context, + std::move(body_nodes) + ); + } + + return parse_block(context, set); +} + +std::unique_ptr stride::ast::parse_anonymous_fn_expression( + const std::shared_ptr& context, + TokenSet& set +) +{ + const auto reference_token = set.peek_next(); + std::vector> parameters = {}; + + int function_flags = SRFLAG_FN_TYPE_ANONYMOUS; + auto function_context = std::make_shared( + context, + ContextType::FUNCTION + ); + + // Parses expressions like: + // (: , ...): -> {} + if (auto header_definition = collect_parenthesized_block(set); + header_definition.has_value() && header_definition->has_next()) + { + parse_function_parameters( + function_context, + header_definition.value(), + parameters, + function_flags + ); + } + + set.expect(TokenType::COLON, "Expected ':' after 
lambda function header definition"); + auto return_type = parse_type( + function_context, + set, + { "Expected type after anonymous function header definition" } + ); + const auto lambda_arrow = set.expect( + TokenType::RARROW, + "Expected '->' after lambda parameters" + ); + + auto lambda_body = consume_anonymous_fn_body(function_context, set); + + static int anonymous_lambda_id = 0; + + auto symbol_name = Symbol( + { set.get_source(), + reference_token.get_source_fragment().offset, + lambda_arrow.get_source_fragment().offset - + reference_token.get_source_fragment().offset }, + ANONYMOUS_FN_PREFIX + std::to_string(anonymous_lambda_id++) + ); + + std::vector> cloned_params; + cloned_params.reserve(parameters.size()); + for (auto& param : parameters) + { + cloned_params.push_back(param->get_type()->clone_ty()); + } + + return std::make_unique( + function_context, + symbol_name, + std::move(parameters), + std::move(lambda_body), + std::move(return_type), + VisibilityModifier::PRIVATE, + // Anonymous functions are always private + function_flags + ); +} + +std::vector> IAstFunction::get_parameters() const +{ + std::vector> cloned_params; + cloned_params.reserve(this->_parameters.size()); + + for (const auto& param : this->_parameters) + { + cloned_params.push_back(param->clone_as()); + } + + return cloned_params; +} + +std::vector> IAstFunction::get_parameter_types() const +{ + std::vector> types; + types.reserve(this->_parameters.size()); + + for (const auto& param : this->_parameters) + { + types.push_back(param->get_type()->clone_ty()); + } + + return types; +} + +FunctionDefinition* IAstFunction::get_function_definition() +{ + if (this->_function_definition != nullptr) + return this->_function_definition; + + const auto& definition = this->get_context()->get_function_definition( + this->get_registered_function_name(), + this->get_parameter_types(), + this->get_generic_parameters().size() + ); + + if (!definition.has_value()) + { + throw parsing_error( + 
ErrorType::REFERENCE_ERROR, + std::format("Function definition for '{}' not found in context", this->get_registered_function_name()), + this->get_source_fragment() + ); + } + + this->_function_definition = definition.value(); + return this->_function_definition; +} + +std::vector IAstFunction::get_generic_function_metadata() +{ + const auto& definition = this->get_function_definition(); + + // If the function is not generic, we just return a singular name (the regular internalized name) + if (!definition->get_type()->is_generic()) + { + return { + GenericFunctionMetadata{ this->get_registered_function_name(), definition->get_llvm_function() } + }; + } + + std::vector metadata; + + for (const auto& [types, llvm_function, node] : definition->get_generic_overloads()) + { + metadata.emplace_back( + get_overloaded_function_name(node->get_registered_function_name(), types), + llvm_function + ); + } + + return metadata; +} + +std::unique_ptr AstFunctionParameter::clone() +{ + return std::make_unique( + this->get_source_fragment(), + this->get_context(), + this->get_name(), + this->get_type()->clone_ty() + ); +} + +std::unique_ptr IAstFunction::clone() +{ + std::vector> cloned_params; + cloned_params.reserve(this->_parameters.size()); + + for (const auto& param : this->_parameters) + { + cloned_params.push_back(param->clone_as()); + } + + return std::make_unique( + this->get_source_fragment(), + this->get_context(), + this->_symbol, + std::move(cloned_params), + this->_body->clone_as(), + this->_annotated_return_type->clone_ty(), + this->_visibility, + this->_flags, + this->_generic_parameters + ); +} + +std::string IAstFunction::to_string() +{ + std::string params; + for (const auto& param : this->_parameters) + { + if (!params.empty()) + params += ", "; + params += param->to_string(); + } + + const auto body_str = this->get_body() == nullptr + ? 
"" + : this->get_body()->to_string(); + + return std::format( + "Function(name: {}(internal: {}), params: [{}], body: {}{} -> {})", + this->is_anonymous() ? "" : this->get_plain_function_name(), + this->get_registered_function_name(), + params, + body_str, + this->is_extern() ? " (extern)" : "", + this->get_return_type()->to_string() + ); +} diff --git a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_codegen.cpp b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_codegen.cpp new file mode 100644 index 00000000..4fbd0874 --- /dev/null +++ b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_codegen.cpp @@ -0,0 +1,206 @@ +#include "ast/closures.h" +#include "ast/nodes/function_declaration.h" + +#include +#include +#include + +using namespace stride::ast; + +llvm::Value* IAstFunction::codegen( + llvm::Module* module, + llvm::IRBuilderBase* builder +) +{ + llvm::Function* function = nullptr; + + for (const auto& [function_name, llvm_function_val] : this->get_generic_function_metadata()) + { + if (!llvm_function_val) + { + throw parsing_error( + ErrorType::COMPILATION_ERROR, + std::format("Function symbol '{}' missing", function_name), + this->get_source_fragment() + ); + } + + + // If the function body has already been generated (has basic blocks), just return the function pointer + if (this->is_extern() || !llvm_function_val->empty()) + { + return llvm_function_val; + } + + // Save the current insert point to restore it later + // This is important when generating nested lambdas + llvm::BasicBlock* saved_insert_block = builder->GetInsertBlock(); + llvm::BasicBlock::iterator saved_insert_point; + const bool has_insert_point = saved_insert_block != nullptr; + + if (saved_insert_block) + { + saved_insert_point = builder->GetInsertPoint(); + } + + llvm::BasicBlock* entry_bb = llvm::BasicBlock::Create( + module->getContext(), + "entry", + llvm_function_val + ); + builder->SetInsertPoint(entry_bb); 
+ + // We create a new builder for the prologue to ensure allocas are at the very top + llvm::IRBuilder prologue_builder( + &llvm_function_val->getEntryBlock(), + llvm_function_val->getEntryBlock().begin() + ); + + // + // Captured variable handling + // Map captured variables to function arguments with __capture_ prefix + // + auto fn_parameter_argument = llvm_function_val->arg_begin(); + for (const auto& capture : this->_captured_variables) + { + if (fn_parameter_argument != llvm_function_val->arg_end()) + { + fn_parameter_argument->setName(closures::format_captured_variable_name(capture.internal_name)); + + // Create alloca with __capture_ prefix so identifier lookup can find it + llvm::AllocaInst* alloca = prologue_builder.CreateAlloca( + fn_parameter_argument->getType(), + nullptr, + closures::format_captured_variable_name_internal(capture.internal_name) + ); + + builder->CreateStore(fn_parameter_argument, alloca); + ++fn_parameter_argument; + } + } + + // + // Function parameter handling + // Here we define the parameters on the stack as memory slots for the function + // + for (const auto& param : this->_parameters) + { + if (fn_parameter_argument != llvm_function_val->arg_end()) + { + fn_parameter_argument->setName(param->get_name() + ".arg"); + + // Create a memory slot on the stack for the parameter + llvm::AllocaInst* alloca = prologue_builder.CreateAlloca( + fn_parameter_argument->getType(), + nullptr, + param->get_name() + ); + + // Store the initial argument value into the alloca + builder->CreateStore(fn_parameter_argument, alloca); + + ++fn_parameter_argument; + } + } + + // Generate Body + llvm::Value* function_body_value = this->_body->codegen(module, builder); + + // Final Safety: Implicit Return + // If the get_body didn't explicitly return (no terminator found), add one. 
+ if (llvm::BasicBlock* current_bb = builder->GetInsertBlock(); + current_bb && !current_bb->getTerminator()) + { + if (llvm::Type* ret_type = llvm_function_val->getReturnType(); + ret_type->isVoidTy()) + { + builder->CreateRetVoid(); + } + else if (function_body_value && function_body_value->getType() == ret_type) + { + builder->CreateRet(function_body_value); + } + // Default return to keep IR valid (useful for main or incomplete functions) + else + { + if (ret_type->isFloatingPointTy()) + { + builder->CreateRet(llvm::ConstantFP::get(ret_type, 0.0)); + } + else if (ret_type->isIntegerTy()) + { + builder->CreateRet(llvm::ConstantInt::get(ret_type, 0)); + } + else + { + throw parsing_error( + ErrorType::COMPILATION_ERROR, + std::format( + "Function '{}' is missing a return path.", + this->get_plain_function_name() + ), + this->get_source_fragment() + ); + } + } + } + + if (llvm::verifyFunction(*llvm_function_val, &llvm::errs())) + { + module->print(llvm::errs(), nullptr); + throw parsing_error( + ErrorType::COMPILATION_ERROR, + "LLVM Function Verification Failed for: " + this->get_plain_function_name(), + this->get_source_fragment() + ); + } + + // Restore the previous insert point for nested lambda generation + if (has_insert_point && saved_insert_block) + { + builder->SetInsertPoint(saved_insert_block, saved_insert_point); + } + + // For anonymous lambdas with captured variables, create a closure structure + // that bundles the function pointer with the current values of captured variables + if (this->is_anonymous()) + { + // Collect the current values of captured variables from the enclosing scope + std::vector captured_values; + for (const auto& capture : this->get_captured_variables()) + { + if (const auto block = builder->GetInsertBlock()) + { + llvm::Function* current_fn = block->getParent(); + llvm::Value* captured_val = closures::lookup_variable_or_capture(current_fn, capture.internal_name); + + if (!captured_val) + { + captured_val = 
closures::lookup_variable_by_base_name(current_fn, capture.name); + } + + if (captured_val) + { + // Load the value if it's an alloca + if (auto* alloca = llvm::dyn_cast(captured_val)) + { + captured_val = builder->CreateLoad( + alloca->getAllocatedType(), + alloca, + capture.internal_name + ); + } + captured_values.push_back(captured_val); + } + } + } + + // Create and return a closure instead of the raw function pointer + return closures::create_closure(module, builder, llvm_function_val, captured_values); + } + + function = llvm_function_val; + } + + return function; +} diff --git a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_forward_refs.cpp b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_forward_refs.cpp new file mode 100644 index 00000000..ba04640f --- /dev/null +++ b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_forward_refs.cpp @@ -0,0 +1,497 @@ +#include "ast/casting.h" +#include "ast/definitions/function_definition.h" +#include "ast/nodes/conditional_statement.h" +#include "ast/nodes/for_loop.h" +#include "ast/nodes/function_declaration.h" +#include "ast/nodes/return_statement.h" +#include "ast/nodes/while_loop.h" + +#include +#include + +using namespace stride::ast; + +/** + * Here we define the function in the symbol table, so it can be looked up in the codegen phase. + */ +void IAstFunction::resolve_forward_references( + llvm::Module* module, + llvm::IRBuilderBase* builder +) +{ + std::vector captures; + const auto outer_context = this->get_context()->get_parent_context() != nullptr + ? 
this->get_context()->get_parent_context() + : this->get_context(); + + collect_free_variables(this->get_body(), this->get_context(), outer_context, captures); + + // Register captured variables in the lambda's context so they can be referenced + for (const auto& capture : captures) + { + this->add_captured_variable(capture); + + // Also define the capture in the lambda's context so identifier lookup works + if (const auto outer_var = this->get_context()->lookup_variable(capture.name, true)) + { + this->get_context()->define_variable( + capture, + outer_var->get_type()->clone_ty(), + VisibilityModifier::PRIVATE + ); + } + } + + // Avoid re-registering if already declared. + // Named functions are looked up by their scoped name; anonymous functions are + // tracked by the cached _llvm_function pointer (they have no stable string name). + if (this->is_anonymous() && this->get_function_definition()->get_llvm_function() != nullptr) + return; + + // Add captured variables as first parameters + std::vector captured_types; + for (const auto& capture : this->_captured_variables) + { + if (const auto capture_def = this->get_context()->lookup_variable(capture.name, true)) + { + if (llvm::Type* capture_type = capture_def->get_type()->get_llvm_type(module)) + { + captured_types.push_back(capture_type); + } + } + } + + const auto& definition = this->get_function_definition(); + + const auto linkage = this->_visibility == VisibilityModifier::PRIVATE + ? 
llvm::Function::PrivateLinkage + : llvm::Function::ExternalLinkage; + + if (!this->is_generic()) + { + llvm::FunctionType* generic_function_type = this->get_llvm_function_type( + module, + captured_types, + {} + ); + + definition->set_llvm_function(llvm::Function::Create( + generic_function_type, + linkage, + this->get_registered_function_name(), + module + )); + this->_body->resolve_forward_references(module, builder); + + return; + } + + for (const auto& [instantiated_generic_types, llvm_function, node] : definition->get_generic_overloads()) + { + auto instantiated_return_ty = resolve_generics( + this->_annotated_return_type.get(), + this->_generic_parameters, + instantiated_generic_types + ); + + std::vector> instantiated_function_params; + instantiated_function_params.reserve(this->_parameters.size()); + + for (const auto& param : this->_parameters) + { + instantiated_function_params.push_back( + std::make_unique( + param->get_source_fragment(), + param->get_context(), + param->get_name(), + resolve_generics(param->get_type(), this->_generic_parameters, instantiated_generic_types) + ) + ); + } + + node = std::make_unique( + this->get_context(), + this->_symbol, + std::move(instantiated_function_params), + this->_body->clone_as(), + std::move(instantiated_return_ty), + this->get_visibility(), + this->_flags, + this->get_generic_parameters() + ); + + const auto overloaded_fn_name = get_overloaded_function_name( + this->get_registered_function_name(), + instantiated_generic_types); + llvm::FunctionType* generic_function_type = this->get_llvm_function_type( + module, + captured_types, + instantiated_generic_types + ); + + llvm_function = llvm::Function::Create( + generic_function_type, + linkage, + overloaded_fn_name, + module + ); + + for (const auto& instantiated_param : instantiated_function_params) + { + const auto param_symbol = Symbol(instantiated_param->get_source_fragment(), instantiated_param->get_name()); + 
instantiated_param->get_context()->define_variable( + param_symbol, + instantiated_param->get_type()->clone_ty(), + VisibilityModifier::PRIVATE, + true + ); + } + + if (this->is_anonymous()) + llvm_function->addFnAttr("stride.anonymous"); + + this->_body->resolve_forward_references(module, builder); + } +} + +void IAstFunction::collect_free_variables( + IAstNode* node, + const std::shared_ptr& lambda_context, + const std::shared_ptr& outer_context, + std::vector& captures +) +{ + if (!node) + { + return; + } + + auto capture_variable = [&](const std::string& name) + { + if (lambda_context->get_variable_def(name, true)) + { + // No need to collect any variables; they're readily available + return; + } + + // Not local to the lambda - check if it's in an outer scope (and is a variable, not a function) + if (const auto outer_symbol = outer_context->lookup_variable(name, true)) + { + // Check if we haven't already captured this variable + bool already_captured = false; + for (const auto& cap : captures) + { + if (cap.internal_name == outer_symbol->get_internal_symbol_name()) + { + already_captured = true; + break; + } + } + + if (!already_captured) + { + captures.push_back(outer_symbol->get_symbol()); + } + } + }; + + // Check specific node types first, then fall back to generic container handling + + // If it's an identifier, check if it references a variable from outer scope + if (const auto* identifier = cast_expr(node)) + { + capture_variable(identifier->get_name()); + return; + } + + // Handle nested callables (lambdas) - recursively collect their free variables + if (auto* callable = cast_expr(node)) + { + // For nested lambdas, we need to: + // 1. First collect what the nested lambda needs from its body (if not already done) + // 2. 
Then capture those variables that are available in our outer context + + // Check if this nested lambda already has captures collected + // If it does, we just need to propagate them upward + + if (const auto& existing_captures = callable->get_captured_variables(); + existing_captures.empty() && callable->get_body()) + { + // Captures haven't been collected yet, so do it now + std::vector nested_captures; + + // Recursively collect free variables in the nested lambda's body + // The nested lambda's context is its own context, and its outer context is our lambda_context + collect_free_variables( + callable->get_body(), + callable->get_context(), + lambda_context, + nested_captures); + + // Now register the nested lambda's captures + for (const auto& nested_capture : nested_captures) + { + callable->add_captured_variable(nested_capture); + + // Define the capture in the nested lambda's context so identifier lookup works + if (const auto var_def = lambda_context->lookup_variable(nested_capture.name, true)) + { + callable->get_context()->define_variable( + nested_capture, + var_def->get_type()->clone_ty(), + VisibilityModifier::PRIVATE + ); + } + } + } + + // Now capture those variables into OUR scope if they come from outside + for (const auto& cap : callable->get_captured_variables()) + { + capture_variable(cap.name); + } + return; + } + + // Handle return statements + if (const auto* return_stmt = cast_ast(node)) + { + if (return_stmt->get_return_expression().has_value()) + { + collect_free_variables( + return_stmt->get_return_expression().value().get(), + lambda_context, + outer_context, + captures); + } + return; + } + + // Handle function calls + if (const auto* fn_call = cast_expr(node)) + { + // Capture the "function name" if it's actually a variable (anonymous call) + capture_variable(fn_call->get_function_name()); + + for (const auto& arg : fn_call->get_arguments()) + { + collect_free_variables(arg.get(), lambda_context, outer_context, captures); + } + 
return; + } + + // Handle variable declarations (initializer) + if (const auto* var_decl = cast_expr(node)) + { + if (var_decl->get_initial_value()) + { + collect_free_variables( + var_decl->get_initial_value(), + lambda_context, + outer_context, + captures); + } + return; + } + + // Handle variable reassignment + if (const auto* assignment = cast_expr(node)) + { + collect_free_variables(assignment->get_value(), lambda_context, outer_context, captures); + return; + } + + // Handle binary ops + if (const auto* bin_op = cast_expr(node)) + { + collect_free_variables(bin_op->get_left(), lambda_context, outer_context, captures); + collect_free_variables(bin_op->get_right(), lambda_context, outer_context, captures); + return; + } + + // Handle unary ops + if (const auto* unary_op = cast_expr(node)) + { + collect_free_variables(&unary_op->get_operand(), lambda_context, outer_context, captures); + return; + } + + // Handle if statements + if (auto* if_stmt = cast_ast(node)) + { + collect_free_variables(if_stmt->get_condition(), lambda_context, outer_context, captures); + collect_free_variables(if_stmt->get_body(), lambda_context, outer_context, captures); + if (if_stmt->get_else_body()) + { + collect_free_variables( + if_stmt->get_else_body(), + lambda_context, + outer_context, + captures); + } + return; + } + + // Handle while loops + if (auto* while_loop = cast_ast(node)) + { + collect_free_variables( + while_loop->get_condition(), + lambda_context, + outer_context, + captures); + collect_free_variables(while_loop->get_body(), lambda_context, outer_context, captures); + return; + } + + // Handle for loops + if (auto* for_loop = cast_ast(node)) + { + if (for_loop->get_initializer()) + { + collect_free_variables( + for_loop->get_initializer(), + lambda_context, + outer_context, + captures); + } + if (for_loop->get_condition()) + { + collect_free_variables( + for_loop->get_condition(), + lambda_context, + outer_context, + captures); + } + if (for_loop->get_incrementor()) + 
{ + collect_free_variables( + for_loop->get_incrementor(), + lambda_context, + outer_context, + captures); + } + + collect_free_variables(for_loop->get_body(), lambda_context, outer_context, captures); + return; + } + + // Handle array literals + if (const auto* array = cast_expr(node)) + { + for (const auto& elem : array->get_elements()) + { + collect_free_variables(elem.get(), lambda_context, outer_context, captures); + } + return; + } + + // Handle struct initializers + if (const auto* struct_init = cast_expr(node)) + { + for (const auto& val : struct_init->get_initializers() | std::views::values) + { + collect_free_variables(val.get(), lambda_context, outer_context, captures); + } + return; + } + + // Handle chained member access + if (const auto* chained = cast_expr(node)) + { + collect_free_variables(chained->get_base(), lambda_context, outer_context, captures); + return; + } + + // Handle blocks (lambda bodies, function bodies, etc.) + if (const auto* block = cast_ast(node)) + { + for (const auto& child : block->get_children()) + { + collect_free_variables(child.get(), lambda_context, outer_context, captures); + } + return; + } + + // Generic container handling (if statements, loops, etc.) 
- this should be last + if (auto* container = dynamic_cast(node)) + { + if (const auto* body = container->get_body()) + { + for (const auto& child : body->get_children()) + { + collect_free_variables(child.get(), lambda_context, outer_context, captures); + } + } + } +} + +llvm::FunctionType* IAstFunction::get_llvm_function_type( + llvm::Module* module, + std::vector captured_variables, + const GenericTypeList& generic_instantiation_types +) const +{ + std::vector base_parameter_types; + + if (generic_instantiation_types.empty()) + { + for (const auto& param : this->_parameters) + { + if (llvm::Type* param_type = param->get_type()->get_llvm_type(module)) + { + base_parameter_types.push_back(param_type); + } + } + } + else + { + for (const auto& param : this->_parameters) + { + const auto& resolved_generic_param_type = resolve_generics( + param->get_type(), + this->_generic_parameters, + generic_instantiation_types + ); + if (llvm::Type* param_type = resolved_generic_param_type->get_llvm_type(module)) + { + base_parameter_types.push_back(param_type); + } + } + } + + std::vector parameter_types; + parameter_types.reserve(base_parameter_types.size() + captured_variables.size()); + + // Captured variables are first in the internal LLVM function type + parameter_types.insert(parameter_types.end(), captured_variables.begin(), captured_variables.end()); + parameter_types.insert(parameter_types.end(), base_parameter_types.begin(), base_parameter_types.end()); + + const auto return_type = this->get_return_type()->get_llvm_type(module); + + if (!return_type) + { + throw parsing_error( + ErrorType::COMPILATION_ERROR, + "Could not get LLVM return type for function: " + this->get_plain_function_name(), + this->get_source_fragment() + ); + } + + const auto& llvm_function_ty = llvm::FunctionType::get(return_type, parameter_types, this->is_variadic());; + + if (const auto fn = module->getFunction(this->get_registered_function_name()); + fn != nullptr && fn->getFunctionType() != 
llvm_function_ty) + { + throw parsing_error( + ErrorType::COMPILATION_ERROR, + std::format( + "Function symbol '{}' already exists with a different signature", + this->get_registered_function_name() + ), + this->get_source_fragment() + ); + } + + return llvm_function_ty; +} diff --git a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_validation.cpp b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_validation.cpp new file mode 100644 index 00000000..87a6a31d --- /dev/null +++ b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_validation.cpp @@ -0,0 +1,217 @@ +#include "ast/casting.h" +#include "ast/definitions/function_definition.h" +#include "ast/nodes/conditional_statement.h" +#include "ast/nodes/function_declaration.h" +#include "ast/nodes/return_statement.h" + +using namespace stride::ast; + +void IAstFunction::validate() +{ + // Extern functions have no body to validate + if (this->is_extern()) + return; + + if (!this->is_generic()) + { + validate_candidate(this); + return; + } + + // + // For generic functions, we create a new copy of the function with all parameters resolved, and do validation + // on that copy. This is because we want to validate the function body with the actual types that will be used in + // the function, rather than the generic placeholders. 
+ // + // create a copy of this function with the parameters instantiated + for (const auto definition = this->get_function_definition(); + const auto& [types, function, node] : definition->get_generic_overloads()) + { + auto instantiated_return_ty = resolve_generics( + this->_annotated_return_type.get(), + this->_generic_parameters, + types + ); + + std::vector> instantiated_function_params; + instantiated_function_params.reserve(this->_parameters.size()); + + for (const auto& param : this->_parameters) + { + instantiated_function_params.push_back( + std::make_unique( + param->get_source_fragment(), + param->get_context(), + param->get_name(), + resolve_generics(param->get_type(), this->_generic_parameters, types) + ) + ); + } + + node = std::make_unique( + this->get_context(), + this->_symbol, + std::move(instantiated_function_params), + this->_body->clone_as(), + std::move(instantiated_return_ty), + this->get_visibility(), + this->_flags, + this->get_generic_parameters() + ); + + validate_candidate(node.get()); + } +} + +void IAstFunction::validate_candidate(IAstFunction* candidate) +{ + candidate->get_body()->validate(); + + const auto& ret_ty = candidate->get_return_type(); + const auto return_statements = collect_return_statements(candidate->get_body()); + + // For void types, we only disallow returning expressions, as this is redundant. + if (const auto void_ret = cast_type(ret_ty); + void_ret != nullptr && void_ret->get_primitive_type() == PrimitiveType::VOID) + { + for (const auto& return_stmt : return_statements) + { + if (return_stmt->get_return_expression().has_value()) + { + throw parsing_error( + ErrorType::TYPE_ERROR, + std::format( + "{} has return type 'void' and cannot return a value.", + candidate->is_anonymous() + ? 
"Anonymous function" + : std::format("Function '{}'", candidate->get_plain_function_name())), + { + ErrorSourceReference( + "unexpected return value", + return_stmt->get_source_fragment() + ), + ErrorSourceReference( + "Function returning void type", + candidate->get_source_fragment() + ) + } + + ); + } + } + return; + } + + if (return_statements.empty()) + { + if (cast_type(ret_ty)) + { + throw parsing_error( + ErrorType::TYPE_ERROR, + std::format( + "Function '{}' returns a struct type, but no return statement is present.", + candidate->get_plain_function_name()), + candidate->get_source_fragment()); + } + + throw parsing_error( + ErrorType::COMPILATION_ERROR, + std::format( + "{} is missing a return statement.", + candidate->is_anonymous() + ? "Anonymous function" + : std::format("Function '{}'", candidate->get_plain_function_name())), + candidate->get_source_fragment() + ); + } + + for (const auto& return_stmt : return_statements) + { + if (return_stmt->is_void_type()) + { + if (!ret_ty->is_void_ty()) + { + throw parsing_error( + ErrorType::TYPE_ERROR, + std::format( + "Function '{}' returns a value of type '{}', but no return statement is present.", + candidate->is_anonymous() ? "" : candidate->get_plain_function_name(), + ret_ty->to_string()), + return_stmt->get_source_fragment() + ); + } + return; + } + + if (const auto& ret_expr = return_stmt->get_return_expression().value(); + !ret_expr->get_type()->equals(ret_ty) && + !ret_expr->get_type()->is_assignable_to(ret_ty)) + { + const auto error_fragment = ErrorSourceReference( + std::format( + "expected {}{}", + candidate->get_return_type()->is_primitive() + ? "" + : candidate->get_return_type()->is_function() + ? "function-type " + : "struct-type ", + candidate->get_return_type()->to_string()), + ret_expr->get_source_fragment() + ); + + throw parsing_error( + ErrorType::TYPE_ERROR, + std::format( + "Function '{}' expected a return type of '{}', but received '{}'.", + candidate->is_anonymous() ? 
"" : candidate->get_plain_function_name(), + ret_ty->get_type_name(), + ret_expr->get_type()->get_type_name()), + { error_fragment } + ); + } + } +} + +std::vector IAstFunction::collect_return_statements(const AstBlock* body) +{ + if (!body) + { + return {}; + } + + std::vector return_statements; + for (const auto& child : body->get_children()) + { + if (auto* return_stmt = dynamic_cast(child.get())) + { + return_statements.push_back(return_stmt); + } + + // Recursively collect from child containers + if (const auto container_node = dynamic_cast(child. + get())) + { + const auto aggregated = collect_return_statements( + container_node->get_body()); + return_statements.insert(return_statements.end(), + aggregated.begin(), + aggregated.end()); + } + + // Edge case: if statements hold the `else` block too, though this doesn't fall under the + // `IAstContainer` abstraction. The `get_body` part is added in the previous case, though we + // still need to add the else body + if (const auto if_statement = dynamic_cast(child. 
+ get())) + { + const auto aggregated = collect_return_statements( + if_statement->get_else_body()); + return_statements.insert( + return_statements.end(), + aggregated.begin(), + aggregated.end() + ); + } + } + return return_statements; +} diff --git a/packages/compiler/src/ast/nodes/functions/function_declaration.cpp b/packages/compiler/src/ast/nodes/functions/function_declaration.cpp deleted file mode 100644 index e3cdf639..00000000 --- a/packages/compiler/src/ast/nodes/functions/function_declaration.cpp +++ /dev/null @@ -1,1234 +0,0 @@ -#include "ast/definitions/function_definition.h" - -#include "errors.h" -#include "ast/casting.h" -#include "ast/closures.h" -#include "ast/modifiers.h" -#include "ast/parsing_context.h" -#include "ast/symbols.h" -#include "ast/nodes/blocks.h" -#include "ast/nodes/conditional_statement.h" -#include "ast/nodes/expression.h" -#include "ast/nodes/for_loop.h" -#include "ast/nodes/function_declaration.h" -#include "ast/nodes/return_statement.h" -#include "ast/nodes/while_loop.h" -#include "ast/tokens/token.h" -#include "ast/tokens/token_set.h" - -#include -#include -#include -#include -#include -#include -#include -#include - -using namespace stride::ast; -using namespace stride::ast::definition; - -/** - * Will attempt to parse the provided token stream into an AstFunctionDefinitionNode. 
- */ -std::unique_ptr stride::ast::parse_fn_declaration( - const std::shared_ptr& context, - TokenSet& set, - VisibilityModifier modifier -) -{ - int function_flags = 0; - const auto reference_token = set.peek_next(); - if (set.peek_next_eq(TokenType::KEYWORD_EXTERN)) - { - set.next(); - function_flags |= SRFLAG_FN_TYPE_EXTERN; - } - - if (set.peek_next_eq(TokenType::KEYWORD_ASYNC)) - { - set.next(); - function_flags |= SRFLAG_FN_TYPE_ASYNC; - } - - set.expect(TokenType::KEYWORD_FN); - - // Here we expect to receive the function name - const auto fn_name_tok = set.expect(TokenType::IDENTIFIER, "Expected function name"); - const auto& fn_name = fn_name_tok.get_lexeme(); - - auto function_context = std::make_shared(context, ContextType::FUNCTION); - - GenericParameterList generic_parameter_names = parse_generic_declaration(set); - - if (function_flags & SRFLAG_FN_TYPE_EXTERN && !generic_parameter_names.empty()) - { - set.throw_error("Extern functions cannot have generic parameters"); - } - - set.expect(TokenType::LPAREN, "Expected '(' after function name"); - std::vector> parameters; - - // Parameter parsing - if (!set.peek_next_eq(TokenType::RPAREN)) - { - parse_function_parameters(function_context, set, parameters, function_flags); - - if (!set.peek_next_eq(TokenType::RPAREN)) - { - set.throw_error( - "Expected closing parenthesis after variadic parameter; variadic parameter must be the last parameter in the function signature" - ); - } - } - - set.expect(TokenType::RPAREN, "Expected ')' after function parameters"); - set.expect(TokenType::COLON, "Expected a colon after function definition"); - - // Return type doesn't have the same flags as the function, hence NONE - auto return_type = parse_type(context, set, { "Expected return type in function header" }); - - const auto& position = SourceFragment::join( - reference_token.get_source_fragment(), - return_type->get_source_fragment()); - auto sym_function_name = Symbol(position, context->get_name(), fn_name); - - 
std::unique_ptr body = nullptr; - - if (function_flags & SRFLAG_FN_TYPE_EXTERN) - { - set.expect(TokenType::SEMICOLON, "Expected ';' after extern function declaration"); - body = AstBlock::create_empty(function_context, position); - } - else - { - body = parse_block(function_context, set); - } - - return std::make_unique( - function_context, - sym_function_name, - std::move(parameters), - std::move(body), - std::move(return_type), - modifier, - function_flags, - std::move(generic_parameter_names) - ); -} - -std::unique_ptr consume_anonymous_fn_body( - const std::shared_ptr& context, - TokenSet& set) -{ - if (!set.peek_next_eq(TokenType::LBRACE)) - { - auto expr = parse_inline_expression(context, set); - - const auto src_frag = expr->get_source_fragment(); - std::vector> body_nodes; - body_nodes.push_back(std::move(expr)); - - return std::make_unique( - src_frag, - context, - std::move(body_nodes) - ); - } - - return parse_block(context, set); -} - -std::unique_ptr stride::ast::parse_anonymous_fn_expression( - const std::shared_ptr& context, - TokenSet& set -) -{ - const auto reference_token = set.peek_next(); - std::vector> parameters = {}; - - int function_flags = SRFLAG_FN_TYPE_ANONYMOUS; - auto function_context = std::make_shared( - context, - ContextType::FUNCTION - ); - - // Parses expressions like: - // (: , ...): -> {} - if (auto header_definition = collect_parenthesized_block(set); - header_definition.has_value() && header_definition->has_next()) - { - parse_function_parameters( - function_context, - header_definition.value(), - parameters, - function_flags - ); - } - - set.expect(TokenType::COLON, "Expected ':' after lambda function header definition"); - auto return_type = parse_type( - function_context, - set, - { "Expected type after anonymous function header definition" } - ); - const auto lambda_arrow = set.expect( - TokenType::RARROW, - "Expected '->' after lambda parameters" - ); - - auto lambda_body = consume_anonymous_fn_body(function_context, 
set); - - static int anonymous_lambda_id = 0; - - auto symbol_name = Symbol( - { set.get_source(), - reference_token.get_source_fragment().offset, - lambda_arrow.get_source_fragment().offset - - reference_token.get_source_fragment().offset }, - ANONYMOUS_FN_PREFIX + std::to_string(anonymous_lambda_id++) - ); - - std::vector> cloned_params; - cloned_params.reserve(parameters.size()); - for (auto& param : parameters) - { - cloned_params.push_back(param->get_type()->clone_ty()); - } - - return std::make_unique( - function_context, - symbol_name, - std::move(parameters), - std::move(lambda_body), - std::move(return_type), - VisibilityModifier::PRIVATE, - // Anonymous functions are always private - function_flags - ); -} - -std::vector collect_return_statements(const AstBlock* body) -{ - if (!body) - { - return {}; - } - - std::vector return_statements; - for (const auto& child : body->get_children()) - { - if (auto* return_stmt = dynamic_cast(child.get())) - { - return_statements.push_back(return_stmt); - } - - // Recursively collect from child containers - if (const auto container_node = dynamic_cast(child. - get())) - { - const auto aggregated = collect_return_statements( - container_node->get_body()); - return_statements.insert(return_statements.end(), - aggregated.begin(), - aggregated.end()); - } - - // Edge case: if statements hold the `else` block too, though this doesn't fall under the - // `IAstContainer` abstraction. The `get_body` part is added in the previous case, though we - // still need to add the else body - if (const auto if_statement = dynamic_cast(child. 
- get())) - { - const auto aggregated = collect_return_statements( - if_statement->get_else_body()); - return_statements.insert( - return_statements.end(), - aggregated.begin(), - aggregated.end() - ); - } - } - return return_statements; -} - -void collect_free_variables( - IAstNode* node, - const std::shared_ptr& lambda_context, - const std::shared_ptr& outer_context, - std::vector& captures -) -{ - if (!node) - { - return; - } - - auto capture_variable = [&](const std::string& name) - { - if (lambda_context->get_variable_def(name, true)) - { - // No need to collect any variables; they're readily available - return; - } - - // Not local to the lambda - check if it's in an outer scope (and is a variable, not a function) - if (const auto outer_symbol = outer_context->lookup_variable(name, true)) - { - // Check if we haven't already captured this variable - bool already_captured = false; - for (const auto& cap : captures) - { - if (cap.internal_name == outer_symbol->get_internal_symbol_name()) - { - already_captured = true; - break; - } - } - - if (!already_captured) - { - captures.push_back(outer_symbol->get_symbol()); - } - } - }; - - // Check specific node types first, then fall back to generic container handling - - // If it's an identifier, check if it references a variable from outer scope - if (const auto* identifier = cast_expr(node)) - { - capture_variable(identifier->get_name()); - return; - } - - // Handle nested callables (lambdas) - recursively collect their free variables - if (auto* callable = cast_expr(node)) - { - // For nested lambdas, we need to: - // 1. First collect what the nested lambda needs from its body (if not already done) - // 2. 
Then capture those variables that are available in our outer context - - // Check if this nested lambda already has captures collected - // If it does, we just need to propagate them upward - - if (const auto& existing_captures = callable->get_captured_variables(); - existing_captures.empty() && callable->get_body()) - { - // Captures haven't been collected yet, so do it now - std::vector nested_captures; - - // Recursively collect free variables in the nested lambda's body - // The nested lambda's context is its own context, and its outer context is our lambda_context - collect_free_variables( - callable->get_body(), - callable->get_context(), - lambda_context, - nested_captures); - - // Now register the nested lambda's captures - for (const auto& nested_capture : nested_captures) - { - callable->add_captured_variable(nested_capture); - - // Define the capture in the nested lambda's context so identifier lookup works - if (const auto var_def = lambda_context->lookup_variable(nested_capture.name, true)) - { - callable->get_context()->define_variable( - nested_capture, - var_def->get_type()->clone_ty(), - VisibilityModifier::PRIVATE - ); - } - } - } - - // Now capture those variables into OUR scope if they come from outside - for (const auto& cap : callable->get_captured_variables()) - { - capture_variable(cap.name); - } - return; - } - - // Handle return statements - if (const auto* return_stmt = cast_ast(node)) - { - if (return_stmt->get_return_expression().has_value()) - { - collect_free_variables( - return_stmt->get_return_expression().value().get(), - lambda_context, - outer_context, - captures); - } - return; - } - - // Handle function calls - if (const auto* fn_call = cast_expr(node)) - { - // Capture the "function name" if it's actually a variable (anonymous call) - capture_variable(fn_call->get_function_name()); - - for (const auto& arg : fn_call->get_arguments()) - { - collect_free_variables(arg.get(), lambda_context, outer_context, captures); - } - 
return; - } - - // Handle variable declarations (initializer) - if (const auto* var_decl = cast_expr(node)) - { - if (var_decl->get_initial_value()) - { - collect_free_variables( - var_decl->get_initial_value(), - lambda_context, - outer_context, - captures); - } - return; - } - - // Handle variable reassignment - if (const auto* assignment = cast_expr(node)) - { - collect_free_variables(assignment->get_value(), lambda_context, outer_context, captures); - return; - } - - // Handle binary ops - if (const auto* bin_op = cast_expr(node)) - { - collect_free_variables(bin_op->get_left(), lambda_context, outer_context, captures); - collect_free_variables(bin_op->get_right(), lambda_context, outer_context, captures); - return; - } - - // Handle unary ops - if (const auto* unary_op = cast_expr(node)) - { - collect_free_variables(&unary_op->get_operand(), lambda_context, outer_context, captures); - return; - } - - // Handle if statements - if (auto* if_stmt = cast_ast(node)) - { - collect_free_variables(if_stmt->get_condition(), lambda_context, outer_context, captures); - collect_free_variables(if_stmt->get_body(), lambda_context, outer_context, captures); - if (if_stmt->get_else_body()) - { - collect_free_variables( - if_stmt->get_else_body(), - lambda_context, - outer_context, - captures); - } - return; - } - - // Handle while loops - if (auto* while_loop = cast_ast(node)) - { - collect_free_variables( - while_loop->get_condition(), - lambda_context, - outer_context, - captures); - collect_free_variables(while_loop->get_body(), lambda_context, outer_context, captures); - return; - } - - // Handle for loops - if (auto* for_loop = cast_ast(node)) - { - if (for_loop->get_initializer()) - { - collect_free_variables( - for_loop->get_initializer(), - lambda_context, - outer_context, - captures); - } - if (for_loop->get_condition()) - { - collect_free_variables( - for_loop->get_condition(), - lambda_context, - outer_context, - captures); - } - if (for_loop->get_incrementor()) - 
{ - collect_free_variables( - for_loop->get_incrementor(), - lambda_context, - outer_context, - captures); - } - - collect_free_variables(for_loop->get_body(), lambda_context, outer_context, captures); - return; - } - - // Handle array literals - if (const auto* array = cast_expr(node)) - { - for (const auto& elem : array->get_elements()) - { - collect_free_variables(elem.get(), lambda_context, outer_context, captures); - } - return; - } - - // Handle struct initializers - if (const auto* struct_init = cast_expr(node)) - { - for (const auto& val : struct_init->get_initializers() | std::views::values) - { - collect_free_variables(val.get(), lambda_context, outer_context, captures); - } - return; - } - - // Handle chained member access - if (const auto* chained = cast_expr(node)) - { - collect_free_variables(chained->get_base(), lambda_context, outer_context, captures); - return; - } - - // Handle blocks (lambda bodies, function bodies, etc.) - if (const auto* block = cast_ast(node)) - { - for (const auto& child : block->get_children()) - { - collect_free_variables(child.get(), lambda_context, outer_context, captures); - } - return; - } - - // Generic container handling (if statements, loops, etc.) 
- this should be last - if (auto* container = dynamic_cast(node)) - { - if (const auto* body = container->get_body()) - { - for (const auto& child : body->get_children()) - { - collect_free_variables(child.get(), lambda_context, outer_context, captures); - } - } - } -} - -std::vector> IAstFunction::get_parameters() const -{ - std::vector> cloned_params; - cloned_params.reserve(this->_parameters.size()); - - for (const auto& param : this->_parameters) - { - cloned_params.push_back(param->clone_as()); - } - - return cloned_params; -} - -void IAstFunction::validate_candidate(IAstFunction* candidate) -{ - candidate->get_body()->validate(); - - const auto& ret_ty = candidate->get_return_type(); - const auto return_statements = collect_return_statements(candidate->get_body()); - - // For void types, we only disallow returning expressions, as this is redundant. - if (const auto void_ret = cast_type(ret_ty); - void_ret != nullptr && void_ret->get_primitive_type() == PrimitiveType::VOID) - { - for (const auto& return_stmt : return_statements) - { - if (return_stmt->get_return_expression().has_value()) - { - throw parsing_error( - ErrorType::TYPE_ERROR, - std::format( - "{} has return type 'void' and cannot return a value.", - candidate->is_anonymous() - ? 
"Anonymous function" - : std::format("Function '{}'", candidate->get_plain_function_name())), - { - ErrorSourceReference( - "unexpected return value", - return_stmt->get_source_fragment() - ), - ErrorSourceReference( - "Function returning void type", - candidate->get_source_fragment() - ) - } - - ); - } - } - return; - } - - if (return_statements.empty()) - { - if (cast_type(ret_ty)) - { - throw parsing_error( - ErrorType::TYPE_ERROR, - std::format( - "Function '{}' returns a struct type, but no return statement is present.", - candidate->get_plain_function_name()), - candidate->get_source_fragment()); - } - - throw parsing_error( - ErrorType::COMPILATION_ERROR, - std::format( - "{} is missing a return statement.", - candidate->is_anonymous() - ? "Anonymous function" - : std::format("Function '{}'", candidate->get_plain_function_name())), - candidate->get_source_fragment() - ); - } - - for (const auto& return_stmt : return_statements) - { - if (return_stmt->is_void_type()) - { - if (!ret_ty->is_void_ty()) - { - throw parsing_error( - ErrorType::TYPE_ERROR, - std::format( - "Function '{}' returns a value of type '{}', but no return statement is present.", - candidate->is_anonymous() ? "" : candidate->get_plain_function_name(), - ret_ty->to_string()), - return_stmt->get_source_fragment() - ); - } - return; - } - - if (const auto& ret_expr = return_stmt->get_return_expression().value(); - !ret_expr->get_type()->equals(ret_ty) && - !ret_expr->get_type()->is_assignable_to(ret_ty)) - { - const auto error_fragment = ErrorSourceReference( - std::format( - "expected {}{}", - candidate->get_return_type()->is_primitive() - ? "" - : candidate->get_return_type()->is_function() - ? "function-type " - : "struct-type ", - candidate->get_return_type()->to_string()), - ret_expr->get_source_fragment() - ); - - throw parsing_error( - ErrorType::TYPE_ERROR, - std::format( - "Function '{}' expected a return type of '{}', but received '{}'.", - candidate->is_anonymous() ? 
"" : candidate->get_plain_function_name(), - ret_ty->get_type_name(), - ret_expr->get_type()->get_type_name()), - { error_fragment } - ); - } - } -} - -void IAstFunction::validate() -{ - // Extern functions have no body to validate - if (this->is_extern()) - return; - - if (!this->is_generic()) - { - validate_candidate(this); - return; - } - - // - // For generic functions, we create a new copy of the function with all parameters resolved, and do validation - // on that copy. This is because we want to validate the function body with the actual types that will be used in - // the function, rather than the generic placeholders. - // - // create a copy of this function with the parameters instantiated - for (const auto definition = this->get_function_definition(); - const auto& [types, function, node] : definition->get_generic_overloads()) - { - auto instantiated_return_ty = resolve_generics( - this->_annotated_return_type.get(), - this->_generic_parameters, - types - ); - - std::vector> instantiated_function_params; - instantiated_function_params.reserve(this->_parameters.size()); - - for (const auto& param : this->_parameters) - { - instantiated_function_params.push_back( - std::make_unique( - param->get_source_fragment(), - param->get_context(), - param->get_name(), - resolve_generics(param->get_type(), this->_generic_parameters, types) - ) - ); - } - - const auto& candidate = std::make_unique( - this->get_context(), - this->_symbol, - std::move(instantiated_function_params), - this->_body->clone_as(), - std::move(instantiated_return_ty), - this->get_visibility(), - this->_flags, - this->get_generic_parameters() - ); - - validate_candidate(candidate.get()); - } -} - -llvm::Value* IAstFunction::codegen( - llvm::Module* module, - llvm::IRBuilderBase* builder -) -{ - llvm::Function* function = nullptr; - - for (const auto& [function_name, llvm_function_val] : this->get_function_overload_metadata()) - { - if (!llvm_function_val) - { - throw parsing_error( - 
ErrorType::COMPILATION_ERROR, - std::format("Function symbol '{}' missing", function_name), - this->get_source_fragment() - ); - } - - - // If the function body has already been generated (has basic blocks), just return the function pointer - if (this->is_extern() || !llvm_function_val->empty()) - { - return llvm_function_val; - } - - // Save the current insert point to restore it later - // This is important when generating nested lambdas - llvm::BasicBlock* saved_insert_block = builder->GetInsertBlock(); - llvm::BasicBlock::iterator saved_insert_point; - const bool has_insert_point = saved_insert_block != nullptr; - - if (saved_insert_block) - { - saved_insert_point = builder->GetInsertPoint(); - } - - llvm::BasicBlock* entry_bb = llvm::BasicBlock::Create( - module->getContext(), - "entry", - llvm_function_val - ); - builder->SetInsertPoint(entry_bb); - - // We create a new builder for the prologue to ensure allocas are at the very top - llvm::IRBuilder prologue_builder( - &llvm_function_val->getEntryBlock(), - llvm_function_val->getEntryBlock().begin() - ); - - // - // Captured variable handling - // Map captured variables to function arguments with __capture_ prefix - // - auto fn_parameter_argument = llvm_function_val->arg_begin(); - for (const auto& capture : this->_captured_variables) - { - if (fn_parameter_argument != llvm_function_val->arg_end()) - { - fn_parameter_argument->setName(closures::format_captured_variable_name(capture.internal_name)); - - // Create alloca with __capture_ prefix so identifier lookup can find it - llvm::AllocaInst* alloca = prologue_builder.CreateAlloca( - fn_parameter_argument->getType(), - nullptr, - closures::format_captured_variable_name_internal(capture.internal_name) - ); - - builder->CreateStore(fn_parameter_argument, alloca); - ++fn_parameter_argument; - } - } - - // - // Function parameter handling - // Here we define the parameters on the stack as memory slots for the function - // - for (const auto& param : 
this->_parameters) - { - if (fn_parameter_argument != llvm_function_val->arg_end()) - { - fn_parameter_argument->setName(param->get_name() + ".arg"); - - // Create a memory slot on the stack for the parameter - llvm::AllocaInst* alloca = prologue_builder.CreateAlloca( - fn_parameter_argument->getType(), - nullptr, - param->get_name() - ); - - // Store the initial argument value into the alloca - builder->CreateStore(fn_parameter_argument, alloca); - - ++fn_parameter_argument; - } - } - - // Generate Body - llvm::Value* function_body_value = this->_body->codegen(module, builder); - - // Final Safety: Implicit Return - // If the get_body didn't explicitly return (no terminator found), add one. - if (llvm::BasicBlock* current_bb = builder->GetInsertBlock(); - current_bb && !current_bb->getTerminator()) - { - if (llvm::Type* ret_type = llvm_function_val->getReturnType(); - ret_type->isVoidTy()) - { - builder->CreateRetVoid(); - } - else if (function_body_value && function_body_value->getType() == ret_type) - { - builder->CreateRet(function_body_value); - } - // Default return to keep IR valid (useful for main or incomplete functions) - else - { - if (ret_type->isFloatingPointTy()) - { - builder->CreateRet(llvm::ConstantFP::get(ret_type, 0.0)); - } - else if (ret_type->isIntegerTy()) - { - builder->CreateRet(llvm::ConstantInt::get(ret_type, 0)); - } - else - { - throw parsing_error( - ErrorType::COMPILATION_ERROR, - std::format( - "Function '{}' is missing a return path.", - this->get_plain_function_name() - ), - this->get_source_fragment() - ); - } - } - } - - if (llvm::verifyFunction(*llvm_function_val, &llvm::errs())) - { - module->print(llvm::errs(), nullptr); - throw parsing_error( - ErrorType::COMPILATION_ERROR, - "LLVM Function Verification Failed for: " + this->get_plain_function_name(), - this->get_source_fragment() - ); - } - - // Restore the previous insert point for nested lambda generation - if (has_insert_point && saved_insert_block) - { - 
builder->SetInsertPoint(saved_insert_block, saved_insert_point); - } - - // For anonymous lambdas with captured variables, create a closure structure - // that bundles the function pointer with the current values of captured variables - if (this->is_anonymous()) - { - // Collect the current values of captured variables from the enclosing scope - std::vector captured_values; - for (const auto& capture : this->get_captured_variables()) - { - if (const auto block = builder->GetInsertBlock()) - { - llvm::Function* current_fn = block->getParent(); - llvm::Value* captured_val = closures::lookup_variable_or_capture(current_fn, capture.internal_name); - - if (!captured_val) - { - captured_val = closures::lookup_variable_by_base_name(current_fn, capture.name); - } - - if (captured_val) - { - // Load the value if it's an alloca - if (auto* alloca = llvm::dyn_cast(captured_val)) - { - captured_val = builder->CreateLoad( - alloca->getAllocatedType(), - alloca, - capture.internal_name - ); - } - captured_values.push_back(captured_val); - } - } - } - - // Create and return a closure instead of the raw function pointer - return closures::create_closure(module, builder, llvm_function_val, captured_values); - } - - function = llvm_function_val; - } - - return function; -} - -/** - * Here we define the function in the symbol table, so it can be looked up in the codegen phase. - */ -void IAstFunction::resolve_forward_references( - llvm::Module* module, - llvm::IRBuilderBase* builder -) -{ - std::vector captures; - const auto outer_context = this->get_context()->get_parent_context() != nullptr - ? 
this->get_context()->get_parent_context() - : this->get_context(); - - collect_free_variables(this->get_body(), this->get_context(), outer_context, captures); - - // Register captured variables in the lambda's context so they can be referenced - for (const auto& capture : captures) - { - this->add_captured_variable(capture); - - // Also define the capture in the lambda's context so identifier lookup works - if (const auto outer_var = this->get_context()->lookup_variable(capture.name, true)) - { - this->get_context()->define_variable( - capture, - outer_var->get_type()->clone_ty(), - VisibilityModifier::PRIVATE - ); - } - } - - // Avoid re-registering if already declared. - // Named functions are looked up by their scoped name; anonymous functions are - // tracked by the cached _llvm_function pointer (they have no stable string name). - if (this->is_anonymous() && this->get_function_definition()->get_llvm_function() != nullptr) - return; - - // Add captured variables as first parameters - std::vector captured_types; - for (const auto& capture : this->_captured_variables) - { - if (const auto capture_def = this->get_context()->lookup_variable(capture.name, true)) - { - if (llvm::Type* capture_type = capture_def->get_type()->get_llvm_type(module)) - { - captured_types.push_back(capture_type); - } - } - } - - const auto& definition = this->get_function_definition(); - - const auto linkage = this->_visibility == VisibilityModifier::PRIVATE - ? 
llvm::Function::PrivateLinkage - : llvm::Function::ExternalLinkage; - - if (!this->is_generic()) - { - const auto overloaded_fn_name = this->get_registered_function_name(); - llvm::FunctionType* generic_function_type = this->get_llvm_function_type( - module, - captured_types, - {} - ); - - definition->set_llvm_function(llvm::Function::Create( - generic_function_type, - linkage, - overloaded_fn_name, - module - )); - this->_body->resolve_forward_references(module, builder); - - return; - } - - for (const auto& [types, llvm_function, node] : definition->get_generic_overloads()) - { - auto instantiated_return_ty = resolve_generics( - this->_annotated_return_type.get(), - this->_generic_parameters, - types - ); - - std::vector> instantiated_function_params; - instantiated_function_params.reserve(this->_parameters.size()); - - for (const auto& param : this->_parameters) - { - instantiated_function_params.push_back( - std::make_unique( - param->get_source_fragment(), - param->get_context(), - param->get_name(), - resolve_generics(param->get_type(), this->_generic_parameters, types) - ) - ); - } - - node = std::make_unique( - this->get_context(), - this->_symbol, - std::move(instantiated_function_params), - this->_body->clone_as(), - std::move(instantiated_return_ty), - this->get_visibility(), - this->_flags, - this->get_generic_parameters() - ); - - const auto overloaded_fn_name = get_overloaded_function_name(this->get_registered_function_name(), types); - llvm::FunctionType* generic_function_type = this->get_llvm_function_type( - module, - captured_types, - types - ); - - llvm_function = llvm::Function::Create( - generic_function_type, - linkage, - overloaded_fn_name, - module - ); - - for (const auto ¶m : instantiated_function_params) - { - const auto param_symbol = Symbol(param->get_source_fragment(), param->get_name()); - this->get_context()->define_variable( - param_symbol, - param->get_type()->clone_ty(), - VisibilityModifier::PRIVATE, - true - ); - } - - if 
(this->is_anonymous()) - llvm_function->addFnAttr("stride.anonymous"); - - this->_body->resolve_forward_references(module, builder); - } -} - -llvm::FunctionType* IAstFunction::get_llvm_function_type( - llvm::Module* module, - std::vector captured_variables, - const GenericTypeList& generic_instantiation_types -) const -{ - std::vector base_parameter_types; - - if (generic_instantiation_types.empty()) - { - for (const auto& param : this->_parameters) - { - if (llvm::Type* param_type = param->get_type()->get_llvm_type(module)) - { - base_parameter_types.push_back(param_type); - } - } - } - else - { - for (const auto& param : this->_parameters) - { - const auto& resolved_generic_param_type = resolve_generics( - param->get_type(), - this->_generic_parameters, - generic_instantiation_types - ); - if (llvm::Type* param_type = resolved_generic_param_type->get_llvm_type(module)) - { - base_parameter_types.push_back(param_type); - } - } - } - - std::vector parameter_types; - parameter_types.reserve(base_parameter_types.size() + captured_variables.size()); - - // Captured variables are first in the internal LLVM function type - parameter_types.insert(parameter_types.end(), captured_variables.begin(), captured_variables.end()); - parameter_types.insert(parameter_types.end(), base_parameter_types.begin(), base_parameter_types.end()); - - const auto return_type = this->get_return_type()->get_llvm_type(module); - - if (!return_type) - { - throw parsing_error( - ErrorType::COMPILATION_ERROR, - "Could not get LLVM return type for function: " + this->get_plain_function_name(), - this->get_source_fragment() - ); - } - - const auto& llvm_function_ty = llvm::FunctionType::get(return_type, parameter_types, this->is_variadic());; - - if (const auto fn = module->getFunction(this->get_registered_function_name()); - fn != nullptr && fn->getFunctionType() != llvm_function_ty) - { - throw parsing_error( - ErrorType::COMPILATION_ERROR, - std::format( - "Function symbol '{}' already exists 
with a different signature", - this->get_registered_function_name() - ), - this->get_source_fragment() - ); - } - - return llvm_function_ty; -} - -std::vector> IAstFunction::get_parameter_types() const -{ - std::vector> types; - types.reserve(this->_parameters.size()); - - for (const auto& param : this->_parameters) - { - types.push_back(param->get_type()->clone_ty()); - } - - return types; -} - -FunctionDefinition* IAstFunction::get_function_definition() -{ - if (this->_function_definition != nullptr) - return this->_function_definition; - - const auto& definition = this->get_context()->get_function_definition( - this->get_registered_function_name(), - this->get_parameter_types(), - this->get_generic_parameters().size() - ); - - if (!definition.has_value()) - { - throw parsing_error( - ErrorType::REFERENCE_ERROR, - std::format("Function definition for '{}' not found in context", this->get_registered_function_name()), - this->get_source_fragment() - ); - } - - this->_function_definition = definition.value(); - return this->_function_definition; -} - -std::vector IAstFunction::get_function_overload_metadata() -{ - const auto& definition = this->get_function_definition(); - - // If the function is not generic, we just return a singular name (the regular internalized name) - if (!definition->get_type()->is_generic()) - { - return { - GenericFunctionMetadata{ this->get_registered_function_name(), definition->get_llvm_function() } - }; - } - - std::vector metadata; - - for (const auto& [types, llvm_function, node] : definition->get_generic_overloads()) - { - metadata.emplace_back( - get_overloaded_function_name(node->get_registered_function_name(), types), - llvm_function - ); - } - - return metadata; -} - -std::unique_ptr AstFunctionParameter::clone() -{ - return std::make_unique( - this->get_source_fragment(), - this->get_context(), - this->get_name(), - this->get_type()->clone_ty() - ); -} - -std::unique_ptr IAstFunction::clone() -{ - std::vector> cloned_params; - 
cloned_params.reserve(this->_parameters.size()); - - for (const auto& param : this->_parameters) - { - cloned_params.push_back(param->clone_as()); - } - - return std::make_unique( - this->get_source_fragment(), - this->get_context(), - this->_symbol, - std::move(cloned_params), - this->_body->clone_as(), - this->_annotated_return_type->clone_ty(), - this->_visibility, - this->_flags, - this->_generic_parameters - ); -} - -std::string IAstFunction::to_string() -{ - std::string params; - for (const auto& param : this->_parameters) - { - if (!params.empty()) - params += ", "; - params += param->to_string(); - } - - const auto body_str = this->get_body() == nullptr - ? "" - : this->get_body()->to_string(); - - return std::format( - "Function(name: {}(internal: {}), params: [{}], body: {}{} -> {})", - this->is_anonymous() ? "" : this->get_plain_function_name(), - this->get_registered_function_name(), - params, - body_str, - this->is_extern() ? " (extern)" : "", - this->get_return_type()->to_string() - ); -} diff --git a/packages/compiler/src/ast/type_inference.cpp b/packages/compiler/src/ast/type_inference.cpp index 8ddebfa6..e4bd25a7 100644 --- a/packages/compiler/src/ast/type_inference.cpp +++ b/packages/compiler/src/ast/type_inference.cpp @@ -322,7 +322,7 @@ std::unique_ptr stride::ast::infer_identifier_type(const AstIdentifier ); } - if (const auto field = dynamic_cast(identifier_def.value())) + if (const auto field = dynamic_cast(identifier_def.value())) { return field->get_type()->clone_ty(); } diff --git a/packages/compiler/src/compilation/program.cpp b/packages/compiler/src/compilation/program.cpp index 3f754ef0..aab974e9 100644 --- a/packages/compiler/src/compilation/program.cpp +++ b/packages/compiler/src/compilation/program.cpp @@ -66,8 +66,11 @@ std::unique_ptr Program::prepare_module( for (const auto& node : this->_ast->get_files() | std::views::values) { runtime::register_runtime_symbols(node->get_context()); + + // Type resolution 
traverser.visit_block(&type_visitor, node.get()); + // Resolving forward references - Ensures symbols certain symbols are available before implementation node->resolve_forward_references( module.get(), &builder From a976680d9aec99a50d52dd7bac4e4a6d818579fe Mon Sep 17 00:00:00 2001 From: Luca Warmenhoven Date: Mon, 16 Mar 2026 17:01:23 +0100 Subject: [PATCH 11/31] refactor: move `resolve_forward_references` from `AstFunctionCall` to dedicated file, improve generic type resolution and argument handling --- .../compiler/include/ast/nodes/expression.h | 11 ++++++++++ .../nodes/functions/call/function_call.cpp | 18 ----------------- .../call/function_call_forward_refs.cpp | 20 +++++++++++++++++++ .../function_declaration_codegen.cpp | 1 - .../function_declaration_validation.cpp | 6 +++--- .../{ => declaration}/function_parameters.cpp | 0 6 files changed, 34 insertions(+), 22 deletions(-) create mode 100644 packages/compiler/src/ast/nodes/functions/call/function_call_forward_refs.cpp rename packages/compiler/src/ast/nodes/functions/{ => declaration}/function_parameters.cpp (100%) diff --git a/packages/compiler/include/ast/nodes/expression.h b/packages/compiler/include/ast/nodes/expression.h index fd8241b6..2a3df530 100644 --- a/packages/compiler/include/ast/nodes/expression.h +++ b/packages/compiler/include/ast/nodes/expression.h @@ -87,6 +87,7 @@ namespace stride::ast std::unique_ptr _type; friend class IAstFunction; + friend class AstIdentifier; public: explicit IAstExpression( @@ -95,6 +96,16 @@ namespace stride::ast ) : IAstNode(source_position, context) {} + explicit IAstExpression( + const SourceFragment& source_position, + const std::shared_ptr& context, + std::unique_ptr type + ) : + IAstExpression(source_position, context) + { + this->_type = std::move(type); + } + ~IAstExpression() override = default; llvm::Value* codegen( diff --git a/packages/compiler/src/ast/nodes/functions/call/function_call.cpp 
b/packages/compiler/src/ast/nodes/functions/call/function_call.cpp index 5a13ce23..fd042890 100644 --- a/packages/compiler/src/ast/nodes/functions/call/function_call.cpp +++ b/packages/compiler/src/ast/nodes/functions/call/function_call.cpp @@ -188,24 +188,6 @@ void AstFunctionCall::validate() } } -void AstFunctionCall::resolve_forward_references(llvm::Module* module, llvm::IRBuilderBase* builder) -{ - // Add generic types to function definition's generic instantiations - if (!this->_generic_type_arguments.empty()) - { - const auto& definition = this->get_function_definition(); - - definition->add_generic_instantiation( - copy_generic_type_list(this->_generic_type_arguments) - ); - } - - for (const auto& arg : this->_arguments) - { - arg->resolve_forward_references(module, builder); - } -} - std::vector> AstFunctionCall::get_argument_types() const { if (this->_arguments.empty()) diff --git a/packages/compiler/src/ast/nodes/functions/call/function_call_forward_refs.cpp b/packages/compiler/src/ast/nodes/functions/call/function_call_forward_refs.cpp new file mode 100644 index 00000000..c336bdd5 --- /dev/null +++ b/packages/compiler/src/ast/nodes/functions/call/function_call_forward_refs.cpp @@ -0,0 +1,20 @@ +#include "ast/definitions/function_definition.h" +#include "ast/nodes/expression.h" + +using namespace stride::ast; + +void AstFunctionCall::resolve_forward_references(llvm::Module* module, llvm::IRBuilderBase* builder) +{ + // Add generic types to function definition's generic instantiations + if (!this->_generic_type_arguments.empty()) + { + const auto& definition = this->get_function_definition(); + + definition->add_generic_instantiation( + copy_generic_type_list(this->_generic_type_arguments) + ); + } + + for (const auto& arg : this->_arguments) + arg->resolve_forward_references(module, builder); +} diff --git a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_codegen.cpp 
b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_codegen.cpp index 4fbd0874..1700dd1b 100644 --- a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_codegen.cpp +++ b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_codegen.cpp @@ -25,7 +25,6 @@ llvm::Value* IAstFunction::codegen( ); } - // If the function body has already been generated (has basic blocks), just return the function pointer if (this->is_extern() || !llvm_function_val->empty()) { diff --git a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_validation.cpp b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_validation.cpp index 87a6a31d..6ebea037 100644 --- a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_validation.cpp +++ b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_validation.cpp @@ -25,12 +25,12 @@ void IAstFunction::validate() // // create a copy of this function with the parameters instantiated for (const auto definition = this->get_function_definition(); - const auto& [types, function, node] : definition->get_generic_overloads()) + const auto& [instantiated_generic_types, function, node] : definition->get_generic_overloads()) { auto instantiated_return_ty = resolve_generics( this->_annotated_return_type.get(), this->_generic_parameters, - types + instantiated_generic_types ); std::vector> instantiated_function_params; @@ -43,7 +43,7 @@ void IAstFunction::validate() param->get_source_fragment(), param->get_context(), param->get_name(), - resolve_generics(param->get_type(), this->_generic_parameters, types) + resolve_generics(param->get_type(), this->_generic_parameters, instantiated_generic_types) ) ); } diff --git a/packages/compiler/src/ast/nodes/functions/function_parameters.cpp b/packages/compiler/src/ast/nodes/functions/declaration/function_parameters.cpp similarity index 100% rename from 
packages/compiler/src/ast/nodes/functions/function_parameters.cpp rename to packages/compiler/src/ast/nodes/functions/declaration/function_parameters.cpp From d7daa73cc461dd8a29c8cf809ac31c4f8a734030 Mon Sep 17 00:00:00 2001 From: Luca Warmenhoven Date: Mon, 16 Mar 2026 17:26:22 +0100 Subject: [PATCH 12/31] refactor: add `resolve_generics_in_body` for recursive generic resolution in function bodies, standardize naming in function metadata - Introduced `resolve_generics_in_body` to resolve generic types in function bodies recursively, with bottom-up traversal. - Renamed `GenericFunctionMetadata` to `FunctionImplementation`, streamlining function metadata handling. - Updated logic for cloning function bodies with resolved generics. - Improved type re-inference during generic parameter substitution. --- packages/compiler/include/ast/generics.h | 15 ++ .../include/ast/nodes/function_declaration.h | 4 +- .../src/ast/context/function_registry.cpp | 2 +- packages/compiler/src/ast/generics.cpp | 163 +++++++++++++++++- .../declaration/function_declaration.cpp | 23 +-- .../function_declaration_codegen.cpp | 4 +- .../function_declaration_forward_refs.cpp | 4 +- .../function_declaration_validation.cpp | 14 +- 8 files changed, 205 insertions(+), 24 deletions(-) diff --git a/packages/compiler/include/ast/generics.h b/packages/compiler/include/ast/generics.h index 0e71ec13..01daf91f 100644 --- a/packages/compiler/include/ast/generics.h +++ b/packages/compiler/include/ast/generics.h @@ -6,6 +6,7 @@ namespace stride::ast { + class IAstNode; class AstObjectInitializer; class AstObjectType; @@ -49,4 +50,18 @@ namespace stride::ast GenericTypeList copy_generic_type_list(const GenericTypeList& list); std::string get_overloaded_function_name(std::string function_name, const GenericTypeList& overload_types); + + /** + * Recursively walks a function body AST and resolves generic type parameters + * on all expression nodes, in the same manner as resolve_generics does for types. 
+ * + * For each expression, the type is re-inferred (via context lookup) and then + * any generic parameters are substituted with their concrete instantiated types. + * The walk is bottom-up so that child expression types are resolved before their parents. + */ + void resolve_generics_in_body( + IAstNode* node, + const GenericParameterList& param_names, + const GenericTypeList& instantiated_types + ); } diff --git a/packages/compiler/include/ast/nodes/function_declaration.h b/packages/compiler/include/ast/nodes/function_declaration.h index 8515f9a6..571875b9 100644 --- a/packages/compiler/include/ast/nodes/function_declaration.h +++ b/packages/compiler/include/ast/nodes/function_declaration.h @@ -73,7 +73,7 @@ namespace stride::ast * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ - struct GenericFunctionMetadata + struct FunctionImplementation { std::string overload_function_name; llvm::Function* llvm_function; @@ -136,7 +136,7 @@ namespace stride::ast /// function is defined with generic parameters, there will be several overloads generated /// for each generic instantiation. This function returns the internalized name of each overload. 
[[nodiscard]] - std::vector get_generic_function_metadata(); + std::vector get_function_implementation_data(); [[nodiscard]] AstBlock* get_body() override diff --git a/packages/compiler/src/ast/context/function_registry.cpp b/packages/compiler/src/ast/context/function_registry.cpp index fa7ae86f..69168dac 100644 --- a/packages/compiler/src/ast/context/function_registry.cpp +++ b/packages/compiler/src/ast/context/function_registry.cpp @@ -164,7 +164,7 @@ bool ParsingContext::is_function_defined_globally( bool FunctionDefinition::has_generic_instantiation(const std::vector>& generic_types) const { - for (const auto& [instantiated_generic_types, llvm_function, node] : this->_generic_overloads) + for (const auto& [instantiated_generic_types, llvm_function, _node] : this->_generic_overloads) { bool all_equal = true; for (size_t i = 0; i < generic_types.size(); i++) diff --git a/packages/compiler/src/ast/generics.cpp b/packages/compiler/src/ast/generics.cpp index fc3669b6..a4f53b70 100644 --- a/packages/compiler/src/ast/generics.cpp +++ b/packages/compiler/src/ast/generics.cpp @@ -3,11 +3,20 @@ #include "errors.h" #include "ast/casting.h" #include "ast/parsing_context.h" +#include "ast/type_inference.h" +#include "ast/nodes/blocks.h" +#include "ast/nodes/conditional_statement.h" #include "ast/nodes/expression.h" +#include "ast/nodes/for_loop.h" +#include "ast/nodes/function_declaration.h" +#include "ast/nodes/return_statement.h" #include "ast/nodes/types.h" +#include "ast/nodes/while_loop.h" #include "ast/tokens/token.h" #include "ast/tokens/token_set.h" +#include + using namespace stride::ast; GenericParameterList stride::ast::parse_generic_declaration(TokenSet& set) @@ -114,7 +123,7 @@ std::unique_ptr stride::ast::resolve_generics( ); } - if (auto* object_type = cast_type(type)) + if (const auto* object_type = cast_type(type)) { const auto& members = object_type->get_members(); ObjectTypeMemberList resolved_members; @@ -199,7 +208,6 @@ std::unique_ptr 
stride::ast::instantiate_generic_type( const AstAliasType* alias_type, const definition::TypeDefinition* type_definition) { - const auto& instantiated_types = alias_type->get_instantiated_generic_types(); const auto& generic_param_names = type_definition->get_generics_parameters(); @@ -321,3 +329,154 @@ std::string stride::ast::get_overloaded_function_name(std::string function_name, join(generic_instantiation_type_names, "_") ); } + +/** + * Resolves an expression's type by re-inferring it (from context) and then substituting + * any generic parameter names with the concrete instantiated types. + */ +static void resolve_expression_type( + stride::ast::IAstExpression* expr, + const stride::ast::GenericParameterList& param_names, + const stride::ast::GenericTypeList& instantiated_types +) +{ + auto inferred_type = stride::ast::infer_expression_type(expr); + expr->set_type( + stride::ast::resolve_generics(inferred_type.get(), param_names, instantiated_types) + ); +} + +void stride::ast::resolve_generics_in_body( + IAstNode* node, + const GenericParameterList& param_names, + const GenericTypeList& instantiated_types +) +{ + if (!node || param_names.empty()) + return; + + // + // Statement / container nodes — recurse into children first (bottom-up). 
+ // + + if (auto* block = dynamic_cast(node)) + { + for (const auto& child : block->get_children()) + resolve_generics_in_body(child.get(), param_names, instantiated_types); + return; + } + + if (auto* return_stmt = dynamic_cast(node)) + { + if (return_stmt->get_return_expression().has_value()) + resolve_generics_in_body( + return_stmt->get_return_expression().value().get(), + param_names, instantiated_types); + return; + } + + if (auto* conditional = dynamic_cast(node)) + { + resolve_generics_in_body(conditional->get_condition(), param_names, instantiated_types); + resolve_generics_in_body(conditional->get_body(), param_names, instantiated_types); + if (conditional->get_else_body()) + resolve_generics_in_body(conditional->get_else_body(), param_names, instantiated_types); + return; + } + + if (auto* while_loop = dynamic_cast(node)) + { + if (while_loop->get_condition()) + resolve_generics_in_body(while_loop->get_condition(), param_names, instantiated_types); + resolve_generics_in_body(while_loop->get_body(), param_names, instantiated_types); + return; + } + + if (auto* for_loop = dynamic_cast(node)) + { + if (for_loop->get_initializer()) + resolve_generics_in_body(for_loop->get_initializer(), param_names, instantiated_types); + if (for_loop->get_condition()) + resolve_generics_in_body(for_loop->get_condition(), param_names, instantiated_types); + if (for_loop->get_incrementor()) + resolve_generics_in_body(for_loop->get_incrementor(), param_names, instantiated_types); + resolve_generics_in_body(for_loop->get_body(), param_names, instantiated_types); + return; + } + + // + // Expression nodes — recurse into sub-expressions first, then resolve this node's type. 
+ // + + auto* expr = dynamic_cast(node); + if (!expr) + return; + + // --- Recurse into child expressions (bottom-up order) --- + + if (auto* binary = cast_expr(expr)) + { + resolve_generics_in_body(binary->get_left(), param_names, instantiated_types); + resolve_generics_in_body(binary->get_right(), param_names, instantiated_types); + } + else if (auto* unary = cast_expr(expr)) + { + resolve_generics_in_body(&unary->get_operand(), param_names, instantiated_types); + } + else if (auto* var_decl = cast_expr(expr)) + { + if (var_decl->get_initial_value()) + resolve_generics_in_body(var_decl->get_initial_value(), param_names, instantiated_types); + } + else if (auto* fn_call = cast_expr(expr)) + { + for (const auto& arg : fn_call->get_arguments()) + resolve_generics_in_body(arg.get(), param_names, instantiated_types); + } + else if (auto* array = cast_expr(expr)) + { + for (const auto& elem : array->get_elements()) + resolve_generics_in_body(elem.get(), param_names, instantiated_types); + } + else if (auto* array_accessor = cast_expr(expr)) + { + resolve_generics_in_body(array_accessor->get_array_base(), param_names, instantiated_types); + resolve_generics_in_body(array_accessor->get_index(), param_names, instantiated_types); + } + else if (auto* struct_init = cast_expr(expr)) + { + for (const auto& val : struct_init->get_initializers() | std::views::values) + resolve_generics_in_body(val.get(), param_names, instantiated_types); + } + else if (auto* tuple_init = cast_expr(expr)) + { + for (const auto& member : tuple_init->get_members()) + resolve_generics_in_body(member.get(), param_names, instantiated_types); + } + else if (auto* reassign = cast_expr(expr)) + { + resolve_generics_in_body(reassign->get_identifier(), param_names, instantiated_types); + resolve_generics_in_body(reassign->get_value(), param_names, instantiated_types); + } + else if (auto* chained = cast_expr(expr)) + { + resolve_generics_in_body(chained->get_base(), param_names, instantiated_types); + } 
+ else if (auto* type_cast = cast_expr(expr)) + { + resolve_generics_in_body(type_cast->get_value(), param_names, instantiated_types); + } + else if (auto* indirect_call = cast_expr(expr)) + { + for (const auto& arg : indirect_call->get_args()) + resolve_generics_in_body(arg.get(), param_names, instantiated_types); + resolve_generics_in_body(indirect_call->get_callee(), param_names, instantiated_types); + } + else if (auto* function_node = cast_expr(expr)) + { + resolve_generics_in_body(function_node->get_body(), param_names, instantiated_types); + } + + // --- Now resolve the type for this expression node --- + resolve_expression_type(expr, param_names, instantiated_types); +} diff --git a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration.cpp b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration.cpp index 6f7b4268..09f52a5f 100644 --- a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration.cpp +++ b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration.cpp @@ -1,29 +1,22 @@ -#include "ast/definitions/function_definition.h" +#include "ast/nodes/function_declaration.h" #include "errors.h" -#include "ast/casting.h" #include "ast/closures.h" #include "ast/modifiers.h" #include "ast/parsing_context.h" #include "ast/symbols.h" +#include "ast/definitions/function_definition.h" #include "ast/nodes/blocks.h" -#include "ast/nodes/conditional_statement.h" #include "ast/nodes/expression.h" -#include "ast/nodes/for_loop.h" -#include "ast/nodes/function_declaration.h" #include "ast/nodes/return_statement.h" -#include "ast/nodes/while_loop.h" #include "ast/tokens/token.h" #include "ast/tokens/token_set.h" #include -#include #include #include #include -#include #include -#include using namespace stride::ast; using namespace stride::ast::definition; @@ -258,29 +251,29 @@ FunctionDefinition* IAstFunction::get_function_definition() return this->_function_definition; } -std::vector 
IAstFunction::get_generic_function_metadata() +std::vector IAstFunction::get_function_implementation_data() { const auto& definition = this->get_function_definition(); // If the function is not generic, we just return a singular name (the regular internalized name) - if (!definition->get_type()->is_generic()) + if (definition->get_generic_overloads().empty()) { return { - GenericFunctionMetadata{ this->get_registered_function_name(), definition->get_llvm_function() } + FunctionImplementation{ this->get_registered_function_name(), definition->get_llvm_function() } }; } - std::vector metadata; + std::vector implementations; for (const auto& [types, llvm_function, node] : definition->get_generic_overloads()) { - metadata.emplace_back( + implementations.emplace_back( get_overloaded_function_name(node->get_registered_function_name(), types), llvm_function ); } - return metadata; + return implementations; } std::unique_ptr AstFunctionParameter::clone() diff --git a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_codegen.cpp b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_codegen.cpp index 1700dd1b..db3bb56f 100644 --- a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_codegen.cpp +++ b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_codegen.cpp @@ -7,6 +7,8 @@ using namespace stride::ast; + + llvm::Value* IAstFunction::codegen( llvm::Module* module, llvm::IRBuilderBase* builder @@ -14,7 +16,7 @@ llvm::Value* IAstFunction::codegen( { llvm::Function* function = nullptr; - for (const auto& [function_name, llvm_function_val] : this->get_generic_function_metadata()) + for (const auto& [function_name, llvm_function_val] : this->get_function_implementation_data()) { if (!llvm_function_val) { diff --git a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_forward_refs.cpp 
b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_forward_refs.cpp index ba04640f..8199ed75 100644 --- a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_forward_refs.cpp +++ b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_forward_refs.cpp @@ -111,12 +111,12 @@ void IAstFunction::resolve_forward_references( node = std::make_unique( this->get_context(), - this->_symbol, + this->get_symbol(), std::move(instantiated_function_params), this->_body->clone_as(), std::move(instantiated_return_ty), this->get_visibility(), - this->_flags, + this->get_flags(), this->get_generic_parameters() ); diff --git a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_validation.cpp b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_validation.cpp index 6ebea037..c87c353c 100644 --- a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_validation.cpp +++ b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_validation.cpp @@ -48,11 +48,23 @@ void IAstFunction::validate() ); } + // Clone the body and resolve generic types on every expression within it. + // The cloned body's expressions have no types set (clone does not preserve inferred types), + // so resolve_generics_in_body re-infers each expression's type from context (which still + // carries the generic parameter names, e.g. T) and then substitutes them with the concrete + // instantiated types (e.g. i32). 
+ auto resolved_body = this->_body->clone_as(); + resolve_generics_in_body( + resolved_body.get(), + this->_generic_parameters, + instantiated_generic_types + ); + node = std::make_unique( this->get_context(), this->_symbol, std::move(instantiated_function_params), - this->_body->clone_as(), + std::move(resolved_body), std::move(instantiated_return_ty), this->get_visibility(), this->_flags, From 3f06386510b6e5efad755eb314c98f08a7bb715b Mon Sep 17 00:00:00 2001 From: Luca Warmenhoven Date: Mon, 16 Mar 2026 18:01:25 +0100 Subject: [PATCH 13/31] Resolved generic resolution for functions --- .../ast/definitions/function_definition.h | 4 ++- .../compiler/include/ast/nodes/traversal.h | 21 +-------------- packages/compiler/include/ast/visitor.h | 20 ++++++++++++++ .../src/ast/context/function_registry.cpp | 24 ++++++++++++++++- .../functions/call/function_call_codegen.cpp | 16 +++++++---- .../call/function_call_forward_refs.cpp | 10 ------- .../function_declaration_forward_refs.cpp | 10 +++++++ .../src/ast/traversal/expression_visitor.cpp | 13 ++++++++- .../src/ast/traversal/function_visitor.cpp | 2 +- .../compiler/src/ast/traversal/traversal.cpp | 1 + packages/compiler/src/compilation/program.cpp | 27 +++++++++++++------ 11 files changed, 101 insertions(+), 47 deletions(-) diff --git a/packages/compiler/include/ast/definitions/function_definition.h b/packages/compiler/include/ast/definitions/function_definition.h index 13cfb2b4..4584406e 100644 --- a/packages/compiler/include/ast/definitions/function_definition.h +++ b/packages/compiler/include/ast/definitions/function_definition.h @@ -58,7 +58,7 @@ namespace stride::ast::definition return (this->_flags & SRFLAG_FN_TYPE_VARIADIC) != 0; } - void add_generic_instantiation(GenericTypeList generic_overload_types); + void add_generic_overload(GenericTypeList generic_overload_types); [[nodiscard]] const std::vector& get_generic_overloads() const @@ -69,6 +69,8 @@ namespace stride::ast::definition [[nodiscard]] bool 
has_generic_instantiation(const GenericTypeList& generic_types) const; + llvm::Function* get_generic_overload_llvm_function(const GenericTypeList& generic_types) const; + ~FunctionDefinition() override = default; bool matches_type_signature(const std::string& name, const AstFunctionType* signature) const; diff --git a/packages/compiler/include/ast/nodes/traversal.h b/packages/compiler/include/ast/nodes/traversal.h index 438e43ea..f43a671f 100644 --- a/packages/compiler/include/ast/nodes/traversal.h +++ b/packages/compiler/include/ast/nodes/traversal.h @@ -2,6 +2,7 @@ namespace stride::ast { + class IVisitor; class AstFunctionCall; class AstPackage; class AstImport; @@ -15,26 +16,6 @@ namespace stride::ast class AstReturnStatement; class AstBlock; - /// Visitor interface for expression nodes. - /// Implementations receive each expression after all its child expressions - /// have already been visited (bottom-up, post-order traversal). - class IVisitor - { - public: - virtual ~IVisitor() = default; - - /// Called for every expression node, after its sub-expressions have been visited. - virtual void accept(IAstExpression* expr) {}; - - virtual void accept(IAstFunction* expr) {}; - - virtual void accept(AstImport* node) {} - - virtual void accept(AstPackage* node) {} - - virtual void accept(AstFunctionCall* function_call) {} - }; - /// Traverses an AST tree and invokes an IVisitor for each expression node. /// Traversal is bottom-up (children are visited before their parent expression), /// ensuring that child expression types are available when the parent is visited. diff --git a/packages/compiler/include/ast/visitor.h b/packages/compiler/include/ast/visitor.h index 11b1e6d8..a31fa344 100644 --- a/packages/compiler/include/ast/visitor.h +++ b/packages/compiler/include/ast/visitor.h @@ -15,6 +15,26 @@ namespace stride::ast class IAstExpression; class Ast; + /// Visitor interface for expression nodes. 
+ /// Implementations receive each expression after all its child expressions + /// have already been visited (bottom-up, post-order traversal). + class IVisitor + { + public: + virtual ~IVisitor() = default; + + /// Called for every expression node, after its sub-expressions have been visited. + virtual void accept(IAstExpression* expr) {}; + + virtual void accept(IAstFunction* expr) {}; + + virtual void accept(AstImport* node) {} + + virtual void accept(AstPackage* node) {} + + virtual void accept(AstFunctionCall* function_call) {} + }; + /// Visitor that infers and assigns types to every expression node in the AST. /// /// Used together with AstNodeTraverser: the traverser drives bottom-up traversal, diff --git a/packages/compiler/src/ast/context/function_registry.cpp b/packages/compiler/src/ast/context/function_registry.cpp index 69168dac..515b91ec 100644 --- a/packages/compiler/src/ast/context/function_registry.cpp +++ b/packages/compiler/src/ast/context/function_registry.cpp @@ -184,7 +184,7 @@ bool FunctionDefinition::has_generic_instantiation(const std::vector_generic_overloads) + { + bool all_equal = true; + for (size_t i = 0; i < generic_types.size(); i++) + { + if (!instantiated_generic_types[i]->equals(generic_types[i].get())) + { + all_equal = false; + break; + } + } + if (all_equal) + { + return llvm_function; + } + } + + return nullptr; +} \ No newline at end of file diff --git a/packages/compiler/src/ast/nodes/functions/call/function_call_codegen.cpp b/packages/compiler/src/ast/nodes/functions/call/function_call_codegen.cpp index d8a7bdb4..ee4bd9d4 100644 --- a/packages/compiler/src/ast/nodes/functions/call/function_call_codegen.cpp +++ b/packages/compiler/src/ast/nodes/functions/call/function_call_codegen.cpp @@ -41,14 +41,20 @@ llvm::Value* AstFunctionCall::codegen( llvm::Function* AstFunctionCall::resolve_regular_callee(llvm::Module* module) { - const auto& function_definition = this->get_function_definition(); + const auto& fn_def = 
this->get_function_definition(); - if (llvm::Function* callee = module->getFunction(function_definition->get_internal_symbol_name())) + if (this->_generic_type_arguments.empty()) { - return callee; + if (llvm::Function* callee = module->getFunction(fn_def->get_internal_symbol_name())) + { + return callee; + } + } + else if (auto* llvm_func = fn_def->get_generic_overload_llvm_function(this->_generic_type_arguments)) + { + return llvm_func; } - const auto fn_def = function_definition; const auto fn_type = fn_def->get_type(); std::vector param_types; param_types.reserve(fn_type->get_parameter_types().size()); @@ -462,4 +468,4 @@ llvm::Value* AstFunctionCall::codegen_anonymous_function_call( } return nullptr; -} \ No newline at end of file +} diff --git a/packages/compiler/src/ast/nodes/functions/call/function_call_forward_refs.cpp b/packages/compiler/src/ast/nodes/functions/call/function_call_forward_refs.cpp index c336bdd5..77eadd2d 100644 --- a/packages/compiler/src/ast/nodes/functions/call/function_call_forward_refs.cpp +++ b/packages/compiler/src/ast/nodes/functions/call/function_call_forward_refs.cpp @@ -5,16 +5,6 @@ using namespace stride::ast; void AstFunctionCall::resolve_forward_references(llvm::Module* module, llvm::IRBuilderBase* builder) { - // Add generic types to function definition's generic instantiations - if (!this->_generic_type_arguments.empty()) - { - const auto& definition = this->get_function_definition(); - - definition->add_generic_instantiation( - copy_generic_type_list(this->_generic_type_arguments) - ); - } - for (const auto& arg : this->_arguments) arg->resolve_forward_references(module, builder); } diff --git a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_forward_refs.cpp b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_forward_refs.cpp index 8199ed75..5d309b7c 100644 --- a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_forward_refs.cpp +++ 
b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_forward_refs.cpp @@ -86,6 +86,16 @@ void IAstFunction::resolve_forward_references( return; } + if (const auto& overloads = definition->get_generic_overloads(); + overloads.empty()) + { + throw parsing_error( + ErrorType::COMPILATION_ERROR, + std::format("Generic function '{}' has no instantiations", this->get_plain_function_name()), + this->get_source_fragment() + ); + } + for (const auto& [instantiated_generic_types, llvm_function, node] : definition->get_generic_overloads()) { auto instantiated_return_ty = resolve_generics( diff --git a/packages/compiler/src/ast/traversal/expression_visitor.cpp b/packages/compiler/src/ast/traversal/expression_visitor.cpp index 629e6aff..9b05de48 100644 --- a/packages/compiler/src/ast/traversal/expression_visitor.cpp +++ b/packages/compiler/src/ast/traversal/expression_visitor.cpp @@ -1,6 +1,7 @@ #include "ast/parsing_context.h" #include "ast/type_inference.h" #include "ast/visitor.h" +#include "ast/definitions/function_definition.h" #include "ast/nodes/expression.h" using namespace stride::ast; @@ -24,4 +25,14 @@ void ExpressionVisitor::accept(IAstExpression* expr) var_decl->get_visibility() ); } -} \ No newline at end of file + else if (auto* function_call = dynamic_cast(expr); + function_call != nullptr && + !function_call->get_generic_type_arguments().empty() + ) + { + const auto& definition = function_call->get_function_definition(); + definition->add_generic_overload( + copy_generic_type_list(function_call->get_generic_type_arguments()) + ); + } +} diff --git a/packages/compiler/src/ast/traversal/function_visitor.cpp b/packages/compiler/src/ast/traversal/function_visitor.cpp index a8764346..0b3aac48 100644 --- a/packages/compiler/src/ast/traversal/function_visitor.cpp +++ b/packages/compiler/src/ast/traversal/function_visitor.cpp @@ -30,4 +30,4 @@ void FunctionVisitor::accept(IAstFunction* function) function->get_visibility(), function->get_flags() 
); -} +} \ No newline at end of file diff --git a/packages/compiler/src/ast/traversal/traversal.cpp b/packages/compiler/src/ast/traversal/traversal.cpp index 5f4a2c5a..81de9fbc 100644 --- a/packages/compiler/src/ast/traversal/traversal.cpp +++ b/packages/compiler/src/ast/traversal/traversal.cpp @@ -2,6 +2,7 @@ #include "ast/casting.h" #include "ast/parsing_context.h" +#include "ast/visitor.h" #include "ast/nodes/ast_node.h" #include "ast/nodes/blocks.h" #include "ast/nodes/conditional_statement.h" diff --git a/packages/compiler/src/compilation/program.cpp b/packages/compiler/src/compilation/program.cpp index aab974e9..a9bdb569 100644 --- a/packages/compiler/src/compilation/program.cpp +++ b/packages/compiler/src/compilation/program.cpp @@ -48,28 +48,39 @@ std::unique_ptr Program::prepare_module( llvm::IRBuilder<> builder(context); ast::AstNodeTraverser traverser; - ast::ExpressionVisitor type_visitor; + ast::ExpressionVisitor expression_visitor; ast::FunctionVisitor function_visitor; ast::ImportVisitor import_visitor; - /// --- First step - Cross-file symbol registration (imports and function signatures) + // + // First step - Cross-file symbol registration (imports and function signatures) + // for (const auto& [file_name, node] : this->_ast->get_files()) { + // Populate own symbol table with stride runtime symbols + // These are externally available functions that are linked after codegen + runtime::register_runtime_symbols(node->get_context()); + + // Resolve imports and populate local registry - Used for cross registration step import_visitor.set_current_file_name(file_name); traverser.visit_block(&import_visitor, node.get()); - traverser.visit_block(&function_visitor, node.get()); // Ensures functions are defined in our symbol table + // Ensures functions are defined in our symbol table + traverser.visit_block(&function_visitor, node.get()); } import_visitor.cross_register_symbols(this->_ast.get()); - /// --- Second step - Type resolution and symbol forward 
declarations + // + // Second step - Type resolution and symbol forward declarations + // for (const auto& node : this->_ast->get_files() | std::views::values) { - runtime::register_runtime_symbols(node->get_context()); - - // Type resolution - traverser.visit_block(&type_visitor, node.get()); + // Type checker - this must be executed after all external symbols have been populated + traverser.visit_block(&expression_visitor, node.get()); + } + for (const auto& node : this->_ast->get_files() | std::views::values) + { // Resolving forward references - Ensures symbols certain symbols are available before implementation node->resolve_forward_references( module.get(), From 0a00eba34a6fa1e4b2175c9753768107ab69a5a6 Mon Sep 17 00:00:00 2001 From: Luca Warmenhoven Date: Mon, 16 Mar 2026 18:35:10 +0100 Subject: [PATCH 14/31] Added validation for string comparisons --- example.sr | 4 ++-- packages/compiler/src/ast/generics.cpp | 2 +- .../expressions/comparison_operation.cpp | 19 +++++++++++++++++-- .../nodes/functions/call/function_call.cpp | 8 -------- .../functions/call/function_call_validate.cpp | 12 ++++++++++++ .../function_declaration_validation.cpp | 5 +++-- 6 files changed, 35 insertions(+), 15 deletions(-) create mode 100644 packages/compiler/src/ast/nodes/functions/call/function_call_validate.cpp diff --git a/example.sr b/example.sr index 5a1a7d33..25b2e6d7 100644 --- a/example.sr +++ b/example.sr @@ -9,7 +9,7 @@ type Car = { names: Array; }; -fn some_comparison(a: T, b: T): bool { +fn some_comparison(a: A, b: B): bool { return a == b; } @@ -25,7 +25,7 @@ fn main(): i32 { myCar.drive(); io::print("Driving: %s, %s, %s", myCar.names[0], myCar.names[1], myCar.names[2]); - io::print("Compared: %b", some_comparison(5, 5)); + io::print("Compared: %b", some_comparison(5, 5)); return 0; } \ No newline at end of file diff --git a/packages/compiler/src/ast/generics.cpp b/packages/compiler/src/ast/generics.cpp index a4f53b70..698d5471 100644 --- 
a/packages/compiler/src/ast/generics.cpp +++ b/packages/compiler/src/ast/generics.cpp @@ -76,7 +76,7 @@ std::unique_ptr stride::ast::resolve_generics( throw parsing_error( ErrorType::TYPE_ERROR, std::format( - "Failed to resolve generic type: expected {} parameters, got ", + "Failed to resolve generic type: expected {} parameters, got {}", param_names.size(), instantiated_types.size()), type->get_source_fragment() diff --git a/packages/compiler/src/ast/nodes/expressions/comparison_operation.cpp b/packages/compiler/src/ast/nodes/expressions/comparison_operation.cpp index b875d984..4ba3aca1 100644 --- a/packages/compiler/src/ast/nodes/expressions/comparison_operation.cpp +++ b/packages/compiler/src/ast/nodes/expressions/comparison_operation.cpp @@ -59,12 +59,27 @@ void AstComparisonOp::validate() const auto lhs_type = this->get_left()->get_type(); const auto rhs_type = this->get_right()->get_type(); + const auto lhs_primitive = cast_type(lhs_type); + const auto rhs_primitive = cast_type(rhs_type); + // Both sides are primitives if (lhs_type->is_primitive() && rhs_type->is_primitive()) + { + if (lhs_primitive->get_primitive_type() == PrimitiveType::STRING || rhs_primitive->get_primitive_type() == + PrimitiveType::STRING) + { + throw parsing_error( + ErrorType::SEMANTIC_ERROR, + "Cannot compare string literals", + { + ErrorSourceReference(lhs_type->get_type_name(), _lhs->get_source_fragment()), + ErrorSourceReference(rhs_type->get_type_name(), _rhs->get_source_fragment()) + } + ); + } return; + } - const auto lhs_primitive = cast_type(lhs_type); - const auto rhs_primitive = cast_type(rhs_type); // If LHS is NIL and RHS is valid, allow the comparison (nil checks) if (lhs_primitive && rhs_primitive diff --git a/packages/compiler/src/ast/nodes/functions/call/function_call.cpp b/packages/compiler/src/ast/nodes/functions/call/function_call.cpp index fd042890..3180b7d1 100644 --- a/packages/compiler/src/ast/nodes/functions/call/function_call.cpp +++ 
b/packages/compiler/src/ast/nodes/functions/call/function_call.cpp @@ -180,14 +180,6 @@ std::string AstFunctionCall::format_function_name() const return std::format("{}({})", this->get_function_name(), join(arg_types, ", ")); } -void AstFunctionCall::validate() -{ - for (const auto& arg : this->_arguments) - { - arg->validate(); - } -} - std::vector> AstFunctionCall::get_argument_types() const { if (this->_arguments.empty()) diff --git a/packages/compiler/src/ast/nodes/functions/call/function_call_validate.cpp b/packages/compiler/src/ast/nodes/functions/call/function_call_validate.cpp new file mode 100644 index 00000000..52aa11e4 --- /dev/null +++ b/packages/compiler/src/ast/nodes/functions/call/function_call_validate.cpp @@ -0,0 +1,12 @@ +#include "ast/definitions/function_definition.h" +#include "ast/nodes/expression.h" + +using namespace stride::ast; + +void AstFunctionCall::validate() +{ + for (const auto& arg : this->_arguments) + { + arg->validate(); + } +} diff --git a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_validation.cpp b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_validation.cpp index c87c353c..6b2a2d96 100644 --- a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_validation.cpp +++ b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_validation.cpp @@ -68,7 +68,7 @@ void IAstFunction::validate() std::move(instantiated_return_ty), this->get_visibility(), this->_flags, - this->get_generic_parameters() + EMPTY_GENERIC_PARAMETER_LIST // Omit generics - They've been resolved ); validate_candidate(node.get()); @@ -77,7 +77,6 @@ void IAstFunction::validate() void IAstFunction::validate_candidate(IAstFunction* candidate) { - candidate->get_body()->validate(); const auto& ret_ty = candidate->get_return_type(); const auto return_statements = collect_return_statements(candidate->get_body()); @@ -182,6 +181,8 @@ void 
IAstFunction::validate_candidate(IAstFunction* candidate) ); } } + + candidate->get_body()->validate(); } std::vector IAstFunction::collect_return_statements(const AstBlock* body) From e94c446bafb7dadc04460e4fbddf8c29b70b57a1 Mon Sep 17 00:00:00 2001 From: Luca Warmenhoven Date: Mon, 16 Mar 2026 18:49:12 +0100 Subject: [PATCH 15/31] Resolve and enhance generic type handling for structures and functions - Improved generic resolution for object initializers in the AST (`AstObjectInitializer`). - Added logic to resolve return types with generics in function declarations. - Updated expression node with the ability to set generic type arguments. --- example.sr | 29 +++++++++---------- .../compiler/include/ast/nodes/expression.h | 6 ++++ packages/compiler/src/ast/generics.cpp | 13 +++++++++ .../function_declaration_forward_refs.cpp | 26 +++++++++++++++-- 4 files changed, 56 insertions(+), 18 deletions(-) diff --git a/example.sr b/example.sr index 25b2e6d7..9ac3bd4e 100644 --- a/example.sr +++ b/example.sr @@ -2,30 +2,29 @@ import System::{ io::print }; -type Array = T[]; - -type Car = { - drive: () -> void; - names: Array; +type Array = { + elements: T[]; + count: i32; }; -fn some_comparison(a: A, b: B): bool { - return a == b; +fn arrayOf(elements: T[], count: i32): Array { + return Array::{ + elements, + count, + }; } -fn makeCar(): Car { - const names = ["Toyota", "Honda", "Ford", "Toyota"]; +type Callback = () -> void; + +fn makeCb(): Callback { + const names = arrayOf(["Toyota", "Honda", "Ford", "Toyota"], 4); - return Car::{ drive: (): void -> io::print("\x1b[32mDriving the car"), names }; + return (): void -> io::print("\x1b[32mDriving the car: %s", names.elements[1]); } fn main(): i32 { - const myCar = makeCar(); - myCar.drive(); - io::print("Driving: %s, %s, %s", myCar.names[0], myCar.names[1], myCar.names[2]); - - io::print("Compared: %b", some_comparison(5, 5)); + makeCb()(); return 0; } \ No newline at end of file diff --git 
a/packages/compiler/include/ast/nodes/expression.h b/packages/compiler/include/ast/nodes/expression.h index 2a3df530..d3a4bd76 100644 --- a/packages/compiler/include/ast/nodes/expression.h +++ b/packages/compiler/include/ast/nodes/expression.h @@ -892,6 +892,12 @@ namespace stride::ast return !this->_generic_type_arguments.empty(); } + void set_generic_type_arguments(GenericTypeList generic_type_arguments) + { + this->_generic_type_arguments = std::move(generic_type_arguments); + this->_object_type = nullptr; // Reset cached type so it re-resolves + } + llvm::Value* codegen(llvm::Module* module, llvm::IRBuilderBase* builder) override; std::string to_string() override; diff --git a/packages/compiler/src/ast/generics.cpp b/packages/compiler/src/ast/generics.cpp index 698d5471..3212e7b3 100644 --- a/packages/compiler/src/ast/generics.cpp +++ b/packages/compiler/src/ast/generics.cpp @@ -445,6 +445,19 @@ void stride::ast::resolve_generics_in_body( } else if (auto* struct_init = cast_expr(expr)) { + // Resolve generic type arguments on the object initializer itself + // e.g. Array::{ ... } → Array::{ ... 
} + if (struct_init->has_generic_type_arguments()) + { + GenericTypeList resolved_args; + resolved_args.reserve(struct_init->get_generic_type_arguments().size()); + for (const auto& arg : struct_init->get_generic_type_arguments()) + { + resolved_args.push_back(resolve_generics(arg.get(), param_names, instantiated_types)); + } + struct_init->set_generic_type_arguments(std::move(resolved_args)); + } + for (const auto& val : struct_init->get_initializers() | std::views::values) resolve_generics_in_body(val.get(), param_names, instantiated_types); } diff --git a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_forward_refs.cpp b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_forward_refs.cpp index 5d309b7c..1b1f8cd6 100644 --- a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_forward_refs.cpp +++ b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_forward_refs.cpp @@ -119,15 +119,22 @@ void IAstFunction::resolve_forward_references( ); } + auto resolved_body = this->_body->clone_as(); + resolve_generics_in_body( + resolved_body.get(), + this->_generic_parameters, + instantiated_generic_types + ); + node = std::make_unique( this->get_context(), this->get_symbol(), std::move(instantiated_function_params), - this->_body->clone_as(), + std::move(resolved_body), std::move(instantiated_return_ty), this->get_visibility(), this->get_flags(), - this->get_generic_parameters() + EMPTY_GENERIC_PARAMETER_LIST ); const auto overloaded_fn_name = get_overloaded_function_name( @@ -477,7 +484,20 @@ llvm::FunctionType* IAstFunction::get_llvm_function_type( parameter_types.insert(parameter_types.end(), captured_variables.begin(), captured_variables.end()); parameter_types.insert(parameter_types.end(), base_parameter_types.begin(), base_parameter_types.end()); - const auto return_type = this->get_return_type()->get_llvm_type(module); + llvm::Type* return_type; + if 
(!generic_instantiation_types.empty()) + { + auto resolved_return_ty = resolve_generics( + this->_annotated_return_type.get(), + this->_generic_parameters, + generic_instantiation_types + ); + return_type = resolved_return_ty->get_llvm_type(module); + } + else + { + return_type = this->get_return_type()->get_llvm_type(module); + } if (!return_type) { From ea4883f6ae3b6b4276023208e3a4a9d842242fc5 Mon Sep 17 00:00:00 2001 From: Luca Warmenhoven Date: Mon, 16 Mar 2026 19:40:54 +0100 Subject: [PATCH 16/31] Resolve and improve generic type substitution for function calls and declarations - Updated type inference to handle substitution of generic parameter names with concrete types in function calls. - Refactored `function_definition` to `function_declaration` and enhanced overload handling for generic functions. - Adjusted LLVM code generation to use resolved bodies for generic overloads. - Updated function implementation metadata to include resolved function bodies. - Minor updates to tests to integrate empty generic type lists and align with new logic. 
--- .../include/ast/nodes/function_declaration.h | 1 + .../declaration/function_declaration.cpp | 3 ++- .../function_declaration_codegen.cpp | 7 ++++--- .../function_declaration_forward_refs.cpp | 3 ++- packages/compiler/src/ast/type_inference.cpp | 18 +++++++++++++++++- .../compiler/tests/test_type_inference.cpp | 4 +++- packages/compiler/tests/utils.h | 8 +++++--- 7 files changed, 34 insertions(+), 10 deletions(-) diff --git a/packages/compiler/include/ast/nodes/function_declaration.h b/packages/compiler/include/ast/nodes/function_declaration.h index 571875b9..c39cd67c 100644 --- a/packages/compiler/include/ast/nodes/function_declaration.h +++ b/packages/compiler/include/ast/nodes/function_declaration.h @@ -77,6 +77,7 @@ namespace stride::ast { std::string overload_function_name; llvm::Function* llvm_function; + AstBlock* body = nullptr; // Non-null for generic overloads (resolved body) }; class IAstFunction diff --git a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration.cpp b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration.cpp index 09f52a5f..845d5091 100644 --- a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration.cpp +++ b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration.cpp @@ -269,7 +269,8 @@ std::vector IAstFunction::get_function_implementation_da { implementations.emplace_back( get_overloaded_function_name(node->get_registered_function_name(), types), - llvm_function + llvm_function, + node->get_body() ); } diff --git a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_codegen.cpp b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_codegen.cpp index db3bb56f..33d0ba13 100644 --- a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_codegen.cpp +++ b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_codegen.cpp @@ -16,7 +16,7 @@ llvm::Value* 
IAstFunction::codegen( { llvm::Function* function = nullptr; - for (const auto& [function_name, llvm_function_val] : this->get_function_implementation_data()) + for (const auto& [function_name, llvm_function_val, overload_body] : this->get_function_implementation_data()) { if (!llvm_function_val) { @@ -104,8 +104,9 @@ llvm::Value* IAstFunction::codegen( } } - // Generate Body - llvm::Value* function_body_value = this->_body->codegen(module, builder); + // Generate Body — use resolved body for generic overloads + AstBlock* body_to_codegen = overload_body ? overload_body : this->_body.get(); + llvm::Value* function_body_value = body_to_codegen->codegen(module, builder); // Final Safety: Implicit Return // If the get_body didn't explicitly return (no terminator found), add one. diff --git a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_forward_refs.cpp b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_forward_refs.cpp index 1b1f8cd6..4b9ed2e0 100644 --- a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_forward_refs.cpp +++ b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_forward_refs.cpp @@ -1,5 +1,6 @@ #include "ast/casting.h" #include "ast/definitions/function_definition.h" +#include "ast/generics.h" #include "ast/nodes/conditional_statement.h" #include "ast/nodes/for_loop.h" #include "ast/nodes/function_declaration.h" @@ -167,7 +168,7 @@ void IAstFunction::resolve_forward_references( if (this->is_anonymous()) llvm_function->addFnAttr("stride.anonymous"); - this->_body->resolve_forward_references(module, builder); + node->get_body()->resolve_forward_references(module, builder); } } diff --git a/packages/compiler/src/ast/type_inference.cpp b/packages/compiler/src/ast/type_inference.cpp index e4bd25a7..bf21a606 100644 --- a/packages/compiler/src/ast/type_inference.cpp +++ b/packages/compiler/src/ast/type_inference.cpp @@ -3,6 +3,7 @@ #include "errors.h" 
#include "ast/casting.h" #include "ast/flags.h" +#include "ast/generics.h" #include "ast/parsing_context.h" #include "ast/definitions/function_definition.h" #include "ast/nodes/function_declaration.h" @@ -35,7 +36,22 @@ std::unique_ptr stride::ast::infer_function_call_return_type(AstFuncti ); fn_def.has_value()) { - return fn_def.value()->get_type()->get_return_type()->clone_ty(); + auto return_type = fn_def.value()->get_type()->get_return_type()->clone_ty(); + + // For generic function calls, resolve the return type by substituting + // generic parameter names with the concrete type arguments from the call site. + // e.g. arrayOf(...) has return type Array → Array + if (const auto& generic_args = fn_call->get_generic_type_arguments(); + !generic_args.empty()) + { + const auto& param_names = fn_def.value()->get_type()->get_generic_parameter_names(); + if (param_names.size() == generic_args.size()) + { + return_type = resolve_generics(return_type.get(), param_names, generic_args); + } + } + + return return_type; } /// --- Handling lambda functions that might be assigned to variables diff --git a/packages/compiler/tests/test_type_inference.cpp b/packages/compiler/tests/test_type_inference.cpp index 97a26a0a..c52b9cb3 100644 --- a/packages/compiler/tests/test_type_inference.cpp +++ b/packages/compiler/tests/test_type_inference.cpp @@ -2,7 +2,7 @@ #include "ast/nodes/literal_values.h" #include "ast/nodes/expression.h" #include "ast/nodes/types.h" -#include "ast/nodes/function_definition.h" +#include "ast/nodes/function_declaration.h" #include "ast/parsing_context.h" #include "ast/symbols.h" #include "errors.h" @@ -405,6 +405,7 @@ TEST_F(TypeInferenceTest, InferFunctionCall) context, dummy_iden("foo"), std::move(args), + EMPTY_GENERIC_TYPE_LIST, 0); EXPECT_EQ(infer_expression_type(fn_call.get())->to_string(), "f32"); @@ -425,6 +426,7 @@ TEST_F(TypeInferenceTest, InferFunctionCall) context, dummy_iden("bar"), ExpressionList{}, + EMPTY_GENERIC_TYPE_LIST, 0); 
EXPECT_EQ(infer_expression_type(l_fn_call.get())->to_string(), "i64"); } diff --git a/packages/compiler/tests/utils.h b/packages/compiler/tests/utils.h index 2b2f15eb..ffdfe5b4 100644 --- a/packages/compiler/tests/utils.h +++ b/packages/compiler/tests/utils.h @@ -29,17 +29,19 @@ namespace stride::tests auto node = parse_sequential(context, tokens); ast::AstNodeTraverser traverser; - ast::ExpressionVisitor type_visitor; + ast::ExpressionVisitor expression_visitor; ast::FunctionVisitor function_visitor; ast::ImportVisitor import_visitor; + runtime::register_runtime_symbols(node->get_context()); + import_visitor.set_current_file_name("test.sr"); traverser.visit_block(&import_visitor, node.get()); traverser.visit_block(&function_visitor, node.get()); + // import_visitor.cross_register_symbols(this->_ast.get()); - runtime::register_runtime_symbols(node->get_context()); - traverser.visit_block(&type_visitor, node.get()); + traverser.visit_block(&expression_visitor, node.get()); node->validate(); From 4991539543c3df836fe57317033f5c7b9f613ae7 Mon Sep 17 00:00:00 2001 From: Luca Warmenhoven Date: Mon, 16 Mar 2026 19:52:54 +0100 Subject: [PATCH 17/31] Enhance generic function handling by adding LLVM function attributes and expanding free variable collection for various expression types --- .../function_declaration_forward_refs.cpp | 41 ++++++++++++++++++- 1 file changed, 39 insertions(+), 2 deletions(-) diff --git a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_forward_refs.cpp b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_forward_refs.cpp index 4b9ed2e0..448c36e9 100644 --- a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_forward_refs.cpp +++ b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_forward_refs.cpp @@ -76,12 +76,17 @@ void IAstFunction::resolve_forward_references( {} ); - definition->set_llvm_function(llvm::Function::Create( + auto* llvm_func = 
llvm::Function::Create( generic_function_type, linkage, this->get_registered_function_name(), module - )); + ); + definition->set_llvm_function(llvm_func); + + if (this->is_anonymous()) + llvm_func->addFnAttr("stride.anonymous"); + this->_body->resolve_forward_references(module, builder); return; @@ -421,6 +426,38 @@ void IAstFunction::collect_free_variables( return; } + // Handle array member access (e.g., arr[0], names.elements[1]) + if (const auto* array_access = cast_expr(node)) + { + collect_free_variables(array_access->get_array_base(), lambda_context, outer_context, captures); + collect_free_variables(array_access->get_index(), lambda_context, outer_context, captures); + return; + } + + // Handle indirect calls (calling function pointers / lambdas stored in variables) + if (const auto* indirect_call = cast_expr(node)) + { + collect_free_variables(indirect_call->get_callee(), lambda_context, outer_context, captures); + for (const auto& arg : indirect_call->get_args()) + collect_free_variables(arg.get(), lambda_context, outer_context, captures); + return; + } + + // Handle type casts + if (const auto* type_cast = cast_expr(node)) + { + collect_free_variables(type_cast->get_value(), lambda_context, outer_context, captures); + return; + } + + // Handle tuple initializers + if (const auto* tuple_init = cast_expr(node)) + { + for (const auto& member : tuple_init->get_members()) + collect_free_variables(member.get(), lambda_context, outer_context, captures); + return; + } + // Handle blocks (lambda bodies, function bodies, etc.) 
if (const auto* block = cast_ast(node)) { From 40e6d33ff664495d576be8773825ac9a407cd61c Mon Sep 17 00:00:00 2001 From: Luca Warmenhoven Date: Mon, 16 Mar 2026 20:39:32 +0100 Subject: [PATCH 18/31] Add generic support for function calls and improve type resolution --- example.sr | 6 +++ .../compiler/include/ast/nodes/expression.h | 5 ++- .../nodes/functions/call/function_call.cpp | 13 ++++++- .../functions/call/function_call_codegen.cpp | 39 ++++++++++++++----- .../src/ast/traversal/expression_visitor.cpp | 18 +++++---- 5 files changed, 60 insertions(+), 21 deletions(-) diff --git a/example.sr b/example.sr index 9ac3bd4e..a6e7b963 100644 --- a/example.sr +++ b/example.sr @@ -14,6 +14,10 @@ fn arrayOf(elements: T[], count: i32): Array { }; } +const increment: (i32) -> i32 = (x: i32): i32 -> { + return x + 1; +}; + type Callback = () -> void; fn makeCb(): Callback { @@ -26,5 +30,7 @@ fn main(): i32 { makeCb()(); + const result: i32 = increment(10); + return 0; } \ No newline at end of file diff --git a/packages/compiler/include/ast/nodes/expression.h b/packages/compiler/include/ast/nodes/expression.h index d3a4bd76..db77980c 100644 --- a/packages/compiler/include/ast/nodes/expression.h +++ b/packages/compiler/include/ast/nodes/expression.h @@ -20,6 +20,7 @@ namespace stride::ast namespace definition { + class FieldDefinition; class FunctionDefinition; class IDefinition; } @@ -392,7 +393,7 @@ namespace stride::ast GenericTypeList _generic_type_arguments; int _flags; - definition::FunctionDefinition* _definition = nullptr; + definition::IDefinition* _definition = nullptr; public: explicit AstFunctionCall( @@ -409,7 +410,7 @@ namespace stride::ast _flags(flags) {} [[nodiscard]] - definition::FunctionDefinition* get_function_definition(); + definition::IDefinition* get_function_definition(); [[nodiscard]] const ExpressionList& get_arguments() const diff --git a/packages/compiler/src/ast/nodes/functions/call/function_call.cpp 
b/packages/compiler/src/ast/nodes/functions/call/function_call.cpp index 3180b7d1..eaaa881b 100644 --- a/packages/compiler/src/ast/nodes/functions/call/function_call.cpp +++ b/packages/compiler/src/ast/nodes/functions/call/function_call.cpp @@ -206,9 +206,9 @@ const GenericTypeList& AstFunctionCall::get_generic_type_arguments() return this->_generic_type_arguments; } -FunctionDefinition* AstFunctionCall::get_function_definition() +IDefinition* AstFunctionCall::get_function_definition() { - if (this->_definition != nullptr) + if (this->_definition) return this->_definition; if (const auto def = this->get_context()->get_function_definition( @@ -222,6 +222,15 @@ FunctionDefinition* AstFunctionCall::get_function_definition() return this->_definition; } + if (const auto field_def = this->get_context()->get_variable_def( + this->get_scoped_function_name(), + true + )) + { + this->_definition = field_def; + return this->_definition; + } + throw parsing_error( ErrorType::REFERENCE_ERROR, std::format("Function '{}' was not found in this scope", this->format_function_name()), diff --git a/packages/compiler/src/ast/nodes/functions/call/function_call_codegen.cpp b/packages/compiler/src/ast/nodes/functions/call/function_call_codegen.cpp index ee4bd9d4..bdb323ac 100644 --- a/packages/compiler/src/ast/nodes/functions/call/function_call_codegen.cpp +++ b/packages/compiler/src/ast/nodes/functions/call/function_call_codegen.cpp @@ -32,7 +32,7 @@ llvm::Value* AstFunctionCall::codegen( : ""; throw parsing_error( - ErrorType::REFERENCE_ERROR, + ErrorType::COMPILATION_ERROR, std::format("Function '{}' was not found in this scope", this->format_function_name()), this->get_source_fragment(), suggested_alternative @@ -41,21 +41,40 @@ llvm::Value* AstFunctionCall::codegen( llvm::Function* AstFunctionCall::resolve_regular_callee(llvm::Module* module) { - const auto& fn_def = this->get_function_definition(); + const auto& definition = this->get_function_definition(); - if 
(this->_generic_type_arguments.empty()) + const AstFunctionType* fn_type = nullptr; + + if (const auto* fn_definition = dynamic_cast(definition)) { - if (llvm::Function* callee = module->getFunction(fn_def->get_internal_symbol_name())) + fn_type = fn_definition->get_type(); + if (this->_generic_type_arguments.empty()) + { + if (llvm::Function* callee = fn_definition->get_llvm_function()) + { + return callee; + } + } + else if (auto* llvm_func = fn_definition->get_generic_overload_llvm_function(this->_generic_type_arguments)) { - return callee; + return llvm_func; } + } else if (dynamic_cast(definition)) + { + // Callable variables (function pointers, lambdas) are handled by + // codegen_anonymous_function_call, not as regular callees. + return nullptr; } - else if (auto* llvm_func = fn_def->get_generic_overload_llvm_function(this->_generic_type_arguments)) + + if (fn_type == nullptr) { - return llvm_func; + throw parsing_error( + ErrorType::COMPILATION_ERROR, + std::format("Function '{}' is not a function", this->format_function_name()), + this->get_source_fragment() + ); } - const auto fn_type = fn_def->get_type(); std::vector param_types; param_types.reserve(fn_type->get_parameter_types().size()); @@ -76,7 +95,7 @@ llvm::Function* AstFunctionCall::resolve_regular_callee(llvm::Module* module) } else { - llvm_is_variadic = fn_def->is_variadic(); + llvm_is_variadic = fn_type->is_variadic(); } llvm::FunctionType* llvm_fn_type = llvm::FunctionType::get( @@ -89,7 +108,7 @@ llvm::Function* AstFunctionCall::resolve_regular_callee(llvm::Module* module) // the callee is actually a non-variadic function that takes a va_list. // But we should use the actual function name for the lookup. 
auto callee_cand = module->getOrInsertFunction( - fn_def->get_internal_symbol_name(), + definition->get_internal_symbol_name(), llvm_fn_type ); diff --git a/packages/compiler/src/ast/traversal/expression_visitor.cpp b/packages/compiler/src/ast/traversal/expression_visitor.cpp index 9b05de48..35487df6 100644 --- a/packages/compiler/src/ast/traversal/expression_visitor.cpp +++ b/packages/compiler/src/ast/traversal/expression_visitor.cpp @@ -1,3 +1,4 @@ +#include "ast/casting.h" #include "ast/parsing_context.h" #include "ast/type_inference.h" #include "ast/visitor.h" @@ -16,12 +17,12 @@ void ExpressionVisitor::accept(IAstExpression* expr) { // Use the initial value's type (already set by bottom-up traversal) as the // canonical type registered in context, which is what identifier lookups rely on. - const auto canonical_type = var_decl->get_type(); - canonical_type->set_flags(var_decl->get_flags()); // Ensure type flags are preserved + const auto inferred_type = var_decl->get_type(); + inferred_type->set_flags(var_decl->get_flags()); // Ensure type flags are preserved var_decl->get_context()->define_variable( var_decl->get_symbol(), - canonical_type->clone_ty(), + inferred_type->clone_ty(), var_decl->get_visibility() ); } @@ -30,9 +31,12 @@ void ExpressionVisitor::accept(IAstExpression* expr) !function_call->get_generic_type_arguments().empty() ) { - const auto& definition = function_call->get_function_definition(); - definition->add_generic_overload( - copy_generic_type_list(function_call->get_generic_type_arguments()) - ); + auto* definition = function_call->get_function_definition(); + if (auto* fn_def = dynamic_cast(definition)) + { + fn_def->add_generic_overload( + copy_generic_type_list(function_call->get_generic_type_arguments()) + ); + } } } From 7669ad0a7d50630f28876e942356d70c6dc93a96 Mon Sep 17 00:00:00 2001 From: Luca Warmenhoven Date: Mon, 16 Mar 2026 20:50:46 +0100 Subject: [PATCH 19/31] Resolve function-like variable lookup, resolved segfault by 
incorrect variable capturing --- .../src/ast/nodes/functions/call/function_call.cpp | 2 +- .../function_declaration_forward_refs.cpp | 14 ++++++++++++++ .../function_declaration_validation.cpp | 1 + 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/packages/compiler/src/ast/nodes/functions/call/function_call.cpp b/packages/compiler/src/ast/nodes/functions/call/function_call.cpp index eaaa881b..7630f0d5 100644 --- a/packages/compiler/src/ast/nodes/functions/call/function_call.cpp +++ b/packages/compiler/src/ast/nodes/functions/call/function_call.cpp @@ -222,7 +222,7 @@ IDefinition* AstFunctionCall::get_function_definition() return this->_definition; } - if (const auto field_def = this->get_context()->get_variable_def( + if (const auto field_def = this->get_context()->lookup_variable( this->get_scoped_function_name(), true )) diff --git a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_forward_refs.cpp b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_forward_refs.cpp index 448c36e9..cace6947 100644 --- a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_forward_refs.cpp +++ b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_forward_refs.cpp @@ -200,6 +200,20 @@ void IAstFunction::collect_free_variables( // Not local to the lambda - check if it's in an outer scope (and is a variable, not a function) if (const auto outer_symbol = outer_context->lookup_variable(name, true)) { + // Global variables don't need capturing - they're accessible directly + // Walk up contexts to find which scope owns the variable; skip if global + auto* ctx = outer_context.get(); + while (ctx != nullptr) + { + if (ctx->get_variable_def(outer_symbol->get_internal_symbol_name())) + { + if (ctx->is_global_scope()) + return; + break; + } + ctx = ctx->get_parent_context().get(); + } + // Check if we haven't already captured this variable bool already_captured = false; for (const 
auto& cap : captures) diff --git a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_validation.cpp b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_validation.cpp index 6b2a2d96..878b3b6b 100644 --- a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_validation.cpp +++ b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_validation.cpp @@ -110,6 +110,7 @@ void IAstFunction::validate_candidate(IAstFunction* candidate) ); } } + candidate->get_body()->validate(); return; } From 60fcbbf13b8d7b11a4bbe3e5bf1abba0aaf94f87 Mon Sep 17 00:00:00 2001 From: Luca Warmenhoven Date: Mon, 16 Mar 2026 21:07:45 +0100 Subject: [PATCH 20/31] Added guard clause for edge case to prevent out of bounds access --- .../src/ast/context/function_registry.cpp | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/packages/compiler/src/ast/context/function_registry.cpp b/packages/compiler/src/ast/context/function_registry.cpp index 515b91ec..6e65994e 100644 --- a/packages/compiler/src/ast/context/function_registry.cpp +++ b/packages/compiler/src/ast/context/function_registry.cpp @@ -90,7 +90,9 @@ const // different generic parameter overloads. // // For generic overloads, we just check whether the name and generic count is equal. 
- if (!this->_function_type->get_generic_parameter_names().empty() && generic_argument_count > 0) + if (!this->_function_type->get_generic_parameter_names().empty() && + generic_argument_count > 0 && + this->_function_type->get_generic_parameter_names().size() == generic_argument_count) return true; const auto& self_params = this->_function_type->get_parameter_types(); @@ -169,6 +171,11 @@ bool FunctionDefinition::has_generic_instantiation(const std::vectorequals(generic_types[i].get())) { all_equal = false; @@ -202,6 +209,11 @@ llvm::Function* FunctionDefinition::get_generic_overload_llvm_function(const Gen bool all_equal = true; for (size_t i = 0; i < generic_types.size(); i++) { + if (instantiated_generic_types.size() != generic_types.size()) + { + continue; + } + if (!instantiated_generic_types[i]->equals(generic_types[i].get())) { all_equal = false; @@ -215,4 +227,4 @@ llvm::Function* FunctionDefinition::get_generic_overload_llvm_function(const Gen } return nullptr; -} \ No newline at end of file +} From 895052504d09df1d5a5c5bf429cd5b43862023c7 Mon Sep 17 00:00:00 2001 From: Luca Warmenhoven Date: Mon, 16 Mar 2026 21:08:19 +0100 Subject: [PATCH 21/31] Removed redundant comment --- packages/compiler/tests/utils.h | 1 - 1 file changed, 1 deletion(-) diff --git a/packages/compiler/tests/utils.h b/packages/compiler/tests/utils.h index ffdfe5b4..5b2292c8 100644 --- a/packages/compiler/tests/utils.h +++ b/packages/compiler/tests/utils.h @@ -39,7 +39,6 @@ namespace stride::tests traverser.visit_block(&import_visitor, node.get()); traverser.visit_block(&function_visitor, node.get()); - // import_visitor.cross_register_symbols(this->_ast.get()); traverser.visit_block(&expression_visitor, node.get()); From 90f41f7d20e746df86fed9a703c027b0440c758b Mon Sep 17 00:00:00 2001 From: Luca Warmenhoven Date: Mon, 16 Mar 2026 21:11:53 +0100 Subject: [PATCH 22/31] Added additional conditions for function call check to prevent incorrect early return --- 
packages/compiler/include/ast/nodes/types.h | 2 +- .../src/ast/nodes/functions/call/function_call.cpp | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/packages/compiler/include/ast/nodes/types.h b/packages/compiler/include/ast/nodes/types.h index d8267cbd..e475db84 100644 --- a/packages/compiler/include/ast/nodes/types.h +++ b/packages/compiler/include/ast/nodes/types.h @@ -374,7 +374,7 @@ namespace stride::ast [[nodiscard]] bool is_generic() const { - return this->_generic_param_names.empty(); + return !this->_generic_param_names.empty(); } [[nodiscard]] diff --git a/packages/compiler/src/ast/nodes/functions/call/function_call.cpp b/packages/compiler/src/ast/nodes/functions/call/function_call.cpp index 7630f0d5..b270c90d 100644 --- a/packages/compiler/src/ast/nodes/functions/call/function_call.cpp +++ b/packages/compiler/src/ast/nodes/functions/call/function_call.cpp @@ -123,7 +123,12 @@ bool stride::ast::is_direct_function_call(const TokenSet& set) default: // Optimization, where we know for sure it can't be part of a generic instantiation - if (!next_token.is_type_token() && next_token.get_type() != TokenType::COMMA) + if (!next_token.is_type_token() && + next_token.get_type() != TokenType::COMMA && + next_token.get_type() != TokenType::DOUBLE_COLON + && next_token.get_type() != TokenType::LSQUARE_BRACKET + && next_token.get_type() != TokenType::RSQUARE_BRACKET + && next_token.get_type() != TokenType::QUESTION) { return false; } From 5f57a42bb20fc5d79cd4c89d75f6a77d9e081722 Mon Sep 17 00:00:00 2001 From: Luca Warmenhoven Date: Mon, 16 Mar 2026 21:22:33 +0100 Subject: [PATCH 23/31] Add additional test for generic type resolution --- packages/compiler/tests/test_generics.cpp | 472 +++++++++++++++++++++- 1 file changed, 469 insertions(+), 3 deletions(-) diff --git a/packages/compiler/tests/test_generics.cpp b/packages/compiler/tests/test_generics.cpp index aa7c3849..4bea6bb4 100644 --- a/packages/compiler/tests/test_generics.cpp +++ 
b/packages/compiler/tests/test_generics.cpp @@ -3,18 +3,22 @@ using namespace stride::tests; +// ============================================================================ +// Generic type definitions (type aliases and objects) +// ============================================================================ + TEST(Generics, ResolveGenericNamedTypeUnderlyingType) { // type SomeNamed = T[] - // const some_var: SomeNamed = 12; // This should FAIL if i32[] is correctly resolved as underlying type, + // const some_var: SomeNamed = 12; // This should FAIL if i32[] is correctly resolved as underlying type, // because 12 is not an array. // However, the issue description says: // "I expect that instantiate_named_type(SomeNamed) yields i32[]" - + // If it yields i32[], then: // const some_var: SomeNamed = [1, 2, 3]; // should WORK. - + assert_compiles(R"( type SomeNamed = T[]; const some_var: SomeNamed = [1, 2, 3]; @@ -98,3 +102,465 @@ TEST(Generics, FunctionSignatureMismatch) } )"); } + +TEST(Generics, GenericObjectWithMixedGenericAndConcreteFields) +{ + assert_compiles(R"( + type Tagged = { tag: string; value: T; }; + const t = Tagged::{ tag: "count", value: 42 }; + )"); +} + +// ============================================================================ +// Generic function definitions — body type resolution +// ============================================================================ + +TEST(GenericFunctions, ComparisonEqualsInBody) +{ + // The body uses `==` on generic params — requires resolve_generics_in_body + // to resolve T to a primitive so the comparison validator accepts it + assert_compiles(R"( + fn are_equal(a: T, b: T): bool { + return a == b; + } + + fn main(): i32 { + const result = are_equal(5, 5); + return 0; + } + )"); +} + +TEST(GenericFunctions, ComparisonNotEqualInBody) +{ + assert_compiles(R"( + fn not_equal(a: T, b: T): bool { + return a != b; + } + + fn main(): i32 { + const r = not_equal(1, 2); + return 0; + } + )"); +} + 
+TEST(GenericFunctions, ComparisonLessThanInBody) +{ + assert_compiles(R"( + fn less_than(a: T, b: T): bool { + return a < b; + } + + fn main(): i32 { + const r = less_than(1, 2); + return 0; + } + )"); +} + +TEST(GenericFunctions, ComparisonGreaterThanOrEqualInBody) +{ + assert_compiles(R"( + fn gte(a: T, b: T): bool { + return a >= b; + } + + fn main(): i32 { + const r = gte(5, 3); + return 0; + } + )"); +} + +TEST(GenericFunctions, AdditionInBody) +{ + assert_compiles(R"( + fn add(a: T, b: T): T { + return a + b; + } + + fn main(): i32 { + const sum = add(3, 4); + return 0; + } + )"); +} + +TEST(GenericFunctions, SubtractionInBody) +{ + assert_compiles(R"( + fn subtract(a: T, b: T): T { + return a - b; + } + + fn main(): i32 { + const d = subtract(10, 3); + return 0; + } + )"); +} + +TEST(GenericFunctions, MultiplicationInBody) +{ + assert_compiles(R"( + fn multiply(a: T, b: T): T { + return a * b; + } + + fn main(): i32 { + const p = multiply(6, 7); + return 0; + } + )"); +} + +TEST(GenericFunctions, ReturnsGenericType) +{ + // The function's return type is the generic parameter itself + assert_compiles(R"( + fn identity(value: T): T { + return value; + } + + fn main(): i32 { + const x = identity(42); + return 0; + } + )"); +} + +TEST(GenericFunctions, VoidReturnType) +{ + // Generic function that returns void — body uses generic params but + // doesn't return them + assert_compiles(R"( + fn consume(value: T): void { + } + + fn main(): void { + consume(42); + } + )"); +} + +// ============================================================================ +// Generic function — multiple type parameters +// ============================================================================ + +TEST(GenericFunctions, TwoTypeParams) +{ + assert_compiles(R"( + fn first_of(a: A, b: B): A { + return a; + } + + fn main(): i32 { + const x = first_of(10, true); + return 0; + } + )"); +} + +TEST(GenericFunctions, TwoTypeParamsReturnSecond) +{ + assert_compiles(R"( + fn second_of(a: 
A, b: B): B { + return b; + } + + fn main(): i32 { + const x = second_of(true, 42); + return 0; + } + )"); +} + +// ============================================================================ +// Generic function — mixed generic and concrete parameters +// ============================================================================ + +TEST(GenericFunctions, MixedGenericAndConcreteParams) +{ + assert_compiles(R"( + fn apply_flag(value: T, flag: bool): T { + return value; + } + + fn main(): i32 { + const x = apply_flag(7, true); + return 0; + } + )"); +} + +// ============================================================================ +// Generic function — multiple instantiations of the same definition +// ============================================================================ + +TEST(GenericFunctions, MultipleInstantiationsSameFunction) +{ + // Same generic function instantiated with different concrete types + assert_compiles(R"( + fn identity(v: T): T { + return v; + } + + fn main(): i32 { + const a = identity(1); + const b = identity(true); + return 0; + } + )"); +} + +TEST(GenericFunctions, MultipleInstantiationsComparison) +{ + assert_compiles(R"( + fn is_equal(a: T, b: T): bool { + return a == b; + } + + fn main(): i32 { + const ri = is_equal(5, 5); + const rf = is_equal(1.0D, 2.0D); + return 0; + } + )"); +} + +TEST(GenericFunctions, FloatInstantiation) +{ + assert_compiles(R"( + fn add(a: T, b: T): T { + return a + b; + } + + fn main(): i32 { + const sum = add(1.5D, 2.5D); + return 0; + } + )"); +} + +// ============================================================================ +// Generic function — complex body patterns (control flow) +// ============================================================================ + +TEST(GenericFunctions, ConditionalWithGenericComparison) +{ + assert_compiles(R"( + fn max_of(a: T, b: T): T { + if (a > b) { + return a; + } + return b; + } + + fn main(): i32 { + const m = max_of(3, 7); + return 0; + } + )"); +} + 
+TEST(GenericFunctions, ConditionalWithElse) +{ + assert_compiles(R"( + fn min_of(a: T, b: T): T { + if (a < b) { + return a; + } else { + return b; + } + } + + fn main(): i32 { + const m = min_of(3, 7); + return 0; + } + )"); +} + +TEST(GenericFunctions, ClampFunction) +{ + assert_compiles(R"( + fn clamp(value: T, lo: T, hi: T): T { + if (value < lo) { + return lo; + } + if (value > hi) { + return hi; + } + return value; + } + + fn main(): i32 { + const r = clamp(15, 0, 10); + return 0; + } + )"); +} + +TEST(GenericFunctions, WhileLoopInBody) +{ + assert_compiles(R"( + fn count_down(start: T): T { + let n = start; + while (n > 0) { + n = n - 1; + } + return n; + } + + fn main(): i32 { + const z = count_down(5); + return 0; + } + )"); +} + +TEST(GenericFunctions, ForLoopInBody) +{ + assert_compiles(R"( + fn accumulate(base: T, n: i32): T { + let result = base; + for (let i: i32 = 0; i < n; i++) { + result = result + base; + } + return result; + } + + fn main(): i32 { + const s = accumulate(3, 4); + return 0; + } + )"); +} + +TEST(GenericFunctions, LocalVariableWithGenericType) +{ + // A local variable inside a generic function body whose type derives + // from the generic parameter + assert_compiles(R"( + fn double_it(x: T): T { + const result: T = x + x; + return result; + } + + fn main(): i32 { + const d = double_it(21); + return 0; + } + )"); +} + +// ============================================================================ +// Generic functions interacting with generic object types +// ============================================================================ + +TEST(GenericFunctions, AcceptsGenericObject) +{ + assert_compiles(R"( + type Pair = { first: T; second: T; }; + + fn get_first(p: Pair): T { + return p.first; + } + + fn main(): i32 { + const p = Pair::{ first: 1, second: 2 }; + const v = get_first(p); + return 0; + } + )"); +} + +TEST(GenericFunctions, ReturnsGenericObject) +{ + assert_compiles(R"( + type Wrapper = { value: T; }; + + fn wrap(v: 
T): Wrapper { + return Wrapper::{ value: v }; + } + + fn main(): i32 { + const w = wrap(99); + return 0; + } + )"); +} + +TEST(GenericFunctions, CreatesGenericObjectInBody) +{ + assert_compiles(R"( + type Box = { inner: T; }; + + fn make_box(v: T): Box { + const b = Box::{ inner: v }; + return b; + } + + fn main(): i32 { + const b = make_box(5); + return 0; + } + )"); +} + +TEST(GenericFunctions, GenericObjectFieldAccess) +{ + assert_compiles(R"( + type Container = { data: T; }; + + fn extract(c: Container): T { + return c.data; + } + + fn main(): i32 { + const c = Container::{ data: 123 }; + const v = extract(c); + return 0; + } + )"); +} + +// ============================================================================ +// Generic function — uninstantiated (should compile without errors) +// ============================================================================ + +TEST(GenericFunctions, UninstantiatedGenericIsNotCodegenerated) +{ + // A generic function that is never called should not cause errors; + // it simply won't be code-generated. 
+ assert_compiles(R"( + fn unused_generic(x: T): T { + return x; + } + + fn main(): i32 { + return 0; + } + )"); +} + +// ============================================================================ +// Generic function — error cases +// ============================================================================ + +TEST(GenericFunctions, ReturnTypeMismatch) +{ + // Generic function declared to return T, but instantiated as i32 and + // returning a bool literal — should fail type checking + assert_throws_message(R"( + fn wrong_return(v: T): T { + return true; + } + + fn main(): i32 { + const x = wrong_return(1); + return 0; + } + )", "expected a return type of"); +} From 95bed9111635d0bf79cd53e6c46112b6dbc1f9ed Mon Sep 17 00:00:00 2001 From: Luca Warmenhoven Date: Mon, 16 Mar 2026 21:40:18 +0100 Subject: [PATCH 24/31] Add handling for unresolved generic parameters in type inference and casting logic - Guarded against unresolved generic parameters in multiple cases (e.g., type extraction, casting, and assignability checks). - Updated type inference to prioritize unresolved generic parameters when resolving operation types. - Adjusted `AstVariableDeclaration` creation to use `std::optional` for annotated types. 
--- .../expressions/variable_declaration.cpp | 4 +++- .../src/ast/nodes/types/alias_type.cpp | 20 +++++++++++++++++++ .../compiler/src/ast/nodes/types/types.cpp | 7 +++++++ packages/compiler/src/ast/type_inference.cpp | 13 ++++++++++++ 4 files changed, 43 insertions(+), 1 deletion(-) diff --git a/packages/compiler/src/ast/nodes/expressions/variable_declaration.cpp b/packages/compiler/src/ast/nodes/expressions/variable_declaration.cpp index df747aaa..36c91946 100644 --- a/packages/compiler/src/ast/nodes/expressions/variable_declaration.cpp +++ b/packages/compiler/src/ast/nodes/expressions/variable_declaration.cpp @@ -561,7 +561,9 @@ std::unique_ptr AstVariableDeclaration::clone() return std::make_unique( this->get_context(), this->_symbol, - this->has_annotated_type() ? this->_annotated_type.value()->clone_ty() : nullptr, + this->has_annotated_type() + ? std::optional>(this->_annotated_type.value()->clone_ty()) + : std::nullopt, this->_initial_value->clone_as(), this->_visibility ); diff --git a/packages/compiler/src/ast/nodes/types/alias_type.cpp b/packages/compiler/src/ast/nodes/types/alias_type.cpp index 5316432f..76be6da8 100644 --- a/packages/compiler/src/ast/nodes/types/alias_type.cpp +++ b/packages/compiler/src/ast/nodes/types/alias_type.cpp @@ -254,6 +254,12 @@ IAstType* AstAliasType::get_underlying_type() bool AstAliasType::is_castable_to_impl(IAstType* other) { + // Unresolved generic parameters cannot be cast + if (!this->get_type_definition().has_value()) + { + return false; + } + const auto self_reference_type = get_underlying_type(); // Check our base type is a primitive, and whether that type is castable to `other` @@ -288,6 +294,13 @@ bool AstAliasType::is_assignable_to_impl(IAstType* other) // type SomePrimitive = i32[] // const someVar: SomePrimitive = [1, 2, 3]; // In this case, `[1, 2, 3]` should be assignable to the base types of `SomePrimitive` + // If this is an unresolved generic parameter (no type definition), + // we can't resolve its 
underlying type. + if (!this->get_type_definition().has_value()) + { + return false; + } + if (const auto self_base_type = get_underlying_type()) { return self_base_type->is_assignable_to(other); @@ -332,6 +345,13 @@ bool AstAliasType::equals(IAstType* other) return true; } + // If this is an unresolved generic parameter (no type definition), + // we can't resolve its underlying type — just compare by name. + if (!this->get_type_definition().has_value()) + { + return false; + } + if (const auto self_base = this->get_underlying_type()) { return self_base->equals(other); diff --git a/packages/compiler/src/ast/nodes/types/types.cpp b/packages/compiler/src/ast/nodes/types/types.cpp index 62345486..5e36844d 100644 --- a/packages/compiler/src/ast/nodes/types/types.cpp +++ b/packages/compiler/src/ast/nodes/types/types.cpp @@ -154,6 +154,13 @@ AstPrimitiveType* extract_primitive_reference_types(IAstType* type) if (const auto named = cast_type(type)) { + // If this is an unresolved generic parameter (no type definition), + // we can't extract primitive reference types from it. + if (!named->get_type_definition().has_value()) + { + return nullptr; + } + const auto ref_type = named->get_underlying_type(); return extract_primitive_reference_types(ref_type); diff --git a/packages/compiler/src/ast/type_inference.cpp b/packages/compiler/src/ast/type_inference.cpp index bf21a606..a3686bfc 100644 --- a/packages/compiler/src/ast/type_inference.cpp +++ b/packages/compiler/src/ast/type_inference.cpp @@ -100,6 +100,19 @@ std::unique_ptr stride::ast::infer_binary_op_type(IBinaryOp* operation return std::move(lhs); } + // If either operand is an unresolved generic parameter (e.g. T), + // return it as the result type — it will be resolved when the generic is instantiated. 
+ if (auto* lhs_alias = cast_type(lhs.get()); + lhs_alias && !lhs_alias->get_type_definition().has_value()) + { + return std::move(lhs); + } + if (auto* rhs_alias = cast_type(rhs.get()); + rhs_alias && !rhs_alias->get_type_definition().has_value()) + { + return std::move(rhs); + } + // --- Pointers have priority if (lhs->is_pointer() && !rhs->is_pointer()) { From 076d060655585f5bcb6b4a219928d4f302ca2f94 Mon Sep 17 00:00:00 2001 From: Luca Warmenhoven Date: Mon, 16 Mar 2026 22:29:58 +0100 Subject: [PATCH 25/31] Refactor code generation and generic handling logic - Extracted `codegen_ptr` from `AstIdentifier` for pointer-based code generation. - Refactored generic type resolution to include annotated types in `AstVariableDeclaration`. - Streamlined identifier resolution for globals and functions. - Simplified generic overload management and added guard clauses for absent overloads in function declarations. --- .../compiler/include/ast/nodes/expression.h | 10 +++ packages/compiler/src/ast/generics.cpp | 11 +++- .../src/ast/nodes/expressions/identifier.cpp | 61 ++++++++++--------- .../expressions/variable_reassignation.cpp | 2 +- .../declaration/function_declaration.cpp | 32 +++++----- .../function_declaration_forward_refs.cpp | 6 +- 6 files changed, 70 insertions(+), 52 deletions(-) diff --git a/packages/compiler/include/ast/nodes/expression.h b/packages/compiler/include/ast/nodes/expression.h index db77980c..5915c1fb 100644 --- a/packages/compiler/include/ast/nodes/expression.h +++ b/packages/compiler/include/ast/nodes/expression.h @@ -219,6 +219,11 @@ namespace stride::ast llvm::IRBuilderBase* builder ) override; + llvm::Value* codegen_ptr( + llvm::Module* module, + llvm::IRBuilderBase* builder + ); + std::string to_string() override; bool is_reducible() override @@ -557,6 +562,11 @@ namespace stride::ast return std::nullopt; } + void set_annotated_type(std::unique_ptr type) + { + this->_annotated_type = std::move(type); + } + [[nodiscard]] IAstExpression* 
get_initial_value() const { diff --git a/packages/compiler/src/ast/generics.cpp b/packages/compiler/src/ast/generics.cpp index 3212e7b3..7cb59376 100644 --- a/packages/compiler/src/ast/generics.cpp +++ b/packages/compiler/src/ast/generics.cpp @@ -88,7 +88,9 @@ std::unique_ptr stride::ast::resolve_generics( { if (param_names[i] == named_type->get_name()) { - return instantiated_types[i]->clone_ty(); + auto resolved = instantiated_types[i]->clone_ty(); + resolved->set_flags(resolved->get_flags() | named_type->get_flags()); + return resolved; } } @@ -425,6 +427,13 @@ void stride::ast::resolve_generics_in_body( } else if (auto* var_decl = cast_expr(expr)) { + if (var_decl->has_annotated_type()) + { + var_decl->set_annotated_type( + resolve_generics(var_decl->get_annotated_type().value(), param_names, instantiated_types) + ); + } + if (var_decl->get_initial_value()) resolve_generics_in_body(var_decl->get_initial_value(), param_names, instantiated_types); } diff --git a/packages/compiler/src/ast/nodes/expressions/identifier.cpp b/packages/compiler/src/ast/nodes/expressions/identifier.cpp index e3f476b3..d93312cd 100644 --- a/packages/compiler/src/ast/nodes/expressions/identifier.cpp +++ b/packages/compiler/src/ast/nodes/expressions/identifier.cpp @@ -33,7 +33,7 @@ std::optional AstIdentifier::get_definition() const return std::nullopt; } -llvm::Value* AstIdentifier::codegen( +llvm::Value* AstIdentifier::codegen_ptr( llvm::Module* module, llvm::IRBuilderBase* builder ) @@ -75,15 +75,7 @@ llvm::Value* AstIdentifier::codegen( if (auto* global = module->getNamedGlobal(internal_name)) { - return builder->CreateLoad( - global->getValueType(), - global - ); - } - - if (auto* arg = llvm::dyn_cast_or_null(val)) - { - return arg; + return global; } throw parsing_error( @@ -94,17 +86,35 @@ llvm::Value* AstIdentifier::codegen( } } - // Check if it's a function argument - if (auto* arg = llvm::dyn_cast_or_null(val)) + if (!val) { - return arg; - } + if (const auto global = 
module->getNamedGlobal(internal_name)) + { + return global; + } - if (auto* load = llvm::dyn_cast_or_null(val)) - { - return load; + if (auto* function = module->getFunction(internal_name)) + { + return function; + } + + throw parsing_error( + ErrorType::REFERENCE_ERROR, + std::format("Identifier '{}' not found in this scope", this->get_name()), + this->get_source_fragment() + ); } + return val; +} + +llvm::Value* AstIdentifier::codegen( + llvm::Module* module, + llvm::IRBuilderBase* builder +) +{ + llvm::Value* val = this->codegen_ptr(module, builder); + if (auto* alloca = llvm::dyn_cast_or_null(val)) { // Load the value from the allocated variable @@ -115,33 +125,24 @@ llvm::Value* AstIdentifier::codegen( ); } - if (const auto global = module->getNamedGlobal(internal_name)) + if (const auto* global = llvm::dyn_cast_or_null(val)) { // Only generate a Load instruction if we are inside a BasicBlock (Function context). if (builder->GetInsertBlock()) { return builder->CreateLoad( global->getValueType(), - global + val ); } // If we are in Global context (initializing a global variable), we cannot generate // instructions. We return the GlobalVariable* itself. This allows parent nodes (like // MemberAccessor) to perform Constant Folding or ConstantExpr GEPs on the address. 
- return global; + return val; } - if (auto* function = module->getFunction(internal_name)) - { - return function; - } - - throw parsing_error( - ErrorType::REFERENCE_ERROR, - std::format("Identifier '{}' not found in this scope", this->get_name()), - this->get_source_fragment() - ); + return val; } std::unique_ptr AstIdentifier::clone() diff --git a/packages/compiler/src/ast/nodes/expressions/variable_reassignation.cpp b/packages/compiler/src/ast/nodes/expressions/variable_reassignation.cpp index 5b2a59bd..158b4365 100644 --- a/packages/compiler/src/ast/nodes/expressions/variable_reassignation.cpp +++ b/packages/compiler/src/ast/nodes/expressions/variable_reassignation.cpp @@ -199,7 +199,7 @@ llvm::Value* AstVariableReassignment::codegen( llvm::IRBuilderBase* builder ) { - llvm::Value* variable = this->get_identifier()->codegen(module, builder); + llvm::Value* variable = this->_identifier->codegen_ptr(module, builder); // Save the insertion point before codegen, as callable types (lambdas) may change it llvm::BasicBlock* saved_block = builder->GetInsertBlock(); diff --git a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration.cpp b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration.cpp index 845d5091..ac44e32e 100644 --- a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration.cpp +++ b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration.cpp @@ -255,26 +255,28 @@ std::vector IAstFunction::get_function_implementation_da { const auto& definition = this->get_function_definition(); - // If the function is not generic, we just return a singular name (the regular internalized name) - if (definition->get_generic_overloads().empty()) + // If the function is generic, we return its instantiated overloads. + // If it's generic but has no overloads, return empty list. 
+ if (this->is_generic()) { - return { - FunctionImplementation{ this->get_registered_function_name(), definition->get_llvm_function() } - }; - } + std::vector implementations; - std::vector implementations; + for (const auto& [types, llvm_function, node] : definition->get_generic_overloads()) + { + implementations.emplace_back( + get_overloaded_function_name(node->get_registered_function_name(), types), + llvm_function, + node->get_body() + ); + } - for (const auto& [types, llvm_function, node] : definition->get_generic_overloads()) - { - implementations.emplace_back( - get_overloaded_function_name(node->get_registered_function_name(), types), - llvm_function, - node->get_body() - ); + return implementations; } - return implementations; + // For non-generic functions, return the single implementation. + return { + FunctionImplementation{ this->get_registered_function_name(), definition->get_llvm_function() } + }; } std::unique_ptr AstFunctionParameter::clone() diff --git a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_forward_refs.cpp b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_forward_refs.cpp index cace6947..be1b8cb6 100644 --- a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_forward_refs.cpp +++ b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_forward_refs.cpp @@ -95,11 +95,7 @@ void IAstFunction::resolve_forward_references( if (const auto& overloads = definition->get_generic_overloads(); overloads.empty()) { - throw parsing_error( - ErrorType::COMPILATION_ERROR, - std::format("Generic function '{}' has no instantiations", this->get_plain_function_name()), - this->get_source_fragment() - ); + return; } for (const auto& [instantiated_generic_types, llvm_function, node] : definition->get_generic_overloads()) From 0b8966b443d164973dd9fbdedb4eeac33d96d202 Mon Sep 17 00:00:00 2001 From: Luca Warmenhoven Date: Mon, 16 Mar 2026 22:43:10 +0100 Subject: 
[PATCH 26/31] Improve handling of generic functions in declarations and calls - Added logic to temporarily update and restore parameter types during validation of generic function instantiations. - Refined LLVM code generation to resolve return and parameter types for generic functions. - Enhanced generic function matching in the function registry based on parameter counts. --- .../src/ast/context/function_registry.cpp | 9 ++ .../functions/call/function_call_codegen.cpp | 23 +++- .../function_declaration_validation.cpp | 107 ++++++++++-------- 3 files changed, 87 insertions(+), 52 deletions(-) diff --git a/packages/compiler/src/ast/context/function_registry.cpp b/packages/compiler/src/ast/context/function_registry.cpp index 6e65994e..75799bbf 100644 --- a/packages/compiler/src/ast/context/function_registry.cpp +++ b/packages/compiler/src/ast/context/function_registry.cpp @@ -63,6 +63,15 @@ bool FunctionDefinition::matches_type_signature( const AstFunctionType* signature ) const { + if (this->get_internal_symbol_name() != name) + return false; + + // Handle matching for generic functions + if (this->get_type()->is_generic() && signature->is_generic()) + { + return this->get_type()->get_generic_parameter_names().size() == signature->get_generic_parameter_names().size(); + } + if (!this->_function_type->get_return_type()->equals(signature->get_return_type().get())) return false; diff --git a/packages/compiler/src/ast/nodes/functions/call/function_call_codegen.cpp b/packages/compiler/src/ast/nodes/functions/call/function_call_codegen.cpp index bdb323ac..474c1dab 100644 --- a/packages/compiler/src/ast/nodes/functions/call/function_call_codegen.cpp +++ b/packages/compiler/src/ast/nodes/functions/call/function_call_codegen.cpp @@ -78,12 +78,27 @@ llvm::Function* AstFunctionCall::resolve_regular_callee(llvm::Module* module) std::vector param_types; param_types.reserve(fn_type->get_parameter_types().size()); - for (const auto& param : fn_type->get_parameter_types()) + 
llvm::Type* ret_type = nullptr; + + if (!this->_generic_type_arguments.empty()) { - param_types.push_back(param->get_llvm_type(module)); + const auto& param_names = fn_type->get_generic_parameter_names(); + for (const auto& param : fn_type->get_parameter_types()) + { + auto resolved = resolve_generics(param.get(), param_names, this->_generic_type_arguments); + param_types.push_back(resolved->get_llvm_type(module)); + } + auto resolved_ret = resolve_generics(fn_type->get_return_type().get(), param_names, this->_generic_type_arguments); + ret_type = resolved_ret->get_llvm_type(module); + } + else + { + for (const auto& param : fn_type->get_parameter_types()) + { + param_types.push_back(param->get_llvm_type(module)); + } + ret_type = fn_type->get_return_type()->get_llvm_type(module); } - - llvm::Type* ret_type = fn_type->get_return_type()->get_llvm_type(module); // When propagating varargs (call has '...'), the callee receives the caller's // va_list as an extra fixed pointer argument rather than as true variadic args. diff --git a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_validation.cpp b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_validation.cpp index 878b3b6b..eaace4c8 100644 --- a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_validation.cpp +++ b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_validation.cpp @@ -19,60 +19,71 @@ void IAstFunction::validate() } // - // For generic functions, we create a new copy of the function with all parameters resolved, and do validation - // on that copy. This is because we want to validate the function body with the actual types that will be used in - // the function, rather than the generic placeholders. 
- // - // create a copy of this function with the parameters instantiated - for (const auto definition = this->get_function_definition(); - const auto& [instantiated_generic_types, function, node] : definition->get_generic_overloads()) - { - auto instantiated_return_ty = resolve_generics( - this->_annotated_return_type.get(), - this->_generic_parameters, - instantiated_generic_types - ); + // For generic functions, we create a new copy of the function with all parameters resolved, and do validation + // on that copy. This is because we want to validate the function body with the actual types that will be used in + // the function, rather than the generic placeholders. + // + // create a copy of this function with the parameters instantiated + for (const auto definition = this->get_function_definition(); + const auto& [instantiated_generic_types, function, node] : definition->get_generic_overloads()) + { + auto instantiated_return_ty = resolve_generics( + this->_annotated_return_type.get(), + this->_generic_parameters, + instantiated_generic_types + ); - std::vector> instantiated_function_params; - instantiated_function_params.reserve(this->_parameters.size()); + std::vector> instantiated_function_params; + instantiated_function_params.reserve(this->_parameters.size()); - for (const auto& param : this->_parameters) - { - instantiated_function_params.push_back( - std::make_unique( - param->get_source_fragment(), - param->get_context(), - param->get_name(), - resolve_generics(param->get_type(), this->_generic_parameters, instantiated_generic_types) - ) + // Temporarily update parameter types in the context so resolve_generics_in_body + // and subsequent validation can find the concrete types. 
+ std::vector>> old_param_types; + for (const auto& param : this->_parameters) + { + if (auto def = this->get_context()->lookup_variable(param->get_name(), true)) + { + old_param_types.push_back({ def, def->get_type()->clone_ty() }); + def->set_type(resolve_generics(def->get_type(), this->_generic_parameters, instantiated_generic_types)); + } + + instantiated_function_params.push_back( + std::make_unique( + param->get_source_fragment(), + param->get_context(), + param->get_name(), + resolve_generics(param->get_type(), this->_generic_parameters, instantiated_generic_types) + ) + ); + } + + // Clone the body and resolve generic types on every expression within it. + auto resolved_body = this->_body->clone_as(); + resolve_generics_in_body( + resolved_body.get(), + this->_generic_parameters, + instantiated_generic_types ); - } - // Clone the body and resolve generic types on every expression within it. - // The cloned body's expressions have no types set (clone does not preserve inferred types), - // so resolve_generics_in_body re-infers each expression's type from context (which still - // carries the generic parameter names, e.g. T) and then substitutes them with the concrete - // instantiated types (e.g. i32). 
- auto resolved_body = this->_body->clone_as(); - resolve_generics_in_body( - resolved_body.get(), - this->_generic_parameters, - instantiated_generic_types - ); + node = std::make_unique( + this->get_context(), + this->_symbol, + std::move(instantiated_function_params), + std::move(resolved_body), + std::move(instantiated_return_ty), + this->get_visibility(), + this->_flags, + EMPTY_GENERIC_PARAMETER_LIST // Omit generics - They've been resolved + ); - node = std::make_unique( - this->get_context(), - this->_symbol, - std::move(instantiated_function_params), - std::move(resolved_body), - std::move(instantiated_return_ty), - this->get_visibility(), - this->_flags, - EMPTY_GENERIC_PARAMETER_LIST // Omit generics - They've been resolved - ); + validate_candidate(node.get()); - validate_candidate(node.get()); - } + // Restore original parameter types in the context + for (auto& [def, old_type] : old_param_types) + { + def->set_type(std::move(old_type)); + } + } } void IAstFunction::validate_candidate(IAstFunction* candidate) From 5d989d3f6a4c636ed65dec64203521c3e3c62411 Mon Sep 17 00:00:00 2001 From: Luca Warmenhoven Date: Tue, 17 Mar 2026 18:36:50 +0100 Subject: [PATCH 27/31] Added function call visitor to ensure function calls instantiate generic function variants before type resolution to ensure generic parameters are resolved correctly --- example.sr | 11 +- .../include/ast/definitions/definitions.h | 144 +++++++++++++++++ .../ast/definitions/function_definition.h | 5 +- .../compiler/include/ast/nodes/expression.h | 2 - .../compiler/include/ast/parsing_context.h | 145 +----------------- .../include/ast/{nodes => }/traversal.h | 0 packages/compiler/include/ast/visitor.h | 8 +- .../src/ast/context/function_registry.cpp | 2 +- packages/compiler/src/ast/generics.cpp | 8 +- .../nodes/functions/call/function_call.cpp | 11 -- .../function_declaration_codegen.cpp | 17 +- .../src/ast/traversal/expression_visitor.cpp | 13 -- .../ast/traversal/function_call_visitor.cpp | 19 
+++ .../compiler/src/ast/traversal/traversal.cpp | 6 +- packages/compiler/src/compilation/program.cpp | 13 +- packages/compiler/tests/utils.h | 4 +- 16 files changed, 215 insertions(+), 193 deletions(-) create mode 100644 packages/compiler/include/ast/definitions/definitions.h rename packages/compiler/include/ast/{nodes => }/traversal.h (100%) create mode 100644 packages/compiler/src/ast/traversal/function_call_visitor.cpp diff --git a/example.sr b/example.sr index a6e7b963..24a6596f 100644 --- a/example.sr +++ b/example.sr @@ -5,32 +5,27 @@ import System::{ type Array = { elements: T[]; count: i32; + at: (i32) -> T; }; fn arrayOf(elements: T[], count: i32): Array { return Array::{ elements, count, + at: (index: i32): T -> elements[index] }; } -const increment: (i32) -> i32 = (x: i32): i32 -> { - return x + 1; -}; - type Callback = () -> void; fn makeCb(): Callback { const names = arrayOf(["Toyota", "Honda", "Ford", "Toyota"], 4); - return (): void -> io::print("\x1b[32mDriving the car: %s", names.elements[1]); + return (): void -> io::print("\x1b[32mDriving the car: %s", names.at(1)); } fn main(): i32 { - makeCb()(); - const result: i32 = increment(10); - return 0; } \ No newline at end of file diff --git a/packages/compiler/include/ast/definitions/definitions.h b/packages/compiler/include/ast/definitions/definitions.h new file mode 100644 index 00000000..7663d2fa --- /dev/null +++ b/packages/compiler/include/ast/definitions/definitions.h @@ -0,0 +1,144 @@ +#pragma once +#include "ast/symbols.h" + + +namespace stride::ast::definition +{ + class FunctionDefinition; + + enum class SymbolType + { + CLASS, + VARIABLE, + ENUM, + ENUM_MEMBER, + STRUCT, + STRUCT_MEMBER + }; + + class IDefinition + { + Symbol _symbol; + VisibilityModifier _visibility; + + public: + explicit IDefinition( + Symbol symbol, + const VisibilityModifier modifier + ) : + _symbol(std::move(symbol)), + _visibility(modifier) {} + + virtual ~IDefinition() = default; + + [[nodiscard]] + std::string 
get_internal_symbol_name() const + { + return this->_symbol.internal_name; + } + + [[nodiscard]] + Symbol get_symbol() const + { + return this->_symbol; + } + + [[nodiscard]] + VisibilityModifier get_visibility() const + { + return this->_visibility; + } + + [[nodiscard]] + virtual std::unique_ptr clone() const = 0; + + void set_visibility(const VisibilityModifier visibility) + { + this->_visibility = visibility; + } + }; + + class TypeDefinition + : public IDefinition + { + std::unique_ptr _type; + GenericParameterList _generics; + + public: + explicit TypeDefinition( + Symbol type_name_symbol, + std::unique_ptr type, + GenericParameterList generics, + const VisibilityModifier visibility + ) : + IDefinition(std::move(type_name_symbol), visibility), + _type(std::move(type)), + _generics(std::move(generics)) {} + + [[nodiscard]] + IAstType* get_type() const + { + return this->_type.get(); + } + + [[nodiscard]] + GenericParameterList get_generics_parameters() const + { + return this->_generics; + } + + [[nodiscard]] + bool is_generic() const + { + return !this->_generics.empty(); + } + + [[nodiscard]] + std::unique_ptr clone() const override + { + return std::make_unique( + get_symbol(), + _type->clone_ty(), + get_generics_parameters(), + get_visibility()); + } + }; + + class FieldDefinition : public IDefinition + { + std::unique_ptr _type; + + /// Can be either a variable or a field in a struct/class + public: + explicit FieldDefinition( + const Symbol& symbol, + std::unique_ptr type, + const VisibilityModifier visibility + ) : + IDefinition(symbol, visibility), + _type(std::move(type)) {} + + [[nodiscard]] + IAstType* get_type() const + { + return this->_type.get(); + } + + [[nodiscard]] + std::string get_field_name() const + { + return this->get_symbol().name; + } + + [[nodiscard]] + std::unique_ptr clone() const override + { + return std::make_unique(get_symbol(), _type->clone_ty(), get_visibility()); + } + + void set_type(std::unique_ptr type) + { + this->_type 
= std::move(type); + } + }; +} // namespace stride::ast::definition diff --git a/packages/compiler/include/ast/definitions/function_definition.h b/packages/compiler/include/ast/definitions/function_definition.h index 4584406e..f4fb8762 100644 --- a/packages/compiler/include/ast/definitions/function_definition.h +++ b/packages/compiler/include/ast/definitions/function_definition.h @@ -1,8 +1,9 @@ #pragma once -#include "ast/parsing_context.h" +#include "ast/definitions/definitions.h" #include "ast/nodes/function_declaration.h" +#include #include namespace stride::ast::definition @@ -58,7 +59,7 @@ namespace stride::ast::definition return (this->_flags & SRFLAG_FN_TYPE_VARIADIC) != 0; } - void add_generic_overload(GenericTypeList generic_overload_types); + void add_generic_instantiation(GenericTypeList generic_overload_types); [[nodiscard]] const std::vector& get_generic_overloads() const diff --git a/packages/compiler/include/ast/nodes/expression.h b/packages/compiler/include/ast/nodes/expression.h index 5915c1fb..d5a06c63 100644 --- a/packages/compiler/include/ast/nodes/expression.h +++ b/packages/compiler/include/ast/nodes/expression.h @@ -457,8 +457,6 @@ namespace stride::ast llvm::IRBuilderBase* builder ) override; - bool is_reducible() override; - std::optional> reduce() override; std::unique_ptr clone() override; diff --git a/packages/compiler/include/ast/parsing_context.h b/packages/compiler/include/ast/parsing_context.h index 2892de2c..ec9077ab 100644 --- a/packages/compiler/include/ast/parsing_context.h +++ b/packages/compiler/include/ast/parsing_context.h @@ -3,6 +3,7 @@ #include "modifiers.h" #include "symbols.h" #include "ast/nodes/types.h" +#include "definitions/definitions.h" #include #include @@ -17,9 +18,6 @@ namespace llvm namespace stride::ast { - class AstFunctionDeclaration; - enum class VisibilityModifier; - enum class ContextType { GLOBAL, @@ -29,147 +27,6 @@ namespace stride::ast CONTROL_FLOW }; - namespace definition - { - class 
FunctionDefinition; - - enum class SymbolType - { - CLASS, - VARIABLE, - ENUM, - ENUM_MEMBER, - STRUCT, - STRUCT_MEMBER - }; - - class IDefinition - { - Symbol _symbol; - VisibilityModifier _visibility; - - public: - explicit IDefinition( - Symbol symbol, - const VisibilityModifier modifier - ) : - _symbol(std::move(symbol)), - _visibility(modifier) {} - - virtual ~IDefinition() = default; - - [[nodiscard]] - std::string get_internal_symbol_name() const - { - return this->_symbol.internal_name; - } - - [[nodiscard]] - Symbol get_symbol() const - { - return this->_symbol; - } - - [[nodiscard]] - VisibilityModifier get_visibility() const - { - return this->_visibility; - } - - [[nodiscard]] - virtual std::unique_ptr clone() const = 0; - - void set_visibility(const VisibilityModifier visibility) - { - this->_visibility = visibility; - } - }; - - class TypeDefinition - : public IDefinition - { - std::unique_ptr _type; - GenericParameterList _generics; - - public: - explicit TypeDefinition( - Symbol type_name_symbol, - std::unique_ptr type, - GenericParameterList generics, - const VisibilityModifier visibility - ) : - IDefinition(std::move(type_name_symbol), visibility), - _type(std::move(type)), - _generics(std::move(generics)) {} - - [[nodiscard]] - IAstType* get_type() const - { - return this->_type.get(); - } - - [[nodiscard]] - GenericParameterList get_generics_parameters() const - { - return this->_generics; - } - - [[nodiscard]] - bool is_generic() const - { - return !this->_generics.empty(); - } - - [[nodiscard]] - std::unique_ptr clone() const override - { - return std::make_unique( - get_symbol(), - _type->clone_ty(), - get_generics_parameters(), - get_visibility()); - } - }; - - class FieldDefinition : public IDefinition - { - std::unique_ptr _type; - - /// Can be either a variable or a field in a struct/class - public: - explicit FieldDefinition( - const Symbol& symbol, - std::unique_ptr type, - const VisibilityModifier visibility - ) : - IDefinition(symbol, 
visibility), - _type(std::move(type)) {} - - [[nodiscard]] - IAstType* get_type() const - { - return this->_type.get(); - } - - [[nodiscard]] - std::string get_field_name() const - { - return this->get_symbol().name; - } - - [[nodiscard]] - std::unique_ptr clone() const override - { - return std::make_unique(get_symbol(), _type->clone_ty(), get_visibility()); - } - - void set_type(std::unique_ptr type) - { - this->_type = std::move(type); - } - }; - } // namespace definition - class ParsingContext { /** diff --git a/packages/compiler/include/ast/nodes/traversal.h b/packages/compiler/include/ast/traversal.h similarity index 100% rename from packages/compiler/include/ast/nodes/traversal.h rename to packages/compiler/include/ast/traversal.h diff --git a/packages/compiler/include/ast/visitor.h b/packages/compiler/include/ast/visitor.h index a31fa344..05c1c6aa 100644 --- a/packages/compiler/include/ast/visitor.h +++ b/packages/compiler/include/ast/visitor.h @@ -1,6 +1,6 @@ #pragma once -#include "nodes/traversal.h" +#include "traversal.h" #include #include @@ -57,6 +57,12 @@ namespace stride::ast void accept(IAstFunction* function) override; }; + class FunctionCallVisitor : public IVisitor + { + public: + void accept(AstFunctionCall* function_call) override; + }; + class ImportVisitor : public IVisitor { std::string _current_file_name; // temporary values diff --git a/packages/compiler/src/ast/context/function_registry.cpp b/packages/compiler/src/ast/context/function_registry.cpp index 75799bbf..7a368833 100644 --- a/packages/compiler/src/ast/context/function_registry.cpp +++ b/packages/compiler/src/ast/context/function_registry.cpp @@ -200,7 +200,7 @@ bool FunctionDefinition::has_generic_instantiation(const std::vectorset_type( stride::ast::resolve_generics(inferred_type.get(), param_names, instantiated_types) ); diff --git a/packages/compiler/src/ast/nodes/functions/call/function_call.cpp b/packages/compiler/src/ast/nodes/functions/call/function_call.cpp index 
b270c90d..3929f883 100644 --- a/packages/compiler/src/ast/nodes/functions/call/function_call.cpp +++ b/packages/compiler/src/ast/nodes/functions/call/function_call.cpp @@ -12,7 +12,6 @@ #include #include -#include #include using namespace stride::ast; @@ -270,14 +269,6 @@ std::unique_ptr AstFunctionCall::clone() ); } -bool AstFunctionCall::is_reducible() -{ - // TODO: implement - // Function calls can be reducible if the function returns - // a constant value or if all arguments are reducible. - return false; -} - std::optional> AstFunctionCall::reduce() { return std::nullopt; @@ -320,8 +311,6 @@ std::string AstFunctionCall::get_formatted_call() const std::string AstFunctionCall::to_string() { - std::ostringstream oss; - std::vector arg_types; for (const auto& arg : this->get_arguments()) { diff --git a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_codegen.cpp b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_codegen.cpp index 33d0ba13..39d8b7f8 100644 --- a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_codegen.cpp +++ b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_codegen.cpp @@ -8,7 +8,6 @@ using namespace stride::ast; - llvm::Value* IAstFunction::codegen( llvm::Module* module, llvm::IRBuilderBase* builder @@ -16,7 +15,21 @@ llvm::Value* IAstFunction::codegen( { llvm::Function* function = nullptr; - for (const auto& [function_name, llvm_function_val, overload_body] : this->get_function_implementation_data()) + const auto& implementations = get_function_implementation_data(); + + if (implementations.empty()) + { + throw parsing_error( + ErrorType::COMPILATION_ERROR, + std::format( + "No instantiations for function '{}' found", + this->get_plain_function_name() + ), + this->get_source_fragment() + ); + } + + for (const auto& [function_name, llvm_function_val, overload_body] : implementations) { if (!llvm_function_val) { diff --git 
a/packages/compiler/src/ast/traversal/expression_visitor.cpp b/packages/compiler/src/ast/traversal/expression_visitor.cpp index 35487df6..87e689b3 100644 --- a/packages/compiler/src/ast/traversal/expression_visitor.cpp +++ b/packages/compiler/src/ast/traversal/expression_visitor.cpp @@ -26,17 +26,4 @@ void ExpressionVisitor::accept(IAstExpression* expr) var_decl->get_visibility() ); } - else if (auto* function_call = dynamic_cast(expr); - function_call != nullptr && - !function_call->get_generic_type_arguments().empty() - ) - { - auto* definition = function_call->get_function_definition(); - if (auto* fn_def = dynamic_cast(definition)) - { - fn_def->add_generic_overload( - copy_generic_type_list(function_call->get_generic_type_arguments()) - ); - } - } } diff --git a/packages/compiler/src/ast/traversal/function_call_visitor.cpp b/packages/compiler/src/ast/traversal/function_call_visitor.cpp new file mode 100644 index 00000000..544aadb2 --- /dev/null +++ b/packages/compiler/src/ast/traversal/function_call_visitor.cpp @@ -0,0 +1,19 @@ +#include "ast/visitor.h" +#include "ast/definitions/function_definition.h" +#include "ast/nodes/expression.h" + +using namespace stride::ast; + +void FunctionCallVisitor::accept(AstFunctionCall* function_call) +{ + if (function_call->get_generic_type_arguments().empty()) + return; + + auto* definition = function_call->get_function_definition(); + if (auto* fn_def = dynamic_cast(definition)) + { + fn_def->add_generic_instantiation( + copy_generic_type_list(function_call->get_generic_type_arguments()) + ); + } +} diff --git a/packages/compiler/src/ast/traversal/traversal.cpp b/packages/compiler/src/ast/traversal/traversal.cpp index 81de9fbc..e2f69f00 100644 --- a/packages/compiler/src/ast/traversal/traversal.cpp +++ b/packages/compiler/src/ast/traversal/traversal.cpp @@ -1,4 +1,4 @@ -#include "ast/nodes/traversal.h" +#include "../../../include/ast/traversal.h" #include "ast/casting.h" #include "ast/parsing_context.h" @@ -55,10 +55,12 @@ 
void AstNodeTraverser::visit_expression(IVisitor* visitor, IAstExpression* node) if (var_decl->get_initial_value()) visit_expression(visitor, var_decl->get_initial_value()); } - else if (const auto* fn_call = cast_expr(node)) + else if (auto* fn_call = cast_expr(node)) { for (const auto& arg : fn_call->get_arguments()) visit_expression(visitor, arg.get()); + + visitor->accept(fn_call); } else if (const auto* array = cast_expr(node)) { diff --git a/packages/compiler/src/compilation/program.cpp b/packages/compiler/src/compilation/program.cpp index a9bdb569..1e7d9964 100644 --- a/packages/compiler/src/compilation/program.cpp +++ b/packages/compiler/src/compilation/program.cpp @@ -2,7 +2,7 @@ #include "ast/ast.h" #include "ast/visitor.h" -#include "ast/nodes/traversal.h" +#include "../../include/ast/traversal.h" #include "runtime/symbols.h" #include @@ -50,6 +50,7 @@ std::unique_ptr Program::prepare_module( ast::AstNodeTraverser traverser; ast::ExpressionVisitor expression_visitor; ast::FunctionVisitor function_visitor; + ast::FunctionCallVisitor function_call_visitor; ast::ImportVisitor import_visitor; // @@ -71,7 +72,15 @@ std::unique_ptr Program::prepare_module( import_visitor.cross_register_symbols(this->_ast.get()); // - // Second step - Type resolution and symbol forward declarations + // Generic resolution - Instantiates functions that have generic arguments + // + for (const auto& node : this->_ast->get_files() | std::views::values) + { + traverser.visit_block(&function_call_visitor, node.get()); + } + + // + // Third step - Type resolution and symbol forward declarations // for (const auto& node : this->_ast->get_files() | std::views::values) { diff --git a/packages/compiler/tests/utils.h b/packages/compiler/tests/utils.h index 5b2292c8..a657d96a 100644 --- a/packages/compiler/tests/utils.h +++ b/packages/compiler/tests/utils.h @@ -5,7 +5,7 @@ #include "ast/parsing_context.h" #include "ast/visitor.h" #include "ast/nodes/blocks.h" -#include 
"ast/nodes/traversal.h" +#include "../include/ast/traversal.h" #include "ast/tokens/tokenizer.h" #include "runtime/symbols.h" @@ -31,6 +31,7 @@ namespace stride::tests ast::AstNodeTraverser traverser; ast::ExpressionVisitor expression_visitor; ast::FunctionVisitor function_visitor; + ast::FunctionCallVisitor function_call_visitor; ast::ImportVisitor import_visitor; runtime::register_runtime_symbols(node->get_context()); @@ -39,6 +40,7 @@ namespace stride::tests traverser.visit_block(&import_visitor, node.get()); traverser.visit_block(&function_visitor, node.get()); + traverser.visit_block(&function_call_visitor, node.get()); traverser.visit_block(&expression_visitor, node.get()); From 47d332d504d8e91fb20646416dd269c292a89a3f Mon Sep 17 00:00:00 2001 From: Luca Warmenhoven Date: Wed, 18 Mar 2026 17:57:20 +0100 Subject: [PATCH 28/31] Enhance generic function handling by adding support for retrieving generic function definitions and improving instantiation logic --- packages/compiler/include/ast/generics.h | 14 -- .../compiler/include/ast/parsing_context.h | 6 + .../src/ast/context/function_registry.cpp | 11 +- packages/compiler/src/ast/generics.cpp | 155 ------------------ .../function_declaration_forward_refs.cpp | 5 - .../function_declaration_validation.cpp | 104 ++++++------ .../src/ast/nodes/types/alias_type.cpp | 2 - .../ast/traversal/function_call_visitor.cpp | 17 +- .../compiler/src/ast/traversal/traversal.cpp | 2 +- 9 files changed, 79 insertions(+), 237 deletions(-) diff --git a/packages/compiler/include/ast/generics.h b/packages/compiler/include/ast/generics.h index 01daf91f..35cd7ef5 100644 --- a/packages/compiler/include/ast/generics.h +++ b/packages/compiler/include/ast/generics.h @@ -50,18 +50,4 @@ namespace stride::ast GenericTypeList copy_generic_type_list(const GenericTypeList& list); std::string get_overloaded_function_name(std::string function_name, const GenericTypeList& overload_types); - - /** - * Recursively walks a function body AST and 
resolves generic type parameters - * on all expression nodes, in the same manner as resolve_generics does for types. - * - * For each expression, the type is re-inferred (via context lookup) and then - * any generic parameters are substituted with their concrete instantiated types. - * The walk is bottom-up so that child expression types are resolved before their parents. - */ - void resolve_generics_in_body( - IAstNode* node, - const GenericParameterList& param_names, - const GenericTypeList& instantiated_types - ); } diff --git a/packages/compiler/include/ast/parsing_context.h b/packages/compiler/include/ast/parsing_context.h index ec9077ab..ddbafb5d 100644 --- a/packages/compiler/include/ast/parsing_context.h +++ b/packages/compiler/include/ast/parsing_context.h @@ -117,6 +117,12 @@ namespace stride::ast size_t instantiated_generic_count = 0 ) const; + [[nodiscard]] + std::optional get_generic_function_definition( + const std::string& function_name, + size_t instantiated_generic_count = 0 + ) const; + std::optional get_function_definition( const std::string& function_name, IAstType* function_type diff --git a/packages/compiler/src/ast/context/function_registry.cpp b/packages/compiler/src/ast/context/function_registry.cpp index 7a368833..c5c7a916 100644 --- a/packages/compiler/src/ast/context/function_registry.cpp +++ b/packages/compiler/src/ast/context/function_registry.cpp @@ -31,6 +31,14 @@ std::optional ParsingContext::get_function_definition( return std::nullopt; } +std::optional ParsingContext::get_generic_function_definition( + const std::string& function_name, + const size_t instantiated_generic_count +) const +{ + return get_function_definition(function_name, {}, instantiated_generic_count); +} + std::optional ParsingContext::get_function_definition( const std::string& function_name, // We might call this function with an anonymous type, hence not having `AstFunctionType` @@ -69,7 +77,8 @@ bool FunctionDefinition::matches_type_signature( // Handle 
matching for generic functions if (this->get_type()->is_generic() && signature->is_generic()) { - return this->get_type()->get_generic_parameter_names().size() == signature->get_generic_parameter_names().size(); + return this->get_type()->get_generic_parameter_names().size() == signature->get_generic_parameter_names(). + size(); } if (!this->_function_type->get_return_type()->equals(signature->get_return_type().get())) diff --git a/packages/compiler/src/ast/generics.cpp b/packages/compiler/src/ast/generics.cpp index fe92f04d..092d2b34 100644 --- a/packages/compiler/src/ast/generics.cpp +++ b/packages/compiler/src/ast/generics.cpp @@ -347,158 +347,3 @@ static void resolve_expression_type( stride::ast::resolve_generics(inferred_type.get(), param_names, instantiated_types) ); } - -void stride::ast::resolve_generics_in_body( - IAstNode* node, - const GenericParameterList& param_names, - const GenericTypeList& instantiated_types -) -{ - if (!node || param_names.empty()) - return; - - // - // Statement / container nodes — recurse into children first (bottom-up). 
- // - - if (auto* block = dynamic_cast(node)) - { - for (const auto& child : block->get_children()) - resolve_generics_in_body(child.get(), param_names, instantiated_types); - return; - } - - if (auto* return_stmt = dynamic_cast(node)) - { - if (return_stmt->get_return_expression().has_value()) - resolve_generics_in_body( - return_stmt->get_return_expression().value().get(), - param_names, instantiated_types); - return; - } - - if (auto* conditional = dynamic_cast(node)) - { - resolve_generics_in_body(conditional->get_condition(), param_names, instantiated_types); - resolve_generics_in_body(conditional->get_body(), param_names, instantiated_types); - if (conditional->get_else_body()) - resolve_generics_in_body(conditional->get_else_body(), param_names, instantiated_types); - return; - } - - if (auto* while_loop = dynamic_cast(node)) - { - if (while_loop->get_condition()) - resolve_generics_in_body(while_loop->get_condition(), param_names, instantiated_types); - resolve_generics_in_body(while_loop->get_body(), param_names, instantiated_types); - return; - } - - if (auto* for_loop = dynamic_cast(node)) - { - if (for_loop->get_initializer()) - resolve_generics_in_body(for_loop->get_initializer(), param_names, instantiated_types); - if (for_loop->get_condition()) - resolve_generics_in_body(for_loop->get_condition(), param_names, instantiated_types); - if (for_loop->get_incrementor()) - resolve_generics_in_body(for_loop->get_incrementor(), param_names, instantiated_types); - resolve_generics_in_body(for_loop->get_body(), param_names, instantiated_types); - return; - } - - // - // Expression nodes — recurse into sub-expressions first, then resolve this node's type. 
- // - - auto* expr = dynamic_cast(node); - if (!expr) - return; - - // --- Recurse into child expressions (bottom-up order) --- - - if (auto* binary = cast_expr(expr)) - { - resolve_generics_in_body(binary->get_left(), param_names, instantiated_types); - resolve_generics_in_body(binary->get_right(), param_names, instantiated_types); - } - else if (auto* unary = cast_expr(expr)) - { - resolve_generics_in_body(&unary->get_operand(), param_names, instantiated_types); - } - else if (auto* var_decl = cast_expr(expr)) - { - if (var_decl->has_annotated_type()) - { - var_decl->set_annotated_type( - resolve_generics(var_decl->get_annotated_type().value(), param_names, instantiated_types) - ); - } - - if (var_decl->get_initial_value()) - resolve_generics_in_body(var_decl->get_initial_value(), param_names, instantiated_types); - } - else if (auto* fn_call = cast_expr(expr)) - { - for (const auto& arg : fn_call->get_arguments()) - resolve_generics_in_body(arg.get(), param_names, instantiated_types); - } - else if (auto* array = cast_expr(expr)) - { - for (const auto& elem : array->get_elements()) - resolve_generics_in_body(elem.get(), param_names, instantiated_types); - } - else if (auto* array_accessor = cast_expr(expr)) - { - resolve_generics_in_body(array_accessor->get_array_base(), param_names, instantiated_types); - resolve_generics_in_body(array_accessor->get_index(), param_names, instantiated_types); - } - else if (auto* struct_init = cast_expr(expr)) - { - // Resolve generic type arguments on the object initializer itself - // e.g. Array::{ ... } → Array::{ ... 
} - if (struct_init->has_generic_type_arguments()) - { - GenericTypeList resolved_args; - resolved_args.reserve(struct_init->get_generic_type_arguments().size()); - for (const auto& arg : struct_init->get_generic_type_arguments()) - { - resolved_args.push_back(resolve_generics(arg.get(), param_names, instantiated_types)); - } - struct_init->set_generic_type_arguments(std::move(resolved_args)); - } - - for (const auto& val : struct_init->get_initializers() | std::views::values) - resolve_generics_in_body(val.get(), param_names, instantiated_types); - } - else if (auto* tuple_init = cast_expr(expr)) - { - for (const auto& member : tuple_init->get_members()) - resolve_generics_in_body(member.get(), param_names, instantiated_types); - } - else if (auto* reassign = cast_expr(expr)) - { - resolve_generics_in_body(reassign->get_identifier(), param_names, instantiated_types); - resolve_generics_in_body(reassign->get_value(), param_names, instantiated_types); - } - else if (auto* chained = cast_expr(expr)) - { - resolve_generics_in_body(chained->get_base(), param_names, instantiated_types); - } - else if (auto* type_cast = cast_expr(expr)) - { - resolve_generics_in_body(type_cast->get_value(), param_names, instantiated_types); - } - else if (auto* indirect_call = cast_expr(expr)) - { - for (const auto& arg : indirect_call->get_args()) - resolve_generics_in_body(arg.get(), param_names, instantiated_types); - resolve_generics_in_body(indirect_call->get_callee(), param_names, instantiated_types); - } - else if (auto* function_node = cast_expr(expr)) - { - resolve_generics_in_body(function_node->get_body(), param_names, instantiated_types); - } - - // --- Now resolve the type for this expression node --- - resolve_expression_type(expr, param_names, instantiated_types); -} diff --git a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_forward_refs.cpp b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_forward_refs.cpp index 
be1b8cb6..b1ee613b 100644 --- a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_forward_refs.cpp +++ b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_forward_refs.cpp @@ -122,11 +122,6 @@ void IAstFunction::resolve_forward_references( } auto resolved_body = this->_body->clone_as(); - resolve_generics_in_body( - resolved_body.get(), - this->_generic_parameters, - instantiated_generic_types - ); node = std::make_unique( this->get_context(), diff --git a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_validation.cpp b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_validation.cpp index eaace4c8..381af1b4 100644 --- a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_validation.cpp +++ b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_validation.cpp @@ -19,71 +19,65 @@ void IAstFunction::validate() } // - // For generic functions, we create a new copy of the function with all parameters resolved, and do validation - // on that copy. This is because we want to validate the function body with the actual types that will be used in - // the function, rather than the generic placeholders. - // - // create a copy of this function with the parameters instantiated - for (const auto definition = this->get_function_definition(); - const auto& [instantiated_generic_types, function, node] : definition->get_generic_overloads()) - { - auto instantiated_return_ty = resolve_generics( - this->_annotated_return_type.get(), - this->_generic_parameters, - instantiated_generic_types - ); + // For generic functions, we create a new copy of the function with all parameters resolved, and do validation + // on that copy. This is because we want to validate the function body with the actual types that will be used in + // the function, rather than the generic placeholders. 
+ // + // create a copy of this function with the parameters instantiated + for (const auto definition = this->get_function_definition(); + const auto& [instantiated_generic_types, function, node] : definition->get_generic_overloads()) + { + auto instantiated_return_ty = resolve_generics( + this->_annotated_return_type.get(), + this->_generic_parameters, + instantiated_generic_types + ); - std::vector> instantiated_function_params; - instantiated_function_params.reserve(this->_parameters.size()); + std::vector> instantiated_function_params; + instantiated_function_params.reserve(this->_parameters.size()); - // Temporarily update parameter types in the context so resolve_generics_in_body - // and subsequent validation can find the concrete types. - std::vector>> old_param_types; - for (const auto& param : this->_parameters) + // Temporarily update parameter types in the context so resolve_generics_in_body + // and subsequent validation can find the concrete types. + std::vector>> old_param_types; + for (const auto& param : this->_parameters) + { + if (auto def = this->get_context()->lookup_variable(param->get_name(), true)) { - if (auto def = this->get_context()->lookup_variable(param->get_name(), true)) - { - old_param_types.push_back({ def, def->get_type()->clone_ty() }); - def->set_type(resolve_generics(def->get_type(), this->_generic_parameters, instantiated_generic_types)); - } - - instantiated_function_params.push_back( - std::make_unique( - param->get_source_fragment(), - param->get_context(), - param->get_name(), - resolve_generics(param->get_type(), this->_generic_parameters, instantiated_generic_types) - ) - ); + old_param_types.push_back({ def, def->get_type()->clone_ty() }); + def->set_type(resolve_generics(def->get_type(), this->_generic_parameters, instantiated_generic_types)); } - // Clone the body and resolve generic types on every expression within it. 
- auto resolved_body = this->_body->clone_as(); - resolve_generics_in_body( - resolved_body.get(), - this->_generic_parameters, - instantiated_generic_types + instantiated_function_params.push_back( + std::make_unique( + param->get_source_fragment(), + param->get_context(), + param->get_name(), + resolve_generics(param->get_type(), this->_generic_parameters, instantiated_generic_types) + ) ); + } - node = std::make_unique( - this->get_context(), - this->_symbol, - std::move(instantiated_function_params), - std::move(resolved_body), - std::move(instantiated_return_ty), - this->get_visibility(), - this->_flags, - EMPTY_GENERIC_PARAMETER_LIST // Omit generics - They've been resolved - ); + // Clone the body and resolve generic types on every expression within it. + + node = std::make_unique( + this->get_context(), + this->_symbol, + std::move(instantiated_function_params), + this->_body->clone_as(), + std::move(instantiated_return_ty), + this->get_visibility(), + this->_flags, + EMPTY_GENERIC_PARAMETER_LIST // Omit generics - They've been resolved + ); - validate_candidate(node.get()); + validate_candidate(node.get()); - // Restore original parameter types in the context - for (auto& [def, old_type] : old_param_types) - { - def->set_type(std::move(old_type)); - } + // Restore original parameter types in the context + for (auto& [def, old_type] : old_param_types) + { + def->set_type(std::move(old_type)); } + } } void IAstFunction::validate_candidate(IAstFunction* candidate) diff --git a/packages/compiler/src/ast/nodes/types/alias_type.cpp b/packages/compiler/src/ast/nodes/types/alias_type.cpp index 76be6da8..7ea328f0 100644 --- a/packages/compiler/src/ast/nodes/types/alias_type.cpp +++ b/packages/compiler/src/ast/nodes/types/alias_type.cpp @@ -160,9 +160,7 @@ IAstType* AstAliasType::get_underlying_type() { // Prevent reinstantiating type if it's a complex type if (this->_underlying_type != nullptr) - { return this->_underlying_type.get(); - } const auto& 
reference_type_definition = this->get_type_definition(); diff --git a/packages/compiler/src/ast/traversal/function_call_visitor.cpp b/packages/compiler/src/ast/traversal/function_call_visitor.cpp index 544aadb2..acea54b5 100644 --- a/packages/compiler/src/ast/traversal/function_call_visitor.cpp +++ b/packages/compiler/src/ast/traversal/function_call_visitor.cpp @@ -9,11 +9,20 @@ void FunctionCallVisitor::accept(AstFunctionCall* function_call) if (function_call->get_generic_type_arguments().empty()) return; - auto* definition = function_call->get_function_definition(); - if (auto* fn_def = dynamic_cast(definition)) + auto* definition = function_call->get_context()->get_generic_function_definition( + function_call->get_function_name(), + function_call->get_generic_type_arguments().size() + ).value_or(nullptr); + + if (!definition) { - fn_def->add_generic_instantiation( - copy_generic_type_list(function_call->get_generic_type_arguments()) + throw parsing_error( + ErrorType::COMPILATION_ERROR, + std::format("Could not find generic function definition for '{}'", function_call->get_function_name()), + function_call->get_source_fragment() ); } + definition->add_generic_instantiation( + copy_generic_type_list(function_call->get_generic_type_arguments()) + ); } diff --git a/packages/compiler/src/ast/traversal/traversal.cpp b/packages/compiler/src/ast/traversal/traversal.cpp index e2f69f00..4110b259 100644 --- a/packages/compiler/src/ast/traversal/traversal.cpp +++ b/packages/compiler/src/ast/traversal/traversal.cpp @@ -38,8 +38,8 @@ void AstNodeTraverser::visit_expression(IVisitor* visitor, IAstExpression* node) // IAstFunction is an expression but needs special handling (body traversal + params) if (auto* fn = cast_expr(node)) { - visitor->accept(fn); visit_block(visitor, fn->get_body()); + visitor->accept(fn); } else if (const auto* binary = cast_expr(node)) { From f0e6eb89fa80a75936ebe8994ae1614d4c60bac3 Mon Sep 17 00:00:00 2001 From: Luca Warmenhoven Date: Wed, 25 Mar 
2026 08:05:00 +0100 Subject: [PATCH 29/31] Renamed `parsing_context` to `symbol_table` --- packages/compiler/include/ast/ast.h | 6 +- packages/compiler/include/ast/generics.h | 4 +- .../compiler/include/ast/nodes/ast_node.h | 8 +- packages/compiler/include/ast/nodes/blocks.h | 6 +- .../include/ast/nodes/conditional_statement.h | 4 +- .../ast/nodes/control_flow_statements.h | 10 +-- .../compiler/include/ast/nodes/expression.h | 74 +++++++++---------- .../compiler/include/ast/nodes/for_loop.h | 4 +- .../include/ast/nodes/function_declaration.h | 20 ++--- packages/compiler/include/ast/nodes/import.h | 6 +- .../include/ast/nodes/literal_values.h | 28 +++---- packages/compiler/include/ast/nodes/module.h | 4 +- packages/compiler/include/ast/nodes/package.h | 4 +- .../include/ast/nodes/return_statement.h | 6 +- packages/compiler/include/ast/nodes/switch.h | 6 +- .../include/ast/nodes/type_definition.h | 8 +- packages/compiler/include/ast/nodes/types.h | 28 +++---- .../compiler/include/ast/nodes/while_loop.h | 6 +- .../ast/{parsing_context.h => symbol_table.h} | 24 +++--- packages/compiler/include/ast/symbols.h | 2 +- packages/compiler/include/program.h | 2 +- packages/compiler/include/runtime/symbols.h | 4 +- packages/compiler/src/ast/ast.cpp | 8 +- .../src/ast/context/field_registry.cpp | 16 ++-- .../src/ast/context/function_registry.cpp | 14 ++-- .../{parsing_context.cpp => symbol_table.cpp} | 12 +-- .../src/ast/context/type_registry.cpp | 14 ++-- packages/compiler/src/ast/generics.cpp | 4 +- packages/compiler/src/ast/nodes/blocks.cpp | 2 +- .../src/ast/nodes/conditional_statement.cpp | 12 +-- .../nodes/control_flow/break_statement.cpp | 8 +- .../nodes/control_flow/continue_statement.cpp | 8 +- .../nodes/expressions/array_initializer.cpp | 2 +- .../expressions/array_member_accessor.cpp | 2 +- .../nodes/expressions/binary_operation.cpp | 2 +- .../src/ast/nodes/expressions/expression.cpp | 20 ++--- .../src/ast/nodes/expressions/identifier.cpp | 2 +- 
.../ast/nodes/expressions/member_accessor.cpp | 6 +- .../nodes/expressions/object_initializer.cpp | 6 +- .../src/ast/nodes/expressions/type_cast.cpp | 2 +- .../ast/nodes/expressions/unary_operation.cpp | 4 +- .../expressions/variable_declaration.cpp | 6 +- .../expressions/variable_reassignation.cpp | 4 +- packages/compiler/src/ast/nodes/for_loop.cpp | 16 ++-- .../nodes/functions/call/function_call.cpp | 4 +- .../functions/call/function_call_codegen.cpp | 2 +- .../declaration/function_declaration.cpp | 12 +-- .../function_declaration_forward_refs.cpp | 4 +- .../declaration/function_parameters.cpp | 4 +- .../src/ast/nodes/functions/symbols.cpp | 4 +- packages/compiler/src/ast/nodes/import.cpp | 6 +- packages/compiler/src/ast/nodes/literals.cpp | 12 +-- packages/compiler/src/ast/nodes/module.cpp | 6 +- packages/compiler/src/ast/nodes/package.cpp | 2 +- .../src/ast/nodes/return_statement.cpp | 4 +- .../src/ast/nodes/types/alias_type.cpp | 4 +- .../src/ast/nodes/types/enum_type.cpp | 8 +- .../src/ast/nodes/types/function_type.cpp | 2 +- .../src/ast/nodes/types/object_type.cpp | 8 +- .../src/ast/nodes/types/primitive_type.cpp | 2 +- .../src/ast/nodes/types/tuple_type.cpp | 2 +- .../src/ast/nodes/types/type_definition.cpp | 4 +- .../compiler/src/ast/nodes/types/types.cpp | 4 +- .../compiler/src/ast/nodes/while_loop.cpp | 10 +-- packages/compiler/src/ast/optionals.cpp | 2 +- .../src/ast/traversal/expression_visitor.cpp | 2 +- .../src/ast/traversal/function_visitor.cpp | 2 +- .../compiler/src/ast/traversal/traversal.cpp | 2 +- packages/compiler/src/ast/type_inference.cpp | 2 +- packages/compiler/src/stl/runtime_symbols.cpp | 4 +- .../compiler/tests/test_chained_accessor.cpp | 6 +- packages/compiler/tests/test_type_casts.cpp | 2 +- .../compiler/tests/test_type_inference.cpp | 6 +- packages/compiler/tests/test_types.cpp | 2 +- packages/compiler/tests/utils.h | 6 +- 75 files changed, 287 insertions(+), 287 deletions(-) rename packages/compiler/include/ast/{parsing_context.h => 
symbol_table.h} (91%) rename packages/compiler/src/ast/context/{parsing_context.cpp => symbol_table.cpp} (91%) diff --git a/packages/compiler/include/ast/ast.h b/packages/compiler/include/ast/ast.h index 2fd519e0..1eee5a35 100644 --- a/packages/compiler/include/ast/ast.h +++ b/packages/compiler/include/ast/ast.h @@ -1,5 +1,5 @@ #pragma once -#include "parsing_context.h" +#include "symbol_table.h" #include #include @@ -41,12 +41,12 @@ namespace stride::ast }; std::unique_ptr parse_next_statement( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set ); std::unique_ptr parse_sequential( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set ); } diff --git a/packages/compiler/include/ast/generics.h b/packages/compiler/include/ast/generics.h index 35cd7ef5..47f61669 100644 --- a/packages/compiler/include/ast/generics.h +++ b/packages/compiler/include/ast/generics.h @@ -16,7 +16,7 @@ namespace stride::ast } class AstAliasType; - class ParsingContext; + class SymbolTable; class IAstType; class TokenSet; @@ -28,7 +28,7 @@ namespace stride::ast GenericParameterList parse_generic_declaration(TokenSet& set); - GenericTypeList parse_generic_type_arguments(const std::shared_ptr& context, TokenSet& set); + GenericTypeList parse_generic_type_arguments(const std::shared_ptr& context, TokenSet& set); std::unique_ptr resolve_generics( IAstType* type, diff --git a/packages/compiler/include/ast/nodes/ast_node.h b/packages/compiler/include/ast/nodes/ast_node.h index c4e251a3..1c469776 100644 --- a/packages/compiler/include/ast/nodes/ast_node.h +++ b/packages/compiler/include/ast/nodes/ast_node.h @@ -17,17 +17,17 @@ namespace stride::ast { class AstBlock; - class ParsingContext; + class SymbolTable; class IAstNode { const SourceFragment _source_position; - const std::shared_ptr _context; + const std::shared_ptr _context; public: explicit IAstNode( const SourceFragment& source, - const std::shared_ptr& context + const 
std::shared_ptr& context ) : _source_position(source), _context(context) {} @@ -45,7 +45,7 @@ namespace stride::ast } [[nodiscard]] - std::shared_ptr get_context() const + std::shared_ptr get_context() const { return this->_context; } diff --git a/packages/compiler/include/ast/nodes/blocks.h b/packages/compiler/include/ast/nodes/blocks.h index 58bb3938..323f74fc 100644 --- a/packages/compiler/include/ast/nodes/blocks.h +++ b/packages/compiler/include/ast/nodes/blocks.h @@ -18,7 +18,7 @@ namespace stride::ast public: explicit AstBlock( const SourceFragment& source, - const std::shared_ptr& context, + const std::shared_ptr& context, std::vector> children ) : IAstNode(source, context), @@ -49,7 +49,7 @@ namespace stride::ast ~AstBlock() override = default; static std::unique_ptr create_empty( - const std::shared_ptr& context, + const std::shared_ptr& context, const SourceFragment& source ) { @@ -60,7 +60,7 @@ namespace stride::ast }; std::unique_ptr parse_block( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set); std::optional collect_block(TokenSet& set); diff --git a/packages/compiler/include/ast/nodes/conditional_statement.h b/packages/compiler/include/ast/nodes/conditional_statement.h index 2c9d207b..668c7a0d 100644 --- a/packages/compiler/include/ast/nodes/conditional_statement.h +++ b/packages/compiler/include/ast/nodes/conditional_statement.h @@ -18,7 +18,7 @@ namespace stride::ast public: explicit AstConditionalStatement( const SourceFragment& source, - const std::shared_ptr& context, + const std::shared_ptr& context, std::unique_ptr condition, std::unique_ptr body, std::unique_ptr else_body @@ -62,6 +62,6 @@ namespace stride::ast }; std::unique_ptr parse_if_statement( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set); } // namespace stride::ast diff --git a/packages/compiler/include/ast/nodes/control_flow_statements.h b/packages/compiler/include/ast/nodes/control_flow_statements.h index 
1a55135d..9a3b8c5b 100644 --- a/packages/compiler/include/ast/nodes/control_flow_statements.h +++ b/packages/compiler/include/ast/nodes/control_flow_statements.h @@ -10,7 +10,7 @@ namespace stride::ast public: explicit IAstControlFlowStatement( const SourceFragment& source, - const std::shared_ptr& context + const std::shared_ptr& context ) : IAstNode(source, context) {} }; @@ -21,7 +21,7 @@ namespace stride::ast public: explicit AstContinueStatement( const SourceFragment& source, - const std::shared_ptr& context + const std::shared_ptr& context ) : IAstControlFlowStatement(source, context) {} @@ -43,7 +43,7 @@ namespace stride::ast public: explicit AstBreakStatement( const SourceFragment& source, - const std::shared_ptr& context + const std::shared_ptr& context ) : IAstControlFlowStatement(source, context) {} @@ -60,12 +60,12 @@ namespace stride::ast }; std::unique_ptr parse_continue_statement( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set ); std::unique_ptr parse_break_statement( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set ); } diff --git a/packages/compiler/include/ast/nodes/expression.h b/packages/compiler/include/ast/nodes/expression.h index d5a06c63..7ca2f7d9 100644 --- a/packages/compiler/include/ast/nodes/expression.h +++ b/packages/compiler/include/ast/nodes/expression.h @@ -16,7 +16,7 @@ namespace stride::ast enum class TokenType; class AstLiteral; class AstFunctionParameter; - class ParsingContext; + class SymbolTable; namespace definition { @@ -93,13 +93,13 @@ namespace stride::ast public: explicit IAstExpression( const SourceFragment& source_position, - const std::shared_ptr& context + const std::shared_ptr& context ) : IAstNode(source_position, context) {} explicit IAstExpression( const SourceFragment& source_position, - const std::shared_ptr& context, + const std::shared_ptr& context, std::unique_ptr type ) : IAstExpression(source_position, context) @@ -157,7 +157,7 @@ 
namespace stride::ast public: explicit AstArray( const SourceFragment& source, - const std::shared_ptr& context, + const std::shared_ptr& context, ExpressionList elements ) : IAstExpression(source, context), @@ -193,7 +193,7 @@ namespace stride::ast public: explicit AstIdentifier( - const std::shared_ptr& context, + const std::shared_ptr& context, Symbol symbol ) : IAstExpression(symbol.symbol_position, context), @@ -248,7 +248,7 @@ namespace stride::ast public: explicit AstArrayMemberAccessor( const SourceFragment& source, - const std::shared_ptr& context, + const std::shared_ptr& context, std::unique_ptr array_base, std::unique_ptr index_expr ) : @@ -295,7 +295,7 @@ namespace stride::ast public: explicit AstChainedExpression( const SourceFragment& source, - const std::shared_ptr& context, + const std::shared_ptr& context, std::unique_ptr base, std::unique_ptr followup ) : @@ -348,7 +348,7 @@ namespace stride::ast public: explicit AstIndirectCall( const SourceFragment& source, - const std::shared_ptr& context, + const std::shared_ptr& context, std::unique_ptr callee, ExpressionList args ) : @@ -402,7 +402,7 @@ namespace stride::ast public: explicit AstFunctionCall( - const std::shared_ptr& context, + const std::shared_ptr& context, std::unique_ptr function_name_identifier, ExpressionList arguments, GenericTypeList generic_type_arguments, @@ -507,7 +507,7 @@ namespace stride::ast public: explicit AstVariableDeclaration( - const std::shared_ptr& context, + const std::shared_ptr& context, Symbol symbol, std::optional> variable_type, std::unique_ptr initial_value, @@ -606,7 +606,7 @@ namespace stride::ast explicit IBinaryOp( const SourceFragment& source, - const std::shared_ptr& context, + const std::shared_ptr& context, std::unique_ptr lsh, std::unique_ptr rsh ) : @@ -635,7 +635,7 @@ namespace stride::ast public: explicit AstBinaryArithmeticOp( const SourceFragment& source, - const std::shared_ptr& context, + const std::shared_ptr& context, std::unique_ptr left, 
const BinaryOpType op, std::unique_ptr right @@ -677,7 +677,7 @@ namespace stride::ast public: explicit AstLogicalOp( const SourceFragment& source, - const std::shared_ptr& context, + const std::shared_ptr& context, std::unique_ptr left, const LogicalOpType op, std::unique_ptr right @@ -710,7 +710,7 @@ namespace stride::ast public: explicit AstComparisonOp( const SourceFragment& source, - const std::shared_ptr& context, + const std::shared_ptr& context, std::unique_ptr left, const ComparisonOpType op, std::unique_ptr right @@ -744,7 +744,7 @@ namespace stride::ast public: explicit AstUnaryOp( const SourceFragment& source, - const std::shared_ptr& context, + const std::shared_ptr& context, const UnaryOpType op, std::unique_ptr operand ) : @@ -799,7 +799,7 @@ namespace stride::ast public: explicit AstVariableReassignment( const SourceFragment& source, - const std::shared_ptr& context, + const std::shared_ptr& context, std::unique_ptr identifier, const MutativeAssignmentType op, std::unique_ptr value @@ -867,7 +867,7 @@ namespace stride::ast public: explicit AstObjectInitializer( const SourceFragment& source, - const std::shared_ptr& context, + const std::shared_ptr& context, std::string struct_name, std::vector member_initializers, GenericTypeList generic_type_arguments = {} @@ -927,7 +927,7 @@ namespace stride::ast public: explicit AstVariadicArgReference( const SourceFragment& source, - const std::shared_ptr& context) : + const std::shared_ptr& context) : IAstExpression(source, context) {} llvm::Value* codegen( @@ -958,7 +958,7 @@ namespace stride::ast public: explicit AstTupleInitializer( const SourceFragment& source, - const std::shared_ptr& context, + const std::shared_ptr& context, ExpressionList members ) : IAstExpression(source, context), @@ -990,7 +990,7 @@ namespace stride::ast public: explicit AstTypeCastOp( const SourceFragment& source, - const std::shared_ptr& context, + const std::shared_ptr& context, std::unique_ptr value, std::unique_ptr target_type ) 
: @@ -1029,46 +1029,46 @@ namespace stride::ast /// Parses a complete standalone expression from tokens std::unique_ptr parse_standalone_expression( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set); /// Parses an expression that appears inline, e.g., within a statement or as a sub-expression std::unique_ptr parse_inline_expression( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set); /// Parses a single part of a standalone expression std::unique_ptr parse_inline_expression_part( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set); /// Parses a variable declaration statement std::unique_ptr parse_variable_declaration( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set, VisibilityModifier modifier); /// Parses a variable declaration that appears inline within a larger expression context std::unique_ptr parse_variable_declaration_inline( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set, VisibilityModifier modifier); /// Parses a function invocation into an AstFunctionCall expression node std::unique_ptr parse_function_call( - const std::shared_ptr& context, + const std::shared_ptr& context, AstIdentifier* identifier, TokenSet& set); /// Parses a variable assignment statement std::optional> parse_variable_reassignment( - const std::shared_ptr& context, + const std::shared_ptr& context, AstIdentifier* identifier, TokenSet& set); /// Parses a binary arithmetic operation using precedence climbing std::optional> parse_arithmetic_binary_operation_optional( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set, std::unique_ptr lhs, int min_precedence @@ -1076,58 +1076,58 @@ namespace stride::ast /// Parses a single chained member access step: consumes `.identifier` and wraps lhs std::unique_ptr parse_chained_member_access( - const std::shared_ptr& context, + const std::shared_ptr& context, 
TokenSet& set, std::unique_ptr lhs ); /// Parses a unary operator expression std::optional> parse_binary_unary_op( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set ); /// Parses an array initializer expression, e.g., [1, 2, 3] std::unique_ptr parse_array_initializer( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set ); /// Parses an array subscript: consumes `[]` and wraps the base expression std::unique_ptr parse_array_member_accessor( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set, std::unique_ptr array_base ); /// Parses an indirect call: consumes `()` and wraps the callee expression std::unique_ptr parse_indirect_call( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set, std::unique_ptr callee ); /// Parses a struct initializer expression into an AstObjectInitializer node std::unique_ptr parse_object_initializer( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set ); /// Parses a dot-separated identifier into its individual name segments, e.g., `foo::bar::baz` std::unique_ptr parse_segmented_identifier( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set, const std::string& error_message ); /// Parses a lambda function literal into an expression node std::unique_ptr parse_anonymous_fn_expression( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set ); std::optional> parse_type_cast_op( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set, IAstExpression* lhs ); diff --git a/packages/compiler/include/ast/nodes/for_loop.h b/packages/compiler/include/ast/nodes/for_loop.h index c54ce23c..53062ed3 100644 --- a/packages/compiler/include/ast/nodes/for_loop.h +++ b/packages/compiler/include/ast/nodes/for_loop.h @@ -17,7 +17,7 @@ namespace stride::ast public: explicit AstForLoop( const SourceFragment& source, - const 
std::shared_ptr& context, + const std::shared_ptr& context, std::unique_ptr initiator, std::unique_ptr condition, std::unique_ptr increment, @@ -65,7 +65,7 @@ namespace stride::ast }; std::unique_ptr parse_for_loop_statement( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set, VisibilityModifier modifier); } // namespace stride::ast diff --git a/packages/compiler/include/ast/nodes/function_declaration.h b/packages/compiler/include/ast/nodes/function_declaration.h index c39cd67c..619a5db5 100644 --- a/packages/compiler/include/ast/nodes/function_declaration.h +++ b/packages/compiler/include/ast/nodes/function_declaration.h @@ -4,7 +4,7 @@ #include "blocks.h" #include "expression.h" #include "ast/modifiers.h" -#include "ast/parsing_context.h" +#include "ast/symbol_table.h" #include @@ -32,7 +32,7 @@ namespace stride::ast public: explicit AstFunctionParameter( const SourceFragment& source, - const std::shared_ptr& context, + const std::shared_ptr& context, std::string param_name, std::unique_ptr param_type ) : @@ -100,7 +100,7 @@ namespace stride::ast public: explicit IAstFunction( const SourceFragment& source, - const std::shared_ptr& context, + const std::shared_ptr& context, Symbol symbol, std::vector> parameters, std::unique_ptr body, @@ -254,8 +254,8 @@ namespace stride::ast static void collect_free_variables( IAstNode* node, - const std::shared_ptr& lambda_context, - const std::shared_ptr& outer_context, + const std::shared_ptr& lambda_context, + const std::shared_ptr& outer_context, std::vector& captures ); @@ -268,7 +268,7 @@ namespace stride::ast { public: explicit AstFunctionDeclaration( - const std::shared_ptr& context, + const std::shared_ptr& context, Symbol symbol, std::vector> parameters, std::unique_ptr body, @@ -297,7 +297,7 @@ namespace stride::ast { public: explicit AstLambdaFunctionExpression( - const std::shared_ptr& context, + const std::shared_ptr& context, Symbol symbol, std::vector> parameters, std::unique_ptr 
body, @@ -321,19 +321,19 @@ namespace stride::ast }; std::unique_ptr parse_fn_declaration( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set, VisibilityModifier modifier ); void parse_standalone_fn_param( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set, std::vector>& parameters ); void parse_function_parameters( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set, std::vector>& parameters, int& function_flags diff --git a/packages/compiler/include/ast/nodes/import.h b/packages/compiler/include/ast/nodes/import.h index 38516978..fde2b062 100644 --- a/packages/compiler/include/ast/nodes/import.h +++ b/packages/compiler/include/ast/nodes/import.h @@ -9,7 +9,7 @@ namespace stride::ast { class TokenSet; - class ParsingContext; + class SymbolTable; typedef struct Dependency { @@ -27,7 +27,7 @@ namespace stride::ast public: explicit AstImport( const SourceFragment& source, - const std::shared_ptr& context, + const std::shared_ptr& context, std::unique_ptr package_identifier, std::vector> import_list ) : @@ -60,7 +60,7 @@ namespace stride::ast }; std::unique_ptr parse_import_statement( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set ); } // namespace stride::ast diff --git a/packages/compiler/include/ast/nodes/literal_values.h b/packages/compiler/include/ast/nodes/literal_values.h index 5b4bab7a..3214d600 100644 --- a/packages/compiler/include/ast/nodes/literal_values.h +++ b/packages/compiler/include/ast/nodes/literal_values.h @@ -18,7 +18,7 @@ namespace stride::ast public: explicit AstLiteral( const SourceFragment& source, - const std::shared_ptr& context, + const std::shared_ptr& context, const PrimitiveType type ) : IAstExpression(source, context), @@ -50,7 +50,7 @@ namespace stride::ast public: explicit IAstLiteralBase( const SourceFragment& source, - const std::shared_ptr& context, + const std::shared_ptr& context, const PrimitiveType 
type, T value ) : @@ -70,7 +70,7 @@ namespace stride::ast public: explicit AstStringLiteral( const SourceFragment& source, - const std::shared_ptr& context, + const std::shared_ptr& context, std::string val ) : // Strings are only considered to be a single byte, @@ -96,7 +96,7 @@ namespace stride::ast public: explicit AstIntLiteral( const SourceFragment& source, - const std::shared_ptr& context, + const std::shared_ptr& context, const PrimitiveType type, const int64_t value, const int flags = SRFLAG_TYPE_INT_SIGNED @@ -131,7 +131,7 @@ namespace stride::ast public: explicit AstFpLiteral( const SourceFragment& source, - const std::shared_ptr& context, + const std::shared_ptr& context, const PrimitiveType type, const long double value ) : @@ -153,7 +153,7 @@ namespace stride::ast public: explicit AstBooleanLiteral( const SourceFragment& source, - const std::shared_ptr& context, + const std::shared_ptr& context, const bool value ) : IAstLiteralBase(source, context, PrimitiveType::BOOL, value) {} @@ -173,7 +173,7 @@ namespace stride::ast public: explicit AstCharLiteral( const SourceFragment& source, - const std::shared_ptr& context, + const std::shared_ptr& context, const char value ) : IAstLiteralBase(source, context, PrimitiveType::CHAR, value) {} @@ -194,7 +194,7 @@ namespace stride::ast public: AstNilLiteral( const SourceFragment& source, - const std::shared_ptr& context + const std::shared_ptr& context ) : AstLiteral(source, context, PrimitiveType::NIL) {} @@ -212,27 +212,27 @@ namespace stride::ast }; std::optional> parse_literal_optional( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set); std::optional> parse_boolean_literal_optional( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set); std::optional> parse_float_literal_optional( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set); std::optional> parse_integer_literal_optional( - const std::shared_ptr& context, + const 
std::shared_ptr& context, TokenSet& set); std::optional> parse_string_literal_optional( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set); std::optional> parse_char_literal_optional( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set); inline bool is_literal_ast_node(IAstNode* node) diff --git a/packages/compiler/include/ast/nodes/module.h b/packages/compiler/include/ast/nodes/module.h index 306a61f7..53bb0c04 100644 --- a/packages/compiler/include/ast/nodes/module.h +++ b/packages/compiler/include/ast/nodes/module.h @@ -19,7 +19,7 @@ namespace stride::ast explicit AstModule( const SourceFragment& source, - const std::shared_ptr& context, + const std::shared_ptr& context, std::string name, std::unique_ptr body ) : @@ -54,6 +54,6 @@ namespace stride::ast }; std::unique_ptr parse_module_statement( - const std::shared_ptr&, + const std::shared_ptr&, TokenSet& set); } // namespace stride::ast diff --git a/packages/compiler/include/ast/nodes/package.h b/packages/compiler/include/ast/nodes/package.h index 2e2f12ea..28f4ffcc 100644 --- a/packages/compiler/include/ast/nodes/package.h +++ b/packages/compiler/include/ast/nodes/package.h @@ -14,7 +14,7 @@ namespace stride::ast public: explicit AstPackage( const SourceFragment& source, - const std::shared_ptr& context, + const std::shared_ptr& context, std::string package_name ) : IAstNode(source, context), @@ -41,6 +41,6 @@ namespace stride::ast bool is_package_declaration(const TokenSet& set); std::unique_ptr parse_package_declaration( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set); } // namespace stride::ast diff --git a/packages/compiler/include/ast/nodes/return_statement.h b/packages/compiler/include/ast/nodes/return_statement.h index 62a92063..512720d7 100644 --- a/packages/compiler/include/ast/nodes/return_statement.h +++ b/packages/compiler/include/ast/nodes/return_statement.h @@ -18,7 +18,7 @@ namespace stride namespace 
stride::ast { - class ParsingContext; + class SymbolTable; class AstReturnStatement : public IAstNode @@ -28,7 +28,7 @@ namespace stride::ast public: explicit AstReturnStatement( const SourceFragment& source, - const std::shared_ptr& context, + const std::shared_ptr& context, std::optional> value ) : IAstNode(source, context), @@ -64,6 +64,6 @@ namespace stride::ast }; std::unique_ptr parse_return_statement( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set); } // namespace stride::ast diff --git a/packages/compiler/include/ast/nodes/switch.h b/packages/compiler/include/ast/nodes/switch.h index c9c14e6d..8c2d7c33 100644 --- a/packages/compiler/include/ast/nodes/switch.h +++ b/packages/compiler/include/ast/nodes/switch.h @@ -1,7 +1,7 @@ #pragma once #include "ast_node.h" -#include "ast/parsing_context.h" +#include "ast/symbol_table.h" #include "ast/tokens/token_set.h" #include @@ -15,7 +15,7 @@ namespace stride::ast public: AstSwitchBranch( const SourceFragment& source, - const std::shared_ptr& context) : + const std::shared_ptr& context) : IAstNode(source, context) {} }; @@ -28,7 +28,7 @@ namespace stride::ast public: explicit AstSwitch( const SourceFragment& source, - const std::shared_ptr& context, + const std::shared_ptr& context, std::string name) : IAstNode(source, context), _name(std::move(name)) {} diff --git a/packages/compiler/include/ast/nodes/type_definition.h b/packages/compiler/include/ast/nodes/type_definition.h index 04a0c5c3..8e9c99b3 100644 --- a/packages/compiler/include/ast/nodes/type_definition.h +++ b/packages/compiler/include/ast/nodes/type_definition.h @@ -22,7 +22,7 @@ namespace stride::ast public: explicit AstTypeDefinition( const SourceFragment& source, - const std::shared_ptr& context, + const std::shared_ptr& context, std::string name, std::unique_ptr type, const VisibilityModifier visibility, @@ -79,19 +79,19 @@ namespace stride::ast }; std::unique_ptr parse_type_definition( - const std::shared_ptr& 
context, + const std::shared_ptr& context, TokenSet& set, VisibilityModifier modifier ); EnumMemberPair parse_enumerable_member( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set, size_t element_index ); std::unique_ptr parse_enum_type_definition( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set, VisibilityModifier modifier); } diff --git a/packages/compiler/include/ast/nodes/types.h b/packages/compiler/include/ast/nodes/types.h index e475db84..29cc4442 100644 --- a/packages/compiler/include/ast/nodes/types.h +++ b/packages/compiler/include/ast/nodes/types.h @@ -69,7 +69,7 @@ namespace stride::ast public: explicit IAstType( const SourceFragment& source, - const std::shared_ptr& context, + const std::shared_ptr& context, const int flags ) : IAstNode(source, context), @@ -190,7 +190,7 @@ namespace stride::ast public: explicit AstPrimitiveType( const SourceFragment& source, - const std::shared_ptr& context, + const std::shared_ptr& context, const PrimitiveType type, const int flags = SRFLAG_NONE ) : @@ -273,7 +273,7 @@ namespace stride::ast public: explicit AstAliasType( const SourceFragment& source, - const std::shared_ptr& context, + const std::shared_ptr& context, std::string name, const int flags = SRFLAG_NONE, GenericTypeList generic_parameters = EMPTY_GENERIC_TYPE_LIST @@ -348,7 +348,7 @@ namespace stride::ast public: explicit AstFunctionType( const SourceFragment& source, - const std::shared_ptr& context, + const std::shared_ptr& context, std::vector> parameters, std::unique_ptr return_type, GenericParameterList generic_parameter_names = {}, @@ -421,7 +421,7 @@ namespace stride::ast public: explicit AstArrayType( const SourceFragment& source, - const std::shared_ptr& context, + const std::shared_ptr& context, std::unique_ptr element_type, const size_t initial_length, const int flags = SRFLAG_NONE @@ -482,7 +482,7 @@ namespace stride::ast public: explicit AstObjectType( const SourceFragment& 
source, - const std::shared_ptr& context, + const std::shared_ptr& context, std::string type_name, ObjectTypeMemberList members, const int flags = SRFLAG_NONE, @@ -534,7 +534,7 @@ namespace stride::ast public: explicit AstEnumType( const SourceFragment& source, - const std::shared_ptr& context, + const std::shared_ptr& context, std::string enum_name, std::vector members, int flags = SRFLAG_NONE @@ -584,7 +584,7 @@ namespace stride::ast public: explicit AstTupleType( const SourceFragment& source, - const std::shared_ptr& context, + const std::shared_ptr& context, std::vector> members, const int flags = SRFLAG_NONE ) : @@ -646,7 +646,7 @@ namespace stride::ast }; std::unique_ptr parse_type( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set, const TypeParsingOptions& options); @@ -658,27 +658,27 @@ namespace stride::ast ); std::optional> parse_primitive_type_optional( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set, const TypeParsingOptions& options); std::optional> parse_alias_type_optional( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set, const TypeParsingOptions& options); std::optional> parse_function_type_optional( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set, const TypeParsingOptions& options); std::optional> parse_object_type_optional( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set, const TypeParsingOptions& options); std::optional> parse_tuple_type_optional( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set, const TypeParsingOptions& options); diff --git a/packages/compiler/include/ast/nodes/while_loop.h b/packages/compiler/include/ast/nodes/while_loop.h index 04958e74..b0e9fb12 100644 --- a/packages/compiler/include/ast/nodes/while_loop.h +++ b/packages/compiler/include/ast/nodes/while_loop.h @@ -6,7 +6,7 @@ namespace stride::ast { class IAstNode; - class 
ParsingContext; + class SymbolTable; class TokenSet; enum class VisibilityModifier; @@ -20,7 +20,7 @@ namespace stride::ast public: explicit AstWhileLoop( const SourceFragment& source, - const std::shared_ptr& context, + const std::shared_ptr& context, std::unique_ptr condition, std::unique_ptr body ) : @@ -52,7 +52,7 @@ namespace stride::ast }; std::unique_ptr parse_while_loop_statement( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set, VisibilityModifier modifier); } // namespace stride::ast diff --git a/packages/compiler/include/ast/parsing_context.h b/packages/compiler/include/ast/symbol_table.h similarity index 91% rename from packages/compiler/include/ast/parsing_context.h rename to packages/compiler/include/ast/symbol_table.h index ddbafb5d..5aae36cc 100644 --- a/packages/compiler/include/ast/parsing_context.h +++ b/packages/compiler/include/ast/symbol_table.h @@ -27,7 +27,7 @@ namespace stride::ast CONTROL_FLOW }; - class ParsingContext + class SymbolTable { /** * Name of the context. 
This can be used for function name mangling, @@ -35,7 +35,7 @@ namespace stride::ast */ std::string _context_name; ContextType _context_type; - std::shared_ptr _parent_registry; + std::shared_ptr _parent_registry; std::vector> _symbols; @@ -44,27 +44,27 @@ namespace stride::ast static inline std::vector> control_flow_loop_blocks; public: - explicit ParsingContext( + explicit SymbolTable( std::string context_name, const ContextType type, - std::shared_ptr parent) : + std::shared_ptr parent) : _context_name(std::move(context_name)), _context_type(type), _parent_registry(std::move(parent)) {} /// Non-specific scope context definitions, e.g., for/while-loop blocks - explicit ParsingContext( - std::shared_ptr parent, + explicit SymbolTable( + std::shared_ptr parent, const ContextType type ) : // Context gets the same name as the parent - ParsingContext(parent->_context_name, type, std::move(parent)) {} + SymbolTable(parent->_context_name, type, std::move(parent)) {} /// Root node initialization - explicit ParsingContext() : - ParsingContext("", ContextType::GLOBAL, nullptr) {} + explicit SymbolTable() : + SymbolTable("", ContextType::GLOBAL, nullptr) {} - ParsingContext& operator=(const ParsingContext&) = delete; + SymbolTable& operator=(const SymbolTable&) = delete; [[nodiscard]] ContextType get_context_type() const @@ -141,7 +141,7 @@ namespace stride::ast const std::string& internal_name) const; [[nodiscard]] - std::shared_ptr get_parent_context() const + std::shared_ptr get_parent_context() const { return this->_parent_registry; } @@ -219,7 +219,7 @@ namespace stride::ast } [[nodiscard]] - const ParsingContext& traverse_to_root() const; + const SymbolTable& traverse_to_root() const; }; std::string scope_type_to_str(const ContextType& scope_type); diff --git a/packages/compiler/include/ast/symbols.h b/packages/compiler/include/ast/symbols.h index 4cf2c8e6..85e79af1 100644 --- a/packages/compiler/include/ast/symbols.h +++ b/packages/compiler/include/ast/symbols.h @@ 
-56,7 +56,7 @@ namespace stride::ast using SymbolNameSegments = std::vector; Symbol resolve_internal_function_name( - const std::shared_ptr& context, + const std::shared_ptr& context, const SourceFragment& position, const SymbolNameSegments& function_name_segments, const std::vector& parameter_types); diff --git a/packages/compiler/include/program.h b/packages/compiler/include/program.h index d7bd1767..0e9c6c36 100644 --- a/packages/compiler/include/program.h +++ b/packages/compiler/include/program.h @@ -1,7 +1,7 @@ #pragma once #include "cli.h" #include "ast/ast.h" -#include "ast/parsing_context.h" +#include "ast/symbol_table.h" #include "ast/nodes/ast_node.h" #include "ast/nodes/blocks.h" diff --git a/packages/compiler/include/runtime/symbols.h b/packages/compiler/include/runtime/symbols.h index 7e76d6df..30c0ec9e 100644 --- a/packages/compiler/include/runtime/symbols.h +++ b/packages/compiler/include/runtime/symbols.h @@ -6,12 +6,12 @@ namespace llvm::orc { } namespace stride::ast { - class ParsingContext; + class SymbolTable; } namespace stride::runtime { - void register_runtime_symbols(const std::shared_ptr& context); + void register_runtime_symbols(const std::shared_ptr& context); void register_jit_symbols(llvm::orc::LLJIT* jit); } diff --git a/packages/compiler/src/ast/ast.cpp b/packages/compiler/src/ast/ast.cpp index 7d49dda6..94d4d3fb 100644 --- a/packages/compiler/src/ast/ast.cpp +++ b/packages/compiler/src/ast/ast.cpp @@ -2,7 +2,7 @@ #include "files.h" #include "ast/modifiers.h" -#include "ast/parsing_context.h" +#include "ast/symbol_table.h" #include "ast/nodes/blocks.h" #include "ast/nodes/conditional_statement.h" #include "ast/nodes/control_flow_statements.h" @@ -56,7 +56,7 @@ std::pair> Ast::parse_file(const FilePath& p const auto source_file = read_file(path); auto tokens = tokenizer::tokenize(source_file); - const auto context = std::make_shared(); + const auto context = std::make_shared(); auto file_node = parse_sequential(context, tokens); @@ 
-107,7 +107,7 @@ void Ast::optimize() } std::unique_ptr stride::ast::parse_next_statement( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set) { // Phase 1 - These sequences are simple to parse; they have no visibility modifiers, hence we @@ -171,7 +171,7 @@ std::unique_ptr stride::ast::parse_next_statement( } std::unique_ptr stride::ast::parse_sequential( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set ) { diff --git a/packages/compiler/src/ast/context/field_registry.cpp b/packages/compiler/src/ast/context/field_registry.cpp index ee5307ef..15d19db6 100644 --- a/packages/compiler/src/ast/context/field_registry.cpp +++ b/packages/compiler/src/ast/context/field_registry.cpp @@ -1,11 +1,11 @@ #include "errors.h" -#include "ast/parsing_context.h" +#include "ast/symbol_table.h" #include using namespace stride::ast; -definition::FieldDefinition* ParsingContext::get_variable_def( +definition::FieldDefinition* SymbolTable::get_variable_def( const std::string& variable_name, const bool use_raw_name ) const @@ -25,7 +25,7 @@ definition::FieldDefinition* ParsingContext::get_variable_def( return nullptr; } -bool ParsingContext::is_field_defined_in_scope( +bool SymbolTable::is_field_defined_in_scope( const std::string& variable_name) const { return std::ranges::any_of( @@ -41,7 +41,7 @@ bool ParsingContext::is_field_defined_in_scope( }); } -bool ParsingContext::is_field_defined_globally( +bool SymbolTable::is_field_defined_globally( const std::string& field_name) const { auto current = this; @@ -56,7 +56,7 @@ bool ParsingContext::is_field_defined_globally( return false; } -void ParsingContext::define_variable_globally( +void SymbolTable::define_variable_globally( Symbol variable_symbol, std::unique_ptr type, VisibilityModifier visibility, @@ -88,7 +88,7 @@ void ParsingContext::define_variable_globally( ); } - auto& global_scope = const_cast(this->traverse_to_root()); + auto& global_scope = 
const_cast(this->traverse_to_root()); global_scope._symbols.push_back( std::make_unique( std::move(variable_symbol), @@ -98,7 +98,7 @@ void ParsingContext::define_variable_globally( ); } -void ParsingContext::define_variable( +void SymbolTable::define_variable( Symbol variable_sym, std::unique_ptr type, VisibilityModifier visibility, @@ -149,7 +149,7 @@ void ParsingContext::define_variable( ); } -definition::FieldDefinition* ParsingContext::lookup_variable( +definition::FieldDefinition* SymbolTable::lookup_variable( const std::string& name, const bool use_raw_name ) diff --git a/packages/compiler/src/ast/context/function_registry.cpp b/packages/compiler/src/ast/context/function_registry.cpp index c5c7a916..6e527a4e 100644 --- a/packages/compiler/src/ast/context/function_registry.cpp +++ b/packages/compiler/src/ast/context/function_registry.cpp @@ -1,6 +1,6 @@ #include "errors.h" #include "ast/casting.h" -#include "ast/parsing_context.h" +#include "ast/symbol_table.h" #include "ast/definitions/function_definition.h" #include @@ -8,7 +8,7 @@ using namespace stride::ast; using namespace stride::ast::definition; -std::optional ParsingContext::get_function_definition( +std::optional SymbolTable::get_function_definition( const std::string& function_name, const std::vector>& parameter_types, const size_t instantiated_generic_count @@ -31,7 +31,7 @@ std::optional ParsingContext::get_function_definition( return std::nullopt; } -std::optional ParsingContext::get_generic_function_definition( +std::optional SymbolTable::get_generic_function_definition( const std::string& function_name, const size_t instantiated_generic_count ) const @@ -39,7 +39,7 @@ std::optional ParsingContext::get_generic_function_definiti return get_function_definition(function_name, {}, instantiated_generic_count); } -std::optional ParsingContext::get_function_definition( +std::optional SymbolTable::get_function_definition( const std::string& function_name, // We might call this function with an anonymous 
type, hence not having `AstFunctionType` IAstType* function_type @@ -140,14 +140,14 @@ const return true; } -void ParsingContext::define_function( +void SymbolTable::define_function( Symbol function_name, std::unique_ptr function_type, const VisibilityModifier visibility, const int flags ) const { - auto& global_scope = const_cast(this->traverse_to_root()); + auto& global_scope = const_cast(this->traverse_to_root()); if (this->is_function_defined_globally(function_name.internal_name, function_type.get())) { @@ -163,7 +163,7 @@ void ParsingContext::define_function( ); } -bool ParsingContext::is_function_defined_globally( +bool SymbolTable::is_function_defined_globally( const std::string& function_name, const AstFunctionType* function_type ) const diff --git a/packages/compiler/src/ast/context/parsing_context.cpp b/packages/compiler/src/ast/context/symbol_table.cpp similarity index 91% rename from packages/compiler/src/ast/context/parsing_context.cpp rename to packages/compiler/src/ast/context/symbol_table.cpp index 0e8b05b3..62af1224 100644 --- a/packages/compiler/src/ast/context/parsing_context.cpp +++ b/packages/compiler/src/ast/context/symbol_table.cpp @@ -1,4 +1,4 @@ -#include "ast/parsing_context.h" +#include "ast/symbol_table.h" #include "errors.h" #include "ast/symbols.h" @@ -28,7 +28,7 @@ std::string stride::ast::scope_type_to_str(const ContextType& scope_type) return "unknown"; } -const ParsingContext& ParsingContext::traverse_to_root() const +const SymbolTable& SymbolTable::traverse_to_root() const { auto current = this; while (current->_parent_registry) @@ -38,12 +38,12 @@ const ParsingContext& ParsingContext::traverse_to_root() const return *current; } -void ParsingContext::define(std::unique_ptr definition) +void SymbolTable::define(std::unique_ptr definition) { this->_symbols.push_back(std::move(definition)); } -std::optional> ParsingContext::get_definition_by_internal_name(const std::string& internal_name) const +std::optional> 
SymbolTable::get_definition_by_internal_name(const std::string& internal_name) const { auto current = this; while (current != nullptr) @@ -85,7 +85,7 @@ static size_t levenshtein_distance(const std::string& a, const std::string& b) return prev[len_b]; } -IDefinition* ParsingContext::fuzzy_find(const std::string& symbol_name) const +IDefinition* SymbolTable::fuzzy_find(const std::string& symbol_name) const { IDefinition* best_match = nullptr; size_t best_distance = std::numeric_limits::max(); @@ -154,7 +154,7 @@ IDefinition* ParsingContext::fuzzy_find(const std::string& symbol_name) const return nullptr; } -IDefinition* ParsingContext::lookup_symbol(const std::string& symbol_name) const +IDefinition* SymbolTable::lookup_symbol(const std::string& symbol_name) const { auto current = this; while (current != nullptr) diff --git a/packages/compiler/src/ast/context/type_registry.cpp b/packages/compiler/src/ast/context/type_registry.cpp index ce6571fe..e75cdf85 100644 --- a/packages/compiler/src/ast/context/type_registry.cpp +++ b/packages/compiler/src/ast/context/type_registry.cpp @@ -1,18 +1,18 @@ #include "errors.h" #include "ast/casting.h" -#include "ast/parsing_context.h" +#include "ast/symbol_table.h" #include using namespace stride::ast; using namespace stride::ast::definition; -bool ParsingContext::is_type_defined(const std::string& type_name) const +bool SymbolTable::is_type_defined(const std::string& type_name) const { return get_type_definition(type_name).has_value(); } -bool ParsingContext::is_struct_type_defined(const std::string& struct_name) const +bool SymbolTable::is_struct_type_defined(const std::string& struct_name) const { const auto type_def = get_type_definition(struct_name); @@ -20,7 +20,7 @@ bool ParsingContext::is_struct_type_defined(const std::string& struct_name) cons cast_type(type_def.value()->get_type()) != nullptr; } -std::optional ParsingContext::get_type_definition(const std::string& name) const +std::optional 
SymbolTable::get_type_definition(const std::string& name) const { auto current = this; @@ -52,7 +52,7 @@ std::optional ParsingContext::get_type_definition(const std::st /// Gets the root struct type layout for name. /// Will recursively look up the parent struct definition if name is a reference struct type. -std::optional ParsingContext::get_object_type(const std::string& name) const +std::optional SymbolTable::get_object_type(const std::string& name) const { const auto type_def = get_type_definition(name); @@ -111,14 +111,14 @@ std::optional ParsingContext::get_object_type(const std::string& return std::nullopt; } -void ParsingContext::define_type( +void SymbolTable::define_type( const Symbol& type_name, std::unique_ptr type, GenericParameterList generics, const VisibilityModifier visibility ) const { - auto& root_context = const_cast(this->traverse_to_root()); + auto& root_context = const_cast(this->traverse_to_root()); if (const auto existing_def = root_context.get_type_definition(type_name.internal_name); existing_def.has_value()) diff --git a/packages/compiler/src/ast/generics.cpp b/packages/compiler/src/ast/generics.cpp index 092d2b34..5d0c0801 100644 --- a/packages/compiler/src/ast/generics.cpp +++ b/packages/compiler/src/ast/generics.cpp @@ -2,7 +2,7 @@ #include "errors.h" #include "ast/casting.h" -#include "ast/parsing_context.h" +#include "ast/symbol_table.h" #include "ast/type_inference.h" #include "ast/nodes/blocks.h" #include "ast/nodes/conditional_statement.h" @@ -42,7 +42,7 @@ GenericParameterList stride::ast::parse_generic_declaration(TokenSet& set) return generic_params; } -GenericTypeList stride::ast::parse_generic_type_arguments(const std::shared_ptr& context, TokenSet& set) +GenericTypeList stride::ast::parse_generic_type_arguments(const std::shared_ptr& context, TokenSet& set) { GenericTypeList generic_params; if (set.peek_next_eq(TokenType::LT)) diff --git a/packages/compiler/src/ast/nodes/blocks.cpp 
b/packages/compiler/src/ast/nodes/blocks.cpp index e29fcce4..ae1edfec 100644 --- a/packages/compiler/src/ast/nodes/blocks.cpp +++ b/packages/compiler/src/ast/nodes/blocks.cpp @@ -11,7 +11,7 @@ using namespace stride::ast; std::unique_ptr stride::ast::parse_block( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set ) { diff --git a/packages/compiler/src/ast/nodes/conditional_statement.cpp b/packages/compiler/src/ast/nodes/conditional_statement.cpp index 7c561eba..8df813d6 100644 --- a/packages/compiler/src/ast/nodes/conditional_statement.cpp +++ b/packages/compiler/src/ast/nodes/conditional_statement.cpp @@ -3,7 +3,7 @@ #include "errors.h" #include "ast/ast.h" #include "ast/conditionals.h" -#include "ast/parsing_context.h" +#include "ast/symbol_table.h" #include "ast/tokens/token_set.h" #include @@ -14,7 +14,7 @@ using namespace stride::ast; using namespace stride::ast::definition; std::unique_ptr parse_else_optional( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set ) { @@ -24,7 +24,7 @@ std::unique_ptr parse_else_optional( } const auto reference_token = set.next(); - const auto else_block_context = std::make_shared( + const auto else_block_context = std::make_shared( context, context->get_context_type() ); @@ -33,7 +33,7 @@ std::unique_ptr parse_else_optional( // have to parse_file the block separately if (set.peek_next_eq(TokenType::LBRACE)) { - const auto else_context = std::make_shared( + const auto else_context = std::make_shared( context, context->get_context_type()); return parse_block(else_context, set); @@ -51,13 +51,13 @@ std::unique_ptr parse_else_optional( } std::unique_ptr stride::ast::parse_if_statement( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set ) { const auto reference_token = set.expect(TokenType::KEYWORD_IF); - auto conditional_context = std::make_shared( + auto conditional_context = std::make_shared( context, context->get_context_type()); auto 
if_header_body = collect_parenthesized_block(set); diff --git a/packages/compiler/src/ast/nodes/control_flow/break_statement.cpp b/packages/compiler/src/ast/nodes/control_flow/break_statement.cpp index 9223ba98..52fad5c2 100644 --- a/packages/compiler/src/ast/nodes/control_flow/break_statement.cpp +++ b/packages/compiler/src/ast/nodes/control_flow/break_statement.cpp @@ -1,5 +1,5 @@ #include "errors.h" -#include "ast/parsing_context.h" +#include "ast/symbol_table.h" #include "ast/nodes/control_flow_statements.h" #include @@ -8,7 +8,7 @@ using namespace stride::ast; std::unique_ptr stride::ast::parse_break_statement( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set ) { @@ -20,7 +20,7 @@ std::unique_ptr stride::ast::parse_break_statement( llvm::Value* AstBreakStatement::codegen(llvm::Module* module, llvm::IRBuilderBase* builder) { - if (ParsingContext::get_control_flow_blocks().empty()) + if (SymbolTable::get_control_flow_blocks().empty()) { throw parsing_error( ErrorType::COMPILATION_ERROR, @@ -42,7 +42,7 @@ llvm::Value* AstBreakStatement::codegen(llvm::Module* module, llvm::IRBuilderBas return nullptr; } - llvm::BasicBlock* break_target = ParsingContext::get_control_flow_blocks().back().second; + llvm::BasicBlock* break_target = SymbolTable::get_control_flow_blocks().back().second; if (!break_target) { return nullptr; diff --git a/packages/compiler/src/ast/nodes/control_flow/continue_statement.cpp b/packages/compiler/src/ast/nodes/control_flow/continue_statement.cpp index 645fac11..09e5dc99 100644 --- a/packages/compiler/src/ast/nodes/control_flow/continue_statement.cpp +++ b/packages/compiler/src/ast/nodes/control_flow/continue_statement.cpp @@ -1,5 +1,5 @@ #include "errors.h" -#include "ast/parsing_context.h" +#include "ast/symbol_table.h" #include "ast/nodes/control_flow_statements.h" #include @@ -7,7 +7,7 @@ using namespace stride::ast; std::unique_ptr stride::ast::parse_continue_statement( - const std::shared_ptr& context, + 
const std::shared_ptr& context, TokenSet& set ) { @@ -19,7 +19,7 @@ std::unique_ptr stride::ast::parse_continue_statement( llvm::Value* AstContinueStatement::codegen(llvm::Module* module, llvm::IRBuilderBase* builder) { - if (ParsingContext::get_control_flow_blocks().empty()) + if (SymbolTable::get_control_flow_blocks().empty()) { throw parsing_error( ErrorType::COMPILATION_ERROR, @@ -28,7 +28,7 @@ llvm::Value* AstContinueStatement::codegen(llvm::Module* module, llvm::IRBuilder ); } - const auto continue_block = ParsingContext::get_control_flow_blocks().back().first; + const auto continue_block = SymbolTable::get_control_flow_blocks().back().first; builder->CreateBr(continue_block); // Since we branched, create a new block for unreachable code, but since it's a statement, return nullptr diff --git a/packages/compiler/src/ast/nodes/expressions/array_initializer.cpp b/packages/compiler/src/ast/nodes/expressions/array_initializer.cpp index be316bd2..a78ffcb3 100644 --- a/packages/compiler/src/ast/nodes/expressions/array_initializer.cpp +++ b/packages/compiler/src/ast/nodes/expressions/array_initializer.cpp @@ -11,7 +11,7 @@ bool stride::ast::is_array_initializer(const TokenSet& set) } std::unique_ptr stride::ast::parse_array_initializer( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set ) { diff --git a/packages/compiler/src/ast/nodes/expressions/array_member_accessor.cpp b/packages/compiler/src/ast/nodes/expressions/array_member_accessor.cpp index db10cd7e..2ba0e9b1 100644 --- a/packages/compiler/src/ast/nodes/expressions/array_member_accessor.cpp +++ b/packages/compiler/src/ast/nodes/expressions/array_member_accessor.cpp @@ -13,7 +13,7 @@ using namespace stride::ast; std::unique_ptr stride::ast::parse_array_member_accessor( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set, std::unique_ptr array_base) { diff --git a/packages/compiler/src/ast/nodes/expressions/binary_operation.cpp 
b/packages/compiler/src/ast/nodes/expressions/binary_operation.cpp index f616e182..34ed7bc9 100644 --- a/packages/compiler/src/ast/nodes/expressions/binary_operation.cpp +++ b/packages/compiler/src/ast/nodes/expressions/binary_operation.cpp @@ -80,7 +80,7 @@ std::string AstBinaryArithmeticOp::to_string() */ std::optional> stride::ast::parse_arithmetic_binary_operation_optional( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set, std::unique_ptr lhs, const int min_precedence diff --git a/packages/compiler/src/ast/nodes/expressions/expression.cpp b/packages/compiler/src/ast/nodes/expressions/expression.cpp index 4b5a0d41..f6d72eb8 100644 --- a/packages/compiler/src/ast/nodes/expressions/expression.cpp +++ b/packages/compiler/src/ast/nodes/expressions/expression.cpp @@ -25,7 +25,7 @@ std::string IAstExpression::to_string() } std::unique_ptr stride::ast::parse_inline_expression_part( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set ) { @@ -144,7 +144,7 @@ std::unique_ptr stride::ast::parse_inline_expression_part( */ std::unique_ptr parse_arithmetic_tier( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set ) { @@ -176,7 +176,7 @@ std::unique_ptr parse_arithmetic_tier( } std::unique_ptr parse_comparison_tier( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set ) { @@ -198,7 +198,7 @@ std::unique_ptr parse_comparison_tier( } std::unique_ptr parse_logical_tier( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set ) { @@ -222,7 +222,7 @@ std::unique_ptr parse_logical_tier( // Kept for backward compatibility / external usage if any, but now updated to use correct tiers for // RHS std::optional> parse_logical_operation_optional( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set, std::unique_ptr lhs) { @@ -249,7 +249,7 @@ std::optional> parse_logical_operation_optional( // Kept for backward 
compatibility / external usage if any std::optional> parse_comparative_operation_optional( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set, std::unique_ptr lhs) { @@ -276,7 +276,7 @@ std::optional> parse_comparative_operation_optio } std::unique_ptr parse_expression_internal( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set ) { @@ -292,7 +292,7 @@ std::unique_ptr parse_expression_internal( * General expression parsing. These can occur in global / function scopes */ std::unique_ptr stride::ast::parse_standalone_expression( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set ) { @@ -304,7 +304,7 @@ std::unique_ptr stride::ast::parse_standalone_expression( } std::unique_ptr stride::ast::parse_inline_expression( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set ) { @@ -312,7 +312,7 @@ std::unique_ptr stride::ast::parse_inline_expression( } std::unique_ptr stride::ast::parse_segmented_identifier( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set, const std::string& error_message) { diff --git a/packages/compiler/src/ast/nodes/expressions/identifier.cpp b/packages/compiler/src/ast/nodes/expressions/identifier.cpp index d93312cd..4897377e 100644 --- a/packages/compiler/src/ast/nodes/expressions/identifier.cpp +++ b/packages/compiler/src/ast/nodes/expressions/identifier.cpp @@ -1,6 +1,6 @@ #include "errors.h" #include "ast/closures.h" -#include "ast/parsing_context.h" +#include "ast/symbol_table.h" #include "ast/nodes/expression.h" #include diff --git a/packages/compiler/src/ast/nodes/expressions/member_accessor.cpp b/packages/compiler/src/ast/nodes/expressions/member_accessor.cpp index 1b8167ef..d465c94f 100644 --- a/packages/compiler/src/ast/nodes/expressions/member_accessor.cpp +++ b/packages/compiler/src/ast/nodes/expressions/member_accessor.cpp @@ -2,7 +2,7 @@ #include "formatting.h" #include "ast/casting.h" 
#include "ast/closures.h" -#include "ast/parsing_context.h" +#include "ast/symbol_table.h" #include "ast/nodes/blocks.h" #include "ast/nodes/expression.h" #include "ast/tokens/token_set.h" @@ -21,7 +21,7 @@ bool stride::ast::is_member_accessor(const TokenSet& set) /// Consumes `.identifier` and wraps `lhs` in an AstChainedExpression. std::unique_ptr stride::ast::parse_chained_member_access( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set, std::unique_ptr lhs ) @@ -46,7 +46,7 @@ std::unique_ptr stride::ast::parse_chained_member_access( /// Consumes `()` and wraps `callee` in an AstIndirectCall. std::unique_ptr stride::ast::parse_indirect_call( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set, std::unique_ptr callee ) diff --git a/packages/compiler/src/ast/nodes/expressions/object_initializer.cpp b/packages/compiler/src/ast/nodes/expressions/object_initializer.cpp index dafac600..1b10fbcd 100644 --- a/packages/compiler/src/ast/nodes/expressions/object_initializer.cpp +++ b/packages/compiler/src/ast/nodes/expressions/object_initializer.cpp @@ -1,6 +1,6 @@ #include "errors.h" #include "ast/casting.h" -#include "ast/parsing_context.h" +#include "ast/symbol_table.h" #include "ast/nodes/blocks.h" #include "ast/nodes/expression.h" #include "ast/tokens/token_set.h" @@ -55,7 +55,7 @@ bool stride::ast::is_struct_initializer(const TokenSet& set) } StructMemberInitializerPair parse_object_member_initializer( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set ) { @@ -144,7 +144,7 @@ std::unique_ptr AstObjectInitializer::get_instantiated_object_typ } std::unique_ptr stride::ast::parse_object_initializer( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set ) { diff --git a/packages/compiler/src/ast/nodes/expressions/type_cast.cpp b/packages/compiler/src/ast/nodes/expressions/type_cast.cpp index 00d642c9..9e2e58d9 100644 --- 
a/packages/compiler/src/ast/nodes/expressions/type_cast.cpp +++ b/packages/compiler/src/ast/nodes/expressions/type_cast.cpp @@ -7,7 +7,7 @@ using namespace stride::ast; std::optional> stride::ast::parse_type_cast_op( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set, IAstExpression* lhs ) diff --git a/packages/compiler/src/ast/nodes/expressions/unary_operation.cpp b/packages/compiler/src/ast/nodes/expressions/unary_operation.cpp index b29ec7ca..fddb21a3 100644 --- a/packages/compiler/src/ast/nodes/expressions/unary_operation.cpp +++ b/packages/compiler/src/ast/nodes/expressions/unary_operation.cpp @@ -1,6 +1,6 @@ #include "errors.h" #include "ast/casting.h" -#include "ast/parsing_context.h" +#include "ast/symbol_table.h" #include "ast/nodes/expression.h" #include "ast/nodes/literal_values.h" #include "ast/tokens/token_set.h" @@ -55,7 +55,7 @@ bool requires_identifier_operand(const UnaryOpType op) } std::optional> stride::ast::parse_binary_unary_op( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set ) { diff --git a/packages/compiler/src/ast/nodes/expressions/variable_declaration.cpp b/packages/compiler/src/ast/nodes/expressions/variable_declaration.cpp index 36c91946..4cef982d 100644 --- a/packages/compiler/src/ast/nodes/expressions/variable_declaration.cpp +++ b/packages/compiler/src/ast/nodes/expressions/variable_declaration.cpp @@ -3,7 +3,7 @@ #include "ast/flags.h" #include "ast/modifiers.h" #include "ast/optionals.h" -#include "ast/parsing_context.h" +#include "ast/symbol_table.h" #include "ast/nodes/expression.h" #include "ast/nodes/function_declaration.h" #include "ast/nodes/literal_values.h" @@ -18,7 +18,7 @@ using namespace stride::ast; std::unique_ptr stride::ast::parse_variable_declaration( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set, const VisibilityModifier modifier ) @@ -31,7 +31,7 @@ std::unique_ptr stride::ast::parse_variable_declaration( } 
std::unique_ptr stride::ast::parse_variable_declaration_inline( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set, VisibilityModifier modifier ) diff --git a/packages/compiler/src/ast/nodes/expressions/variable_reassignation.cpp b/packages/compiler/src/ast/nodes/expressions/variable_reassignation.cpp index 158b4365..c21c8eca 100644 --- a/packages/compiler/src/ast/nodes/expressions/variable_reassignation.cpp +++ b/packages/compiler/src/ast/nodes/expressions/variable_reassignation.cpp @@ -1,7 +1,7 @@ #include "errors.h" #include "ast/casting.h" #include "ast/optionals.h" -#include "ast/parsing_context.h" +#include "ast/symbol_table.h" #include "ast/nodes/expression.h" #include "ast/tokens/token.h" #include "ast/tokens/token_set.h" @@ -102,7 +102,7 @@ MutativeAssignmentType parse_mutative_assignment_type(const Token& token) } std::optional> stride::ast::parse_variable_reassignment( - const std::shared_ptr& context, + const std::shared_ptr& context, AstIdentifier* identifier, TokenSet& set ) diff --git a/packages/compiler/src/ast/nodes/for_loop.cpp b/packages/compiler/src/ast/nodes/for_loop.cpp index d9958f4f..11bf84f8 100644 --- a/packages/compiler/src/ast/nodes/for_loop.cpp +++ b/packages/compiler/src/ast/nodes/for_loop.cpp @@ -2,7 +2,7 @@ #include "ast/conditionals.h" #include "ast/modifiers.h" -#include "ast/parsing_context.h" +#include "ast/symbol_table.h" #include "ast/tokens/token.h" #include "ast/tokens/token_set.h" @@ -13,7 +13,7 @@ using namespace stride::ast; using namespace stride::ast::definition; std::unique_ptr collect_initiator( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set) { auto initiator = collect_until_token(set, TokenType::SEMICOLON); @@ -31,7 +31,7 @@ std::unique_ptr collect_initiator( } std::unique_ptr collect_condition( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set) { auto condition = collect_until_token(set, TokenType::SEMICOLON); @@ -46,7 
+46,7 @@ std::unique_ptr collect_condition( } std::unique_ptr collect_incrementor( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set) { if (!set.has_next()) @@ -57,7 +57,7 @@ std::unique_ptr collect_incrementor( } std::unique_ptr stride::ast::parse_for_loop_statement( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set, [[maybe_unused]] VisibilityModifier modifier ) @@ -71,7 +71,7 @@ std::unique_ptr stride::ast::parse_for_loop_statement( } auto header_body = header_body_opt.value(); - const auto for_body_context = std::make_shared( + const auto for_body_context = std::make_shared( context, ContextType::CONTROL_FLOW); @@ -123,11 +123,11 @@ llvm::Value* AstForLoop::codegen( if (this->get_body()) { - ParsingContext::push_control_flow_block(loop_continue_bb, loop_end_bb); + SymbolTable::push_control_flow_block(loop_continue_bb, loop_end_bb); this->get_body()->codegen(module, builder); - ParsingContext::pop_control_flow_block(); + SymbolTable::pop_control_flow_block(); } // If we already have a terminator (e.g., from a break or continue), diff --git a/packages/compiler/src/ast/nodes/functions/call/function_call.cpp b/packages/compiler/src/ast/nodes/functions/call/function_call.cpp index 3929f883..eba9500f 100644 --- a/packages/compiler/src/ast/nodes/functions/call/function_call.cpp +++ b/packages/compiler/src/ast/nodes/functions/call/function_call.cpp @@ -2,7 +2,7 @@ #include "formatting.h" #include "ast/casting.h" #include "ast/flags.h" -#include "ast/parsing_context.h" +#include "ast/symbol_table.h" #include "ast/symbols.h" #include "ast/definitions/function_definition.h" #include "ast/nodes/blocks.h" @@ -18,7 +18,7 @@ using namespace stride::ast; using namespace stride::ast::definition; std::unique_ptr stride::ast::parse_function_call( - const std::shared_ptr& context, + const std::shared_ptr& context, AstIdentifier* identifier, TokenSet& set ) diff --git 
a/packages/compiler/src/ast/nodes/functions/call/function_call_codegen.cpp b/packages/compiler/src/ast/nodes/functions/call/function_call_codegen.cpp index 474c1dab..50ede849 100644 --- a/packages/compiler/src/ast/nodes/functions/call/function_call_codegen.cpp +++ b/packages/compiler/src/ast/nodes/functions/call/function_call_codegen.cpp @@ -1,7 +1,7 @@ #include "ast/casting.h" #include "ast/closures.h" #include "ast/optionals.h" -#include "ast/parsing_context.h" +#include "ast/symbol_table.h" #include "ast/definitions/function_definition.h" #include "ast/nodes/expression.h" diff --git a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration.cpp b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration.cpp index ac44e32e..3d1e11f0 100644 --- a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration.cpp +++ b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration.cpp @@ -3,7 +3,7 @@ #include "errors.h" #include "ast/closures.h" #include "ast/modifiers.h" -#include "ast/parsing_context.h" +#include "ast/symbol_table.h" #include "ast/symbols.h" #include "ast/definitions/function_definition.h" #include "ast/nodes/blocks.h" @@ -25,7 +25,7 @@ using namespace stride::ast::definition; * Will attempt to parse the provided token stream into an AstFunctionDefinitionNode. 
*/ std::unique_ptr stride::ast::parse_fn_declaration( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set, VisibilityModifier modifier ) @@ -50,7 +50,7 @@ std::unique_ptr stride::ast::parse_fn_declaration( const auto fn_name_tok = set.expect(TokenType::IDENTIFIER, "Expected function name"); const auto& fn_name = fn_name_tok.get_lexeme(); - auto function_context = std::make_shared(context, ContextType::FUNCTION); + auto function_context = std::make_shared(context, ContextType::FUNCTION); GenericParameterList generic_parameter_names = parse_generic_declaration(set); @@ -111,7 +111,7 @@ std::unique_ptr stride::ast::parse_fn_declaration( } std::unique_ptr consume_anonymous_fn_body( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set) { if (!set.peek_next_eq(TokenType::LBRACE)) @@ -133,7 +133,7 @@ std::unique_ptr consume_anonymous_fn_body( } std::unique_ptr stride::ast::parse_anonymous_fn_expression( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set ) { @@ -141,7 +141,7 @@ std::unique_ptr stride::ast::parse_anonymous_fn_expression( std::vector> parameters = {}; int function_flags = SRFLAG_FN_TYPE_ANONYMOUS; - auto function_context = std::make_shared( + auto function_context = std::make_shared( context, ContextType::FUNCTION ); diff --git a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_forward_refs.cpp b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_forward_refs.cpp index b1ee613b..335106ee 100644 --- a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_forward_refs.cpp +++ b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_forward_refs.cpp @@ -170,8 +170,8 @@ void IAstFunction::resolve_forward_references( void IAstFunction::collect_free_variables( IAstNode* node, - const std::shared_ptr& lambda_context, - const std::shared_ptr& outer_context, + const std::shared_ptr& 
lambda_context, + const std::shared_ptr& outer_context, std::vector& captures ) { diff --git a/packages/compiler/src/ast/nodes/functions/declaration/function_parameters.cpp b/packages/compiler/src/ast/nodes/functions/declaration/function_parameters.cpp index afbe9d6f..7acc46a7 100644 --- a/packages/compiler/src/ast/nodes/functions/declaration/function_parameters.cpp +++ b/packages/compiler/src/ast/nodes/functions/declaration/function_parameters.cpp @@ -5,7 +5,7 @@ using namespace stride::ast; void stride::ast::parse_function_parameters( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set, std::vector>& parameters, int& function_flags @@ -57,7 +57,7 @@ void stride::ast::parse_function_parameters( } void stride::ast::parse_standalone_fn_param( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set, std::vector>& parameters ) diff --git a/packages/compiler/src/ast/nodes/functions/symbols.cpp b/packages/compiler/src/ast/nodes/functions/symbols.cpp index 38e56694..206656df 100644 --- a/packages/compiler/src/ast/nodes/functions/symbols.cpp +++ b/packages/compiler/src/ast/nodes/functions/symbols.cpp @@ -1,7 +1,7 @@ #include "ast/symbols.h" #include "formatting.h" -#include "ast/parsing_context.h" +#include "ast/symbol_table.h" #include @@ -13,7 +13,7 @@ using namespace stride::ast; * clashes between different contexts. 
*/ Symbol stride::ast::resolve_internal_function_name( - const std::shared_ptr& context, + const std::shared_ptr& context, const SourceFragment& position, const SymbolNameSegments& function_name_segments, const std::vector& parameter_types) diff --git a/packages/compiler/src/ast/nodes/import.cpp b/packages/compiler/src/ast/nodes/import.cpp index 0dd01ca6..70186a01 100644 --- a/packages/compiler/src/ast/nodes/import.cpp +++ b/packages/compiler/src/ast/nodes/import.cpp @@ -2,7 +2,7 @@ #include "errors.h" #include "formatting.h" -#include "ast/parsing_context.h" +#include "ast/symbol_table.h" #include "ast/nodes/expression.h" #include "ast/tokens/token_set.h" @@ -15,7 +15,7 @@ using namespace stride::ast; * It parses the identifiers and returns them as a vector of Symbol objects. */ std::vector> consume_import_submodules( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set ) { @@ -64,7 +64,7 @@ std::vector> consume_import_submodules( * Attempts to parse an import expression from the given TokenSet. */ std::unique_ptr stride::ast::parse_import_statement( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set ) { diff --git a/packages/compiler/src/ast/nodes/literals.cpp b/packages/compiler/src/ast/nodes/literals.cpp index 5698d870..55a074d1 100644 --- a/packages/compiler/src/ast/nodes/literals.cpp +++ b/packages/compiler/src/ast/nodes/literals.cpp @@ -7,7 +7,7 @@ using namespace stride::ast; std::optional> stride::ast::parse_literal_optional( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set) { if (auto str = parse_string_literal_optional(context, set); str. 
@@ -52,7 +52,7 @@ std::optional> stride::ast::parse_literal_optional( } std::optional> stride::ast::parse_string_literal_optional( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set ) { @@ -71,7 +71,7 @@ std::optional> stride::ast::parse_string_literal_opt } std::optional> stride::ast::parse_float_literal_optional( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set ) { @@ -104,7 +104,7 @@ std::optional> stride::ast::parse_float_literal_opti } std::optional> stride::ast::parse_boolean_literal_optional( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set ) { @@ -123,7 +123,7 @@ std::optional> stride::ast::parse_boolean_literal_op } std::optional> stride::ast::parse_char_literal_optional( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set ) { @@ -166,7 +166,7 @@ std::string format_int_conversion_error( } std::optional> stride::ast::parse_integer_literal_optional( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set ) { diff --git a/packages/compiler/src/ast/nodes/module.cpp b/packages/compiler/src/ast/nodes/module.cpp index 35f42ab9..1ba07855 100644 --- a/packages/compiler/src/ast/nodes/module.cpp +++ b/packages/compiler/src/ast/nodes/module.cpp @@ -1,6 +1,6 @@ #include "ast/nodes/module.h" -#include "ast/parsing_context.h" +#include "ast/symbol_table.h" #include "ast/symbols.h" #include "ast/nodes/blocks.h" #include "ast/tokens/token.h" @@ -24,7 +24,7 @@ std::string AstModule::to_string() } std::unique_ptr stride::ast::parse_module_statement( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set) { const auto reference_token = set.expect(TokenType::KEYWORD_MODULE); @@ -42,7 +42,7 @@ std::unique_ptr stride::ast::parse_module_statement( const auto module_name = resolve_internal_name(module_name_segments); - const auto module_context = std::make_shared(module_name, ContextType::MODULE, 
context); + const auto module_context = std::make_shared(module_name, ContextType::MODULE, context); auto module_body = parse_block(module_context, set); return std::make_unique( diff --git a/packages/compiler/src/ast/nodes/package.cpp b/packages/compiler/src/ast/nodes/package.cpp index 26265b8b..59038039 100644 --- a/packages/compiler/src/ast/nodes/package.cpp +++ b/packages/compiler/src/ast/nodes/package.cpp @@ -19,7 +19,7 @@ bool stride::ast::is_package_declaration(const TokenSet& set) } std::unique_ptr stride::ast::parse_package_declaration( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set) { const size_t initial_offset = set.position(); diff --git a/packages/compiler/src/ast/nodes/return_statement.cpp b/packages/compiler/src/ast/nodes/return_statement.cpp index e226761d..ff612d93 100644 --- a/packages/compiler/src/ast/nodes/return_statement.cpp +++ b/packages/compiler/src/ast/nodes/return_statement.cpp @@ -2,7 +2,7 @@ #include "errors.h" #include "ast/optionals.h" -#include "ast/parsing_context.h" +#include "ast/symbol_table.h" #include "ast/tokens/token_set.h" #include @@ -13,7 +13,7 @@ using namespace stride::ast; using namespace stride::ast::definition; std::unique_ptr stride::ast::parse_return_statement( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set ) { diff --git a/packages/compiler/src/ast/nodes/types/alias_type.cpp b/packages/compiler/src/ast/nodes/types/alias_type.cpp index 7ea328f0..8d224fee 100644 --- a/packages/compiler/src/ast/nodes/types/alias_type.cpp +++ b/packages/compiler/src/ast/nodes/types/alias_type.cpp @@ -1,5 +1,5 @@ #include "ast/casting.h" -#include "ast/parsing_context.h" +#include "ast/symbol_table.h" #include "ast/nodes/expression.h" #include "ast/nodes/types.h" #include "ast/tokens/token.h" @@ -8,7 +8,7 @@ using namespace stride::ast; std::optional> stride::ast::parse_alias_type_optional( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& 
set, const TypeParsingOptions& options ) diff --git a/packages/compiler/src/ast/nodes/types/enum_type.cpp b/packages/compiler/src/ast/nodes/types/enum_type.cpp index cef50449..142758ad 100644 --- a/packages/compiler/src/ast/nodes/types/enum_type.cpp +++ b/packages/compiler/src/ast/nodes/types/enum_type.cpp @@ -1,5 +1,5 @@ #include "ast/casting.h" -#include "ast/parsing_context.h" +#include "ast/symbol_table.h" #include "ast/nodes/blocks.h" #include "ast/nodes/literal_values.h" #include "ast/nodes/types.h" @@ -15,7 +15,7 @@ using namespace stride::ast; AstEnumType::AstEnumType( const SourceFragment& source, - const std::shared_ptr& context, + const std::shared_ptr& context, std::string enum_name, std::vector members, const int flags @@ -31,7 +31,7 @@ AstEnumType::AstEnumType( * */ EnumMemberPair stride::ast::parse_enumerable_member( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set, size_t element_index ) @@ -69,7 +69,7 @@ EnumMemberPair stride::ast::parse_enumerable_member( } std::unique_ptr stride::ast::parse_enum_type_definition( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set, [[maybe_unused]] VisibilityModifier modifier ) diff --git a/packages/compiler/src/ast/nodes/types/function_type.cpp b/packages/compiler/src/ast/nodes/types/function_type.cpp index 7865a74f..47f44f33 100644 --- a/packages/compiler/src/ast/nodes/types/function_type.cpp +++ b/packages/compiler/src/ast/nodes/types/function_type.cpp @@ -8,7 +8,7 @@ using namespace stride::ast; std::optional> stride::ast::parse_function_type_optional( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set, const TypeParsingOptions& options ) diff --git a/packages/compiler/src/ast/nodes/types/object_type.cpp b/packages/compiler/src/ast/nodes/types/object_type.cpp index 32d94b0d..aaf3fa5f 100644 --- a/packages/compiler/src/ast/nodes/types/object_type.cpp +++ b/packages/compiler/src/ast/nodes/types/object_type.cpp @@ -1,6 
+1,6 @@ #include "errors.h" #include "ast/casting.h" -#include "ast/parsing_context.h" +#include "ast/symbol_table.h" #include "ast/nodes/blocks.h" #include "ast/nodes/types.h" #include "ast/tokens/token.h" @@ -12,7 +12,7 @@ using namespace stride::ast; void parse_object_member( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set, ObjectTypeMemberList& fields, const TypeParsingOptions& options @@ -65,7 +65,7 @@ void parse_object_member( * */ std::optional> stride::ast::parse_object_type_optional( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set, const TypeParsingOptions& options ) @@ -81,7 +81,7 @@ std::optional> stride::ast::parse_object_type_optional auto struct_body_set = collect_block_required(set, "A struct must have at least 1 member"); ObjectTypeMemberList struct_fields; - const auto struct_type_context = std::make_shared(context, context->get_context_type()); + const auto struct_type_context = std::make_shared(context, context->get_context_type()); // Parse fields while (struct_body_set.has_next()) diff --git a/packages/compiler/src/ast/nodes/types/primitive_type.cpp b/packages/compiler/src/ast/nodes/types/primitive_type.cpp index 21581558..3d553de8 100644 --- a/packages/compiler/src/ast/nodes/types/primitive_type.cpp +++ b/packages/compiler/src/ast/nodes/types/primitive_type.cpp @@ -95,7 +95,7 @@ std::string stride::ast::primitive_type_to_str(const PrimitiveType type, const i } std::optional> stride::ast::parse_primitive_type_optional( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set, const TypeParsingOptions& options ) diff --git a/packages/compiler/src/ast/nodes/types/tuple_type.cpp b/packages/compiler/src/ast/nodes/types/tuple_type.cpp index b792a4ca..ee1b95a1 100644 --- a/packages/compiler/src/ast/nodes/types/tuple_type.cpp +++ b/packages/compiler/src/ast/nodes/types/tuple_type.cpp @@ -9,7 +9,7 @@ using namespace stride::ast; std::optional> 
stride::ast::parse_tuple_type_optional( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set, const TypeParsingOptions& options ) diff --git a/packages/compiler/src/ast/nodes/types/type_definition.cpp b/packages/compiler/src/ast/nodes/types/type_definition.cpp index 6f6f5e02..1d92a1bd 100644 --- a/packages/compiler/src/ast/nodes/types/type_definition.cpp +++ b/packages/compiler/src/ast/nodes/types/type_definition.cpp @@ -1,13 +1,13 @@ #include "ast/nodes/type_definition.h" -#include "ast/parsing_context.h" +#include "ast/symbol_table.h" #include "ast/nodes/expression.h" #include "ast/tokens/token_set.h" using namespace stride::ast; std::unique_ptr stride::ast::parse_type_definition( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set, VisibilityModifier modifier ) diff --git a/packages/compiler/src/ast/nodes/types/types.cpp b/packages/compiler/src/ast/nodes/types/types.cpp index 5e36844d..15d06784 100644 --- a/packages/compiler/src/ast/nodes/types/types.cpp +++ b/packages/compiler/src/ast/nodes/types/types.cpp @@ -2,7 +2,7 @@ #include "errors.h" #include "ast/casting.h" -#include "ast/parsing_context.h" +#include "ast/symbol_table.h" #include "ast/nodes/literal_values.h" #include "ast/tokens/token_set.h" @@ -13,7 +13,7 @@ using namespace stride::ast; std::unique_ptr stride::ast::parse_type( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set, const TypeParsingOptions& options ) diff --git a/packages/compiler/src/ast/nodes/while_loop.cpp b/packages/compiler/src/ast/nodes/while_loop.cpp index f9b37c6c..2ae48e3d 100644 --- a/packages/compiler/src/ast/nodes/while_loop.cpp +++ b/packages/compiler/src/ast/nodes/while_loop.cpp @@ -1,7 +1,7 @@ #include "ast/nodes/while_loop.h" #include "ast/conditionals.h" -#include "ast/parsing_context.h" +#include "ast/symbol_table.h" #include "ast/nodes/blocks.h" #include "ast/tokens/token_set.h" @@ -18,7 +18,7 @@ void AstWhileLoop::validate() } 
std::unique_ptr stride::ast::parse_while_loop_statement( - const std::shared_ptr& context, + const std::shared_ptr& context, TokenSet& set, [[maybe_unused]] VisibilityModifier modifier ) @@ -31,7 +31,7 @@ std::unique_ptr stride::ast::parse_while_loop_statement( set.throw_error("Expected while loop condition"); } - const auto while_body_context = std::make_shared( + const auto while_body_context = std::make_shared( context, ContextType::CONTROL_FLOW ); @@ -71,11 +71,11 @@ llvm::Value* AstWhileLoop::codegen( if (this->get_body()) { - ParsingContext::push_control_flow_block(loop_cond_bb, loop_end_bb); + SymbolTable::push_control_flow_block(loop_cond_bb, loop_end_bb); this->get_body()->codegen(module, builder); - ParsingContext::pop_control_flow_block(); + SymbolTable::pop_control_flow_block(); } builder->CreateBr(loop_cond_bb); diff --git a/packages/compiler/src/ast/optionals.cpp b/packages/compiler/src/ast/optionals.cpp index 575f6ff6..b6ef0a3e 100644 --- a/packages/compiler/src/ast/optionals.cpp +++ b/packages/compiler/src/ast/optionals.cpp @@ -1,6 +1,6 @@ #include "ast/optionals.h" -#include "ast/parsing_context.h" +#include "ast/symbol_table.h" #include #include diff --git a/packages/compiler/src/ast/traversal/expression_visitor.cpp b/packages/compiler/src/ast/traversal/expression_visitor.cpp index 87e689b3..b56897ac 100644 --- a/packages/compiler/src/ast/traversal/expression_visitor.cpp +++ b/packages/compiler/src/ast/traversal/expression_visitor.cpp @@ -1,5 +1,5 @@ #include "ast/casting.h" -#include "ast/parsing_context.h" +#include "ast/symbol_table.h" #include "ast/type_inference.h" #include "ast/visitor.h" #include "ast/definitions/function_definition.h" diff --git a/packages/compiler/src/ast/traversal/function_visitor.cpp b/packages/compiler/src/ast/traversal/function_visitor.cpp index 0b3aac48..05dd81c9 100644 --- a/packages/compiler/src/ast/traversal/function_visitor.cpp +++ b/packages/compiler/src/ast/traversal/function_visitor.cpp @@ -1,4 +1,4 @@ 
-#include "ast/parsing_context.h" +#include "ast/symbol_table.h" #include "ast/type_inference.h" #include "ast/visitor.h" #include "ast/nodes/expression.h" diff --git a/packages/compiler/src/ast/traversal/traversal.cpp b/packages/compiler/src/ast/traversal/traversal.cpp index 4110b259..c8eae770 100644 --- a/packages/compiler/src/ast/traversal/traversal.cpp +++ b/packages/compiler/src/ast/traversal/traversal.cpp @@ -1,7 +1,7 @@ #include "../../../include/ast/traversal.h" #include "ast/casting.h" -#include "ast/parsing_context.h" +#include "ast/symbol_table.h" #include "ast/visitor.h" #include "ast/nodes/ast_node.h" #include "ast/nodes/blocks.h" diff --git a/packages/compiler/src/ast/type_inference.cpp b/packages/compiler/src/ast/type_inference.cpp index a3686bfc..98b0e274 100644 --- a/packages/compiler/src/ast/type_inference.cpp +++ b/packages/compiler/src/ast/type_inference.cpp @@ -4,7 +4,7 @@ #include "ast/casting.h" #include "ast/flags.h" #include "ast/generics.h" -#include "ast/parsing_context.h" +#include "ast/symbol_table.h" #include "ast/definitions/function_definition.h" #include "ast/nodes/function_declaration.h" #include "ast/nodes/literal_values.h" diff --git a/packages/compiler/src/stl/runtime_symbols.cpp b/packages/compiler/src/stl/runtime_symbols.cpp index f2e85944..3c4cbe9a 100644 --- a/packages/compiler/src/stl/runtime_symbols.cpp +++ b/packages/compiler/src/stl/runtime_symbols.cpp @@ -1,5 +1,5 @@ #include "ast/flags.h" -#include "ast/parsing_context.h" +#include "ast/symbol_table.h" #include "runtime/stride_runtime.h" #include "runtime/symbols.h" @@ -9,7 +9,7 @@ using namespace stride::runtime; -void stride::runtime::register_runtime_symbols(const std::shared_ptr& context) +void stride::runtime::register_runtime_symbols(const std::shared_ptr& context) { const auto fragment = SourceFragment(nullptr, 0, 0); std::vector> args; diff --git a/packages/compiler/tests/test_chained_accessor.cpp b/packages/compiler/tests/test_chained_accessor.cpp index 
997beffe..47092ff8 100644 --- a/packages/compiler/tests/test_chained_accessor.cpp +++ b/packages/compiler/tests/test_chained_accessor.cpp @@ -17,7 +17,7 @@ #include "ast/nodes/expression.h" #include "ast/nodes/literal_values.h" #include "ast/nodes/types.h" -#include "ast/parsing_context.h" +#include "ast/symbol_table.h" #include "ast/symbols.h" #include "errors.h" #include "files.h" @@ -204,12 +204,12 @@ TEST(ChainedAccessor, ArraySubscriptInFunctionArg) class ChainedAccessorTypeTest : public ::testing::Test { protected: - std::shared_ptr context; + std::shared_ptr context; std::shared_ptr source; void SetUp() override { - context = std::make_shared(); + context = std::make_shared(); source = std::make_shared("test.sr", ""); // Define a Point struct: { x: i32, y: f32 } diff --git a/packages/compiler/tests/test_type_casts.cpp b/packages/compiler/tests/test_type_casts.cpp index fffd0f33..d02dec6a 100644 --- a/packages/compiler/tests/test_type_casts.cpp +++ b/packages/compiler/tests/test_type_casts.cpp @@ -1,4 +1,4 @@ -#include "ast/parsing_context.h" +#include "ast/symbol_table.h" #include "utils.h" #include diff --git a/packages/compiler/tests/test_type_inference.cpp b/packages/compiler/tests/test_type_inference.cpp index c52b9cb3..d30c0475 100644 --- a/packages/compiler/tests/test_type_inference.cpp +++ b/packages/compiler/tests/test_type_inference.cpp @@ -3,7 +3,7 @@ #include "ast/nodes/expression.h" #include "ast/nodes/types.h" #include "ast/nodes/function_declaration.h" -#include "ast/parsing_context.h" +#include "ast/symbol_table.h" #include "ast/symbols.h" #include "errors.h" #include "files.h" @@ -17,12 +17,12 @@ using namespace stride::tests; class TypeInferenceTest : public ::testing::Test { protected: - std::shared_ptr context; + std::shared_ptr context; std::shared_ptr source; void SetUp() override { - context = std::make_shared(); + context = std::make_shared(); source = std::make_shared("test.sr", ""); } diff --git 
a/packages/compiler/tests/test_types.cpp b/packages/compiler/tests/test_types.cpp index 8a1ba39f..6e7e3ae6 100644 --- a/packages/compiler/tests/test_types.cpp +++ b/packages/compiler/tests/test_types.cpp @@ -1,4 +1,4 @@ -#include "ast/parsing_context.h" +#include "ast/symbol_table.h" #include "utils.h" #include diff --git a/packages/compiler/tests/utils.h b/packages/compiler/tests/utils.h index a657d96a..0c00f595 100644 --- a/packages/compiler/tests/utils.h +++ b/packages/compiler/tests/utils.h @@ -2,7 +2,7 @@ #include "files.h" #include "ast/ast.h" -#include "ast/parsing_context.h" +#include "ast/symbol_table.h" #include "ast/visitor.h" #include "ast/nodes/blocks.h" #include "../include/ast/traversal.h" @@ -19,12 +19,12 @@ namespace stride::tests { - inline std::pair, std::shared_ptr> parse_code_with_context( + inline std::pair, std::shared_ptr> parse_code_with_context( const std::string& code) { const auto source = std::make_shared("test.sr", code); auto tokens = ast::tokenizer::tokenize(source); - const auto context = std::make_shared(); + const auto context = std::make_shared(); auto node = parse_sequential(context, tokens); From 001d75d2d75cd9f4c7f17140ede4ad11f9fa2e14 Mon Sep 17 00:00:00 2001 From: Luca Warmenhoven Date: Wed, 25 Mar 2026 08:21:27 +0100 Subject: [PATCH 30/31] Refactor AST structure to use AstBranch for file representation and update related parsing and traversal logic --- packages/compiler/include/ast/ast.h | 49 ++++++++++- packages/compiler/include/ast/traversal.h | 31 +++++-- packages/compiler/include/ast/visitor.h | 20 ++--- packages/compiler/src/ast/ast.cpp | 36 +++++--- .../src/ast/traversal/expression_visitor.cpp | 2 +- .../ast/traversal/function_call_visitor.cpp | 2 +- .../src/ast/traversal/function_visitor.cpp | 2 +- .../src/ast/traversal/import_visitor.cpp | 20 ++--- .../compiler/src/ast/traversal/traversal.cpp | 88 ++++++++----------- packages/compiler/src/compilation/program.cpp | 30 ++++--- 10 files changed, 165 insertions(+), 115 
deletions(-) diff --git a/packages/compiler/include/ast/ast.h b/packages/compiler/include/ast/ast.h index 1eee5a35..70d37f1e 100644 --- a/packages/compiler/include/ast/ast.h +++ b/packages/compiler/include/ast/ast.h @@ -1,5 +1,7 @@ #pragma once + #include "symbol_table.h" +#include "ast/nodes/blocks.h" #include #include @@ -19,11 +21,50 @@ namespace stride::ast { using FilePath = std::string; + // A branch, representing a file in the AST + class AstBranch + { + std::shared_ptr _source_file; + std::unique_ptr _node; + + public: + explicit AstBranch( + std::shared_ptr source_file, + std::unique_ptr node + ) : + _source_file(std::move(source_file)), + _node(std::move(node)) {} + + [[nodiscard]] + const std::string& get_file_content() const + { + return this->_source_file->source; + } + + [[nodiscard]] + const std::string& get_file_path() const + { + return this->_source_file->path; + } + + [[nodiscard]] + SourceFile* get_source_file() const + { + return this->_source_file.get(); + } + + [[nodiscard]] + AstBlock* get_node() const + { + return this->_node.get(); + } + }; + class Ast { - std::map> _files{}; + std::map> _branches; - static std::pair> parse_file(const FilePath& path); + static std::pair> parse_file(const FilePath& path); public: static std::unique_ptr parse_files( @@ -34,9 +75,9 @@ namespace stride::ast void print() const; - const std::map>& get_files() + const std::map>& get_branches() { - return this->_files; + return this->_branches; } }; diff --git a/packages/compiler/include/ast/traversal.h b/packages/compiler/include/ast/traversal.h index f43a671f..755aa78e 100644 --- a/packages/compiler/include/ast/traversal.h +++ b/packages/compiler/include/ast/traversal.h @@ -1,7 +1,13 @@ #pragma once +#include "symbol_table.h" + +#include namespace stride::ast { + class AstBranch; + enum class ContextType; + class SymbolTable; class IVisitor; class AstFunctionCall; class AstPackage; @@ -21,20 +27,27 @@ namespace stride::ast /// ensuring that child expression 
types are available when the parent is visited. class AstNodeTraverser { - public: - void visit(IVisitor* visitor, IAstNode* node); - - void visit_variable_declaration(IVisitor* visitor, AstVariableDeclaration* node); + private: + std::shared_ptr _root_symbol_table; + std::shared_ptr _current_symbol_table; - void visit_for_loop(IVisitor* visitor, AstForLoop* node); + std::string _context_name; + ContextType _current_context_type; - void visit_while_loop(IVisitor* visitor, AstWhileLoop* node); + public: + explicit AstNodeTraverser( + std::shared_ptr root_symbol_table + ) : + _root_symbol_table(std::move(root_symbol_table)), + _current_symbol_table(root_symbol_table), + _current_context_type(ContextType::GLOBAL) {} - void visit_expression(IVisitor* visitor, IAstExpression* node); + void traverse(IVisitor* visitor, const AstBranch* branch); - void visit_conditional_statement(IVisitor* visitor, AstConditionalStatement* node); + private: + void visit(IVisitor* visitor, IAstNode* node); - void visit_return_statement(IVisitor* visitor, const AstReturnStatement* node); + void visit_expression(IVisitor* visitor, IAstExpression* node); void visit_block(IVisitor* visitor, const AstBlock* node); }; diff --git a/packages/compiler/include/ast/visitor.h b/packages/compiler/include/ast/visitor.h index 05c1c6aa..c71b5d2b 100644 --- a/packages/compiler/include/ast/visitor.h +++ b/packages/compiler/include/ast/visitor.h @@ -24,15 +24,15 @@ namespace stride::ast virtual ~IVisitor() = default; /// Called for every expression node, after its sub-expressions have been visited. 
- virtual void accept(IAstExpression* expr) {}; + virtual void accept_expression_node(IAstExpression* expr) {}; - virtual void accept(IAstFunction* expr) {}; + virtual void accept_function_node(IAstFunction* expr) {}; - virtual void accept(AstImport* node) {} + virtual void accept_import_node(AstImport* node) {} - virtual void accept(AstPackage* node) {} + virtual void accept_package_node(AstPackage* node) {} - virtual void accept(AstFunctionCall* function_call) {} + virtual void accept_function_call(AstFunctionCall* function_call) {} }; /// Visitor that infers and assigns types to every expression node in the AST. @@ -48,19 +48,19 @@ namespace stride::ast /// Infers the type of `expr` and stores it on the node. /// For AstVariableDeclaration: also registers the variable in its context. /// For IAstFunction: also registers the function in its context. - void accept(IAstExpression* expr) override; + void accept_expression_node(IAstExpression* expr) override; }; class FunctionVisitor : public IVisitor { public: - void accept(IAstFunction* function) override; + void accept_function_node(IAstFunction* function) override; }; class FunctionCallVisitor : public IVisitor { public: - void accept(AstFunctionCall* function_call) override; + void accept_function_call(AstFunctionCall* function_call) override; }; class ImportVisitor : public IVisitor @@ -84,9 +84,9 @@ namespace stride::ast this->_current_file_name = file_name; } - void accept(AstImport* node) override; + void accept_import_node(AstImport* node) override; - void accept(AstPackage* node) override; + void accept_package_node(AstPackage* node) override; void cross_register_symbols(Ast* ast) const; }; diff --git a/packages/compiler/src/ast/ast.cpp b/packages/compiler/src/ast/ast.cpp index 94d4d3fb..1463d286 100644 --- a/packages/compiler/src/ast/ast.cpp +++ b/packages/compiler/src/ast/ast.cpp @@ -25,7 +25,7 @@ std::unique_ptr Ast::parse_files(const std::vector& files) { auto ast = std::make_unique(); - std::vector>>> 
futures; + std::vector>>> futures; futures.reserve(files.size()); for (const auto& file : files) @@ -45,39 +45,44 @@ std::unique_ptr Ast::parse_files(const std::vector& files) { auto [file_path, node] = future.get(); - ast->_files.emplace(file_path, std::move(node)); + ast->_branches.emplace(file_path, std::move(node)); } return ast; } -std::pair> Ast::parse_file(const FilePath& path) +std::pair> Ast::parse_file(const FilePath& path) { - const auto source_file = read_file(path); + auto source_file = read_file(path); auto tokens = tokenizer::tokenize(source_file); const auto context = std::make_shared(); auto file_node = parse_sequential(context, tokens); - return { path, std::move(file_node) }; + return { + path, std::make_unique( + std::move(source_file), + std::move(file_node) + ) + }; } void Ast::print() const { - for (const auto& [file_name, node] : this->_files) + for (const auto& [file_name, branch] : this->_branches) { std::cout << "--- " << file_name << " --- " << std::endl; - std::cout << node->to_string() << std::endl; + std::cout << branch->get_node()->to_string() << std::endl; } } void Ast::optimize() { - for (const auto& [file_name, node] : this->_files) + for (const auto& [file_name, branch] : this->_branches) { std::vector> new_children; - for (auto& child_ptr : node->get_children()) + for (auto& child_ptr : branch->get_node()->get_children()) { IAstNode* child = child_ptr.get(); @@ -95,14 +100,19 @@ void Ast::optimize() } // If not reducible or reduction failed, move the original node // This requires changing the loop to take ownership or clone - new_children.push_back(std::unique_ptr(child)); + new_children.push_back(std::unique_ptr(child)); } - this->_files[file_name] = std::make_unique( - node->get_source_fragment(), - node->get_context(), + auto reduced_block = std::make_unique( + branch->get_node()->get_source_fragment(), + branch->get_node()->get_context(), std::move(new_children) ); + + this->_branches[file_name] = std::make_unique( + 
std::unique_ptr(branch->get_source_file()), + std::move(reduced_block) + ); } } diff --git a/packages/compiler/src/ast/traversal/expression_visitor.cpp b/packages/compiler/src/ast/traversal/expression_visitor.cpp index b56897ac..dca8ec04 100644 --- a/packages/compiler/src/ast/traversal/expression_visitor.cpp +++ b/packages/compiler/src/ast/traversal/expression_visitor.cpp @@ -7,7 +7,7 @@ using namespace stride::ast; -void ExpressionVisitor::accept(IAstExpression* expr) +void ExpressionVisitor::accept_expression_node(IAstExpression* expr) { expr->set_type(infer_expression_type(expr)); diff --git a/packages/compiler/src/ast/traversal/function_call_visitor.cpp b/packages/compiler/src/ast/traversal/function_call_visitor.cpp index acea54b5..86e5faa3 100644 --- a/packages/compiler/src/ast/traversal/function_call_visitor.cpp +++ b/packages/compiler/src/ast/traversal/function_call_visitor.cpp @@ -4,7 +4,7 @@ using namespace stride::ast; -void FunctionCallVisitor::accept(AstFunctionCall* function_call) +void FunctionCallVisitor::accept_function_call(AstFunctionCall* function_call) { if (function_call->get_generic_type_arguments().empty()) return; diff --git a/packages/compiler/src/ast/traversal/function_visitor.cpp b/packages/compiler/src/ast/traversal/function_visitor.cpp index 05dd81c9..58fa150e 100644 --- a/packages/compiler/src/ast/traversal/function_visitor.cpp +++ b/packages/compiler/src/ast/traversal/function_visitor.cpp @@ -7,7 +7,7 @@ using namespace stride::ast; -void FunctionVisitor::accept(IAstFunction* function) +void FunctionVisitor::accept_function_node(IAstFunction* function) { function->set_type(infer_function_type(function)); diff --git a/packages/compiler/src/ast/traversal/import_visitor.cpp b/packages/compiler/src/ast/traversal/import_visitor.cpp index cf18f07d..199ed786 100644 --- a/packages/compiler/src/ast/traversal/import_visitor.cpp +++ b/packages/compiler/src/ast/traversal/import_visitor.cpp @@ -5,7 +5,7 @@ #include "ast/nodes/import.h" #include 
"ast/nodes/package.h" -void stride::ast::ImportVisitor::accept(AstImport* node) +void stride::ast::ImportVisitor::accept_import_node(AstImport* node) { const auto& package_identifier = node->get_package_identifier(); const auto& import_identifiers = node->get_import_list(); @@ -42,14 +42,14 @@ void stride::ast::ImportVisitor::accept(AstImport* node) } } -void stride::ast::ImportVisitor::accept(AstPackage* node) +void stride::ast::ImportVisitor::accept_package_node(AstPackage* node) { this->_package_file_mapping[node->get_package_name()].push_back(this->_current_file_name); } void stride::ast::ImportVisitor::cross_register_symbols(Ast* ast) const { - for (const auto& [file_name, node] : ast->get_files()) + for (const auto& [file_name, branch] : ast->get_branches()) { if (!this->_import_registry.contains(file_name)) continue; @@ -63,7 +63,7 @@ void stride::ast::ImportVisitor::cross_register_symbols(Ast* ast) const throw parsing_error( ErrorType::REFERENCE_ERROR, std::format("Package '{}' not found", package_name), - node->get_source_fragment() + branch->get_node()->get_source_fragment() ); } // The Ast nodes from which we wish to extract the symbols @@ -76,8 +76,8 @@ void stride::ast::ImportVisitor::cross_register_symbols(Ast* ast) const std::optional> definition; for (const auto& file_name_with_exports : files_with_exports) { - const auto& node_with_exports = ast->get_files().at(file_name_with_exports); - definition = node_with_exports->get_context()->get_definition_by_internal_name(import_name); + const auto& node_with_exports = ast->get_branches().at(file_name_with_exports); + definition = node_with_exports->get_node()->get_context()->get_definition_by_internal_name(import_name); if (definition.has_value()) break; } @@ -87,7 +87,7 @@ void stride::ast::ImportVisitor::cross_register_symbols(Ast* ast) const throw parsing_error( ErrorType::REFERENCE_ERROR, std::format("Variable or function '{}' not found in package '{}'", import_name, package_name), - 
node->get_source_fragment() + branch->get_node()->get_source_fragment() ); } @@ -96,15 +96,15 @@ void stride::ast::ImportVisitor::cross_register_symbols(Ast* ast) const throw parsing_error( ErrorType::REFERENCE_ERROR, std::format("Variable or function '{}' is not public", import_name), - node->get_source_fragment() + branch->get_node()->get_source_fragment() ); } // Define only if not yet present - if (node->get_context()->get_definition_by_internal_name(definition.value()->get_internal_symbol_name()) + if (branch->get_node()->get_context()->get_definition_by_internal_name(definition.value()->get_internal_symbol_name()) == std::nullopt) { - node->get_context()->define(std::move(definition.value())); + branch->get_node()->get_context()->define(std::move(definition.value())); } } } diff --git a/packages/compiler/src/ast/traversal/traversal.cpp b/packages/compiler/src/ast/traversal/traversal.cpp index c8eae770..2abccdec 100644 --- a/packages/compiler/src/ast/traversal/traversal.cpp +++ b/packages/compiler/src/ast/traversal/traversal.cpp @@ -1,5 +1,6 @@ #include "../../../include/ast/traversal.h" +#include "ast/ast.h" #include "ast/casting.h" #include "ast/symbol_table.h" #include "ast/visitor.h" @@ -19,6 +20,15 @@ using namespace stride::ast; +void AstNodeTraverser::traverse(IVisitor* visitor, const AstBranch* branch) +{ + this->_context_name = ""; + this->_current_context_type = ContextType::GLOBAL; + this->_current_symbol_table = this->_root_symbol_table; + + this->visit(visitor, branch->get_node()); +} + void AstNodeTraverser::visit_block(IVisitor* visitor, const AstBlock* node) { if (!node) @@ -39,7 +49,7 @@ void AstNodeTraverser::visit_expression(IVisitor* visitor, IAstExpression* node) if (auto* fn = cast_expr(node)) { visit_block(visitor, fn->get_body()); - visitor->accept(fn); + visitor->accept_function_node(fn); } else if (const auto* binary = cast_expr(node)) { @@ -60,7 +70,7 @@ void AstNodeTraverser::visit_expression(IVisitor* visitor, IAstExpression* node) 
for (const auto& arg : fn_call->get_arguments()) visit_expression(visitor, arg.get()); - visitor->accept(fn_call); + visitor->accept_function_call(fn_call); } else if (const auto* array = cast_expr(node)) { @@ -96,12 +106,12 @@ void AstNodeTraverser::visit_expression(IVisitor* visitor, IAstExpression* node) else if (auto* function_node = cast_expr(node)) { visit_block(visitor, function_node->get_body()); - visitor->accept(function_node); + visitor->accept_function_node(function_node); } else if (auto* type_cast = cast_expr(node)) { visit_expression(visitor, type_cast->get_value()); - visitor->accept(type_cast); + visitor->accept_expression_node(type_cast); } else if (auto* indirect_call = cast_expr(node)) { @@ -110,51 +120,13 @@ void AstNodeTraverser::visit_expression(IVisitor* visitor, IAstExpression* node) visit_expression(visitor, arg.get()); } visit_expression(visitor, indirect_call->get_callee()); - visitor->accept(indirect_call); + visitor->accept_expression_node(indirect_call); } // AstLiteral, AstIdentifier, AstVariadicArgReference, // AstArrayMemberAccessor (base/index already handled above) — leaf nodes, no children. 
- visitor->accept(node); -} - -void AstNodeTraverser::visit_conditional_statement(IVisitor* visitor, AstConditionalStatement* node) -{ - visit_expression(visitor, node->get_condition()); - visit_block(visitor, node->get_body()); - if (node->get_else_body()) - visit_block(visitor, node->get_else_body()); -} - -void AstNodeTraverser::visit_while_loop(IVisitor* visitor, AstWhileLoop* node) -{ - if (node->get_condition()) - visit_expression(visitor, node->get_condition()); - visit_block(visitor, node->get_body()); -} - -void AstNodeTraverser::visit_for_loop(IVisitor* visitor, AstForLoop* node) -{ - if (node->get_initializer()) - visit_expression(visitor, node->get_initializer()); - if (node->get_condition()) - visit_expression(visitor, node->get_condition()); - if (node->get_incrementor()) - visit_expression(visitor, node->get_incrementor()); - visit_block(visitor, node->get_body()); -} - -void AstNodeTraverser::visit_return_statement(IVisitor* visitor, const AstReturnStatement* node) -{ - if (node->get_return_expression().has_value()) - visit_expression(visitor, node->get_return_expression().value().get()); -} - -void AstNodeTraverser::visit_variable_declaration(IVisitor* visitor, AstVariableDeclaration* node) -{ - visitor->accept(node); - visit_expression(visitor, node->get_initial_value()); + visitor->accept_expression_node(node); } void AstNodeTraverser::visit(IVisitor* visitor, IAstNode* node) @@ -166,19 +138,30 @@ void AstNodeTraverser::visit(IVisitor* visitor, IAstNode* node) if (auto* conditional = dynamic_cast(node)) { - visit_conditional_statement(visitor, conditional); + visit_expression(visitor, conditional->get_condition()); + visit_block(visitor, conditional->get_body()); + if (conditional->get_else_body()) + visit_block(visitor, conditional->get_else_body()); } else if (auto* while_loop = dynamic_cast(node)) { - visit_while_loop(visitor, while_loop); + if (while_loop->get_condition()) + visit_expression(visitor, while_loop->get_condition()); + 
visit_block(visitor, while_loop->get_body()); } - else if (auto* for_loop = dynamic_cast(node)) + else if (const auto* for_loop = dynamic_cast(node)) { - visit_for_loop(visitor, for_loop); + if (for_loop->get_initializer()) + visit_expression(visitor, for_loop->get_initializer()); + if (for_loop->get_condition()) + visit_expression(visitor, for_loop->get_condition()); + if (for_loop->get_incrementor()) + visit_expression(visitor, for_loop->get_incrementor()); } else if (const auto* return_stmt = dynamic_cast(node)) { - visit_return_statement(visitor, return_stmt); + if (return_stmt->get_return_expression().has_value()) + visit_expression(visitor, return_stmt->get_return_expression().value().get()); } else if (auto* module = dynamic_cast(node)) { @@ -194,14 +177,15 @@ void AstNodeTraverser::visit(IVisitor* visitor, IAstNode* node) } else if (auto* variable_declaration = dynamic_cast(node)) { - visit_variable_declaration(visitor, variable_declaration); + visit_expression(visitor, variable_declaration->get_initial_value()); + visitor->accept_expression_node(variable_declaration); } else if (auto* import_node = dynamic_cast(node)) { - visitor->accept(import_node); + visitor->accept_import_node(import_node); } else if (auto* package_node = dynamic_cast(node)) { - visitor->accept(package_node); + visitor->accept_package_node(package_node); } } diff --git a/packages/compiler/src/compilation/program.cpp b/packages/compiler/src/compilation/program.cpp index 1e7d9964..6800e603 100644 --- a/packages/compiler/src/compilation/program.cpp +++ b/packages/compiler/src/compilation/program.cpp @@ -47,7 +47,9 @@ std::unique_ptr Program::prepare_module( llvm::IRBuilder<> builder(context); - ast::AstNodeTraverser traverser; + auto global_symbol_table = std::make_shared(); + + ast::AstNodeTraverser traverser(global_symbol_table); ast::ExpressionVisitor expression_visitor; ast::FunctionVisitor function_visitor; ast::FunctionCallVisitor function_call_visitor; @@ -56,52 +58,52 @@ 
std::unique_ptr Program::prepare_module( // // First step - Cross-file symbol registration (imports and function signatures) // - for (const auto& [file_name, node] : this->_ast->get_files()) + for (const auto& [file_name, branch] : this->_ast->get_branches()) { // Populate own symbol table with stride runtime symbols // These are externally available functions that are linked after codegen - runtime::register_runtime_symbols(node->get_context()); + runtime::register_runtime_symbols(branch->get_node()->get_context()); // Resolve imports and populate local registry - Used for cross registration step import_visitor.set_current_file_name(file_name); - traverser.visit_block(&import_visitor, node.get()); + traverser.traverse(&import_visitor, branch.get()); // Ensures functions are defined in our symbol table - traverser.visit_block(&function_visitor, node.get()); + traverser.traverse(&function_visitor, branch.get()); } import_visitor.cross_register_symbols(this->_ast.get()); // // Generic resolution - Instantiates functions that have generic arguments // - for (const auto& node : this->_ast->get_files() | std::views::values) + for (const auto& node : this->_ast->get_branches() | std::views::values) { - traverser.visit_block(&function_call_visitor, node.get()); + traverser.traverse(&function_call_visitor, node.get()); } // // Third step - Type resolution and symbol forward declarations // - for (const auto& node : this->_ast->get_files() | std::views::values) + for (const auto& node : this->_ast->get_branches() | std::views::values) { // Type checker - this must be executed after all external symbols have been populated - traverser.visit_block(&expression_visitor, node.get()); + traverser.traverse(&expression_visitor, node.get()); } - for (const auto& node : this->_ast->get_files() | std::views::values) + for (const auto& branch : this->_ast->get_branches() | std::views::values) { // Resolving forward references - Ensures symbols certain symbols are available before 
implementation - node->resolve_forward_references( + branch->get_node()->resolve_forward_references( module.get(), &builder ); } /// --- Final step - LLVM IR validation and code generation - for (const auto& node : this->_ast->get_files() | std::views::values) + for (const auto& node : this->_ast->get_branches() | std::views::values) { - node->validate(); - node->codegen(module.get(), &builder); + node->get_node()->validate(); + node->get_node()->codegen(module.get(), &builder); } if (llvm::verifyModule(*module, &llvm::errs())) From 56b954f82cf33742c66fab8c3fcb2c44a7eafbc3 Mon Sep 17 00:00:00 2001 From: Luca Warmenhoven Date: Wed, 25 Mar 2026 09:35:02 +0100 Subject: [PATCH 31/31] Enhance generic function support by adding a visitor for concrete type resolution and improving AST traversal methods --- example.sr | 6 +- packages/compiler/include/ast/ast.h | 5 ++ packages/compiler/include/ast/traversal.h | 2 + .../ast/nodes/expressions/unary_operation.cpp | 7 +-- .../function_declaration_codegen.cpp | 5 ++ .../function_declaration_validation.cpp | 59 +++++++++++++++++++ .../compiler/src/ast/traversal/traversal.cpp | 13 ++-- packages/compiler/src/ast/type_inference.cpp | 10 ++++ packages/compiler/tests/utils.h | 21 ++++--- 9 files changed, 104 insertions(+), 24 deletions(-) diff --git a/example.sr b/example.sr index 24a6596f..402373c7 100644 --- a/example.sr +++ b/example.sr @@ -5,14 +5,14 @@ import System::{ type Array = { elements: T[]; count: i32; - at: (i32) -> T; + // at: (i32) -> T; }; fn arrayOf(elements: T[], count: i32): Array { return Array::{ elements, count, - at: (index: i32): T -> elements[index] + // at: (index: i32): T -> elements[index] }; } @@ -21,7 +21,7 @@ type Callback = () -> void; fn makeCb(): Callback { const names = arrayOf(["Toyota", "Honda", "Ford", "Toyota"], 4); - return (): void -> io::print("\x1b[32mDriving the car: %s", names.at(1)); + return (): void -> io::print("\x1b[32mDriving the car: %s", names.elements[1]); } fn main(): i32 { diff 
--git a/packages/compiler/include/ast/ast.h b/packages/compiler/include/ast/ast.h index 70d37f1e..e8b24654 100644 --- a/packages/compiler/include/ast/ast.h +++ b/packages/compiler/include/ast/ast.h @@ -58,6 +58,11 @@ namespace stride::ast { return this->_node.get(); } + + std::unique_ptr release_node() + { + return std::move(this->_node); + } }; class Ast diff --git a/packages/compiler/include/ast/traversal.h b/packages/compiler/include/ast/traversal.h index 755aa78e..f5e5187f 100644 --- a/packages/compiler/include/ast/traversal.h +++ b/packages/compiler/include/ast/traversal.h @@ -44,6 +44,8 @@ namespace stride::ast void traverse(IVisitor* visitor, const AstBranch* branch); + void traverse_block(IVisitor* visitor, AstBlock* block); + private: void visit(IVisitor* visitor, IAstNode* node); diff --git a/packages/compiler/src/ast/nodes/expressions/unary_operation.cpp b/packages/compiler/src/ast/nodes/expressions/unary_operation.cpp index fddb21a3..4278f539 100644 --- a/packages/compiler/src/ast/nodes/expressions/unary_operation.cpp +++ b/packages/compiler/src/ast/nodes/expressions/unary_operation.cpp @@ -83,14 +83,9 @@ std::optional> stride::ast::parse_binary_unary_o } const auto rhs_expr_pos = rhs_expr->get()->get_source_fragment(); - const auto source = SourceFragment( - op_type_pos.source, - op_type_pos.offset, - rhs_expr_pos.offset + rhs_expr_pos.length - op_type_pos.offset - ); return std::make_unique( - source, + SourceFragment::join(op_type_pos,rhs_expr_pos), context, op_type.value(), std::move(rhs_expr.value()) diff --git a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_codegen.cpp b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_codegen.cpp index 39d8b7f8..168a33d0 100644 --- a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_codegen.cpp +++ b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_codegen.cpp @@ -19,6 +19,11 @@ llvm::Value* 
IAstFunction::codegen( if (implementations.empty()) { + // A generic function that is never instantiated (never called with concrete type + // arguments) has nothing to codegen — this is valid and not an error. + if (this->is_generic()) + return nullptr; + throw parsing_error( ErrorType::COMPILATION_ERROR, std::format( diff --git a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_validation.cpp b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_validation.cpp index 381af1b4..4ae4c303 100644 --- a/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_validation.cpp +++ b/packages/compiler/src/ast/nodes/functions/declaration/function_declaration_validation.cpp @@ -1,8 +1,53 @@ #include "ast/casting.h" #include "ast/definitions/function_definition.h" +#include "ast/generics.h" #include "ast/nodes/conditional_statement.h" #include "ast/nodes/function_declaration.h" #include "ast/nodes/return_statement.h" +#include "ast/type_inference.h" +#include "ast/visitor.h" + +namespace +{ + using namespace stride::ast; + + /// Visitor used to set concrete types on a cloned generic function body. + /// Unlike ExpressionVisitor it: + /// - Resolves generic param names (e.g. T → i32) in the inferred type via resolve_generics. + /// - Uses overwrite=true when registering variable declarations so that re-running on a + /// cloned body (sharing the same SymbolTable as the original) doesn't throw "already defined". 
+ class GenericBodyExpressionVisitor final : public IVisitor + { + const GenericParameterList& _param_names; + const GenericTypeList& _instantiated_types; + + public: + GenericBodyExpressionVisitor( + const GenericParameterList& param_names, + const GenericTypeList& instantiated_types + ) : + _param_names(param_names), + _instantiated_types(instantiated_types) {} + + void accept_expression_node(IAstExpression* expr) override + { + auto inferred = infer_expression_type(expr); + expr->set_type(resolve_generics(inferred.get(), _param_names, _instantiated_types)); + + if (const auto* var_decl = dynamic_cast(expr)) + { + const auto resolved_type = var_decl->get_type(); + resolved_type->set_flags(var_decl->get_flags()); + var_decl->get_context()->define_variable( + var_decl->get_symbol(), + resolved_type->clone_ty(), + var_decl->get_visibility(), + true // overwrite existing registration from the original body traversal + ); + } + } + }; +} // anonymous namespace using namespace stride::ast; @@ -70,6 +115,20 @@ void IAstFunction::validate() EMPTY_GENERIC_PARAMETER_LIST // Omit generics - They've been resolved ); + // Run type inference on the cloned body using the concrete parameter types now + // in context (updated above). This is necessary because clone() does not + // preserve the _type field set during the earlier traversal pass. + // Use GenericBodyExpressionVisitor so that remaining generic param names + // (e.g. T in `Wrapper::{...}`) are also substituted with concrete types. 
+ { + AstNodeTraverser traverser(this->get_context()); + GenericBodyExpressionVisitor generic_expr_visitor( + this->_generic_parameters, + instantiated_generic_types + ); + traverser.traverse_block(&generic_expr_visitor, node->get_body()); + } + validate_candidate(node.get()); // Restore original parameter types in the context diff --git a/packages/compiler/src/ast/traversal/traversal.cpp b/packages/compiler/src/ast/traversal/traversal.cpp index 2abccdec..48a19f76 100644 --- a/packages/compiler/src/ast/traversal/traversal.cpp +++ b/packages/compiler/src/ast/traversal/traversal.cpp @@ -29,6 +29,11 @@ void AstNodeTraverser::traverse(IVisitor* visitor, const AstBranch* branch) this->visit(visitor, branch->get_node()); } +void AstNodeTraverser::traverse_block(IVisitor* visitor, AstBlock* block) +{ + this->visit_block(visitor, block); +} + void AstNodeTraverser::visit_block(IVisitor* visitor, const AstBlock* node) { if (!node) @@ -149,7 +154,7 @@ void AstNodeTraverser::visit(IVisitor* visitor, IAstNode* node) visit_expression(visitor, while_loop->get_condition()); visit_block(visitor, while_loop->get_body()); } - else if (const auto* for_loop = dynamic_cast(node)) + else if (auto* for_loop = dynamic_cast(node)) { if (for_loop->get_initializer()) visit_expression(visitor, for_loop->get_initializer()); @@ -157,6 +162,7 @@ void AstNodeTraverser::visit(IVisitor* visitor, IAstNode* node) visit_expression(visitor, for_loop->get_condition()); if (for_loop->get_incrementor()) visit_expression(visitor, for_loop->get_incrementor()); + visit_block(visitor, for_loop->get_body()); } else if (const auto* return_stmt = dynamic_cast(node)) { @@ -175,11 +181,6 @@ void AstNodeTraverser::visit(IVisitor* visitor, IAstNode* node) { visit_expression(visitor, expr); } - else if (auto* variable_declaration = dynamic_cast(node)) - { - visit_expression(visitor, variable_declaration->get_initial_value()); - visitor->accept_expression_node(variable_declaration); - } else if (auto* import_node = 
dynamic_cast(node)) { visitor->accept_import_node(import_node); diff --git a/packages/compiler/src/ast/type_inference.cpp b/packages/compiler/src/ast/type_inference.cpp index 98b0e274..737e311a 100644 --- a/packages/compiler/src/ast/type_inference.cpp +++ b/packages/compiler/src/ast/type_inference.cpp @@ -380,6 +380,16 @@ std::unique_ptr stride::ast::infer_variable_declaration_type( return value_type->clone_ty(); } + // If the annotated type is an unresolved alias (i.e. a generic parameter like T with no + // type definition in scope), accept the value type directly. This allows constructs like + // `const result: T = x + x;` inside a generic function body, where T is a placeholder. + if (const auto* alias = cast_type(annotated_type.value()); + alias && !alias->get_type_definition().has_value()) + { + value_type->set_flags(declaration->get_flags()); + return value_type->clone_ty(); + } + if (annotated_type.value()->equals(value_type.get())) { return annotated_type.value()->clone_ty(); diff --git a/packages/compiler/tests/utils.h b/packages/compiler/tests/utils.h index 0c00f595..520bed7a 100644 --- a/packages/compiler/tests/utils.h +++ b/packages/compiler/tests/utils.h @@ -26,27 +26,30 @@ namespace stride::tests auto tokens = ast::tokenizer::tokenize(source); const auto context = std::make_shared(); - auto node = parse_sequential(context, tokens); + auto branch = std::make_unique(source, parse_sequential(context, tokens)); - ast::AstNodeTraverser traverser; + ast::AstNodeTraverser traverser(context); ast::ExpressionVisitor expression_visitor; ast::FunctionVisitor function_visitor; ast::FunctionCallVisitor function_call_visitor; ast::ImportVisitor import_visitor; - runtime::register_runtime_symbols(node->get_context()); + runtime::register_runtime_symbols(branch->get_node()->get_context()); import_visitor.set_current_file_name("test.sr"); - traverser.visit_block(&import_visitor, node.get()); + traverser.traverse(&import_visitor, branch.get()); - 
traverser.visit_block(&function_visitor, node.get()); - traverser.visit_block(&function_call_visitor, node.get()); + traverser.traverse(&function_visitor, branch.get()); + traverser.traverse(&function_call_visitor, branch.get()); - traverser.visit_block(&expression_visitor, node.get()); + traverser.traverse(&expression_visitor, branch.get()); - node->validate(); + branch->get_node()->validate(); - return std::make_pair(std::move(node), context); + return std::make_pair( + branch->release_node(), + context + ); } inline std::unique_ptr parse_code(const std::string& code)