From ff546b766c95c73f45eb3ca199537ceed65a1fe6 Mon Sep 17 00:00:00 2001 From: Luca Warmenhoven Date: Wed, 11 Mar 2026 19:16:03 +0100 Subject: [PATCH 01/30] Add enum statement support to Intellij plugin --- example.sr | 6 ++++++ .../src/ast/nodes/expressions/object_initializer.cpp | 5 +++-- packages/compiler/src/ast/nodes/types/object_type.cpp | 4 ++-- packages/compiler/src/stl/stride_runtime.cpp | 3 +-- .../stride-plugin-intellij/src/main/grammars/Stride.bnf | 6 ++++++ .../stride-plugin-intellij/src/main/grammars/Stride.flex | 1 + .../stride/intellij/highlight/StrideSyntaxHighlighter.kt | 2 +- 7 files changed, 20 insertions(+), 7 deletions(-) diff --git a/example.sr b/example.sr index c4b763e3..4ce09d28 100644 --- a/example.sr +++ b/example.sr @@ -3,6 +3,12 @@ import System::{ IO::Read }; +enum SomeEnum { + Variant1, + Variant2, + Variant3 +} + type Array = { length: i32; data: T[]; diff --git a/packages/compiler/src/ast/nodes/expressions/object_initializer.cpp b/packages/compiler/src/ast/nodes/expressions/object_initializer.cpp index a2a977d4..554b2650 100644 --- a/packages/compiler/src/ast/nodes/expressions/object_initializer.cpp +++ b/packages/compiler/src/ast/nodes/expressions/object_initializer.cpp @@ -336,6 +336,7 @@ llvm::Value* AstObjectInitializer::codegen( } else { + // FIXME: // For non-pointers, if they are both structs but have different names, // LLVM doesn't allow bitcast. However, they should have been the same. // If we reach here, we might need a more complex conversion or @@ -359,9 +360,9 @@ llvm::Value* AstObjectInitializer::codegen( for (size_t i = 0; i < dynamic_members.size(); ++i) { auto* member_val = dynamic_members[i]; - auto* target_type = struct_type->getElementType(i); - if (member_val->getType() != target_type) + if (auto* target_type = struct_type->getElementType(i); + member_val->getType() != target_type) { if (member_val->getType()->isPointerTy() && target_type->isPointerTy()) { diff --git a/packages/compiler/src/ast/nodes/types/object_type.cpp b/packages/compiler/src/ast/nodes/types/object_type.cpp index 50b0ed2b..37f65c5a 100644 --- a/packages/compiler/src/ast/nodes/types/object_type.cpp +++ b/packages/compiler/src/ast/nodes/types/object_type.cpp @@ -145,9 +145,9 @@ std::string AstObjectType::get_internalized_name() { std::string scoped_name = resolve_internal_name({ this->get_context()->get_name(), this->get_type_name() }); - for (const auto& [name, type] : this->_members) + for (const auto& generic : this->_instantiated_generics) { - scoped_name += "$" + type->get_type_name(); + scoped_name += "$" + generic->get_type_name(); } return scoped_name; diff --git a/packages/compiler/src/stl/stride_runtime.cpp b/packages/compiler/src/stl/stride_runtime.cpp index d3bba571..ddf74e29 100644 --- a/packages/compiler/src/stl/stride_runtime.cpp +++ b/packages/compiler/src/stl/stride_runtime.cpp @@ -9,8 +9,7 @@ extern "C" { int _printf_internal(const char* format, va_list args) { - int r = vprintf(format, args); - return r; + return vprintf(format, args); } uint64_t _system_time_ns_internal() diff --git a/packages/stride-plugin-intellij/src/main/grammars/Stride.bnf b/packages/stride-plugin-intellij/src/main/grammars/Stride.bnf index 9052d6ed..445c0907 100644 --- a/packages/stride-plugin-intellij/src/main/grammars/Stride.bnf +++ b/packages/stride-plugin-intellij/src/main/grammars/Stride.bnf @@ -16,6 +16,7 @@ FN='fn' NIL='nil' CONST='const' + ENUM='enum' LET='let' CONTINUE='continue' BREAK='break' @@ -111,6 +112,7 @@ private StandaloneItem ::= ( FunctionDeclaration | ModuleStatement | TypeDefinition + | EnumDefinition | VariableDeclarationStatement | ExternFunctionDeclaration | SEMICOLON @@ -129,6 +131,10 @@ FunctionParameter ::= IDENTIFIER COLON Type TypeDefinition ::= KW_TYPE IDENTIFIER [GenericTypeArguments] EQ Type SEMICOLON +EnumDefinition ::= ENUM IDENTIFIER LBRACE EnumMember (COMMA EnumMember)* [COMMA /* trailing */] RBRACE + +EnumMember ::= IDENTIFIER | (IDENTIFIER EQ Expression) + // Ensure structs have at least 1 field ObjectDefinitionFields ::= ObjectField+ // field: i32 diff --git a/packages/stride-plugin-intellij/src/main/grammars/Stride.flex b/packages/stride-plugin-intellij/src/main/grammars/Stride.flex index f94b45e1..ce8837da 100644 --- a/packages/stride-plugin-intellij/src/main/grammars/Stride.flex +++ b/packages/stride-plugin-intellij/src/main/grammars/Stride.flex @@ -47,6 +47,7 @@ STRING_LITERAL=\"([^\\\"\r\n]|\\[^\r\n])*\" "break" { return StrideTypes.BREAK; } "continue" { return StrideTypes.CONTINUE; } "const" { return StrideTypes.CONST; } + "enum" { return StrideTypes.ENUM; } "type" { return StrideTypes.KW_TYPE; } "let" { return StrideTypes.LET; } "extern" { return StrideTypes.EXTERN; } diff --git a/packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/highlight/StrideSyntaxHighlighter.kt b/packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/highlight/StrideSyntaxHighlighter.kt index 468a893f..74417ae8 100644 --- a/packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/highlight/StrideSyntaxHighlighter.kt +++ b/packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/highlight/StrideSyntaxHighlighter.kt @@ -56,7 +56,7 @@ class StrideSyntaxHighlighter : SyntaxHighlighterBase() { override fun getTokenHighlights(tokenType: IElementType): Array { return when (tokenType) { StrideTypes.MODULE, StrideTypes.PACKAGE, StrideTypes.PUBLIC, - StrideTypes.FN, StrideTypes.CONST, + StrideTypes.FN, StrideTypes.CONST, StrideTypes.ENUM, StrideTypes.BREAK, StrideTypes.CONTINUE, StrideTypes.IMPORT, StrideTypes.LET, StrideTypes.EXTERN, StrideTypes.AS, StrideTypes.RETURN, StrideTypes.FOR, StrideTypes.WHILE, From b4cc6c0b5ee175fd8a4f60c213751696d0712dd6 Mon Sep 17 00:00:00 2001 From: Luca Warmenhoven Date: Wed, 11 Mar 2026 20:36:56 +0100 Subject: [PATCH 02/30] Refactor AstAliasType to simplify underlying type retrieval and improve error handling --- example.sr | 34 ++++++-- .../compiler/include/ast/nodes/enumerables.h | 14 ++- packages/compiler/include/ast/nodes/types.h | 10 +-- .../compiler/include/ast/parsing_context.h | 2 +- .../src/ast/context/type_registry.cpp | 2 +- packages/compiler/src/ast/generics.cpp | 18 +++- .../compiler/src/ast/nodes/enumerables.cpp | 38 ++++++-- .../expressions/array_member_accessor.cpp | 14 +-- .../ast/nodes/expressions/member_accessor.cpp | 2 +- .../nodes/expressions/object_initializer.cpp | 12 +-- .../src/ast/nodes/functions/function_call.cpp | 8 +- .../src/ast/nodes/types/alias_type.cpp | 86 +++++++++---------- .../src/ast/nodes/types/function_type.cpp | 5 +- .../src/ast/nodes/types/type_metadata.cpp | 15 +--- .../compiler/src/ast/nodes/types/types.cpp | 31 +++---- packages/compiler/src/ast/type_inference.cpp | 19 +--- packages/compiler/tests/test_generics.cpp | 30 +++++++ packages/standard-library/system/io.sr | 8 +- 18 files changed, 187 insertions(+), 161 deletions(-) diff --git a/example.sr b/example.sr index 4ce09d28..0fed002a 100644 --- a/example.sr +++ b/example.sr @@ -1,7 +1,4 @@ -import System::{ - IO::Print, - IO::Read -}; +import system::{ io::print }; enum SomeEnum { Variant1, @@ -30,11 +27,36 @@ fn make_car(make: string, model: string, year: i32): SomeCar { return SomeCar::{ make, model, year }; } +fn make_person(name: string, age: i32, cars: Array): SomePerson { + return SomePerson::{ + name, + age, + cars + }; +} + fn main(): i32 { + const me = make_person( + "Alice", + 30, + Array::{ + length: 1, + data: [ + make_car("Toyota", "Corolla", 2020) + ] + } + ); + + + const my_cars = me.cars.data; + const my_car = my_cars[0]; + + io::print("I am %s, age %d\n", me.name, me.age); + io::print("My car is a %s %s from %d", my_car.make, my_car.model, my_car.year); - const my_car = make_car("Toyota", "Corolla", 2020); + // const some_enum_val: SomeEnum = SomeEnum::Variant2; - IO::Print("My car is a %s %s from %d", my_car.make, my_car.model, my_car.year); + // io::print("The value of some_enum_val is: %d", some_enum_val); return 0; } \ No newline at end of file diff --git a/packages/compiler/include/ast/nodes/enumerables.h b/packages/compiler/include/ast/nodes/enumerables.h index 1d8ce5b8..480fdc33 100644 --- a/packages/compiler/include/ast/nodes/enumerables.h +++ b/packages/compiler/include/ast/nodes/enumerables.h @@ -40,7 +40,10 @@ namespace stride::ast std::string to_string() override; - llvm::Value* codegen(llvm::Module* module, llvm::IRBuilderBase* builder) override { return nullptr; } + llvm::Value* codegen(llvm::Module* module, llvm::IRBuilderBase* builder) override + { + return nullptr; + } std::unique_ptr clone() override; }; @@ -77,14 +80,19 @@ namespace stride::ast std::string to_string() override; - llvm::Value* codegen(llvm::Module* module, llvm::IRBuilderBase* builder) override { return nullptr; } + llvm::Value* codegen(llvm::Module* module, llvm::IRBuilderBase* builder) override + { + return nullptr; + } std::unique_ptr clone() override; }; std::unique_ptr parse_enumerable_member( const std::shared_ptr& context, - TokenSet& set); + TokenSet& set, + size_t element_index + ); std::unique_ptr parse_enumerable_declaration( const std::shared_ptr& context, diff --git a/packages/compiler/include/ast/nodes/types.h b/packages/compiler/include/ast/nodes/types.h index 661c208f..dfbff849 100644 --- a/packages/compiler/include/ast/nodes/types.h +++ b/packages/compiler/include/ast/nodes/types.h @@ -313,18 +313,14 @@ namespace stride::ast [[nodiscard]] std::optional> get_reference_type() const; - /// Returns the cached underlying type pointer (resolved including generics), - /// populating the cache via get_underlying_type() if needed. - /// The returned pointer is valid as long as this AstAliasType is alive. - [[nodiscard]] - IAstType* get_underlying_type_ptr(); - /// Returns the super base type of the reference, e.g., if we have: /// type RootType = i32; /// type MidType = RootType; /// type LeafType = MidType; /// Then, calling `get_base_reference_type` on `LeafType` will return `i32`. - [[nodiscard]] std::optional> get_underlying_type(); + /// Throws an exception if no underlying type can be resolved. + [[nodiscard]] + IAstType* get_underlying_type(); [[nodiscard]] std::optional get_type_definition() const; diff --git a/packages/compiler/include/ast/parsing_context.h b/packages/compiler/include/ast/parsing_context.h index 21f9a14c..144c711e 100644 --- a/packages/compiler/include/ast/parsing_context.h +++ b/packages/compiler/include/ast/parsing_context.h @@ -350,7 +350,7 @@ namespace stride::ast ) const; [[nodiscard]] - std::optional get_struct_type(const std::string& name) const; + std::optional get_object_type(const std::string& name) const; [[nodiscard]] const definition::IdentifiableSymbolDef* get_symbol_def( diff --git a/packages/compiler/src/ast/context/type_registry.cpp b/packages/compiler/src/ast/context/type_registry.cpp index f2f04f59..ce6571fe 100644 --- a/packages/compiler/src/ast/context/type_registry.cpp +++ b/packages/compiler/src/ast/context/type_registry.cpp @@ -52,7 +52,7 @@ std::optional ParsingContext::get_type_definition(const std::st /// Gets the root struct type layout for name. /// Will recursively look up the parent struct definition if name is a reference struct type. -std::optional ParsingContext::get_struct_type(const std::string& name) const +std::optional ParsingContext::get_object_type(const std::string& name) const { const auto type_def = get_type_definition(name); diff --git a/packages/compiler/src/ast/generics.cpp b/packages/compiler/src/ast/generics.cpp index 0b0b4a0a..f2fbccd9 100644 --- a/packages/compiler/src/ast/generics.cpp +++ b/packages/compiler/src/ast/generics.cpp @@ -118,10 +118,22 @@ std::unique_ptr stride::ast::resolve_generics( } GenericTypeList resolved_generics; - resolved_generics.reserve(object_type->get_instantiated_generics().size()); - for (const auto& gen : object_type->get_instantiated_generics()) + if (const auto& instantiated_generics = object_type->get_instantiated_generics(); + instantiated_generics.empty() && !instantiated_types.empty() && !param_names.empty()) { - resolved_generics.push_back(resolve_generics(gen.get(), param_names, instantiated_types)); + resolved_generics.reserve(instantiated_types.size()); + for (const auto& gen : instantiated_types) + { + resolved_generics.push_back(gen->clone_ty()); + } + } + else + { + resolved_generics.reserve(instantiated_generics.size()); + for (const auto& gen : instantiated_generics) + { + resolved_generics.push_back(resolve_generics(gen.get(), param_names, instantiated_types)); + } } return std::make_unique( diff --git a/packages/compiler/src/ast/nodes/enumerables.cpp b/packages/compiler/src/ast/nodes/enumerables.cpp index de6576e1..bda66c0e 100644 --- a/packages/compiler/src/ast/nodes/enumerables.cpp +++ b/packages/compiler/src/ast/nodes/enumerables.cpp @@ -18,7 +18,9 @@ using namespace stride::ast::definition; */ std::unique_ptr stride::ast::parse_enumerable_member( const std::shared_ptr& context, - TokenSet& set) + TokenSet& set, + size_t element_index +) { const auto member_name_tok = set.expect(TokenType::IDENTIFIER); auto member_sym = member_name_tok.get_lexeme(); @@ -27,10 +29,32 @@ std::unique_ptr stride::ast::parse_enumerable_member( Symbol( member_name_tok.get_source_fragment(), context->get_name(), - member_sym, - /* internal_name = */ - member_sym), - SymbolType::ENUM_MEMBER); + member_sym + ), + SymbolType::ENUM_MEMBER + ); + + if (!set.has_next() || !set.peek_next_eq(TokenType::COLON)) + { + // Using index as element value + if (set.has_next() && set.peek_next_eq(TokenType::COMMA)) + { + // Consume optional trailing comma + set.next(); + } + + return std::make_unique( + member_name_tok.get_source_fragment(), + context, + std::move(member_sym), + std::make_unique( + member_name_tok.get_source_fragment(), + context, + PrimitiveType::INT32, + element_index + ) + ); + } set.expect(TokenType::COLON, "Expected a colon after enum member name"); @@ -76,9 +100,9 @@ std::unique_ptr stride::ast::parse_enumerable_declaration( context, context->get_context_type()); - while (enum_body_subset.has_next()) + for (size_t i = 0; enum_body_subset.has_next(); ++i) { - members.push_back(parse_enumerable_member(enum_definition_context, enum_body_subset)); + members.push_back(parse_enumerable_member(enum_definition_context, enum_body_subset, i)); } return std::make_unique( diff --git a/packages/compiler/src/ast/nodes/expressions/array_member_accessor.cpp b/packages/compiler/src/ast/nodes/expressions/array_member_accessor.cpp index 75704bc1..fbdfab64 100644 --- a/packages/compiler/src/ast/nodes/expressions/array_member_accessor.cpp +++ b/packages/compiler/src/ast/nodes/expressions/array_member_accessor.cpp @@ -82,19 +82,7 @@ llvm::Value* AstArrayMemberAccessor::codegen( if (const auto named_ty = cast_type(array_iden_type.get())) { - if (const auto reference_type = named_ty->get_underlying_type(); - reference_type.has_value()) - { - array_iden_type = reference_type.value()->clone_ty(); - } - else - { - throw parsing_error( - ErrorType::SEMANTIC_ERROR, - "Array member accessor used on non-reference type", - this->get_source_fragment() - ); - } + array_iden_type = named_ty->get_underlying_type()->clone_ty(); } llvm::Value* base_ptr = this->_array_identifier->codegen(module, builder); diff --git a/packages/compiler/src/ast/nodes/expressions/member_accessor.cpp b/packages/compiler/src/ast/nodes/expressions/member_accessor.cpp index 0cb21e98..936686a5 100644 --- a/packages/compiler/src/ast/nodes/expressions/member_accessor.cpp +++ b/packages/compiler/src/ast/nodes/expressions/member_accessor.cpp @@ -113,7 +113,7 @@ llvm::Value* AstMemberAccessor::codegen_global_member_accessor( for (const auto& accessor : this->_members) { - auto struct_def_opt = this->get_context()->get_struct_type(current_struct_name); + auto struct_def_opt = this->get_context()->get_object_type(current_struct_name); if (!struct_def_opt.has_value()) return nullptr; diff --git a/packages/compiler/src/ast/nodes/expressions/object_initializer.cpp b/packages/compiler/src/ast/nodes/expressions/object_initializer.cpp index 554b2650..3d400aac 100644 --- a/packages/compiler/src/ast/nodes/expressions/object_initializer.cpp +++ b/packages/compiler/src/ast/nodes/expressions/object_initializer.cpp @@ -110,17 +110,7 @@ std::unique_ptr AstObjectInitializer::get_instantiated_object_typ { const auto underlying_type = alias_def->get_underlying_type(); - if (!underlying_type.has_value()) - { - throw parsing_error( - ErrorType::COMPILATION_ERROR, - std::format("Named type '{}' does not reference another type, cannot be used as object type", - alias_def->get_name()), - this->get_source_fragment() - ); - } - - if (auto* object_def = cast_type(underlying_type.value().get())) + if (auto* object_def = cast_type(underlying_type)) { auto resolved_type = instantiate_generic_type(this, object_def, type_def.value()); diff --git a/packages/compiler/src/ast/nodes/functions/function_call.cpp b/packages/compiler/src/ast/nodes/functions/function_call.cpp index 43fcea20..fa1c9d5d 100644 --- a/packages/compiler/src/ast/nodes/functions/function_call.cpp +++ b/packages/compiler/src/ast/nodes/functions/function_call.cpp @@ -357,14 +357,10 @@ llvm::Value* AstFunctionCall::codegen_anonymous_function_call( auto base_type = var_def->get_type()->clone_ty(); if (auto* alias_ty = cast_type(base_type.get())) { - if (const auto resolved = alias_ty->get_underlying_type(); - resolved.has_value()) - { - base_type = resolved.value()->clone_ty(); - } + base_type = alias_ty->get_underlying_type()->clone_ty(); } - if (auto* fn_type = dynamic_cast(base_type.get())) + if (const auto* fn_type = dynamic_cast(base_type.get())) { // First: check if the variable's internal name maps to a named function in the // symbol table with a matching type signature. If so, call it directly without diff --git a/packages/compiler/src/ast/nodes/types/alias_type.cpp b/packages/compiler/src/ast/nodes/types/alias_type.cpp index 37116a81..7730225d 100644 --- a/packages/compiler/src/ast/nodes/types/alias_type.cpp +++ b/packages/compiler/src/ast/nodes/types/alias_type.cpp @@ -64,12 +64,10 @@ static std::unique_ptr resolve_nested_underlying_types(std::unique_ptr { if (auto* named = cast_type(type.get())) { - if (auto underlying = named->get_underlying_type(); underlying.has_value()) - { - return std::move(underlying.value()); - } + return named->get_underlying_type()->clone_ty(); } - else if (const auto* array = cast_type(type.get())) + + if (const auto* array = cast_type(type.get())) { auto element_type = array->get_element_type()->clone_ty(); auto resolved_element = resolve_nested_underlying_types(std::move(element_type)); @@ -82,7 +80,8 @@ static std::unique_ptr resolve_nested_underlying_types(std::unique_ptr array->get_flags() ); } - else if (auto* object_type = cast_type(type.get())) + + if (auto* object_type = cast_type(type.get())) { ObjectTypeMemberList resolved_members; for (const auto& [name, member_type] : object_type->get_members()) @@ -90,15 +89,23 @@ static std::unique_ptr resolve_nested_underlying_types(std::unique_ptr resolved_members.emplace_back(name, resolve_nested_underlying_types(member_type->clone_ty())); } + GenericTypeList resolved_generics; + for (const auto& gen : object_type->get_instantiated_generics()) + { + resolved_generics.push_back(resolve_nested_underlying_types(gen->clone_ty())); + } + return std::make_unique( object_type->get_source_fragment(), object_type->get_context(), object_type->get_type_name(), std::move(resolved_members), - object_type->get_flags() + object_type->get_flags(), + std::move(resolved_generics) ); } - else if (const auto* tuple = cast_type(type.get())) + + if (const auto* tuple = cast_type(type.get())) { std::vector> resolved_members; for (const auto& member : tuple->get_members()) @@ -113,7 +120,8 @@ static std::unique_ptr resolve_nested_underlying_types(std::unique_ptr tuple->get_flags() ); } - else if (auto* func = cast_type(type.get())) + + if (const auto* func = cast_type(type.get())) { std::vector> resolved_params; for (const auto& param : func->get_parameter_types()) @@ -135,28 +143,25 @@ static std::unique_ptr resolve_nested_underlying_types(std::unique_ptr return std::move(type); } -IAstType* AstAliasType::get_underlying_type_ptr() -{ - if (!this->_underlying_type) - { - get_underlying_type(); - } - return this->_underlying_type.get(); -} - -std::optional> AstAliasType::get_underlying_type() +IAstType* AstAliasType::get_underlying_type() { // Prevent reinstantiating type if it's a complex type if (this->_underlying_type != nullptr) { - return this->_underlying_type->clone_ty(); + return this->_underlying_type.get(); } const auto& reference_type_definition = this->get_type_definition(); if (!reference_type_definition.has_value()) { - return std::nullopt; + throw parsing_error( + ErrorType::COMPILATION_ERROR, + std::format( + "Could not find definition for type '{}'", + this->get_name()), + this->get_source_fragment() + ); } std::unique_ptr base_type = this->is_generic_overload() @@ -165,7 +170,13 @@ std::optional> AstAliasType::get_underlying_type() if (!base_type) { - return std::nullopt; + throw parsing_error( + ErrorType::COMPILATION_ERROR, + std::format( + "Could not find underlying type for type '{}'", + this->get_name()), + this->get_source_fragment() + ); } int recursion_guard = 0; @@ -225,24 +236,21 @@ std::optional> AstAliasType::get_underlying_type() this->_underlying_type = std::move(base_type); - return this->_underlying_type->clone_ty(); + return this->_underlying_type.get(); } bool AstAliasType::is_castable_to_impl(IAstType* other) { const auto self_reference_type = get_underlying_type(); - if (!self_reference_type.has_value()) - return false; - // Check our base type is a primitive, and whether that type is castable to `other` if (auto* other_primitive = cast_type(other)) { - return self_reference_type.value()->is_castable_to(other_primitive); + return self_reference_type->is_castable_to(other_primitive); } // A named type should be castable to its base type and vice versa - if (self_reference_type.value()->is_castable_to(other)) + if (self_reference_type->is_castable_to(other)) { return true; } @@ -250,10 +258,9 @@ bool AstAliasType::is_castable_to_impl(IAstType* other) // Final case would be to check whether both base types are the same if (auto* other_alias_ty = cast_type(other)) { - if (const auto second_reference_type = other_alias_ty->get_underlying_type(); - second_reference_type.has_value()) + if (const auto second_reference_type = other_alias_ty->get_underlying_type()) { - return self_reference_type.value()->equals(second_reference_type.value().get()); + return self_reference_type->equals(second_reference_type); } } return false; @@ -268,10 +275,9 @@ bool AstAliasType::is_assignable_to_impl(IAstType* other) // type SomePrimitive = i32[] // const someVar: SomePrimitive = [1, 2, 3]; // In this case, `[1, 2, 3]` should be assignable to the base types of `SomePrimitive` - if (const auto self_base_type = get_underlying_type(); - self_base_type.has_value()) + if (const auto self_base_type = get_underlying_type()) { - return self_base_type.value()->is_assignable_to(other); + return self_base_type->is_assignable_to(other); } return false; @@ -279,13 +285,7 @@ bool AstAliasType::is_assignable_to_impl(IAstType* other) llvm::Type* AstAliasType::get_llvm_type_impl(llvm::Module* module) { - if (!this->_underlying_type) - { - if (!this->get_underlying_type().has_value()) - return nullptr; - } - - return this->_underlying_type->get_llvm_type(module); + return this->get_underlying_type()->get_llvm_type(module); } bool AstAliasType::equals(IAstType* other) @@ -319,9 +319,9 @@ bool AstAliasType::equals(IAstType* other) return true; } - if (const auto self_base = this->get_underlying_type(); self_base.has_value()) + if (const auto self_base = this->get_underlying_type()) { - return self_base.value()->equals(other); + return self_base->equals(other); } return false; diff --git a/packages/compiler/src/ast/nodes/types/function_type.cpp b/packages/compiler/src/ast/nodes/types/function_type.cpp index 762d3014..f32fd04b 100644 --- a/packages/compiler/src/ast/nodes/types/function_type.cpp +++ b/packages/compiler/src/ast/nodes/types/function_type.cpp @@ -126,10 +126,7 @@ bool AstFunctionType::is_castable_to_impl(IAstType* other) { if (const auto other_named = cast_type(other)) { - if (const auto base_type = other_named->get_underlying_type(); base_type.has_value()) - { - return this->equals(base_type.value().get()); - } + return this->equals(other_named->get_underlying_type()); } return false; diff --git a/packages/compiler/src/ast/nodes/types/type_metadata.cpp b/packages/compiler/src/ast/nodes/types/type_metadata.cpp index c2f7a3f2..6c6a13eb 100644 --- a/packages/compiler/src/ast/nodes/types/type_metadata.cpp +++ b/packages/compiler/src/ast/nodes/types/type_metadata.cpp @@ -100,13 +100,9 @@ bool AstArrayType::is_assignable_to_impl(IAstType* other) if (auto* other_alias_ty = cast_type(other)) { const auto reference_type = other_alias_ty->get_underlying_type(); - if (!reference_type.has_value()) - { - return false; - } // Validate whether the reference type of `other_named` is assignable to - return this->is_assignable_to(reference_type.value().get()); + return this->is_assignable_to(reference_type); } // If both are arrays, we can just simply check whether their element types are equal @@ -129,14 +125,7 @@ bool AstArrayType::is_castable_to_impl(IAstType* other) // whether `[1, 2, 3]` in `SomeArray` (i32[]) is assignable to `Array(i32)` if (auto* other_alias_ty = cast_type(other)) { - const auto reference_type = other_alias_ty->get_underlying_type(); - if (!reference_type.has_value()) - { - return false; - } - - // Validate whether the reference type of `other_named` is assignable to - return this->is_castable_to(reference_type.value().get()); + return this->is_castable_to(other_alias_ty->get_underlying_type()); } return false; diff --git a/packages/compiler/src/ast/nodes/types/types.cpp b/packages/compiler/src/ast/nodes/types/types.cpp index 8acd1bc5..84d69d71 100644 --- a/packages/compiler/src/ast/nodes/types/types.cpp +++ b/packages/compiler/src/ast/nodes/types/types.cpp @@ -108,19 +108,17 @@ bool IAstType::is_assignable_to(IAstType* other) } // Try to resolve named types on both sides to check assignability - if (auto* this_named = cast_type(this)) + if (auto* self_alias_ty = cast_type(this)) { - if (const auto self_base = this_named->get_underlying_type(); - self_base.has_value() && self_base.value()->is_assignable_to(other)) + if (self_alias_ty->get_underlying_type()->is_assignable_to(other)) { return true; } } - if (auto* other_named = cast_type(other)) + if (auto* other_alias_ty = cast_type(other)) { - if (const auto other_base = other_named->get_underlying_type(); - other_base.has_value() && this->is_assignable_to(other_base.value().get())) + if (this->is_assignable_to(other_alias_ty->get_underlying_type())) { return true; } @@ -158,12 +156,7 @@ AstPrimitiveType* extract_primitive_reference_types(IAstType* type) { const auto ref_type = named->get_underlying_type(); - if (!ref_type.has_value()) - { - return nullptr; - } - - return extract_primitive_reference_types(ref_type.value().get()); + return extract_primitive_reference_types(ref_type); } if (const auto* array_type = cast_type(type)) @@ -211,7 +204,7 @@ std::unique_ptr stride::ast::get_dominant_field_type( // If one is a named type and the other is its base type, we can also return the dominant type if (const auto lhs_named = cast_type(lhs)) { - if (const auto base = lhs_named->get_underlying_type(); base.has_value() && base.value()->equals(rhs)) + if (const auto base = lhs_named->get_underlying_type(); base->equals(rhs)) { return rhs->clone_ty(); } @@ -219,7 +212,8 @@ std::unique_ptr stride::ast::get_dominant_field_type( if (const auto rhs_named = cast_type(rhs)) { - if (const auto base = rhs_named->get_underlying_type(); base.has_value() && base.value()->equals(lhs)) + if (const auto base = rhs_named->get_underlying_type(); + base->equals(lhs)) { return lhs->clone_ty(); } @@ -310,17 +304,14 @@ std::optional stride::ast::get_object_type_from_type(IAstType* t // resolves to AstObjectType with concrete member types, not raw X/Y/Z params). if (auto* alias_type = cast_type(type)) { - if (auto* resolved = alias_type->get_underlying_type_ptr()) + if (auto* object_type = cast_type(alias_type->get_underlying_type())) { - if (auto* object_type = cast_type(resolved)) - { - return object_type; - } + return object_type; } } // Fall back to raw struct type lookup - base_struct_type = type->get_context()->get_struct_type(type->get_type_name()) + base_struct_type = type->get_context()->get_object_type(type->get_type_name()) .value_or(nullptr); if (!base_struct_type) diff --git a/packages/compiler/src/ast/type_inference.cpp b/packages/compiler/src/ast/type_inference.cpp index 91949721..476d0edc 100644 --- a/packages/compiler/src/ast/type_inference.cpp +++ b/packages/compiler/src/ast/type_inference.cpp @@ -403,24 +403,7 @@ std::unique_ptr stride::ast::infer_array_accessor_type( if (const auto alias_type = cast_type(array_type.get())) { // Instantiate type if it contains generics - const auto base_ty = alias_type->get_underlying_type(); - - if (!base_ty.has_value()) - { - throw parsing_error( - ErrorType::TYPE_ERROR, - std::format( - "Named type '{}' does not reference another type, cannot be used as array type", - alias_type->get_name() - ), - { - ErrorSourceReference("Named type", alias_type->get_source_fragment()), - ErrorSourceReference("Requires array type", accessor->get_array_identifier()->get_source_fragment()) - } - ); - } - - if (const auto array_base_ty = cast_type(base_ty.value().get())) + if (const auto array_base_ty = cast_type(alias_type->get_underlying_type())) { return array_base_ty->get_element_type()->clone_ty(); } diff --git a/packages/compiler/tests/test_generics.cpp b/packages/compiler/tests/test_generics.cpp index cf932755..aa7c3849 100644 --- a/packages/compiler/tests/test_generics.cpp +++ b/packages/compiler/tests/test_generics.cpp @@ -68,3 +68,33 @@ TEST(Generics, NestedObjectInstantiation) }; )"); } + +TEST(Generics, FunctionSignatureMismatch) +{ + assert_compiles(R"( + type Array = { + length: i32; + data: T[]; + }; + + type SomeCar = { + make: string; + }; + + type SomePerson = { + cars: Array; + }; + + fn make_person(cars: Array): SomePerson { + return SomePerson::{ cars }; + } + + fn main(): i32 { + const p = make_person(Array::{ + length: 1, + data: [SomeCar::{ make: "Toyota" }] + }); + return 0; + } + )"); +} diff --git a/packages/standard-library/system/io.sr b/packages/standard-library/system/io.sr index 8ac0b71b..da83980c 100644 --- a/packages/standard-library/system/io.sr +++ b/packages/standard-library/system/io.sr @@ -1,6 +1,6 @@ -package System; +package system; -module IO { +module io { pub const stdout_fd: i32 = 0; pub const stdin_fd: i32 = 1; @@ -10,14 +10,14 @@ module IO { * @param input The string to print. * @param ... Additional arguments to format into the string. */ - pub fn Print(input: string, ...): void { + pub fn print(input: string, ...): void { _printf_internal(input, ...); } /** * Reads a string from STDIN */ - pub fn Read(bytes: i32): string { + pub fn read(bytes: i32): string { return _read_in_internal(bytes); } } From 953c9b7135f70158b2499c782d92ec84a58abdea Mon Sep 17 00:00:00 2001 From: Luca Warmenhoven Date: Wed, 11 Mar 2026 22:11:59 +0100 Subject: [PATCH 03/30] Refactored chained expression parsing to allow function calls and member accessors in a single line --- example.sr | 7 +- .../compiler/include/ast/nodes/expression.h | 117 +++-- .../compiler/include/ast/type_inference.h | 12 +- .../compiler/src/ast/nodes/enumerables.cpp | 29 +- .../expressions/array_member_accessor.cpp | 27 +- .../src/ast/nodes/expressions/expression.cpp | 108 ++-- .../ast/nodes/expressions/member_accessor.cpp | 490 ++++++++++-------- .../nodes/functions/function_declaration.cpp | 6 +- .../compiler/src/ast/traversal/traversal.cpp | 8 +- packages/compiler/src/ast/type_inference.cpp | 150 +++--- .../compiler/tests/test_type_inference.cpp | 23 +- 11 files changed, 531 insertions(+), 446 deletions(-) diff --git a/example.sr b/example.sr index 0fed002a..6c73bd46 100644 --- a/example.sr +++ b/example.sr @@ -41,15 +41,12 @@ fn main(): i32 { 30, Array::{ length: 1, - data: [ - make_car("Toyota", "Corolla", 2020) - ] + data: [make_car("Toyota", "Corolla", 2020)] } ); - const my_cars = me.cars.data; - const my_car = my_cars[0]; + const my_car = me.cars[0].data[0]; io::print("I am %s, age %d\n", me.name, me.age); io::print("My car is a %s %s from %d", my_car.make, my_car.model, my_car.year); diff --git a/packages/compiler/include/ast/nodes/expression.h b/packages/compiler/include/ast/nodes/expression.h index a5b5f3d3..154e9208 100644 --- a/packages/compiler/include/ast/nodes/expression.h +++ b/packages/compiler/include/ast/nodes/expression.h @@ -216,24 +216,24 @@ namespace stride::ast class AstArrayMemberAccessor : public IAstExpression { - std::unique_ptr _array_identifier; + std::unique_ptr _array_base; std::unique_ptr _index_accessor_expr; public: explicit AstArrayMemberAccessor( const SourceFragment& source, const std::shared_ptr& context, - std::unique_ptr array_identifier, + std::unique_ptr array_base, std::unique_ptr index_expr ) : IAstExpression(source, context), - _array_identifier(std::move(array_identifier)), + _array_base(std::move(array_base)), _index_accessor_expr(std::move(index_expr)) {} [[nodiscard]] - AstIdentifier* get_array_identifier() const + IAstExpression* get_array_base() const { - return this->_array_identifier.get(); + return this->_array_base.get(); } [[nodiscard]] @@ -257,40 +257,36 @@ namespace stride::ast std::unique_ptr clone() override; }; - class AstMemberAccessor + /// Represents a chained postfix expression: base.member, where base is any expression + /// and followup is an AstIdentifier (the member name). Multi-step chains like a.b.c + /// are represented left-recursively: ChainedExpression(ChainedExpression(a, b), c). + class AstChainedExpression : public IAstExpression { - // When adding function chaining support, - // these types will have to be changed to IAstExpression - std::unique_ptr _base; - std::vector> _members; + std::unique_ptr _base; + std::unique_ptr _followup; // always AstIdentifier at each leaf public: - explicit AstMemberAccessor( + explicit AstChainedExpression( const SourceFragment& source, const std::shared_ptr& context, - std::unique_ptr base, - std::vector> members + std::unique_ptr base, + std::unique_ptr followup ) : IAstExpression(source, context), _base(std::move(base)), - _members(std::move(members)) {} + _followup(std::move(followup)) {} [[nodiscard]] - AstIdentifier* get_base() const + IAstExpression* get_base() const { return this->_base.get(); } [[nodiscard]] - std::vector get_members() const; - - [[nodiscard]] - AstIdentifier* get_last_member() const + IAstExpression* get_followup() const { - return this->_members.empty() - ? nullptr - : this->_members.back().get(); + return this->_followup.get(); } llvm::Value* codegen( @@ -315,6 +311,59 @@ namespace stride::ast ) const; }; + /// Represents an indirect (expression-based) function call: expr(args). + /// Used when the callee is not a simple named identifier, e.g. arr[i]() or obj.fn(). + class AstIndirectCall + : public IAstExpression + { + std::unique_ptr _callee; + ExpressionList _args; + + public: + explicit AstIndirectCall( + const SourceFragment& source, + const std::shared_ptr& context, + std::unique_ptr callee, + ExpressionList args + ) : + IAstExpression(source, context), + _callee(std::move(callee)), + _args(std::move(args)) {} + + [[nodiscard]] + IAstExpression* get_callee() const + { + return this->_callee.get(); + } + + [[nodiscard]] + const ExpressionList& get_args() const + { + return this->_args; + } + + llvm::Value* codegen( + llvm::Module* module, + llvm::IRBuilderBase* builder + ) override; + + std::string to_string() override; + + bool is_reducible() override + { + return false; + } + + std::optional> reduce() override + { + return std::nullopt; + } + + std::unique_ptr clone() override; + + void validate() override; + }; + class AstFunctionCall : public IAstExpression { @@ -953,11 +1002,11 @@ namespace stride::ast int min_precedence ); - /// This parses both function call chaining, and struct member access - std::unique_ptr parse_chained_member_access( + /// Parses a single chained member access step: consumes `.identifier` and wraps lhs + std::unique_ptr parse_chained_member_access( const std::shared_ptr& context, TokenSet& set, - const std::unique_ptr& lhs + std::unique_ptr lhs ); /// Parses a unary operator expression @@ -972,11 +1021,18 @@ namespace stride::ast TokenSet& set ); - /// Parses an array member accessor expression, e.g., [] - std::unique_ptr parse_array_member_accessor( + /// Parses an array subscript: consumes `[]` and wraps the base expression + std::unique_ptr parse_array_member_accessor( + const std::shared_ptr& context, + TokenSet& set, + std::unique_ptr array_base + ); + + /// Parses an indirect call: consumes `()` and wraps the callee expression + std::unique_ptr parse_indirect_call( const std::shared_ptr& context, TokenSet& set, - std::unique_ptr array_identifier + std::unique_ptr callee ); /// Parses a struct initializer expression into an AstObjectInitializer node @@ -1037,7 +1093,6 @@ namespace stride::ast /// This is the case if an expression starts with `{ }` bool is_struct_initializer(const TokenSet& set); - /// Checks whether the subsequent tokens might be member accessors - /// e.g., . - bool is_member_accessor(IAstExpression* lhs, const TokenSet& set); + /// Checks whether the next tokens begin a member access: `.identifier` + bool is_member_accessor(const TokenSet& set); } // namespace stride::ast diff --git a/packages/compiler/include/ast/type_inference.h b/packages/compiler/include/ast/type_inference.h index c953c2c4..172941ca 100644 --- a/packages/compiler/include/ast/type_inference.h +++ b/packages/compiler/include/ast/type_inference.h @@ -11,7 +11,8 @@ namespace stride::ast class IAstExpression; class AstFunctionCall; class AstLiteral; - class AstMemberAccessor; + class AstChainedExpression; + class AstIndirectCall; class AstObjectInitializer; class AstUnaryOp; class IAstType; @@ -46,10 +47,13 @@ namespace stride::ast const AstVariableDeclaration* declaration, int recursion_guard); - /// Infers the type of the field accessed via a member accessor expression - std::unique_ptr infer_object_member_accessor_type(const AstMemberAccessor* member_accessor_expr); + /// Infers the type produced by a chained expression (base.followup member access) + std::unique_ptr infer_chained_expression_type(const AstChainedExpression* chained_expr); - /// Infers the type of the element accessed via an array member accessor expression, which is the element type of the array being accessed + /// Infers the return type of an indirect call expression (expr(args)) + std::unique_ptr infer_indirect_call_type(const AstIndirectCall* call_expr); + + /// Infers the element type produced by an array subscript expression std::unique_ptr infer_array_accessor_type(const AstArrayMemberAccessor* accessor, int recursion_guard); std::unique_ptr infer_function_type( diff --git a/packages/compiler/src/ast/nodes/enumerables.cpp b/packages/compiler/src/ast/nodes/enumerables.cpp index bda66c0e..8117b70f 100644 --- a/packages/compiler/src/ast/nodes/enumerables.cpp +++ b/packages/compiler/src/ast/nodes/enumerables.cpp @@ -5,7 +5,6 @@ #include "ast/tokens/token_set.h" #include -#include using namespace stride::ast; using namespace stride::ast::definition; @@ -82,8 +81,7 @@ std::unique_ptr stride::ast::parse_enumerable_declaration( ) { const auto reference_token = set.expect(TokenType::KEYWORD_ENUM); - const auto enumerable_name_tok = set.expect(TokenType::IDENTIFIER); - auto enumerable_name = enumerable_name_tok.get_lexeme(); + const auto enumerable_name = set.expect(TokenType::IDENTIFIER).get_lexeme(); context->define_symbol( Symbol(reference_token.get_source_fragment(), @@ -94,7 +92,7 @@ std::unique_ptr stride::ast::parse_enumerable_declaration( auto enum_body_subset = collect_block_required(set, "Expected a block in enum declaration"); - std::vector> members = {}; + std::vector> members; auto enum_definition_context = std::make_shared( context, @@ -116,9 +114,9 @@ std::unique_ptr stride::ast::parse_enumerable_declaration( std::unique_ptr AstEnumerable::clone() { std::vector> cloned_members; - cloned_members.reserve(this->get_members().size()); + cloned_members.reserve(this->_members.size()); - for (const auto& member : this->get_members()) + for (const auto& member : this->_members) { cloned_members.push_back(member->clone_as()); } @@ -149,20 +147,15 @@ std::string AstEnumerableMember::to_string() std::string AstEnumerable::to_string() { - std::ostringstream imploded; - - if (this->get_members().empty()) - { - return std::format("Enumerable {} (empty)", this->get_name()); - } + std::vector members; - imploded << this->get_members()[0]->to_string(); - for (size_t i = 1; i < this->get_members().size(); ++i) + for (const auto& member : this->get_members()) { - imploded << "\n " << this->get_members()[i]->to_string(); + members.push_back(member->to_string()); } - return std::format("Enumerable {} (\n {}\n)", - this->get_name(), - imploded.str()); + return std::format( + "Enumerable {} (\n {}\n)", + this->get_name(), + join(members, ",\n ")); } diff --git a/packages/compiler/src/ast/nodes/expressions/array_member_accessor.cpp b/packages/compiler/src/ast/nodes/expressions/array_member_accessor.cpp index fbdfab64..26730e84 100644 --- a/packages/compiler/src/ast/nodes/expressions/array_member_accessor.cpp +++ b/packages/compiler/src/ast/nodes/expressions/array_member_accessor.cpp @@ -11,10 +11,10 @@ using namespace stride::ast; -std::unique_ptr stride::ast::parse_array_member_accessor( +std::unique_ptr stride::ast::parse_array_member_accessor( const std::shared_ptr& context, TokenSet& set, - std::unique_ptr array_identifier) + std::unique_ptr array_base) { auto expression_block = collect_block_variant( set, @@ -32,19 +32,19 @@ std::unique_ptr stride::ast::parse_array_member_accessor( auto index_expression = parse_inline_expression(context, expression_block.value()); - const auto source_pos = SourceFragment::combine(array_identifier->get_source_fragment(), last_src_pos); + const auto source_pos = SourceFragment::combine(array_base->get_source_fragment(), last_src_pos); return std::make_unique( source_pos, context, - std::move(array_identifier), + std::move(array_base), std::move(index_expression) ); } void AstArrayMemberAccessor::validate() { - this->_array_identifier->validate(); + this->_array_base->validate(); this->_index_accessor_expr->validate(); const auto index_accessor_type = this->_index_accessor_expr->get_type(); @@ -78,19 +78,18 @@ llvm::Value* AstArrayMemberAccessor::codegen( llvm::IRBuilderBase* builder ) { - std::unique_ptr array_iden_type = this->_array_identifier->get_type()->clone_ty(); + std::unique_ptr array_base_type = this->_array_base->get_type()->clone_ty(); - if (const auto named_ty = cast_type(array_iden_type.get())) + if (const auto named_ty = cast_type(array_base_type.get())) { - array_iden_type = named_ty->get_underlying_type()->clone_ty(); + array_base_type = named_ty->get_underlying_type()->clone_ty(); } - llvm::Value* base_ptr = this->_array_identifier->codegen(module, builder); + llvm::Value* base_ptr = this->_array_base->codegen(module, builder); llvm::Value* index_val = this->_index_accessor_expr->codegen(module, builder); // Element type, not the array type. - // Assumes `array_iden_type` is something like "T[]" and has an element type you can extract. - const auto* array_ty = cast_type(array_iden_type.get()); + const auto* array_ty = cast_type(array_base_type.get()); if (!array_ty) { throw parsing_error( @@ -123,7 +122,7 @@ std::unique_ptr AstArrayMemberAccessor::clone() return std::make_unique( this->get_source_fragment(), this->get_context(), - this->_array_identifier->clone_as(), + this->_array_base->clone_as(), this->_index_accessor_expr->clone_as() ); } @@ -132,7 +131,7 @@ std::string AstArrayMemberAccessor::to_string() { return std::format( "ArrayAccess({}, {})", - this->_array_identifier->to_string(), + this->_array_base->to_string(), this->_index_accessor_expr->to_string() ); } @@ -140,7 +139,7 @@ std::string AstArrayMemberAccessor::to_string() bool AstArrayMemberAccessor::is_reducible() { // If the value is a literal, it's reducible for sure. - if (cast_expr(this->_array_identifier.get())) + if (cast_expr(this->_array_base.get())) { return true; } diff --git a/packages/compiler/src/ast/nodes/expressions/expression.cpp b/packages/compiler/src/ast/nodes/expressions/expression.cpp index a48516ae..f40262c8 100644 --- a/packages/compiler/src/ast/nodes/expressions/expression.cpp +++ b/packages/compiler/src/ast/nodes/expressions/expression.cpp @@ -29,27 +29,25 @@ std::unique_ptr stride::ast::parse_inline_expression_part( TokenSet& set ) { + std::unique_ptr result; + if (auto lit = parse_literal_optional(context, set); lit.has_value()) { - return std::move(lit.value()); + result = std::move(lit.value()); } - // Will try to parse ::{ ... } - if (is_struct_initializer(set)) + else if (is_struct_initializer(set)) { - return parse_object_initializer(context, set); + result = parse_object_initializer(context, set); } - // Will try to parse [ ... ] - if (is_array_initializer(set)) + else if (is_array_initializer(set)) { - return parse_array_initializer(context, set); + result = parse_array_initializer(context, set); } - - // Could either be a function call, or object access - if (set.peek_next_eq(TokenType::IDENTIFIER)) + // Could either be a function call, or object/array access + else if (set.peek_next_eq(TokenType::IDENTIFIER)) { - /// Regular identifier parsing; can be variable reference const auto reference_token = set.peek_next(); // Mangled name including module, e.g., `Math__PI` const SymbolNameSegments name_segments = parse_segmented_identifier( @@ -67,64 +65,80 @@ std::unique_ptr stride::ast::parse_inline_expression_part( return std::move(reassignment.value()); } - /// Function invocations, e.g., `(...)`, or `::` + // Named function invocations, e.g., `(...)` or `::(...)` if (set.peek_next_eq(TokenType::LPAREN)) { - return parse_function_call(context, name_segments, set); + result = parse_function_call(context, name_segments, set); } - - if (set.peek_next_eq(TokenType::LSQUARE_BRACKET)) - { - return parse_array_member_accessor( - context, - set, - std::move(identifier) - ); - } - - if (is_member_accessor(identifier.get(), set)) + else { - return parse_chained_member_access( - context, - set, - std::move(identifier) - ); + result = std::move(identifier); } - - return std::move(identifier); } - // If the next token is a '(', we'll try to descend into it // until we find another one, e.g. `(1 + (2 * 3))` with nested parentheses - if (set.peek_next_eq(TokenType::LPAREN)) + else if (set.peek_next_eq(TokenType::LPAREN)) { if ((set.peek_eq(TokenType::IDENTIFIER, 1) // Checks for "(: ..." && set.peek_eq(TokenType::COLON, 2)) || (set.peek_eq(TokenType::RPAREN, 1) && // Checks for "():" set.peek_eq(TokenType::COLON, 2))) { - return parse_anonymous_fn_expression(context, set); + result = parse_anonymous_fn_expression(context, set); + } + else + { + set.next(); + // Fixed: Use parse_inline_expression (full expression parser) instead of + // parse_inline_expression_part to allow binary operations inside parentheses. + auto expr = parse_inline_expression(context, set); + // TODO: If we have a comma next, it might be a tuple expression + set.expect(TokenType::RPAREN, "Expected ')' after expression"); + result = std::move(expr); } - - set.next(); - // Fixed: Use parse_inline_expression (full expression parser) instead of - // parse_inline_expression_part to allow binary operations inside parentheses. - auto expr = parse_inline_expression(context, set); - // TODO: If we have a comma next, it might be a tuple expression - set.expect(TokenType::RPAREN, "Expected ')' after expression"); - return expr; } - - if (set.peek_next_eq(TokenType::THREE_DOTS)) + else if (set.peek_next_eq(TokenType::THREE_DOTS)) { const auto& ref = set.next(); - return std::make_unique( + result = std::make_unique( ref.get_source_fragment(), context ); } + else + { + set.throw_error("Invalid token found in expression"); + } + + // Unified postfix operator loop: + // Handles `.member`, `[index]`, and `(args)` chaining on any primary expression. + // Builds a left-recursive tree so each step's base is the result of the previous step. + int recursion_depth = 0; + while (true) + { + if (is_member_accessor(set)) + { + result = parse_chained_member_access(context, set, std::move(result)); + } + else if (set.peek_next_eq(TokenType::LSQUARE_BRACKET)) + { + result = parse_array_member_accessor(context, set, std::move(result)); + } + else if (set.peek_next_eq(TokenType::LPAREN)) + { + result = parse_indirect_call(context, set, std::move(result)); + } + else + { + break; + } + if (++recursion_depth > MAX_RECURSION_DEPTH) + { + set.throw_error("Expression too complex"); + } + } - set.throw_error("Invalid token found in expression"); + return result; } /* @@ -304,7 +318,7 @@ SymbolNameSegments stride::ast::parse_segmented_identifier( TokenSet& set, const std::string& error_message) { - std::vector segments = {}; + std::vector segments; segments.push_back(set.expect(TokenType::IDENTIFIER, error_message).get_lexeme()); diff --git a/packages/compiler/src/ast/nodes/expressions/member_accessor.cpp b/packages/compiler/src/ast/nodes/expressions/member_accessor.cpp index 936686a5..920dc8fc 100644 --- a/packages/compiler/src/ast/nodes/expressions/member_accessor.cpp +++ b/packages/compiler/src/ast/nodes/expressions/member_accessor.cpp @@ -4,6 +4,7 @@ #include "ast/parsing_context.h" #include "ast/nodes/enumerables.h" #include "ast/nodes/expression.h" +#include "ast/nodes/blocks.h" #include "ast/tokens/token_set.h" #include @@ -11,160 +12,137 @@ using namespace stride::ast; -/// This parses both function call chaining, and struct member access -/// e.g., -/// foo().bar.baz() -/// or -/// struct_var.member.member2 ... -std::unique_ptr stride::ast::parse_chained_member_access( +// --------------------------------------------------------------------------- +// Parsing helpers +// --------------------------------------------------------------------------- + +/// Checks whether the next two tokens begin a member access (`.identifier`). +bool stride::ast::is_member_accessor(const TokenSet& set) +{ + return set.peek_eq(TokenType::DOT, 0) + && set.peek_eq(TokenType::IDENTIFIER, 1); +} + +/// Consumes `.identifier` and wraps `lhs` in an AstChainedExpression. +std::unique_ptr stride::ast::parse_chained_member_access( const std::shared_ptr& context, TokenSet& set, - const std::unique_ptr& lhs + std::unique_ptr lhs ) { - std::vector> chained_accessors = {}; - - // Initial accessors - const auto reference_token = set.peek_next(); + set.expect(TokenType::DOT, "Expected '.' in member access"); + const auto member_tok = set.expect(TokenType::IDENTIFIER, "Expected identifier after '.' in member access"); - while (set.peek_next_eq(TokenType::DOT)) - { - set.expect(TokenType::DOT, "Expected '.' after identifier in member access"); - - const auto accessor_iden_tok = set.expect(TokenType::IDENTIFIER, - "Expected identifier after '.' in member access"); + auto member_id = std::make_unique( + context, + Symbol(member_tok.get_source_fragment(), member_tok.get_lexeme()) + ); - auto symbol = Symbol(accessor_iden_tok.get_source_fragment(), - accessor_iden_tok.get_lexeme()); + const auto source = SourceFragment::combine(lhs->get_source_fragment(), member_tok.get_source_fragment()); - chained_accessors.push_back(std::make_unique(context, symbol)); - } + return std::make_unique( + source, + context, + std::move(lhs), + std::move(member_id) + ); +} - const auto lhs_source_pos = lhs->get_source_fragment(); +/// Consumes `()` and wraps `callee` in an AstIndirectCall. +std::unique_ptr stride::ast::parse_indirect_call( + const std::shared_ptr& context, + TokenSet& set, + std::unique_ptr callee +) +{ + const auto callee_src = callee->get_source_fragment(); + auto param_block = collect_parenthesized_block(set); - // TODO: Allow function calls to be the last element as well. - auto lhs_identifier = cast_expr(lhs.get()); - if (!lhs_identifier) + ExpressionList args; + if (param_block.has_value()) { - throw parsing_error( - ErrorType::TYPE_ERROR, - "Member access base must be an identifier", - lhs_source_pos); + auto subset = param_block.value(); + if (subset.has_next()) + { + args.push_back(parse_inline_expression(context, subset)); + while (subset.has_next()) + { + subset.expect(TokenType::COMMA, "Expected ',' between arguments"); + args.push_back(parse_inline_expression(context, subset)); + } + } } - const auto last_source_pos = chained_accessors.back().get()->get_source_fragment(); - const auto source_pos = SourceFragment::combine(lhs_source_pos, last_source_pos); + const auto close_src = set.peek(-1).get_source_fragment(); + const auto source = SourceFragment::combine(callee_src, close_src); - return std::make_unique( - source_pos, + return std::make_unique( + source, context, - lhs_identifier->clone_as(), - std::move(chained_accessors) + std::move(callee), + std::move(args) ); } -std::vector AstMemberAccessor::get_members() const -{ - // We don't wanna transfer ownership to anyone else... - std::vector result; - result.reserve(this->_members.size()); - std::ranges::transform( - this->_members, - std::back_inserter(result), - [](const std::unique_ptr& member) - { - return member.get(); - }); - return result; -} - -bool stride::ast::is_member_accessor(IAstExpression* lhs, const TokenSet& set) -{ - // We assume the expression and subsequent tokens are "member accessors" - // if the LHS is an identifier, and it's followed by `.` - // E.g., `struct_var.member` - if (cast_expr(lhs)) - { - return set.peek_eq(TokenType::DOT, 0) - && set.peek_eq(TokenType::IDENTIFIER, 1); - } - return false; -} +// --------------------------------------------------------------------------- +// AstChainedExpression +// --------------------------------------------------------------------------- -llvm::Value* AstMemberAccessor::codegen_global_member_accessor( +llvm::Value* AstChainedExpression::codegen_global_member_accessor( llvm::Module* module, llvm::IRBuilderBase* builder ) const { llvm::Value* base_val = this->_base->codegen(module, builder); - IAstType* cloned_base_type = this->_base->get_type(); - std::string base_type_name = cloned_base_type->get_type_name(); + IAstType* current_type = this->_base->get_type(); + std::string current_struct_name = current_type->get_type_name(); - // We look for a GlobalVariable with an initializer auto* global_var = llvm::dyn_cast_or_null(base_val); if (!global_var || !global_var->hasInitializer()) { - // If the base isn't a global with an initializer, we can't fold it. return nullptr; } llvm::Constant* current_const = global_var->getInitializer(); - std::string current_struct_name = base_type_name; - for (const auto& accessor : this->_members) + const auto* member_id = cast_expr(this->_followup.get()); + if (!member_id) { - auto struct_def_opt = this->get_context()->get_object_type(current_struct_name); - if (!struct_def_opt.has_value()) - return nullptr; - - const auto struct_def = struct_def_opt.value(); - - const auto member_index = struct_def->get_member_field_index(accessor->get_name()); - if (!member_index.has_value()) - { - return nullptr; - } + return nullptr; + } - // Extract the constant field value - current_const = current_const->getAggregateElement(member_index.value()); - if (!current_const) - { - // Index out of bounds or invalid aggregate - throw parsing_error( - ErrorType::COMPILATION_ERROR, - std::format("Invalid member access on constant '{}'", base_type_name), - this->get_source_fragment() - ); - } + auto struct_def_opt = this->get_context()->get_object_type(current_struct_name); + if (!struct_def_opt.has_value()) + return nullptr; - auto member_field_type = struct_def->get_member_field_type(accessor->get_name()); - if (!member_field_type.has_value()) - { - return nullptr; - } + const auto struct_def = struct_def_opt.value(); + const auto member_index = struct_def->get_member_field_index(member_id->get_name()); + if (!member_index.has_value()) + return nullptr; - cloned_base_type = member_field_type.value(); - current_struct_name = cloned_base_type->get_type_name(); + current_const = current_const->getAggregateElement(member_index.value()); + if (!current_const) + { + throw parsing_error( + ErrorType::COMPILATION_ERROR, + std::format("Invalid member access on constant '{}'", current_struct_name), + this->get_source_fragment() + ); } return current_const; } -llvm::Value* AstMemberAccessor::codegen( +llvm::Value* AstChainedExpression::codegen( llvm::Module* module, llvm::IRBuilderBase* builder ) { - // Global struct definitions have no insertion point, so we need to do - // constant folding here by looking up the global variable and its initializer. if (!builder->GetInsertBlock()) { return codegen_global_member_accessor(module, builder); } - // Standard Code Generation (Function context) - - // Will codegen identifier - reference to variable (getptr) llvm::Value* current_val = this->_base->codegen(module, builder); if (!current_val) { @@ -172,191 +150,245 @@ llvm::Value* AstMemberAccessor::codegen( } const auto base_struct_type = get_object_type_from_type(this->_base->get_type()); - - // Base must be a struct for member access to be valid. - // This would be okay if there were on members, however, this should never happen if (!base_struct_type.has_value()) { throw parsing_error( ErrorType::TYPE_ERROR, - "Member access base must be a struct", + "Member access base must be a struct type", + this->get_source_fragment() + ); + } + + const auto* member_id = cast_expr(this->_followup.get()); + if (!member_id) + { + throw parsing_error( + ErrorType::TYPE_ERROR, + "Chained expression followup must be an identifier (member name)", this->get_source_fragment() ); } IAstType* parent_type = base_struct_type.value(); - std::string parent_struct_internalized_name = base_struct_type.value()->get_internalized_name(); - std::string current_accessor_name = this->_base->get_name(); + const std::string parent_struct_internalized_name = base_struct_type.value()->get_internalized_name(); - // With opaque pointers, we need to know if we are operating on an address (L-value) - // or a loaded struct value (R-value). Pointers allow GEP, values require ExtractValue. const bool is_pointer_ty = current_val->getType()->isPointerTy(); - for (const auto& accessor : this->_members) + auto parent_struct_type_opt = get_object_type_from_type(parent_type); + if (!parent_struct_type_opt.has_value()) { - auto parent_struct_type_opt = get_object_type_from_type(parent_type); - - // In next iteration, it's possible that the previous member produced a non-struct type, - // hence yielding a nullptr. - if (!parent_struct_type_opt.has_value()) - { - throw parsing_error( - ErrorType::TYPE_ERROR, - std::format( - "Cannot access member '{}' of non-struct type '{}'", - accessor->get_name(), - current_accessor_name - ), - this->get_source_fragment() - ); - } - - const auto parent_struct_type = parent_struct_type_opt.value(); + throw parsing_error( + ErrorType::TYPE_ERROR, + std::format( + "Cannot access member '{}' of non-struct type", + member_id->get_name() + ), + this->get_source_fragment() + ); + } - parent_struct_internalized_name = parent_struct_type->get_internalized_name(); + const auto parent_struct_type = parent_struct_type_opt.value(); + const std::string internalized_name = parent_struct_type->get_internalized_name(); - const auto member_index = parent_struct_type->get_member_field_index(accessor->get_name()); + const auto member_index = parent_struct_type->get_member_field_index(member_id->get_name()); + if (!member_index.has_value()) + { + throw parsing_error( + ErrorType::TYPE_ERROR, + std::format( + "Cannot access member '{}' of object type '{}': member does not exist", + member_id->get_name(), + parent_struct_type->get_type_name() + ), + this->get_source_fragment() + ); + } - // Validate whether the previous struct contains the "member" field - if (!member_index.has_value()) + if (is_pointer_ty) + { + llvm::StructType* struct_llvm_type = llvm::StructType::getTypeByName( + module->getContext(), + internalized_name + ); + if (!struct_llvm_type) { throw parsing_error( - ErrorType::TYPE_ERROR, - std::format( - "Cannot access member '{}' of object type '{}': member does not exist", - accessor->get_name(), - parent_struct_type->get_type_name() - ), + ErrorType::COMPILATION_ERROR, + std::format("Object '{}' not registered internally", parent_struct_type->get_type_name()), this->get_source_fragment() ); } - if (is_pointer_ty) - { - // Get the LLVM type for the current struct to generate the GEP - llvm::StructType* struct_llvm_type = llvm::StructType::getTypeByName( - module->getContext(), - parent_struct_internalized_name - ); - if (!struct_llvm_type) - { - throw parsing_error( - ErrorType::COMPILATION_ERROR, - std::format( - "Object '{}' not registered internally", - parent_struct_type->get_type_name()), - this->get_source_fragment()); - } - - // Create the GEP (GetElementPtr) instruction: ¤t_ptr->member - current_val = builder->CreateStructGEP( - struct_llvm_type, - current_val, - member_index.value(), - "ptr_" + accessor->get_name() - ); - } - else - { - // We have a direct value, extract the member: current_val.member - current_val = builder->CreateExtractValue( - current_val, - member_index.value(), - "val_" + accessor->get_name() - ); - } + llvm::Value* member_ptr = builder->CreateStructGEP( + struct_llvm_type, + current_val, + member_index.value(), + "ptr_" + member_id->get_name() + ); - auto member_field_type = parent_struct_type->get_member_field_type(accessor->get_name()); + auto member_field_type = parent_struct_type->get_member_field_type(member_id->get_name()); if (!member_field_type.has_value()) { throw parsing_error( ErrorType::COMPILATION_ERROR, - std::format( - "Unknown member type '{}' in object '{}'", - accessor->get_name(), - parent_struct_type->get_type_name() - ), - this->get_source_fragment()); + std::format("Unknown member type '{}' in object '{}'", + member_id->get_name(), parent_struct_type->get_type_name()), + this->get_source_fragment() + ); } - parent_type = member_field_type.value(); + llvm::Type* final_llvm_type = member_field_type.value()->get_llvm_type(module); + return builder->CreateLoad(final_llvm_type, member_ptr, "val_member_access"); } - if (!parent_type) + // Value (not pointer) — use ExtractValue + llvm::Value* val = builder->CreateExtractValue( + current_val, + member_index.value(), + "val_" + member_id->get_name() + ); + + return val; +} + +std::unique_ptr AstChainedExpression::clone() +{ + return std::make_unique( + this->get_source_fragment(), + this->get_context(), + this->_base->clone_as(), + this->_followup->clone_as() + ); +} + +void AstChainedExpression::validate() +{ + this->_base->validate(); + this->_followup->validate(); +} + +std::string AstChainedExpression::to_string() +{ + return std::format( + "ChainedExpression(base: {}, followup: {})", + this->_base->to_string(), + this->_followup->to_string() + ); +} + +std::optional> AstChainedExpression::reduce() +{ + return std::nullopt; +} + +bool AstChainedExpression::is_reducible() +{ + return false; +} + +// --------------------------------------------------------------------------- +// AstIndirectCall +// --------------------------------------------------------------------------- + +llvm::Value* AstIndirectCall::codegen( + llvm::Module* module, + llvm::IRBuilderBase* builder +) +{ + llvm::Value* callee_val = this->_callee->codegen(module, builder); + if (!callee_val) { throw parsing_error( ErrorType::COMPILATION_ERROR, - std::format( - "Invalid member access on non-struct type '{}'", - this->_base->get_type()->get_type_name() - ), - this->get_source_fragment()); + "Indirect call: could not evaluate callee expression", + this->get_source_fragment() + ); } - // if we were working with pointers, we need to load the final result - if (is_pointer_ty) + // Derive the LLVM function type from the callee's AST type + auto callee_ast_type = this->_callee->get_type()->clone_ty(); + if (auto* alias_ty = cast_type(callee_ast_type.get())) { - llvm::Type* final_llvm_type = parent_type->get_llvm_type(module); + callee_ast_type = alias_ty->get_underlying_type()->clone_ty(); + } - return builder->CreateLoad( - final_llvm_type, - current_val, - "val_member_access" + const auto* fn_type = dynamic_cast(callee_ast_type.get()); + if (!fn_type) + { + throw parsing_error( + ErrorType::TYPE_ERROR, + "Indirect call: callee expression does not have a function type", + this->get_source_fragment() ); } - // If we were working with values (ExtractValue), we already have the result. - return current_val; + std::vector llvm_param_types; + llvm_param_types.reserve(fn_type->get_parameter_types().size()); + for (const auto& param : fn_type->get_parameter_types()) + { + llvm_param_types.push_back(param->get_llvm_type(module)); + } + + llvm::FunctionType* llvm_fn_type = llvm::FunctionType::get( + fn_type->get_return_type()->get_llvm_type(module), + llvm_param_types, + fn_type->is_variadic() + ); + + std::vector args_v; + args_v.reserve(this->_args.size()); + for (const auto& arg : this->_args) + { + llvm::Value* arg_val = arg->codegen(module, builder); + if (!arg_val) + return nullptr; + args_v.push_back(arg_val); + } + + const auto instruction_name = + llvm_fn_type->getReturnType()->isVoidTy() ? "" : "indcalltmp"; + + return builder->CreateCall(llvm_fn_type, callee_val, args_v, instruction_name); } -std::unique_ptr AstMemberAccessor::clone() +std::unique_ptr AstIndirectCall::clone() { - std::vector> members; - members.reserve(this->_members.size()); - - for (const auto& member : this->_members) + ExpressionList cloned_args; + cloned_args.reserve(this->_args.size()); + for (const auto& arg : this->_args) { - members.push_back(member->clone_as()); + cloned_args.push_back(arg->clone_as()); } - return std::make_unique( + return std::make_unique( this->get_source_fragment(), this->get_context(), - this->_base->clone_as(), - std::move(members) + this->_callee->clone_as(), + std::move(cloned_args) ); } -void AstMemberAccessor::validate() +void AstIndirectCall::validate() { - for (const auto& member : this->_members) + this->_callee->validate(); + for (const auto& arg : this->_args) { - member->validate(); + arg->validate(); } } -std::string AstMemberAccessor::to_string() +std::string AstIndirectCall::to_string() { - std::vector member_names; - member_names.reserve(this->_members.size()); - - for (const auto& member : this->_members) + std::vector arg_strs; + arg_strs.reserve(this->_args.size()); + for (const auto& arg : this->_args) { - member_names.push_back(member->get_name()); + arg_strs.push_back(arg->to_string()); } return std::format( - "MemberAccessor(base: {}, member: {})", - this->get_base()->to_string(), - join(member_names, ",")); -} - -std::optional> AstMemberAccessor::reduce() -{ - return std::nullopt; // Not yet reducible -} - -bool AstMemberAccessor::is_reducible() -{ - return false; + "IndirectCall(callee: {}, args: [{}])", + this->_callee->to_string(), + join(arg_strs, ", ") + ); } diff --git a/packages/compiler/src/ast/nodes/functions/function_declaration.cpp b/packages/compiler/src/ast/nodes/functions/function_declaration.cpp index ee4287ae..9e2131bf 100644 --- a/packages/compiler/src/ast/nodes/functions/function_declaration.cpp +++ b/packages/compiler/src/ast/nodes/functions/function_declaration.cpp @@ -461,10 +461,10 @@ void collect_free_variables( return; } - // Handle member access - if (const auto* member = cast_expr(node)) + // Handle chained member access + if (const auto* chained = cast_expr(node)) { - collect_free_variables(member->get_base(), lambda_context, outer_context, captures); + collect_free_variables(chained->get_base(), lambda_context, outer_context, captures); return; } diff --git a/packages/compiler/src/ast/traversal/traversal.cpp b/packages/compiler/src/ast/traversal/traversal.cpp index 50a5383b..57fb9602 100644 --- a/packages/compiler/src/ast/traversal/traversal.cpp +++ b/packages/compiler/src/ast/traversal/traversal.cpp @@ -65,7 +65,7 @@ void AstNodeTraverser::visit_expression(IVisitor* visitor, IAstExpression* node) } else if (const auto* array_accessor = dynamic_cast(node)) { - visit_expression(visitor, array_accessor->get_array_identifier()); + visit_expression(visitor, array_accessor->get_array_base()); visit_expression(visitor, array_accessor->get_index()); } else if (const auto* struct_init = dynamic_cast(node)) @@ -82,10 +82,10 @@ void AstNodeTraverser::visit_expression(IVisitor* visitor, IAstExpression* node) { visit_expression(visitor, reassign->get_value()); } - else if (const auto* member_access = dynamic_cast(node)) + else if (const auto* chained = dynamic_cast(node)) { - // Visit the base identifier so its type is resolved before the accessor's type is inferred. - visit_expression(visitor, member_access->get_base()); + // Visit the base expression so its type is resolved before the accessor's type is inferred. + visit_expression(visitor, chained->get_base()); } else if (auto* function_node = dynamic_cast(node)) { diff --git a/packages/compiler/src/ast/type_inference.cpp b/packages/compiler/src/ast/type_inference.cpp index 476d0edc..127d0916 100644 --- a/packages/compiler/src/ast/type_inference.cpp +++ b/packages/compiler/src/ast/type_inference.cpp @@ -161,101 +161,88 @@ std::unique_ptr stride::ast::infer_array_member_type(const AstArray* a return infer_expression_type(array->get_elements().front().get()); } -std::unique_ptr stride::ast::infer_object_member_accessor_type(const AstMemberAccessor* member_accessor_expr) +std::unique_ptr stride::ast::infer_chained_expression_type(const AstChainedExpression* chained_expr) { - // Base must be an identifier, e.g., .... - const auto base_iden = cast_expr(member_accessor_expr->get_base()); - if (!base_iden) + // Infer the type of the base (left side) + auto base_type = infer_expression_type(chained_expr->get_base()); + + // Resolve alias types to get the underlying struct type + IAstType* struct_type_raw = get_object_type_from_type(base_type.get()).value_or(nullptr); + if (!struct_type_raw) { throw parsing_error( ErrorType::TYPE_ERROR, - "Member access base must be an identifier", - member_accessor_expr->get_source_fragment()); + std::format( + "Member access base must be a struct type, got '{}'", + base_type->get_type_name() + ), + chained_expr->get_source_fragment() + ); } - // ---- Look up the base variable in the symbol table/context - const auto variable_definition = member_accessor_expr->get_context()->lookup_variable( - base_iden->get_name(), - true); - - if (!variable_definition) + const auto* struct_type = dynamic_cast(struct_type_raw); + if (!struct_type) { throw parsing_error( - ErrorType::REFERENCE_ERROR, - std::format("Variable '{}' not found in scope", base_iden->get_name()), - member_accessor_expr->get_source_fragment() + ErrorType::TYPE_ERROR, + std::format("Object type '{}' not found in scope", base_type->get_type_name()), + chained_expr->get_source_fragment() ); } - // ---- Ensure parent (base) is a struct type - IAstType* parent_type = get_object_type_from_type(variable_definition->get_type()). - value_or(nullptr); - - if (!parent_type) + // The followup must be an identifier (the member name) + const auto* member_id = cast_expr(chained_expr->get_followup()); + if (!member_id) { throw parsing_error( ErrorType::TYPE_ERROR, - std::format( - "Object type '{}' not found in this scope", - variable_definition->get_type()->get_type_name()), - member_accessor_expr->get_source_fragment()); + "Chained expression followup must be an identifier", + chained_expr->get_source_fragment() + ); } - std::string parent_name = variable_definition->get_symbol().name; - - // Iterate through all member segments (e.g., .b, .c) - for (const auto& members = member_accessor_expr->get_members(); - const auto member : members) + const auto field_type = struct_type->get_member_field_type(member_id->get_name()); + if (!field_type.has_value()) { - auto struct_type = get_object_type_from_type(parent_type); - - // It's possible that the previous iteration yielded a non-struct type, and thus being nullptr. - if (!struct_type.has_value()) - { - throw parsing_error( - ErrorType::TYPE_ERROR, - std::format( - "Expected member '{}' in '{}' to be of type 'object', got '{}'", - parent_name, - member_accessor_expr->get_base()->get_type()->get_type_name(), - parent_type->get_type_name() - ), - member_accessor_expr->get_source_fragment() - ); - } - - // Resolve the member identifier (e.g., 'b') - // For now, members are already identifiers, though this might change in the future. + throw parsing_error( + ErrorType::TYPE_ERROR, + std::format("Field '{}' not found in struct '{}'", + member_id->get_name(), + base_type->get_type_name()), + chained_expr->get_source_fragment() + ); + } - const auto member_identifier = cast_expr(member); - if (!member_identifier) - { - throw parsing_error( - ErrorType::TYPE_ERROR, - "Member accessor must be an identifier", - member_accessor_expr->get_source_fragment()); - } + return field_type.value()->clone_ty(); +} - const auto field_type = struct_type.value()->get_member_field_type( - member_identifier->get_name()); +std::unique_ptr stride::ast::infer_indirect_call_type(const AstIndirectCall* call_expr) +{ + auto callee_type = infer_expression_type(call_expr->get_callee()); - if (!field_type.has_value()) - { - throw parsing_error( - ErrorType::TYPE_ERROR, - std::format( - "Field '{}' not found in struct", - member_identifier->get_name()), - member_accessor_expr->get_source_fragment()); - } + // Unwrap alias + IAstType* raw_type = callee_type.get(); + std::unique_ptr unwrapped; + if (auto* alias = cast_type(raw_type)) + { + unwrapped = alias->get_underlying_type()->clone_ty(); + raw_type = unwrapped.get(); + } - // Update current_type for the next iteration (or for the final return) - parent_type = field_type.value(); - parent_name = member->get_name(); + const auto* fn_type = dynamic_cast(raw_type); + if (!fn_type) + { + throw parsing_error( + ErrorType::TYPE_ERROR, + std::format( + "Cannot call expression of type '{}' as a function", + callee_type->get_type_name() + ), + call_expr->get_source_fragment() + ); } - // Return the final inferred type - return parent_type->clone_ty(); + return fn_type->get_return_type()->clone_ty(); } std::unique_ptr stride::ast::infer_object_initializer_type(const AstObjectInitializer* struct_initializer) @@ -390,8 +377,8 @@ std::unique_ptr stride::ast::infer_array_accessor_type( const AstArrayMemberAccessor* accessor, const int recursion_guard) { - // Infer the identifier's type. We must ensure it's an array type, and then we can return the member type. - const auto array_type = infer_expression_type(accessor->get_array_identifier(), recursion_guard); + // Infer the base expression's type. We must ensure it's an array type, and then we can return the element type. + const auto array_type = infer_expression_type(accessor->get_array_base(), recursion_guard); // If the immediate type is an array, we can simply return the member type if (const auto array = cast_type(array_type.get())) @@ -414,13 +401,15 @@ std::unique_ptr stride::ast::infer_array_accessor_type( "Named type '{}' references a type that is not an array, cannot be used as array type", alias_type->get_name() ), - accessor->get_array_identifier()->get_source_fragment() + accessor->get_array_base()->get_source_fragment() ); } throw parsing_error( ErrorType::SEMANTIC_ERROR, - "Unable to resolve array member accessor type", + std::format("Expected array type for member accessor, got: '{}'", + accessor->get_array_base()->get_type()->get_type_name() + ), accessor->get_source_fragment() ); } @@ -503,9 +492,14 @@ std::unique_ptr stride::ast::infer_expression_type(IAstExpression* exp return infer_object_initializer_type(struct_init); } - if (const auto* member_accessor = cast_expr(expr)) + if (const auto* chained = cast_expr(expr)) + { + return infer_chained_expression_type(chained); + } + + if (const auto* indirect_call = cast_expr(expr)) { - return infer_object_member_accessor_type(member_accessor); + return infer_indirect_call_type(indirect_call); } if (const auto* function_definition = cast_expr(expr)) diff --git a/packages/compiler/tests/test_type_inference.cpp b/packages/compiler/tests/test_type_inference.cpp index 7fc11dba..0820abe9 100644 --- a/packages/compiler/tests/test_type_inference.cpp +++ b/packages/compiler/tests/test_type_inference.cpp @@ -445,28 +445,26 @@ TEST_F(TypeInferenceTest, InferStructAndMemberAccess) std::make_unique(dummy_sf(), context, "Point"), VisibilityModifier::PUBLIC); - // Nested member access + // Nested member access using AstChainedExpression context->define_variable(dummy_sym("q"), std::make_unique(dummy_sf(), context, "Point"), VisibilityModifier::PUBLIC); auto base2 = std::make_unique(context, dummy_sym("q")); - std::vector> access_members2; - access_members2.push_back(std::make_unique(context, dummy_sym("x"))); - auto member_access2 = std::make_unique( + auto member2 = std::make_unique(context, dummy_sym("x")); + auto chained2 = std::make_unique( dummy_sf(), context, std::move(base2), - std::move(access_members2)); - EXPECT_EQ(infer_expression_type(member_access2.get())->to_string(), "i32"); + std::move(member2)); + EXPECT_EQ(infer_expression_type(chained2.get())->to_string(), "i32"); } -TEST_F(TypeInferenceTest, InferMemberAccessorErrors) +TEST_F(TypeInferenceTest, InferChainedExpressionErrors) { // Base not found auto base = std::make_unique(context, dummy_sym("unknown_var")); - std::vector> members; - members.push_back(std::make_unique(context, dummy_sym("x"))); - auto access = std::make_unique(dummy_sf(), context, std::move(base), std::move(members)); + auto member = std::make_unique(context, dummy_sym("x")); + auto access = std::make_unique(dummy_sf(), context, std::move(base), std::move(member)); EXPECT_THROW(infer_expression_type(access.get()), parsing_error); // Base is not a struct @@ -474,9 +472,8 @@ TEST_F(TypeInferenceTest, InferMemberAccessorErrors) std::make_unique(dummy_sf(), context, PrimitiveType::INT32), VisibilityModifier::PUBLIC); auto base2 = std::make_unique(context, dummy_sym("i")); - std::vector> members2; - members2.push_back(std::make_unique(context, dummy_sym("x"))); - auto access2 = std::make_unique(dummy_sf(), context, std::move(base2), std::move(members2)); + auto member2 = std::make_unique(context, dummy_sym("x")); + auto access2 = std::make_unique(dummy_sf(), context, std::move(base2), std::move(member2)); EXPECT_THROW(infer_expression_type(access2.get()), parsing_error); } From 25939bef1854972d8606fdd05206f4a4b2471fdf Mon Sep 17 00:00:00 2001 From: Luca Warmenhoven Date: Wed, 11 Mar 2026 22:26:18 +0100 Subject: [PATCH 04/30] Added tests for chained expression member accessor --- .../compiler/tests/test_chained_accessor.cpp | 507 ++++++++++++++++++ 1 file changed, 507 insertions(+) create mode 100644 packages/compiler/tests/test_chained_accessor.cpp diff --git a/packages/compiler/tests/test_chained_accessor.cpp b/packages/compiler/tests/test_chained_accessor.cpp new file mode 100644 index 00000000..997beffe --- /dev/null +++ b/packages/compiler/tests/test_chained_accessor.cpp @@ -0,0 +1,507 @@ +/// Tests for chained expression parsing and type inference. +/// Covers: +/// - Simple member access: a.b +/// - Multi-step chains: a.b.c.d +/// - Array subscript followed by member access: arr[i].field +/// - Member access followed by array subscript: a.b[i] +/// - Mixed chains: a[i].b.c[j].d +/// - Function call returning struct, then member access: fn().field +/// - Array of function pointers: arr[i]().field +/// - AstChainedExpression type inference unit tests +/// - AstArrayMemberAccessor with chained base type inference +/// - AstIndirectCall type inference +/// - Error cases: non-struct base, unknown member, non-function indirect call + +#include "utils.h" +#include "ast/type_inference.h" +#include "ast/nodes/expression.h" +#include "ast/nodes/literal_values.h" +#include "ast/nodes/types.h" +#include "ast/parsing_context.h" +#include "ast/symbols.h" +#include "errors.h" +#include "files.h" + +#include +#include + +using namespace stride; +using namespace stride::ast; +using namespace stride::tests; + +// ============================================================================= +// Integration (parse + codegen) tests +// ============================================================================= + +TEST(ChainedAccessor, SimpleMemberAccess) +{ + assert_compiles(R"( + type Point = { x: i32; y: i32; }; + fn get_x(p: Point): i32 { + return p.x; + } + fn main(): void { + const p: Point = Point::{ x: 10, y: 20 }; + const x = p.x; + } + )"); +} + +TEST(ChainedAccessor, MultiStepMemberAccess) +{ + assert_compiles(R"( + type Inner = { value: i32; }; + type Middle = { inner: Inner; }; + type Outer = { middle: Middle; }; + fn main(): void { + const o: Outer = Outer::{ + middle: Middle::{ + inner: Inner::{ value: 42 } + } + }; + const v = o.middle.inner.value; + } + )"); +} + +TEST(ChainedAccessor, ArraySubscriptThenMember) +{ + assert_compiles(R"( + type Point = { x: i32; y: i32; }; + fn main(): void { + const pts: Point[] = [ Point::{ x: 1, y: 2 } ]; + const x = pts[0].x; + } + )"); +} + +TEST(ChainedAccessor, MemberThenArraySubscript) +{ + assert_compiles(R"( + type Container = { data: i32[]; }; + fn main(): void { + const c: Container = Container::{ data: [10, 20, 30] }; + const first = c.data[0]; + } + )"); +} + +TEST(ChainedAccessor, ChainWithArrayAndMember) +{ + assert_compiles(R"( + type Point = { x: i32; y: i32; }; + type Row = { points: Point[]; }; + fn main(): void { + const row: Row = Row::{ points: [ Point::{ x: 3, y: 4 } ] }; + const y = row.points[0].y; + } + )"); +} + +TEST(ChainedAccessor, DeepMixedChain) +{ + assert_compiles(R"( + type Leaf = { val: i32; }; + type Branch = { leaves: Leaf[]; }; + type Tree = { branches: Branch[]; }; + fn main(): void { + const tree: Tree = Tree::{ + branches: [ + Branch::{ leaves: [ Leaf::{ val: 99 } ] } + ] + }; + const v = tree.branches[0].leaves[0].val; + } + )"); +} + +TEST(ChainedAccessor, FunctionReturnsThenMemberAccess) +{ + assert_compiles(R"( + type Point = { x: i32; y: i32; }; + fn make_point(): Point { + return Point::{ x: 5, y: 6 }; + } + fn main(): void { + const x = make_point().x; + const y = make_point().y; + } + )"); +} + +TEST(ChainedAccessor, FunctionReturnsThenMultiStepAccess) +{ + assert_compiles(R"( + type Inner = { value: i32; }; + type Wrapper = { inner: Inner; }; + fn make(): Wrapper { + return Wrapper::{ inner: Inner::{ value: 7 } }; + } + fn main(): void { + const v = make().inner.value; + } + )"); +} + +TEST(ChainedAccessor, FunctionReturnsThenArraySubscript) +{ + assert_compiles(R"( + type Container = { data: i32[]; }; + fn make(): Container { + return Container::{ data: [1, 2, 3] }; + } + fn main(): void { + const second = make().data[1]; + } + )"); +} + +TEST(ChainedAccessor, ChainedMemberAccessInReturn) +{ + assert_compiles(R"( + type Inner = { value: i32; }; + type Outer = { inner: Inner; }; + fn get_value(o: Outer): i32 { + return o.inner.value; + } + fn main(): void { + const o: Outer = Outer::{ inner: Inner::{ value: 123 } }; + const v = get_value(o); + } + )"); +} + +TEST(ChainedAccessor, ChainedMemberInCondition) +{ + assert_compiles(R"( + type Stats = { score: i32; }; + type Player = { stats: Stats; }; + fn main(): void { + const p: Player = Player::{ stats: Stats::{ score: 100 } }; + if (p.stats.score > 50) { + const x: i32 = 1; + } + } + )"); +} + +TEST(ChainedAccessor, ArraySubscriptInFunctionArg) +{ + assert_compiles(R"( + type Point = { x: i32; y: i32; }; + fn print_x(x: i32): void {} + fn main(): void { + const pts: Point[] = [ Point::{ x: 7, y: 8 } ]; + print_x(pts[0].x); + } + )"); +} + +// ============================================================================= +// Type inference unit tests +// ============================================================================= + +class ChainedAccessorTypeTest : public ::testing::Test +{ +protected: + std::shared_ptr context; + std::shared_ptr source; + + void SetUp() override + { + context = std::make_shared(); + source = std::make_shared("test.sr", ""); + + // Define a Point struct: { x: i32, y: f32 } + std::vector>> point_members; + point_members.emplace_back("x", std::make_unique(sf(), context, PrimitiveType::INT32)); + point_members.emplace_back("y", std::make_unique(sf(), context, PrimitiveType::FLOAT32)); + auto point_ty = std::make_unique(sf(), context, "Point", std::move(point_members)); + context->define_type(sym("Point"), std::move(point_ty), {}, VisibilityModifier::PUBLIC); + + // Define a Wrapper struct: { pt: Point, values: i32[] } + std::vector>> wrapper_members; + wrapper_members.emplace_back("pt", std::make_unique(sf(), context, "Point")); + wrapper_members.emplace_back( + "values", + std::make_unique( + sf(), + context, + std::make_unique( + sf(), + context, + PrimitiveType::INT32), + 0)); + auto wrapper_ty = std::make_unique(sf(), context, "Wrapper", std::move(wrapper_members)); + context->define_type(sym("Wrapper"), std::move(wrapper_ty), {}, VisibilityModifier::PUBLIC); + + // Register variables + context->define_variable( + sym("p"), + std::make_unique(sf(), context, "Point"), + VisibilityModifier::PUBLIC); + context->define_variable( + sym("w"), + std::make_unique(sf(), context, "Wrapper"), + VisibilityModifier::PUBLIC); + + // Register a Point[] array variable + context->define_variable( + sym("pts"), + std::make_unique( + sf(), + context, + std::make_unique(sf(), context, "Point"), + 0), + VisibilityModifier::PUBLIC); + + // Register a function returning Point + std::vector> fn_params; + auto fn_ret = std::make_unique(sf(), context, "Point"); + auto fn_ty = std::make_unique(sf(), context, std::move(fn_params), std::move(fn_ret)); + context->define_function(sym("make_point"), std::move(fn_ty), VisibilityModifier::PUBLIC, 0); + } + + [[nodiscard]] SourceFragment sf() const + { + return { source, 0, 0 }; + } + + [[nodiscard]] Symbol sym(const std::string& name) const + { + return Symbol(sf(), name); + } + + /// Builds: Identifier(name) + [[nodiscard]] std::unique_ptr id(const std::string& name) const + { + return std::make_unique(context, sym(name)); + } + + /// Builds: ChainedExpression(base, Identifier(member)) + [[nodiscard]] std::unique_ptr chain( + std::unique_ptr base, + const std::string& member) const + { + return std::make_unique(sf(), context, std::move(base), id(member)); + } + + /// Builds: ArrayAccess(base, IntLiteral(idx)) + [[nodiscard]] std::unique_ptr subscript( + std::unique_ptr base, + int32_t idx) const + { + auto index = std::make_unique(sf(), context, PrimitiveType::INT32, idx, 0); + return std::make_unique(sf(), context, std::move(base), std::move(index)); + } +}; + +// --- Simple chain: p.x → i32 + +TEST_F(ChainedAccessorTypeTest, SimpleMemberAccess_i32) +{ + const auto expr = chain(id("p"), "x"); + EXPECT_EQ(infer_expression_type(expr.get())->to_string(), "i32"); +} + +TEST_F(ChainedAccessorTypeTest, SimpleMemberAccess_f32) +{ + const auto expr = chain(id("p"), "y"); + EXPECT_EQ(infer_expression_type(expr.get())->to_string(), "f32"); +} + +// --- Two-step chain: w.pt.x → i32 + +TEST_F(ChainedAccessorTypeTest, TwoStepChain) +{ + // ChainedExpression(ChainedExpression(w, pt), x) + auto inner = chain(id("w"), "pt"); + const auto outer = chain(std::move(inner), "x"); + EXPECT_EQ(infer_expression_type(outer.get())->to_string(), "i32"); +} + +// --- Two-step chain resolves to correct second member: w.pt.y → f32 + +TEST_F(ChainedAccessorTypeTest, TwoStepChain_SecondMember) +{ + auto inner = chain(id("w"), "pt"); + const auto outer = chain(std::move(inner), "y"); + EXPECT_EQ(infer_expression_type(outer.get())->to_string(), "f32"); +} + +// --- Array subscript then member: pts[0].x → i32 + +TEST_F(ChainedAccessorTypeTest, ArraySubscriptThenMember) +{ + auto access = subscript(id("pts"), 0); + const auto expr = chain(std::move(access), "x"); + EXPECT_EQ(infer_expression_type(expr.get())->to_string(), "i32"); +} + +// --- Array subscript then member (y): pts[0].y → f32 + +TEST_F(ChainedAccessorTypeTest, ArraySubscriptThenMember_f32) +{ + auto access = subscript(id("pts"), 0); + const auto expr = chain(std::move(access), "y"); + EXPECT_EQ(infer_expression_type(expr.get())->to_string(), "f32"); +} + +// --- Member then array subscript: w.values[0] → i32 + +TEST_F(ChainedAccessorTypeTest, MemberThenArraySubscript) +{ + auto member = chain(id("w"), "values"); + const auto expr = subscript(std::move(member), 0); + EXPECT_EQ(infer_expression_type(expr.get())->to_string(), "i32"); +} + +// --- Indirect call then member: make_point().x → i32 + +TEST_F(ChainedAccessorTypeTest, IndirectCallThenMember) +{ + auto callee = id("make_point"); + // Set the callee's type to the function type + // (type inference needs set_type to work on identifiers that are functions) + // Build the FunctionCall first via AstFunctionCall; then build IndirectCall on top + // In practice, indirect call is used on non-named callees; here we test type inference + // by constructing AstIndirectCall directly with a correctly-typed callee. + std::vector> fn_params; + auto fn_ret = std::make_unique(sf(), context, "Point"); + auto fn_ast_ty = std::make_unique(sf(), context, std::move(fn_params), std::move(fn_ret)); + callee->set_type(std::move(fn_ast_ty)); + + auto call = std::make_unique(sf(), context, std::move(callee), ExpressionList{}); + const auto expr = chain(std::move(call), "x"); + EXPECT_EQ(infer_expression_type(expr.get())->to_string(), "i32"); +} + +// --- Indirect call type inference: returns correct return type + +TEST_F(ChainedAccessorTypeTest, IndirectCallReturnsCorrectType) +{ + auto callee = id("make_point"); + std::vector> fn_params; + auto fn_ret = std::make_unique(sf(), context, "Point"); + auto fn_ast_ty = std::make_unique(sf(), context, std::move(fn_params), std::move(fn_ret)); + callee->set_type(std::move(fn_ast_ty)); + + const auto call = std::make_unique(sf(), context, std::move(callee), ExpressionList{}); + EXPECT_EQ(infer_expression_type(call.get())->to_string(), "Point"); +} + +// --- Array subscript followed by member followed by subscript (deep chain) +// w.values[0] (just array subscript on a member array): already tested above. +// Here: subscript → chain → subscript through wrapper's values. +// We test: (AstArrayMemberAccessor base= chain(w, values))[0] → i32 + +TEST_F(ChainedAccessorTypeTest, ChainedMemberThenSubscript_ReturnsElementType) +{ + auto member_access = chain(id("w"), "values"); // type: i32[] + const auto arr_access = subscript(std::move(member_access), 2); // type: i32 + EXPECT_EQ(infer_expression_type(arr_access.get())->to_string(), "i32"); +} + +// ============================================================================= +// Error cases +// ============================================================================= + +TEST_F(ChainedAccessorTypeTest, ErrorMemberAccessOnNonStruct) +{ + // Define an int variable, try to access .x on it → type error + context->define_variable( + sym("n"), + std::make_unique(sf(), context, PrimitiveType::INT32), + VisibilityModifier::PUBLIC); + + auto expr = chain(id("n"), "x"); + EXPECT_THROW(infer_expression_type(expr.get()), parsing_error); +} + +TEST_F(ChainedAccessorTypeTest, ErrorUndefinedBaseVariable) +{ + const auto expr = chain(id("does_not_exist"), "x"); + EXPECT_THROW(infer_expression_type(expr.get()), parsing_error); +} + +TEST_F(ChainedAccessorTypeTest, ErrorMemberNotInStruct) +{ + const auto expr = chain(id("p"), "z"); // Point has x, y — not z + EXPECT_THROW(infer_expression_type(expr.get()), parsing_error); +} + +TEST_F(ChainedAccessorTypeTest, ErrorMemberNotInStructAtDepth2) +{ + auto inner = chain(id("w"), "pt"); // w.pt is a Point + auto outer = chain(std::move(inner), "nonexistent"); // Point has no such field + EXPECT_THROW(infer_expression_type(outer.get()), parsing_error); +} + +TEST_F(ChainedAccessorTypeTest, ErrorIndirectCallOnNonFunction) +{ + auto callee = id("p"); + // Set type to a non-function type + callee->set_type(std::make_unique(sf(), context, "Point")); + auto call = std::make_unique(sf(), context, std::move(callee), ExpressionList{}); + EXPECT_THROW(infer_expression_type(call.get()), parsing_error); +} + +TEST_F(ChainedAccessorTypeTest, ErrorArraySubscriptOnNonArray) +{ + // Access p[0] where p is a Point (not an array) → type error + auto index = std::make_unique(sf(), context, PrimitiveType::INT32, 0, 0); + auto expr = std::make_unique(sf(), context, id("p"), std::move(index)); + EXPECT_THROW(infer_expression_type(expr.get()), parsing_error); +} + +// ============================================================================= +// Integration error tests +// ============================================================================= + +TEST(ChainedAccessorErrors, AccessNonExistentMember) +{ + assert_throws_message(R"( + type Point = { x: i32; y: i32; }; + fn main(): void { + const p: Point = Point::{ x: 1, y: 2 }; + const z = p.z; + } + )", + "not found"); +} + +TEST(ChainedAccessorErrors, AccessMemberOnPrimitive) +{ + assert_throws_message(R"( + fn main(): void { + const n: i32 = 42; + const x = n.field; + } + )", + "struct"); +} + +TEST(ChainedAccessorErrors, AccessNonExistentNestedMember) +{ + assert_throws_message(R"( + type Inner = { val: i32; }; + type Outer = { inner: Inner; }; + fn main(): void { + const o: Outer = Outer::{ inner: Inner::{ val: 1 } }; + const bad = o.inner.bad_field; + } + )", + "not found"); +} + +TEST(ChainedAccessorErrors, ArraySubscriptOutOfChain_BadMember) +{ + assert_throws_message(R"( + type Point = { x: i32; }; + fn main(): void { + const pts: Point[] = [ Point::{ x: 1 } ]; + const v = pts[0].nonexistent; + } + )", + "not found"); +} From 08f97617cc62e46208df2e2afad0cc2a05cedf91 Mon Sep 17 00:00:00 2001 From: Luca Warmenhoven Date: Thu, 12 Mar 2026 08:14:28 +0100 Subject: [PATCH 05/30] Refactor object type handling to improve clarity and support generics --- example.sr | 20 +++++-------- .../compiler/include/ast/nodes/expression.h | 6 ++-- packages/compiler/include/ast/nodes/types.h | 12 +++++++- .../nodes/expressions/array_initializer.cpp | 11 ++++--- .../nodes/expressions/object_initializer.cpp | 30 ++++++++++++++----- 5 files changed, 49 insertions(+), 30 deletions(-) diff --git a/example.sr b/example.sr index 6c73bd46..6ec63d7f 100644 --- a/example.sr +++ b/example.sr @@ -17,36 +17,30 @@ type SomeCar = { year: i32; }; -type SomePerson = { +type SomePerson = { name: string; age: i32; - cars: Array; + cars: T; }; fn make_car(make: string, model: string, year: i32): SomeCar { return SomeCar::{ make, model, year }; } -fn make_person(name: string, age: i32, cars: Array): SomePerson { - return SomePerson::{ - name, - age, - cars - }; +fn make_person(name: string, age: i32, cars: Array): SomePerson> { + return SomePerson>::{ name, age, cars }; } fn main(): i32 { + const my_car = make_car("Honda", "Civic", 2018); const me = make_person( "Alice", 30, - Array::{ - length: 1, - data: [make_car("Toyota", "Corolla", 2020)] - } + Array::{ length: 1, data: [my_car] } ); - const my_car = me.cars[0].data[0]; + const my_car = me.cars.data[0]; io::print("I am %s, age %d\n", me.name, me.age); io::print("My car is a %s %s from %d", my_car.make, my_car.model, my_car.year); diff --git a/packages/compiler/include/ast/nodes/expression.h b/packages/compiler/include/ast/nodes/expression.h index 154e9208..6b8d0047 100644 --- a/packages/compiler/include/ast/nodes/expression.h +++ b/packages/compiler/include/ast/nodes/expression.h @@ -791,7 +791,7 @@ namespace stride::ast class AstObjectInitializer : public IAstExpression { - std::string _struct_name; + std::string _object_type_name; std::vector _member_initializers; GenericTypeList _generic_type_arguments; @@ -807,7 +807,7 @@ namespace stride::ast GenericTypeList generic_type_arguments = {} ) : IAstExpression(source, context), - _struct_name(std::move(struct_name)), + _object_type_name(std::move(struct_name)), _member_initializers(std::move(member_initializers)), _generic_type_arguments(std::move(generic_type_arguments)) {} @@ -820,7 +820,7 @@ namespace stride::ast [[nodiscard]] const std::string& get_struct_name() const { - return _struct_name; + return _object_type_name; } [[nodiscard]] diff --git a/packages/compiler/include/ast/nodes/types.h b/packages/compiler/include/ast/nodes/types.h index dfbff849..f587b250 100644 --- a/packages/compiler/include/ast/nodes/types.h +++ b/packages/compiler/include/ast/nodes/types.h @@ -285,7 +285,17 @@ namespace stride::ast std::string get_type_name() override { - return this->get_name(); + if (!this->_generic_types.empty()) + { + std::vector generic_names; + for (const auto& generic : this->_generic_types) + { + generic_names.push_back(generic->get_type_name()); + } + + return std::format("{}<{}>", this->_name, join(generic_names, ", ")); + } + return this->_name; } std::string to_string() override; diff --git a/packages/compiler/src/ast/nodes/expressions/array_initializer.cpp b/packages/compiler/src/ast/nodes/expressions/array_initializer.cpp index 31c91601..73d68a32 100644 --- a/packages/compiler/src/ast/nodes/expressions/array_initializer.cpp +++ b/packages/compiler/src/ast/nodes/expressions/array_initializer.cpp @@ -42,15 +42,14 @@ std::unique_ptr stride::ast::parse_array_initializer( } } - const auto& ref_src_pos = reference_token.get_source_fragment(); const auto& last_token_pos = set.peek(-1).get_source_fragment(); // `]` is already consumed, so we peek back at it + const auto position = SourceFragment::combine(reference_token.get_source_fragment(), last_token_pos); + return std::make_unique( - SourceFragment( - ref_src_pos.source, - ref_src_pos.offset, - last_token_pos.offset + last_token_pos.length - ref_src_pos.offset), + position, context, - std::move(elements)); + std::move(elements) + ); } diff --git a/packages/compiler/src/ast/nodes/expressions/object_initializer.cpp b/packages/compiler/src/ast/nodes/expressions/object_initializer.cpp index 3d400aac..e39a7a74 100644 --- a/packages/compiler/src/ast/nodes/expressions/object_initializer.cpp +++ b/packages/compiler/src/ast/nodes/expressions/object_initializer.cpp @@ -86,17 +86,33 @@ std::unique_ptr AstObjectInitializer::get_instantiated_object_typ return this->_object_type->clone_as(); } - const auto type_def = this->get_context()->get_type_definition(this->_struct_name); + const auto type_def = this->get_context()->get_type_definition(this->_object_type_name); if (!type_def.has_value()) { throw parsing_error( ErrorType::COMPILATION_ERROR, - std::format("Struct type '{}' is undefined", this->_struct_name), + std::format("Object type '{}' is undefined", this->_object_type_name), this->get_source_fragment() ); } + if (!this->_generic_type_arguments.empty() && !type_def.value()->get_generics_parameters().empty()) + { + if (this->_generic_type_arguments.size() != type_def.value()->get_generics_parameters().size()) + { + throw parsing_error( + ErrorType::TYPE_ERROR, + std::format("Invalid instantiation of object type '{}': expected {} generic arguments, got {}", + this->_object_type_name, + type_def.value()->get_generics_parameters().size(), + this->_generic_type_arguments.size() + ), + this->get_source_fragment() + ); + } + } + if (const auto* object_def = cast_type(type_def.value()->get_type())) { auto resolved_type = instantiate_generic_type(this, const_cast(object_def), type_def.value()); @@ -122,7 +138,7 @@ std::unique_ptr AstObjectInitializer::get_instantiated_object_typ throw parsing_error( ErrorType::COMPILATION_ERROR, - std::format("Type '{}' is not an object", this->_struct_name), + std::format("Type '{}' is not an object", this->_object_type_name), this->get_source_fragment() ); } @@ -189,7 +205,7 @@ void AstObjectInitializer::validate() std::format( "Too {} members found in object '{}': expected {}, got {}", object_members.size() > this->_member_initializers.size() ? "few" : "many", - this->_struct_name, + this->_object_type_name, object_members.size(), this->_member_initializers.size()), this->get_source_fragment() @@ -207,7 +223,7 @@ void AstObjectInitializer::validate() ErrorType::TYPE_ERROR, std::format( "Object type '{}' has no member named '{}'", - this->_struct_name, + this->_object_type_name, field_name ), this->get_source_fragment()); @@ -298,7 +314,7 @@ llvm::Value* AstObjectInitializer::codegen( { throw parsing_error( ErrorType::COMPILATION_ERROR, - std::format("Struct type '{}' is undefined", this->_struct_name), + std::format("Struct type '{}' is undefined", this->_object_type_name), this->get_source_fragment() ); } @@ -396,7 +412,7 @@ std::unique_ptr AstObjectInitializer::clone() return std::make_unique( this->get_source_fragment(), this->get_context(), - this->_struct_name, + this->_object_type_name, std::move(member_initializers), std::move(member_generic_types) ); From 45c7dc7c6453003ac49e4dd32ac6e9c5dd62ddbf Mon Sep 17 00:00:00 2001 From: Luca Warmenhoven Date: Thu, 12 Mar 2026 09:28:48 +0100 Subject: [PATCH 06/30] Enhance type name retrieval in AstObjectType and AstAliasType for better generic support --- example.sr | 2 +- packages/compiler/include/ast/nodes/types.h | 40 +++++-------------- packages/compiler/src/ast/generics.cpp | 4 +- .../src/ast/nodes/return_statement.cpp | 2 + .../src/ast/nodes/types/alias_type.cpp | 17 +++++++- .../src/ast/nodes/types/object_type.cpp | 35 +++++++++++----- 6 files changed, 56 insertions(+), 44 deletions(-) diff --git a/example.sr b/example.sr index 6ec63d7f..f1f06a86 100644 --- a/example.sr +++ b/example.sr @@ -36,7 +36,7 @@ fn main(): i32 { const me = make_person( "Alice", 30, - Array::{ length: 1, data: [my_car] } + Array::{ length: 1, data: [123] } ); diff --git a/packages/compiler/include/ast/nodes/types.h b/packages/compiler/include/ast/nodes/types.h index f587b250..c464f3b0 100644 --- a/packages/compiler/include/ast/nodes/types.h +++ b/packages/compiler/include/ast/nodes/types.h @@ -283,20 +283,7 @@ namespace stride::ast [[nodiscard]] std::unique_ptr clone() override; - std::string get_type_name() override - { - if (!this->_generic_types.empty()) - { - std::vector generic_names; - for (const auto& generic : this->_generic_types) - { - generic_names.push_back(generic->get_type_name()); - } - - return std::format("{}<{}>", this->_name, join(generic_names, ", ")); - } - return this->_name; - } + std::string get_type_name() override; std::string to_string() override; @@ -488,31 +475,22 @@ namespace stride::ast _instantiated_generics(std::move(instantiated_generics)), _type_name(std::move(type_name)) {} - [[nodiscard]] - const GenericTypeList& get_instantiated_generics() const - { - return _instantiated_generics; - } + [[nodiscard]] const GenericTypeList& get_instantiated_generics() const; - [[nodiscard]] - ObjectTypeMemberList get_members() const; + [[nodiscard]] ObjectTypeMemberList get_members() const; - [[nodiscard]] - std::optional get_member_field_type(const std::string& field_name) const; + [[nodiscard]] std::optional get_member_field_type(const std::string& field_name) const; - [[nodiscard]] - std::optional get_member_field_index(const std::string& field_name) const; + [[nodiscard]] std::optional get_member_field_index(const std::string& field_name) const; - [[nodiscard]] - std::string get_type_name() override + [[nodiscard]] const std::string& get_base_name() const { return this->_type_name; } - std::string to_string() override - { - return this->get_type_name(); - } + [[nodiscard]] std::string get_type_name() override; + + std::string to_string() override; [[nodiscard]] bool equals(IAstType* other) override; diff --git a/packages/compiler/src/ast/generics.cpp b/packages/compiler/src/ast/generics.cpp index f2fbccd9..c590755c 100644 --- a/packages/compiler/src/ast/generics.cpp +++ b/packages/compiler/src/ast/generics.cpp @@ -139,7 +139,7 @@ std::unique_ptr stride::ast::resolve_generics( return std::make_unique( object_type->get_source_fragment(), object_type->get_context(), - object_type->get_type_name(), + object_type->get_base_name(), std::move(resolved_members), object_type->get_flags(), std::move(resolved_generics) @@ -260,7 +260,7 @@ std::unique_ptr stride::ast::instantiate_generic_type( return std::make_unique( type->get_source_fragment(), type->get_context(), - type->get_type_name(), + type->get_base_name(), std::move(resolved_members), type->get_flags(), std::move(resolved_args) diff --git a/packages/compiler/src/ast/nodes/return_statement.cpp b/packages/compiler/src/ast/nodes/return_statement.cpp index a744b83f..1ba8e98c 100644 --- a/packages/compiler/src/ast/nodes/return_statement.cpp +++ b/packages/compiler/src/ast/nodes/return_statement.cpp @@ -7,6 +7,7 @@ #include #include +#include using namespace stride::ast; using namespace stride::ast::definition; @@ -156,6 +157,7 @@ llvm::Value* AstReturnStatement::codegen( } else { + module->print(llvm::errs(), nullptr); throw parsing_error( ErrorType::COMPILATION_ERROR, "Cannot cast return value to function return type", diff --git a/packages/compiler/src/ast/nodes/types/alias_type.cpp b/packages/compiler/src/ast/nodes/types/alias_type.cpp index 7730225d..5e95ccf5 100644 --- a/packages/compiler/src/ast/nodes/types/alias_type.cpp +++ b/packages/compiler/src/ast/nodes/types/alias_type.cpp @@ -98,7 +98,7 @@ static std::unique_ptr resolve_nested_underlying_types(std::unique_ptr return std::make_unique( object_type->get_source_fragment(), object_type->get_context(), - object_type->get_type_name(), + object_type->get_base_name(), std::move(resolved_members), object_type->get_flags(), std::move(resolved_generics) @@ -327,6 +327,21 @@ bool AstAliasType::equals(IAstType* other) return false; } +std::string AstAliasType::get_type_name() +{ + if (!this->_generic_types.empty()) + { + std::vector generic_names; + for (const auto& generic : this->_generic_types) + { + generic_names.push_back(generic->get_type_name()); + } + + return std::format("{}<{}>", this->_name, join(generic_names, ", ")); + } + return this->_name; +} + std::unique_ptr AstAliasType::clone() { GenericTypeList generic_types; diff --git a/packages/compiler/src/ast/nodes/types/object_type.cpp b/packages/compiler/src/ast/nodes/types/object_type.cpp index 37f65c5a..0ad2a501 100644 --- a/packages/compiler/src/ast/nodes/types/object_type.cpp +++ b/packages/compiler/src/ast/nodes/types/object_type.cpp @@ -143,14 +143,7 @@ std::optional AstObjectType::get_member_field_index(const std::string& fiel /// resulting in no LLVM duplication std::string AstObjectType::get_internalized_name() { - std::string scoped_name = resolve_internal_name({ this->get_context()->get_name(), this->get_type_name() }); - - for (const auto& generic : this->_instantiated_generics) - { - scoped_name += "$" + generic->get_type_name(); - } - - return scoped_name; + return this->get_type_name(); } llvm::Type* AstObjectType::get_llvm_type_impl(llvm::Module* module) @@ -205,6 +198,30 @@ llvm::Type* AstObjectType::get_llvm_type_impl(llvm::Module* module) return struct_type; } +const GenericTypeList& AstObjectType::get_instantiated_generics() const +{ + return _instantiated_generics; +} + +std::string AstObjectType::get_type_name() +{ + if (!this->_instantiated_generics.empty()) + { + std::vector generic_names; + for (const auto& generic : this->_instantiated_generics) + { + generic_names.push_back(generic->get_type_name()); + } + return std::format("{}<{}>", this->_type_name, join(generic_names, ", ")); + } + return this->_type_name; +} + +std::string AstObjectType::to_string() +{ + return this->get_type_name(); +} + bool AstObjectType::equals(IAstType* other) { if (const auto other_struct_ty = cast_type(other)) @@ -270,7 +287,7 @@ std::unique_ptr AstObjectType::clone() return std::make_unique( this->get_source_fragment(), this->get_context(), - this->get_type_name(), + this->_type_name, std::move(cloned_members), this->get_flags(), std::move(cloned_generics) From b003daf8da76cc683947702e4ac56d808db54c5f Mon Sep 17 00:00:00 2001 From: Luca Warmenhoven Date: Thu, 12 Mar 2026 09:30:57 +0100 Subject: [PATCH 07/30] Update example --- example.sr | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/example.sr b/example.sr index f1f06a86..4527a7b3 100644 --- a/example.sr +++ b/example.sr @@ -8,7 +8,7 @@ enum SomeEnum { type Array = { length: i32; - data: T[]; + data: T []; }; type SomeCar = { @@ -27,8 +27,8 @@ fn make_car(make: string, model: string, year: i32): SomeCar { return SomeCar::{ make, model, year }; } -fn make_person(name: string, age: i32, cars: Array): SomePerson> { - return SomePerson>::{ name, age, cars }; +fn make_person(name: string, age: i32, cars: Array): SomePerson> { + return SomePerson>::{ name, age, cars }; } fn main(): i32 { @@ -36,7 +36,7 @@ fn main(): i32 { const me = make_person( "Alice", 30, - Array::{ length: 1, data: [123] } + Array::{ length: 1, data: [my_car] } ); From e66fe2a74e44beeceed704e77df2a37e4377d07a Mon Sep 17 00:00:00 2001 From: Luca Warmenhoven Date: Thu, 12 Mar 2026 18:21:02 +0100 Subject: [PATCH 08/30] remove: unused files and documentation, refactor types, enhance import handling and package structure --- example.sr | 39 ++++++++++++------- packages/compiler/compilation-diagram.md | 18 --------- packages/compiler/config-json-schema.json | 7 ---- packages/compiler/include/ast/visitor.h | 4 +- packages/compiler/src/ast/ast.cpp | 2 - .../src/ast/traversal/import_visitor.cpp | 23 +++++++---- packages/standard-library/datastructures.sr | 26 +++++++++++++ packages/standard-library/system/io.sr | 2 +- .../src/main/grammars/Stride.bnf | 2 +- 9 files changed, 69 insertions(+), 54 deletions(-) delete mode 100644 packages/compiler/compilation-diagram.md delete mode 100644 packages/compiler/config-json-schema.json create mode 100644 packages/standard-library/datastructures.sr diff --git a/example.sr b/example.sr index 4527a7b3..8a970f94 100644 --- a/example.sr +++ b/example.sr @@ -1,4 +1,6 @@ -import system::{ io::print }; +import System::{ + io::print +}; enum SomeEnum { Variant1, @@ -8,36 +10,43 @@ enum SomeEnum { type Array = { length: i32; - data: T []; + data: T[]; }; -type SomeCar = { +type Car = { make: string; model: string; year: i32; }; -type SomePerson = { +type SomePerson = { name: string; age: i32; - cars: T; + cars: TSomeCar; }; -fn make_car(make: string, model: string, year: i32): SomeCar { - return SomeCar::{ make, model, year }; +fn make_car(make: string, model: string, year: i32): Car { + return Car::{ + make, + model, + year + }; } -fn make_person(name: string, age: i32, cars: Array): SomePerson> { - return SomePerson>::{ name, age, cars }; +fn make_person(name: string, age: i32, cars: Array): SomePerson> { + return SomePerson>::{ + name, + age, + cars + }; } fn main(): i32 { const my_car = make_car("Honda", "Civic", 2018); - const me = make_person( - "Alice", - 30, - Array::{ length: 1, data: [my_car] } - ); + const me = make_person("Alice", 30, Array::{ + length: 1, + data: [my_car] + }); const my_car = me.cars.data[0]; @@ -47,7 +56,7 @@ fn main(): i32 { // const some_enum_val: SomeEnum = SomeEnum::Variant2; - // io::print("The value of some_enum_val is: %d", some_enum_val); +// io::print("The value of some_enum_val is: %d", some_enum_val); return 0; } \ No newline at end of file diff --git a/packages/compiler/compilation-diagram.md b/packages/compiler/compilation-diagram.md deleted file mode 100644 index 9a1ff026..00000000 --- a/packages/compiler/compilation-diagram.md +++ /dev/null @@ -1,18 +0,0 @@ -```mermaid -flowchart TD - File1 - File2 - File3 - File4 - - Tokenizer - - File1 --> Tokenizer - File2 --> Tokenizer - File3 --> Tokenizer - File4 --> Tokenizer - - Tokenizer -- Convert to AST --> Parser["Parser"] - - Parser --> Validation["Validation - Type resolution takes place here"] -``` \ No newline at end of file diff --git a/packages/compiler/config-json-schema.json b/packages/compiler/config-json-schema.json deleted file mode 100644 index d2fb135b..00000000 --- a/packages/compiler/config-json-schema.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "$schema": "http://json-schema.org/draft-07/schema", - "type": "object", - "properties": { - - } -} \ No newline at end of file diff --git a/packages/compiler/include/ast/visitor.h b/packages/compiler/include/ast/visitor.h index 28041c37..65008451 100644 --- a/packages/compiler/include/ast/visitor.h +++ b/packages/compiler/include/ast/visitor.h @@ -40,8 +40,8 @@ namespace stride::ast { std::string _current_file_name; // temporary values std::map< - std::string, /* package_name */ - std::string /* file_name */ + std::string, /* package_name */ + std::vector /* file_names */ > _package_file_mapping; std::map< std::string, /* file_name */ diff --git a/packages/compiler/src/ast/ast.cpp b/packages/compiler/src/ast/ast.cpp index f0e06567..87e5b3df 100644 --- a/packages/compiler/src/ast/ast.cpp +++ b/packages/compiler/src/ast/ast.cpp @@ -127,8 +127,6 @@ std::unique_ptr stride::ast::parse_next_statement( return parse_package_declaration(context, set); case TokenType::KEYWORD_IMPORT: return parse_import_statement(context, set); - case TokenType::KEYWORD_TYPE: - return parse_type_definition(context, set, visibility_modifier); case TokenType::KEYWORD_CONTINUE: return parse_continue_statement(context, set); case TokenType::KEYWORD_BREAK: diff --git a/packages/compiler/src/ast/traversal/import_visitor.cpp b/packages/compiler/src/ast/traversal/import_visitor.cpp index 45079232..b8d31edc 100644 --- a/packages/compiler/src/ast/traversal/import_visitor.cpp +++ b/packages/compiler/src/ast/traversal/import_visitor.cpp @@ -40,7 +40,7 @@ void stride::ast::ImportVisitor::accept(AstImport* node) void stride::ast::ImportVisitor::accept(AstPackage* node) { - this->_package_file_mapping.emplace(node->get_package_name(), this->_current_file_name); + this->_package_file_mapping[node->get_package_name()].push_back(this->_current_file_name); } void stride::ast::ImportVisitor::cross_register_symbols(Ast* ast) const @@ -62,16 +62,23 @@ void stride::ast::ImportVisitor::cross_register_symbols(Ast* ast) const node->get_source_fragment() ); } - // The Ast node from which we wish to extract the symbols - const auto& file_name_with_exports = this->_package_file_mapping.at(package_name); - const auto& node_with_exports = ast->get_files().at(file_name_with_exports); + // The Ast nodes from which we wish to extract the symbols + const auto& files_with_exports = this->_package_file_mapping.at(package_name); - // Acquire all symbols from `node_with_exports` + // Acquire all symbols from the package's files for (const auto& import_name : import_names) { - // Acquire import from node_with_exports - auto definition = node_with_exports->get_context()->get_definition_by_internal_name(import_name); - if (!definition) + // Search across all files that belong to this package + std::optional> definition; + for (const auto& file_name_with_exports : files_with_exports) + { + const auto& node_with_exports = ast->get_files().at(file_name_with_exports); + definition = node_with_exports->get_context()->get_definition_by_internal_name(import_name); + if (definition.has_value()) + break; + } + + if (!definition.has_value()) { throw parsing_error( ErrorType::REFERENCE_ERROR, diff --git a/packages/standard-library/datastructures.sr b/packages/standard-library/datastructures.sr new file mode 100644 index 00000000..f26efb78 --- /dev/null +++ b/packages/standard-library/datastructures.sr @@ -0,0 +1,26 @@ +package System; + + +module Structures { + + pub type Array = { + length: i32; + data: T []; + }; + + pub type MapEntry = { + key: K; + value: V; + }; + + pub type Map = { + entries: MapEntry[]; + }; + + pub fn arrayOf(members: T[]): Structures::Array { + return Array::{ + length: members.length, + data: members + }; + } +} \ No newline at end of file diff --git a/packages/standard-library/system/io.sr b/packages/standard-library/system/io.sr index da83980c..f7f5cb2c 100644 --- a/packages/standard-library/system/io.sr +++ b/packages/standard-library/system/io.sr @@ -1,4 +1,4 @@ -package system; +package System; module io { diff --git a/packages/stride-plugin-intellij/src/main/grammars/Stride.bnf b/packages/stride-plugin-intellij/src/main/grammars/Stride.bnf index 445c0907..fd29bcbd 100644 --- a/packages/stride-plugin-intellij/src/main/grammars/Stride.bnf +++ b/packages/stride-plugin-intellij/src/main/grammars/Stride.bnf @@ -129,7 +129,7 @@ private FunctionDeclarationHeader ::= FN IDENTIFIER [GenericTypeArguments] LPARE FunctionParameterList ::= (FunctionParameter (COMMA FunctionParameter)* [COMMA ELLIPSIS]) | [ELLIPSIS] FunctionParameter ::= IDENTIFIER COLON Type -TypeDefinition ::= KW_TYPE IDENTIFIER [GenericTypeArguments] EQ Type SEMICOLON +TypeDefinition ::= [VisibilityModifier] KW_TYPE IDENTIFIER [GenericTypeArguments] EQ Type SEMICOLON EnumDefinition ::= ENUM IDENTIFIER LBRACE EnumMember (COMMA EnumMember)* [COMMA /* trailing */] RBRACE From 2ee23d12e3ca04c0ab757c922032262b564e542a Mon Sep 17 00:00:00 2001 From: Luca Warmenhoven Date: Thu, 12 Mar 2026 18:44:28 +0100 Subject: [PATCH 09/30] Added folding support to stride intellij plugin --- example.sr | 4 +- .../compiler/include/ast/nodes/enumerables.h | 2 + .../compiler/src/ast/nodes/enumerables.cpp | 4 ++ .../src/main/grammars/Stride.bnf | 4 +- .../intellij/editor/StrideFoldingBuilder.kt | 53 +++++++++++++++++++ .../src/main/resources/META-INF/plugin.xml | 3 ++ 6 files changed, 66 insertions(+), 4 deletions(-) create mode 100644 packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/editor/StrideFoldingBuilder.kt diff --git a/example.sr b/example.sr index 8a970f94..2274ae5c 100644 --- a/example.sr +++ b/example.sr @@ -54,9 +54,9 @@ fn main(): i32 { io::print("I am %s, age %d\n", me.name, me.age); io::print("My car is a %s %s from %d", my_car.make, my_car.model, my_car.year); - // const some_enum_val: SomeEnum = SomeEnum::Variant2; + const some_enum_val: SomeEnum = SomeEnum::Variant2; -// io::print("The value of some_enum_val is: %d", some_enum_val); + io::print("The value of some_enum_val is: %d", some_enum_val); return 0; } \ No newline at end of file diff --git a/packages/compiler/include/ast/nodes/enumerables.h b/packages/compiler/include/ast/nodes/enumerables.h index 480fdc33..db2cafa9 100644 --- a/packages/compiler/include/ast/nodes/enumerables.h +++ b/packages/compiler/include/ast/nodes/enumerables.h @@ -86,6 +86,8 @@ namespace stride::ast } std::unique_ptr clone() override; + + void resolve_forward_references(llvm::Module* module, llvm::IRBuilderBase* builder) override; }; std::unique_ptr parse_enumerable_member( diff --git a/packages/compiler/src/ast/nodes/enumerables.cpp b/packages/compiler/src/ast/nodes/enumerables.cpp index 8117b70f..07bfab32 100644 --- a/packages/compiler/src/ast/nodes/enumerables.cpp +++ b/packages/compiler/src/ast/nodes/enumerables.cpp @@ -111,6 +111,10 @@ std::unique_ptr stride::ast::parse_enumerable_declaration( ); } +void AstEnumerable::resolve_forward_references(llvm::Module* module, llvm::IRBuilderBase* builder) +{ +} + std::unique_ptr AstEnumerable::clone() { std::vector> cloned_members; diff --git a/packages/stride-plugin-intellij/src/main/grammars/Stride.bnf b/packages/stride-plugin-intellij/src/main/grammars/Stride.bnf index fd29bcbd..61a02a12 100644 --- a/packages/stride-plugin-intellij/src/main/grammars/Stride.bnf +++ b/packages/stride-plugin-intellij/src/main/grammars/Stride.bnf @@ -146,7 +146,7 @@ VariableDeclaration ::= (CONST | LET) IDENTIFIER COLON Type [EQ Expression] VariableDeclarationStatement ::= [VisibilityModifier] (VariableDeclaration | VariableDeclarationInferredType) SEMICOLON // --- Types --- -Type ::= (PrimitiveType | UserType | FunctionType | StructDefinitionBody | TupleType) [ArrayNotion*] [QUESTION] +Type ::= (PrimitiveType | UserType | FunctionType | ObjectDefinitionBody | TupleType) [ArrayNotion*] [QUESTION] // :: or PrimitiveType ::= VOID | INT8 | INT16 | INT32 | INT64 | UINT8 | UINT16 | UINT32 | UINT64 | FLOAT32 | FLOAT64 | BOOL | CHAR | STRING UserType ::= ScopedIdentifier [GenericTypeValueArguments] { @@ -158,7 +158,7 @@ UserType ::= ScopedIdentifier [GenericTypeValueArguments] { TupleType ::= LPAREN Type (COMMA Type)* RPAREN -private StructDefinitionBody ::= LBRACE [ObjectDefinitionFields] RBRACE +ObjectDefinitionBody ::= LBRACE [ObjectDefinitionFields] RBRACE GenericTypeArguments ::= LT IDENTIFIER (COMMA IDENTIFIER)* GT GenericTypeValueArguments ::= LT Type (COMMA Type)* GT diff --git a/packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/editor/StrideFoldingBuilder.kt b/packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/editor/StrideFoldingBuilder.kt new file mode 100644 index 00000000..37a850a9 --- /dev/null +++ b/packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/editor/StrideFoldingBuilder.kt @@ -0,0 +1,53 @@ +package com.stride.intellij.editor + +import com.intellij.lang.ASTNode +import com.intellij.lang.folding.FoldingBuilderEx +import com.intellij.lang.folding.FoldingDescriptor +import com.intellij.openapi.editor.Document +import com.intellij.openapi.util.TextRange +import com.intellij.psi.PsiElement +import com.intellij.psi.util.PsiTreeUtil +import com.stride.intellij.psi.StrideArrayDeclarationExpression +import com.stride.intellij.psi.StrideBlockStatement +import com.stride.intellij.psi.StrideObjectDefinitionBody + +class StrideFoldingBuilder : FoldingBuilderEx() { + + override fun buildFoldRegions(root: PsiElement, document: Document, quick: Boolean): Array { + val descriptors = mutableListOf() + + PsiTreeUtil.processElements(root) { element -> + when (element) { + is StrideBlockStatement -> { + if (element.textLength > 2) { + descriptors.add(FoldingDescriptor(element.node, element.textRange)) + } + } + is StrideObjectDefinitionBody -> { + if (element.textLength > 2) { + descriptors.add(FoldingDescriptor(element.node, element.textRange)) + } + } + is StrideArrayDeclarationExpression -> { + if (element.textLength > 2) { + descriptors.add(FoldingDescriptor(element.node, element.textRange)) + } + } + } + true + } + + return descriptors.toTypedArray() + } + + override fun getPlaceholderText(node: ASTNode): String { + return when (node.psi) { + is StrideBlockStatement -> "{...}" + is StrideObjectDefinitionBody -> "{...}" + is StrideArrayDeclarationExpression -> "[...]" + else -> "..." + } + } + + override fun isCollapsedByDefault(node: ASTNode): Boolean = false +} diff --git a/packages/stride-plugin-intellij/src/main/resources/META-INF/plugin.xml b/packages/stride-plugin-intellij/src/main/resources/META-INF/plugin.xml index b20a911e..467674aa 100644 --- a/packages/stride-plugin-intellij/src/main/resources/META-INF/plugin.xml +++ b/packages/stride-plugin-intellij/src/main/resources/META-INF/plugin.xml @@ -27,6 +27,9 @@ + + From 4acd08b16f339b313ecc1591ea94229fe1d49322 Mon Sep 17 00:00:00 2001 From: Luca Warmenhoven Date: Thu, 12 Mar 2026 19:17:50 +0100 Subject: [PATCH 10/30] Updated folding implementation for enumerables + file headers --- example.sr | 1 + .../src/main/grammars/Stride.bnf | 9 ++++-- .../intellij/editor/StrideFoldingBuilder.kt | 30 +++++++++++++++---- 3 files changed, 32 insertions(+), 8 deletions(-) diff --git a/example.sr b/example.sr index 2274ae5c..5cde7736 100644 --- a/example.sr +++ b/example.sr @@ -1,3 +1,4 @@ +package somePkg; import System::{ io::print }; diff --git a/packages/stride-plugin-intellij/src/main/grammars/Stride.bnf b/packages/stride-plugin-intellij/src/main/grammars/Stride.bnf index 61a02a12..d2b7be84 100644 --- a/packages/stride-plugin-intellij/src/main/grammars/Stride.bnf +++ b/packages/stride-plugin-intellij/src/main/grammars/Stride.bnf @@ -106,7 +106,9 @@ ] } -StrideFile ::= [PackageStatement] ImportStatement* StandaloneItem* +StrideFile ::= [FileHeader] StandaloneItem* + +FileHeader ::= (PackageStatement ImportStatement*) | ImportStatement+ private StandaloneItem ::= ( FunctionDeclaration @@ -119,7 +121,7 @@ private StandaloneItem ::= ( ) // import some::package::{ module::TypeA, TypeB, function_c}; -ImportStatement ::= IMPORT ScopedIdentifier COLON_COLON LBRACE ScopedIdentifier (COMMA ScopedIdentifier)* RBRACE SEMICOLON +ImportStatement ::= IMPORT ScopedIdentifier COLON_COLON LBRACE ScopedIdentifier (COMMA ScopedIdentifier)* [COMMA /* trailing */] RBRACE SEMICOLON // --- Declarations --- ExternFunctionDeclaration ::= [VisibilityModifier] [ASYNC] EXTERN FunctionDeclarationHeader SEMICOLON @@ -131,7 +133,8 @@ FunctionParameter ::= IDENTIFIER COLON Type TypeDefinition ::= [VisibilityModifier] KW_TYPE IDENTIFIER [GenericTypeArguments] EQ Type SEMICOLON -EnumDefinition ::= ENUM IDENTIFIER LBRACE EnumMember (COMMA EnumMember)* [COMMA /* trailing */] RBRACE +EnumDefinition ::= ENUM IDENTIFIER EnumBody +EnumBody ::= LBRACE EnumMember (COMMA EnumMember)* [COMMA /* trailing */] RBRACE EnumMember ::= IDENTIFIER | (IDENTIFIER EQ Expression) diff --git a/packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/editor/StrideFoldingBuilder.kt b/packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/editor/StrideFoldingBuilder.kt index 37a850a9..0b04450d 100644 --- a/packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/editor/StrideFoldingBuilder.kt +++ b/packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/editor/StrideFoldingBuilder.kt @@ -10,6 +10,9 @@ import com.intellij.psi.util.PsiTreeUtil import com.stride.intellij.psi.StrideArrayDeclarationExpression import com.stride.intellij.psi.StrideBlockStatement import com.stride.intellij.psi.StrideObjectDefinitionBody +import com.stride.intellij.psi.StrideEnumBody +import com.stride.intellij.psi.StrideFileHeader +import com.stride.intellij.psi.StrideTypes class StrideFoldingBuilder : FoldingBuilderEx() { @@ -18,11 +21,21 @@ class StrideFoldingBuilder : FoldingBuilderEx() { PsiTreeUtil.processElements(root) { element -> when (element) { + is StrideFileHeader -> { + if (element.textLength > 0) { + descriptors.add(FoldingDescriptor(element.node, element.textRange)) + } + } is StrideBlockStatement -> { if (element.textLength > 2) { descriptors.add(FoldingDescriptor(element.node, element.textRange)) } } + is StrideEnumBody -> { + if (element.textLength > 2) { + descriptors.add(FoldingDescriptor(element.node, element.textRange)) + } + } is StrideObjectDefinitionBody -> { if (element.textLength > 2) { descriptors.add(FoldingDescriptor(element.node, element.textRange)) @@ -41,13 +54,20 @@ class StrideFoldingBuilder : FoldingBuilderEx() { } override fun getPlaceholderText(node: ASTNode): String { - return when (node.psi) { - is StrideBlockStatement -> "{...}" - is StrideObjectDefinitionBody -> "{...}" - is StrideArrayDeclarationExpression -> "[...]" + return when { + node.psi is StrideFileHeader -> "..." + node.psi is StrideBlockStatement -> "{...}" + node.psi is StrideObjectDefinitionBody -> "{...}" + node.psi is StrideEnumBody -> "{...}" + node.psi is StrideArrayDeclarationExpression -> "[...]" else -> "..." } } - override fun isCollapsedByDefault(node: ASTNode): Boolean = false + override fun isCollapsedByDefault(node: ASTNode): Boolean { + return when { + node.psi is StrideFileHeader -> true + else -> false + } + } } From d874c5d01998c68ad077da8a1cf129da8c6a5fa4 Mon Sep 17 00:00:00 2001 From: Luca Warmenhoven Date: Thu, 12 Mar 2026 19:44:52 +0100 Subject: [PATCH 11/30] Rename scripts and enhance test execution output for better clarity --- .../kotlin/com/stride/intellij/StrideBlock.kt | 6 ++--- .../intellij/StrideFormattingModelBuilder.kt | 10 ++++----- ...h_to_homebrew.sh => brew-install-local.sh} | 0 scripts/run-tests.sh | 22 +++++++++++++++++++ scripts/run_tests.sh | 2 -- 5 files changed, 28 insertions(+), 12 deletions(-) rename scripts/{publish_to_homebrew.sh => brew-install-local.sh} (100%) create mode 100755 scripts/run-tests.sh delete mode 100755 scripts/run_tests.sh diff --git a/packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/StrideBlock.kt b/packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/StrideBlock.kt index 9ca30665..ce20174c 100644 --- a/packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/StrideBlock.kt +++ b/packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/StrideBlock.kt @@ -52,12 +52,10 @@ class StrideBlock( Indent.getNoneIndent() } } - // Struct definition fields (within TYPE that contains a struct body) - parentType == StrideTypes.TYPE -> { + // Struct definition fields (within OBJECT_DEFINITION_BODY) + parentType == StrideTypes.OBJECT_DEFINITION_BODY -> { if (elementType == StrideTypes.OBJECT_DEFINITION_FIELDS) { Indent.getNormalIndent() - } else if (elementType != StrideTypes.LBRACE && elementType != StrideTypes.RBRACE) { - Indent.getNoneIndent() } else { Indent.getNoneIndent() } diff --git a/packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/StrideFormattingModelBuilder.kt b/packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/StrideFormattingModelBuilder.kt index ddcd5f13..c683d9f5 100644 --- a/packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/StrideFormattingModelBuilder.kt +++ b/packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/StrideFormattingModelBuilder.kt @@ -85,8 +85,8 @@ class StrideFormattingModelBuilder : FormattingModelBuilder { // Array notation: T[] .betweenInside(StrideTypes.LSQUARE_BRACKET, StrideTypes.RSQUARE_BRACKET, StrideTypes.TYPE).spaceIf(false) - .afterInside(StrideTypes.LBRACE, StrideTypes.TYPE).lineBreakInCode() - .beforeInside(StrideTypes.RBRACE, StrideTypes.TYPE).lineBreakInCode() + .afterInside(StrideTypes.LBRACE, StrideTypes.OBJECT_DEFINITION_BODY).lineBreakInCode() + .beforeInside(StrideTypes.RBRACE, StrideTypes.OBJECT_DEFINITION_BODY).lineBreakInCode() .afterInside(StrideTypes.LBRACE, StrideTypes.BLOCK_STATEMENT).lineBreakInCode() .beforeInside(StrideTypes.RBRACE, StrideTypes.BLOCK_STATEMENT).lineBreakInCode() .afterInside(StrideTypes.LBRACE, StrideTypes.MODULE_STATEMENT).lineBreakInCode() @@ -101,7 +101,7 @@ class StrideFormattingModelBuilder : FormattingModelBuilder { .afterInside(StrideTypes.LPAREN, StrideTypes.FOR).spaceIf(false) .beforeInside(StrideTypes.LBRACE, StrideTypes.MODULE_STATEMENT).spaceIf(true) .beforeInside(StrideTypes.LBRACE, StrideTypes.BLOCK_STATEMENT).spaceIf(true) - .beforeInside(StrideTypes.LBRACE, StrideTypes.TYPE).spaceIf(true) + .beforeInside(StrideTypes.LBRACE, StrideTypes.OBJECT_DEFINITION_BODY).spaceIf(true) .beforeInside(StrideTypes.LBRACE, StrideTypes.OBJECT_INITIALIZATION).spaceIf(false) // Module items should be on separate lines @@ -116,9 +116,7 @@ class StrideFormattingModelBuilder : FormattingModelBuilder { // Statements should be on separate lines - use STATEMENT wrapper .afterInside(StrideTypes.STATEMENT, StrideTypes.BLOCK_STATEMENT).lineBreakInCode() - .afterInside(StrideTypes.LBRACE, StrideTypes.OBJECT_INITIALIZATION).lineBreakInCode() - .beforeInside(StrideTypes.RBRACE, StrideTypes.OBJECT_INITIALIZATION).lineBreakInCode() - .beforeInside(StrideTypes.OBJECT_INIT_FIELD, StrideTypes.OBJECT_INIT_FIELDS).lineBreakInCode() + .between(StrideTypes.OBJECT_FIELD, StrideTypes.OBJECT_FIELD).lineBreakInCode() .between(StrideTypes.TYPE_DEFINITION, StrideTypes.TYPE_DEFINITION).blankLines(1) .between(StrideTypes.FUNCTION_DECLARATION, StrideTypes.FUNCTION_DECLARATION).blankLines(1) diff --git a/scripts/publish_to_homebrew.sh b/scripts/brew-install-local.sh similarity index 100% rename from scripts/publish_to_homebrew.sh rename to scripts/brew-install-local.sh diff --git a/scripts/run-tests.sh b/scripts/run-tests.sh new file mode 100755 index 00000000..b4cb1bb5 --- /dev/null +++ b/scripts/run-tests.sh @@ -0,0 +1,22 @@ +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +PROJECT_ROOT="$( cd "$SCRIPT_DIR/.." && pwd )" + +BLUE="\033[34m" +WHITE="\033[37m" +RESET="\033[0m" + +printf "${BLUE}+---------------------------------------+${RESET}\n" +printf "${BLUE}| |${RESET}\n" +printf "${BLUE}| ${WHITE}Running tests for Stride compiler... ${BLUE}|${RESET}\n" +printf "${BLUE}| |${RESET}\n" +printf "${BLUE}+---------------------------------------+${RESET}\n" + +cmake --build "${PROJECT_ROOT}/packages/compiler/cmake-build-debug" --target cstride_tests && "${PROJECT_ROOT}/packages/compiler/cmake-build-debug/tests/cstride_tests" + +printf "${BLUE}+----------------------------------------------+${RESET}\n" +printf "${BLUE}| |${RESET}\n" +printf "${BLUE}| ${WHITE}Running tests for stride-plugin-intellij... ${BLUE}|${RESET}\n" +printf "${BLUE}| |${RESET}\n" +printf "${BLUE}+----------------------------------------------+${RESET}\n" + +cd "${PROJECT_ROOT}/packages/stride-plugin-intellij"; ./gradlew :test diff --git a/scripts/run_tests.sh b/scripts/run_tests.sh deleted file mode 100755 index e8bc03e5..00000000 --- a/scripts/run_tests.sh +++ /dev/null @@ -1,2 +0,0 @@ -cmake --build ./packages/compiler/cmake-build-debug --target cstride_tests && -./packages/compiler/cmake-build-debug/cstride_tests \ No newline at end of file From 8ac40943a81ae1143b7b4afe20553cb32c9c4b3f Mon Sep 17 00:00:00 2001 From: Luca Warmenhoven Date: Fri, 13 Mar 2026 12:15:22 +0100 Subject: [PATCH 12/30] Enhance formatting support for object initialization and function calls, including dynamic wrapping based on line width settings --- example.sr | 6 +- .../kotlin/com/stride/intellij/StrideBlock.kt | 131 ++++++++++++++++-- .../intellij/StrideFormattingModelBuilder.kt | 11 +- .../editor/StrideEnterHandlerDelegate.kt | 7 +- .../highlight/StrideColorSettingsPage.kt | 4 +- .../src/main/resources/META-INF/plugin.xml | 2 + .../stride/intellij/StrideFormatterTest.kt | 4 +- .../intellij/StrideLineWidthFormatterTest.kt | 123 ++++++++++++++++ .../intellij/StrideStructInitFormatterTest.kt | 10 +- 9 files changed, 263 insertions(+), 35 deletions(-) create mode 100644 packages/stride-plugin-intellij/src/test/kotlin/com/stride/intellij/StrideLineWidthFormatterTest.kt diff --git a/example.sr b/example.sr index 5cde7736..603001ce 100644 --- a/example.sr +++ b/example.sr @@ -35,11 +35,7 @@ fn make_car(make: string, model: string, year: i32): Car { } fn make_person(name: string, age: i32, cars: Array): SomePerson> { - return SomePerson>::{ - name, - age, - cars - }; + return SomePerson>::{ name, age, cars }; } fn main(): i32 { diff --git a/packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/StrideBlock.kt b/packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/StrideBlock.kt index ce20174c..d32d3130 100644 --- a/packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/StrideBlock.kt +++ b/packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/StrideBlock.kt @@ -5,7 +5,6 @@ import com.intellij.lang.ASTNode import com.intellij.psi.TokenType import com.intellij.psi.codeStyle.CodeStyleSettings import com.intellij.psi.formatter.common.AbstractBlock -import com.intellij.psi.tree.TokenSet import com.stride.intellij.psi.StrideTypes class StrideBlock( @@ -13,14 +12,15 @@ class StrideBlock( private val wrap: Wrap?, private val alignment: Alignment?, private val settings: CodeStyleSettings, - private val spacingBuilder: SpacingBuilder + private val spacingBuilder: SpacingBuilder, + private val strideSettings: StrideCodeStyleSettings ) : AbstractBlock(node, wrap, alignment) { override fun buildChildren(): List { val blocks = mutableListOf() var child = myNode.firstChildNode while (child != null) { if (child.elementType != TokenType.WHITE_SPACE) { - blocks.add(StrideBlock(child, null, null, settings, spacingBuilder)) + blocks.add(StrideBlock(child, null, null, settings, spacingBuilder, strideSettings)) } child = child.treeNext } @@ -33,7 +33,6 @@ class StrideBlock( val parentType = parent.elementType return when { - // Block statements (function bodies, control flow blocks) parentType == StrideTypes.BLOCK_STATEMENT -> { if (elementType != StrideTypes.LBRACE && elementType != StrideTypes.RBRACE) { Indent.getNormalIndent() @@ -52,7 +51,6 @@ class StrideBlock( Indent.getNoneIndent() } } - // Struct definition fields (within OBJECT_DEFINITION_BODY) parentType == StrideTypes.OBJECT_DEFINITION_BODY -> { if (elementType == StrideTypes.OBJECT_DEFINITION_FIELDS) { Indent.getNormalIndent() @@ -60,13 +58,12 @@ class StrideBlock( Indent.getNoneIndent() } } - // Individual struct fields parentType == StrideTypes.OBJECT_DEFINITION_FIELDS -> { Indent.getNoneIndent() } - // Struct initialization fields + // Object initialization: indent fields when wrapping parentType == StrideTypes.OBJECT_INITIALIZATION -> { - if (elementType == StrideTypes.OBJECT_INIT_FIELDS) { + if (elementType == StrideTypes.OBJECT_INIT_FIELDS && shouldWrap(parent)) { Indent.getNormalIndent() } else { Indent.getNoneIndent() @@ -75,13 +72,131 @@ class StrideBlock( parentType == StrideTypes.OBJECT_INIT_FIELDS -> { Indent.getNoneIndent() } + // Function call: indent argument list when wrapping + parentType == StrideTypes.FUNCTION_CALL_EXPRESSION -> { + if (elementType == StrideTypes.ARGUMENT_LIST && shouldWrap(parent)) { + Indent.getNormalIndent() + } else { + Indent.getNoneIndent() + } + } + parentType == StrideTypes.ARGUMENT_LIST -> { + Indent.getNoneIndent() + } else -> Indent.getNoneIndent() } } override fun getSpacing(child1: Block?, child2: Block): Spacing? { + if (child1 == null) return spacingBuilder.getSpacing(this, child1, child2) + + val child1Node = (child1 as? StrideBlock)?.node ?: return spacingBuilder.getSpacing(this, child1, child2) + val child2Node = (child2 as? StrideBlock)?.node ?: return spacingBuilder.getSpacing(this, child1, child2) + val parentType = myNode.elementType + + // Dynamic wrapping for object initialization + if (parentType == StrideTypes.OBJECT_INITIALIZATION) { + val wrap = shouldWrap(myNode) + val child1Type = child1Node.elementType + val child2Type = child2Node.elementType + + if (wrap) { + // Multi-line: line breaks after { and before } + if (child1Type == StrideTypes.LBRACE && child2Type == StrideTypes.OBJECT_INIT_FIELDS) { + return lineBreakSpacing() + } + if (child1Type == StrideTypes.OBJECT_INIT_FIELDS && child2Type == StrideTypes.RBRACE) { + return lineBreakSpacing() + } + // Empty init: no break between { and } + if (child1Type == StrideTypes.LBRACE && child2Type == StrideTypes.RBRACE) { + return noSpacing() + } + } else { + // Single-line: space after { and before } + if (child1Type == StrideTypes.LBRACE && child2Type == StrideTypes.OBJECT_INIT_FIELDS) { + return singleSpacing() + } + if (child1Type == StrideTypes.OBJECT_INIT_FIELDS && child2Type == StrideTypes.RBRACE) { + return singleSpacing() + } + } + } + + // Dynamic wrapping for commas inside object init fields + if (parentType == StrideTypes.OBJECT_INIT_FIELDS) { + val child1Type = child1Node.elementType + + // Find the OBJECT_INITIALIZATION ancestor + val objInitNode = myNode.treeParent + if (objInitNode != null && objInitNode.elementType == StrideTypes.OBJECT_INITIALIZATION) { + val wrap = shouldWrap(objInitNode) + if (wrap && child1Type == StrideTypes.COMMA) { + return lineBreakSpacing() + } + // Single-line comma spacing is handled by the SpacingBuilder + } + } + + // Dynamic wrapping for function calls + if (parentType == StrideTypes.FUNCTION_CALL_EXPRESSION) { + val wrap = shouldWrap(myNode) + val child1Type = child1Node.elementType + val child2Type = child2Node.elementType + + if (wrap) { + // Multi-line: line breaks after ( and before ) + if (child1Type == StrideTypes.LPAREN && child2Type == StrideTypes.ARGUMENT_LIST) { + return lineBreakSpacing() + } + if (child1Type == StrideTypes.ARGUMENT_LIST && child2Type == StrideTypes.RPAREN) { + return lineBreakSpacing() + } + } + // Single-line: default spacing from SpacingBuilder handles this + } + + // Dynamic wrapping for commas inside argument lists + if (parentType == StrideTypes.ARGUMENT_LIST) { + val child1Type = child1Node.elementType + + val funcCallNode = myNode.treeParent + if (funcCallNode != null && funcCallNode.elementType == StrideTypes.FUNCTION_CALL_EXPRESSION) { + val wrap = shouldWrap(funcCallNode) + if (wrap && child1Type == StrideTypes.COMMA) { + return lineBreakSpacing() + } + } + } + return spacingBuilder.getSpacing(this, child1, child2) } override fun isLeaf(): Boolean = myNode.firstChildNode == null + + private fun shouldWrap(node: ASTNode): Boolean { + val maxWidth = strideSettings.MAX_LINE_WIDTH + val textLen = computeSingleLineLength(node) + return textLen > maxWidth + } + + companion object { + fun computeSingleLineLength(node: ASTNode): Int { + // Compute what the text length would be if rendered on a single line + // with standard spacing (one space after commas, spaces around braces, etc.) + return node.text.replace(Regex("\\s+"), " ").trim().length + } + + private fun lineBreakSpacing(): Spacing { + return Spacing.createSpacing(0, 0, 1, true, 0) + } + + private fun singleSpacing(): Spacing { + return Spacing.createSpacing(1, 1, 0, false, 0) + } + + private fun noSpacing(): Spacing { + return Spacing.createSpacing(0, 0, 0, false, 0) + } + } } diff --git a/packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/StrideFormattingModelBuilder.kt b/packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/StrideFormattingModelBuilder.kt index c683d9f5..aece95c0 100644 --- a/packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/StrideFormattingModelBuilder.kt +++ b/packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/StrideFormattingModelBuilder.kt @@ -8,10 +8,11 @@ import com.stride.intellij.psi.StrideTypes class StrideFormattingModelBuilder : FormattingModelBuilder { override fun createModel(element: PsiElement, settings: CodeStyleSettings): FormattingModel { + val strideSettings = settings.getCustomSettings(StrideCodeStyleSettings::class.java) val spacingBuilder = createSpaceBuilder(settings) return FormattingModelProvider.createFormattingModelForPsiFile( element.containingFile, - StrideBlock(element.node, null, null, settings, spacingBuilder), + StrideBlock(element.node, null, null, settings, spacingBuilder, strideSettings), settings ) } @@ -74,10 +75,6 @@ class StrideFormattingModelBuilder : FormattingModelBuilder { .after(StrideTypes.COLON).spaceIf(true) .before(StrideTypes.COLON).spaceIf(false) - // Specific comma rules must come before the general .after(COMMA) rule - // because SpacingBuilder uses first-match semantics - .afterInside(StrideTypes.COMMA, StrideTypes.OBJECT_INIT_FIELDS).lineBreakInCode() - .after(StrideTypes.COMMA).spaceIf(true) .before(StrideTypes.COMMA).spaceIf(false) .before(StrideTypes.SEMICOLON).spaceIf(false) @@ -91,8 +88,8 @@ class StrideFormattingModelBuilder : FormattingModelBuilder { .beforeInside(StrideTypes.RBRACE, StrideTypes.BLOCK_STATEMENT).lineBreakInCode() .afterInside(StrideTypes.LBRACE, StrideTypes.MODULE_STATEMENT).lineBreakInCode() .beforeInside(StrideTypes.RBRACE, StrideTypes.MODULE_STATEMENT).lineBreakInCode() - .afterInside(StrideTypes.LBRACE, StrideTypes.OBJECT_INITIALIZATION).lineBreakInCode() - .beforeInside(StrideTypes.RBRACE, StrideTypes.OBJECT_INITIALIZATION).lineBreakInCode() + // Object initialization brace/comma wrapping is handled dynamically in StrideBlock + // based on line width settings .afterInside(StrideTypes.SCOPED_IDENTIFIER, StrideTypes.MODULE_STATEMENT).spaceIf(true) .afterInside(StrideTypes.COLON_COLON, StrideTypes.OBJECT_INITIALIZATION).spaceIf(false) diff --git a/packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/editor/StrideEnterHandlerDelegate.kt b/packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/editor/StrideEnterHandlerDelegate.kt index 531444d3..b681492e 100644 --- a/packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/editor/StrideEnterHandlerDelegate.kt +++ b/packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/editor/StrideEnterHandlerDelegate.kt @@ -9,6 +9,8 @@ import com.intellij.openapi.util.Ref import com.intellij.openapi.util.TextRange import com.intellij.psi.PsiDocumentManager import com.intellij.psi.PsiFile +import com.intellij.psi.codeStyle.CodeStyleSettingsManager +import com.stride.intellij.StrideLanguage import com.stride.intellij.psi.StrideFile class StrideEnterHandlerDelegate : EnterHandlerDelegate { @@ -34,6 +36,9 @@ class StrideEnterHandlerDelegate : EnterHandlerDelegate { val project = editor.project ?: return Result.Continue val document = editor.document + val indentSize = CodeStyleSettingsManager.getSettings(project) + .getCommonSettings(StrideLanguage).indentOptions?.INDENT_SIZE ?: 4 + val indentString = " ".repeat(indentSize) // Commit document to ensure PSI is up to date PsiDocumentManager.getInstance(project).commitDocument(document) @@ -63,7 +68,7 @@ class StrideEnterHandlerDelegate : EnterHandlerDelegate { // Calculate target indentation val targetIndent = if (shouldIndent) { - baseIndent + " " // Add 4 spaces for one indent level + baseIndent + indentString } else { baseIndent } diff --git a/packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/highlight/StrideColorSettingsPage.kt b/packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/highlight/StrideColorSettingsPage.kt index b8fb66d2..b1797b6f 100644 --- a/packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/highlight/StrideColorSettingsPage.kt +++ b/packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/highlight/StrideColorSettingsPage.kt @@ -27,8 +27,8 @@ class StrideColorSettingsPage : ColorSettingsPage { Time::Sleep }; - type Array<T> = T[]; - type IArray = Array<i32>; + type Array<T>; = T[]; + type IArray = Array; /** * Prints the given string and sleeps for the specified duration. diff --git a/packages/stride-plugin-intellij/src/main/resources/META-INF/plugin.xml b/packages/stride-plugin-intellij/src/main/resources/META-INF/plugin.xml index 467674aa..1be5ae1e 100644 --- a/packages/stride-plugin-intellij/src/main/resources/META-INF/plugin.xml +++ b/packages/stride-plugin-intellij/src/main/resources/META-INF/plugin.xml @@ -31,6 +31,8 @@ implementationClass="com.stride.intellij.editor.StrideFoldingBuilder"/> + + diff --git a/packages/stride-plugin-intellij/src/test/kotlin/com/stride/intellij/StrideFormatterTest.kt b/packages/stride-plugin-intellij/src/test/kotlin/com/stride/intellij/StrideFormatterTest.kt index a0d4e2f8..51bd3937 100644 --- a/packages/stride-plugin-intellij/src/test/kotlin/com/stride/intellij/StrideFormatterTest.kt +++ b/packages/stride-plugin-intellij/src/test/kotlin/com/stride/intellij/StrideFormatterTest.kt @@ -17,9 +17,7 @@ class StrideFormatterTest : BasePlatformTestCase() { type Second = First; fn main(): void { - const s: Second = First::{ - member: 123 - }; + const s: Second = First::{ member: 123 }; } """.trimIndent() myFixture.configureByText("test.sr", before) diff --git a/packages/stride-plugin-intellij/src/test/kotlin/com/stride/intellij/StrideLineWidthFormatterTest.kt b/packages/stride-plugin-intellij/src/test/kotlin/com/stride/intellij/StrideLineWidthFormatterTest.kt new file mode 100644 index 00000000..59fecef7 --- /dev/null +++ b/packages/stride-plugin-intellij/src/test/kotlin/com/stride/intellij/StrideLineWidthFormatterTest.kt @@ -0,0 +1,123 @@ +package com.stride.intellij + +import com.intellij.psi.codeStyle.CodeStyleSettingsManager +import com.intellij.testFramework.fixtures.BasePlatformTestCase + +class StrideLineWidthFormatterTest : BasePlatformTestCase() { + + private fun getStrideSettings(): StrideCodeStyleSettings { + val settings = CodeStyleSettingsManager.getSettings(project) + return settings.getCustomSettings(StrideCodeStyleSettings::class.java) + } + + private fun setIndentSize(size: Int) { + val settings = CodeStyleSettingsManager.getSettings(project) + val indentOptions = settings.getCommonSettings(StrideLanguage).indentOptions + indentOptions?.INDENT_SIZE = size + } + + fun testObjectInitFitsOnOneLine() { + getStrideSettings().MAX_LINE_WIDTH = 120 + val before = """ + const obj = SomeObject::{ field1, field2, field3 }; + """.trimIndent() + val after = """ + const obj = SomeObject::{ field1, field2, field3 }; + """.trimIndent() + myFixture.configureByText("test.sr", before) + myFixture.performEditorAction("ReformatCode") + myFixture.checkResult(after) + } + + fun testObjectInitExceedsLineWidth() { + getStrideSettings().MAX_LINE_WIDTH = 40 + val before = """ + const obj = SomeObject::{ field1, field2, field3, field4, field5 }; + """.trimIndent() + val after = """ + const obj = SomeObject::{ + field1, + field2, + field3, + field4, + field5 + }; + """.trimIndent() + myFixture.configureByText("test.sr", before) + myFixture.performEditorAction("ReformatCode") + myFixture.checkResult(after) + } + + fun testFunctionCallFitsOnOneLine() { + getStrideSettings().MAX_LINE_WIDTH = 120 + val before = """ + fn main(): void { + some::function_call(param1, param2, param3); + } + """.trimIndent() + val after = """ + fn main(): void { + some::function_call(param1, param2, param3); + } + """.trimIndent() + myFixture.configureByText("test.sr", before) + myFixture.performEditorAction("ReformatCode") + myFixture.checkResult(after) + } + + fun testFunctionCallExceedsLineWidth() { + getStrideSettings().MAX_LINE_WIDTH = 30 + val before = """ + fn main(): void { + some::function_call(param1, param2, param3); + } + """.trimIndent() + val after = """ + fn main(): void { + some::function_call( + param1, + param2, + param3 + ); + } + """.trimIndent() + myFixture.configureByText("test.sr", before) + myFixture.performEditorAction("ReformatCode") + myFixture.checkResult(after) + } + + fun testNestedObjectInitMixedWrapping() { + getStrideSettings().MAX_LINE_WIDTH = 60 + val before = """ + const player = Entity::{ id: 1, name: "Player1", position: Vec3::{ x: 10, y: 20, z: 30 } }; + """.trimIndent() + val after = """ + const player = Entity::{ + id: 1, + name: "Player1", + position: Vec3::{ x: 10, y: 20, z: 30 } + }; + """.trimIndent() + myFixture.configureByText("test.sr", before) + myFixture.performEditorAction("ReformatCode") + myFixture.checkResult(after) + } + + fun testCustomIndentSize() { + getStrideSettings().MAX_LINE_WIDTH = 30 + setIndentSize(2) + val before = """ + const obj = SomeObject::{ field1, field2, field3 }; + """.trimIndent() + val after = """ + const obj = SomeObject::{ + field1, + field2, + field3 + }; + """.trimIndent() + myFixture.configureByText("test.sr", before) + myFixture.performEditorAction("ReformatCode") + myFixture.checkResult(after) + } +} diff --git a/packages/stride-plugin-intellij/src/test/kotlin/com/stride/intellij/StrideStructInitFormatterTest.kt b/packages/stride-plugin-intellij/src/test/kotlin/com/stride/intellij/StrideStructInitFormatterTest.kt index ef22e3bb..cb6bd070 100644 --- a/packages/stride-plugin-intellij/src/test/kotlin/com/stride/intellij/StrideStructInitFormatterTest.kt +++ b/packages/stride-plugin-intellij/src/test/kotlin/com/stride/intellij/StrideStructInitFormatterTest.kt @@ -8,15 +8,7 @@ class StrideStructInitFormatterTest : BasePlatformTestCase() { const player = Entity::{ id: 1, name: "Player1", position: Vec3::{ x: 10, y: 20, z: 30 }, }; """.trimIndent() val after = """ - const player = Entity::{ - id: 1, - name: "Player1", - position: Vec3::{ - x: 10, - y: 20, - z: 30 - }, - }; + const player = Entity::{ id: 1, name: "Player1", position: Vec3::{ x: 10, y: 20, z: 30 }, }; """.trimIndent() myFixture.configureByText("test.sr", before) myFixture.performEditorAction("ReformatCode") From 15c542246b593889827ad2cb0338b7558802e31f Mon Sep 17 00:00:00 2001 From: Luca Warmenhoven Date: Fri, 13 Mar 2026 12:15:31 +0100 Subject: [PATCH 13/30] Add custom code style settings for Stride language support --- .../intellij/StrideCodeStyleSettings.kt | 11 +++++ .../StrideCodeStyleSettingsProvider.kt | 33 +++++++++++++ ...StrideLanguageCodeStyleSettingsProvider.kt | 47 +++++++++++++++++++ 3 files changed, 91 insertions(+) create mode 100644 packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/StrideCodeStyleSettings.kt create mode 100644 packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/StrideCodeStyleSettingsProvider.kt create mode 100644 packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/StrideLanguageCodeStyleSettingsProvider.kt diff --git a/packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/StrideCodeStyleSettings.kt b/packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/StrideCodeStyleSettings.kt new file mode 100644 index 00000000..83f30a41 --- /dev/null +++ b/packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/StrideCodeStyleSettings.kt @@ -0,0 +1,11 @@ +package com.stride.intellij + +import com.intellij.psi.codeStyle.CodeStyleSettings +import com.intellij.psi.codeStyle.CustomCodeStyleSettings + +class StrideCodeStyleSettings(container: CodeStyleSettings) : + CustomCodeStyleSettings("StrideCodeStyleSettings", container) { + + @JvmField + var MAX_LINE_WIDTH: Int = 120 +} diff --git a/packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/StrideCodeStyleSettingsProvider.kt b/packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/StrideCodeStyleSettingsProvider.kt new file mode 100644 index 00000000..cf8db7d2 --- /dev/null +++ b/packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/StrideCodeStyleSettingsProvider.kt @@ -0,0 +1,33 @@ +package com.stride.intellij + +import com.intellij.application.options.CodeStyleAbstractConfigurable +import com.intellij.application.options.CodeStyleAbstractPanel +import com.intellij.application.options.TabbedLanguageCodeStylePanel +import com.intellij.psi.codeStyle.CodeStyleConfigurable +import com.intellij.psi.codeStyle.CodeStyleSettings +import com.intellij.psi.codeStyle.CodeStyleSettingsProvider +import com.intellij.psi.codeStyle.CustomCodeStyleSettings + +class StrideCodeStyleSettingsProvider : CodeStyleSettingsProvider() { + override fun createCustomSettings(settings: CodeStyleSettings): CustomCodeStyleSettings { + return StrideCodeStyleSettings(settings) + } + + override fun getConfigurableDisplayName(): String = "Stride" + + override fun createConfigurable( + settings: CodeStyleSettings, + modelSettings: CodeStyleSettings + ): CodeStyleConfigurable { + return object : CodeStyleAbstractConfigurable(settings, modelSettings, configurableDisplayName) { + override fun createPanel(settings: CodeStyleSettings): CodeStyleAbstractPanel { + return StrideCodeStyleMainPanel(currentSettings, settings) + } + } + } + + private class StrideCodeStyleMainPanel( + currentSettings: CodeStyleSettings, + settings: CodeStyleSettings + ) : TabbedLanguageCodeStylePanel(StrideLanguage, currentSettings, settings) +} diff --git a/packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/StrideLanguageCodeStyleSettingsProvider.kt b/packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/StrideLanguageCodeStyleSettingsProvider.kt new file mode 100644 index 00000000..aa6ee127 --- /dev/null +++ b/packages/stride-plugin-intellij/src/main/kotlin/com/stride/intellij/StrideLanguageCodeStyleSettingsProvider.kt @@ -0,0 +1,47 @@ +package com.stride.intellij + +import com.intellij.lang.Language +import com.intellij.psi.codeStyle.CodeStyleSettingsCustomizable +import com.intellij.psi.codeStyle.CommonCodeStyleSettings +import com.intellij.psi.codeStyle.LanguageCodeStyleSettingsProvider + +class StrideLanguageCodeStyleSettingsProvider : LanguageCodeStyleSettingsProvider() { + override fun getLanguage(): Language = StrideLanguage + + override fun customizeSettings(consumer: CodeStyleSettingsCustomizable, settingsType: SettingsType) { + if (settingsType == SettingsType.INDENT_SETTINGS) { + consumer.showStandardOptions("INDENT_SIZE", "TAB_SIZE") + } + if (settingsType == SettingsType.WRAPPING_AND_BRACES_SETTINGS) { + consumer.showStandardOptions("RIGHT_MARGIN") + } + } + + override fun customizeDefaults(commonSettings: CommonCodeStyleSettings, indentOptions: CommonCodeStyleSettings.IndentOptions) { + indentOptions.INDENT_SIZE = 4 + indentOptions.TAB_SIZE = 4 + indentOptions.USE_TAB_CHARACTER = false + } + + override fun getCodeSample(settingsType: SettingsType): String { + return """ + package example; + + type Vector3 = { + x: f64; + y: f64; + z: f64; + }; + + fn create_vector(x: f64, y: f64, z: f64): Vector3 { + const v = Vector3::{ x: x, y: y, z: z }; + return v; + } + + fn main(): void { + const short = Vector3::{ x: 1, y: 2, z: 3 }; + IO::Print("x: %f, y: %f, z: %f", short.x, short.y, short.z); + } + """.trimIndent() + } +} From 072b96297d9f22884a783818055f3812e8d64b45 Mon Sep 17 00:00:00 2001 From: Luca Warmenhoven Date: Fri, 13 Mar 2026 12:28:01 +0100 Subject: [PATCH 14/30] Improve test logging and update test execution command in build scripts --- packages/stride-plugin-intellij/build.gradle.kts | 9 ++++++++- scripts/run-tests.sh | 2 +- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/packages/stride-plugin-intellij/build.gradle.kts b/packages/stride-plugin-intellij/build.gradle.kts index 3285afd5..4aa40549 100644 --- a/packages/stride-plugin-intellij/build.gradle.kts +++ b/packages/stride-plugin-intellij/build.gradle.kts @@ -80,6 +80,13 @@ tasks { } } -tasks.withType { +tasks.named("compileKotlin") { dependsOn(generateStrideParser, generateStrideLexer) } + +tasks.named("test") { + testLogging { + events("passed", "skipped", "failed") + exceptionFormat = org.gradle.api.tasks.testing.logging.TestExceptionFormat.FULL + } +} diff --git a/scripts/run-tests.sh b/scripts/run-tests.sh index b4cb1bb5..7815326a 100755 --- a/scripts/run-tests.sh +++ b/scripts/run-tests.sh @@ -19,4 +19,4 @@ printf "${BLUE}| ${WHITE}Running tests for stride-plugin-intellij... ${BLUE}|${ printf "${BLUE}| |${RESET}\n" printf "${BLUE}+----------------------------------------------+${RESET}\n" -cd "${PROJECT_ROOT}/packages/stride-plugin-intellij"; ./gradlew :test +cd "${PROJECT_ROOT}/packages/stride-plugin-intellij"; ./gradlew clean test From a14c2c75c7d777412f34c6e948e685e8ed099f3f Mon Sep 17 00:00:00 2001 From: Luca Warmenhoven Date: Fri, 13 Mar 2026 12:31:17 +0100 Subject: [PATCH 15/30] Updated CI to log Kotlin tests --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 22bcb1ed..1ebe0abc 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -87,7 +87,7 @@ jobs: - name: Run Tests working-directory: packages/stride-plugin-intellij - run: ./gradlew test + run: ./gradlew clean test deploy: name: Documentation Site Deployment runs-on: ubuntu-latest From c54f8767e0fe192063292cbb0ee7a9797578a0f8 Mon Sep 17 00:00:00 2001 From: Luca Warmenhoven Date: Fri, 13 Mar 2026 12:57:23 +0100 Subject: [PATCH 16/30] Refactor data structure definitions and enhance formatting for object initialization --- .idea/codeStyles/Project.xml | 4 ++ example.sr | 8 +-- packages/standard-library/datastructures.sr | 35 ++++++------- .../kotlin/com/stride/intellij/StrideBlock.kt | 11 ++-- .../intellij/StrideLineWidthFormatterTest.kt | 50 +++++++++++++++++++ 5 files changed, 73 insertions(+), 35 deletions(-) diff --git a/.idea/codeStyles/Project.xml b/.idea/codeStyles/Project.xml index 6abc312f..0af5a21d 100644 --- a/.idea/codeStyles/Project.xml +++ b/.idea/codeStyles/Project.xml @@ -101,6 +101,10 @@