diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 22bcb1ed..1ebe0abc 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -87,7 +87,7 @@ jobs:
- name: Run Tests
working-directory: packages/stride-plugin-intellij
- run: ./gradlew test
+ run: ./gradlew clean test
deploy:
name: Documentation Site Deployment
runs-on: ubuntu-latest
diff --git a/.idea/codeStyles/Project.xml b/.idea/codeStyles/Project.xml
index 6abc312f..0af5a21d 100644
--- a/.idea/codeStyles/Project.xml
+++ b/.idea/codeStyles/Project.xml
@@ -101,6 +101,10 @@
+
+
+
+
diff --git a/example.sr b/example.sr
index c4b763e3..d60698a8 100644
--- a/example.sr
+++ b/example.sr
@@ -1,34 +1,18 @@
import System::{
- IO::Print,
- IO::Read
+ io::print,
};
-type Array = {
- length: i32;
- data: T[];
-};
-
-type SomeCar = {
- make: string;
- model: string;
- year: i32;
-};
-
-type SomePerson = {
- name: string;
- age: i32;
- cars: Array;
-};
-
-fn make_car(make: string, model: string, year: i32): SomeCar {
- return SomeCar::{ make, model, year };
+enum Test {
+ First: 123,
+ Second: 456,
+ Third: 789
}
fn main(): i32 {
- const my_car = make_car("Toyota", "Corolla", 2020);
+ const k: Test = Test::First;
- IO::Print("My car is a %s %s from %d", my_car.make, my_car.model, my_car.year);
+ io::print("k[0] = %d", k);
return 0;
}
\ No newline at end of file
diff --git a/packages/compiler/compilation-diagram.md b/packages/compiler/compilation-diagram.md
deleted file mode 100644
index 9a1ff026..00000000
--- a/packages/compiler/compilation-diagram.md
+++ /dev/null
@@ -1,18 +0,0 @@
-```mermaid
-flowchart TD
- File1
- File2
- File3
- File4
-
- Tokenizer
-
- File1 --> Tokenizer
- File2 --> Tokenizer
- File3 --> Tokenizer
- File4 --> Tokenizer
-
- Tokenizer -- Convert to AST --> Parser["Parser"]
-
- Parser --> Validation["Validation - Type resolution takes place here"]
-```
\ No newline at end of file
diff --git a/packages/compiler/config-json-schema.json b/packages/compiler/config-json-schema.json
deleted file mode 100644
index d2fb135b..00000000
--- a/packages/compiler/config-json-schema.json
+++ /dev/null
@@ -1,7 +0,0 @@
-{
- "$schema": "http://json-schema.org/draft-07/schema",
- "type": "object",
- "properties": {
-
- }
-}
\ No newline at end of file
diff --git a/packages/compiler/include/ast/closures.h b/packages/compiler/include/ast/closures.h
index fcb44787..46fd8202 100644
--- a/packages/compiler/include/ast/closures.h
+++ b/packages/compiler/include/ast/closures.h
@@ -60,7 +60,8 @@ namespace stride::ast::closures
*/
llvm::Function* find_lambda_function(
llvm::Module* module,
- const llvm::FunctionType* fn_type
+ const llvm::FunctionType* fn_type,
+ bool prefer_captures = false
);
/**
diff --git a/packages/compiler/include/ast/flags.h b/packages/compiler/include/ast/flags.h
index 46c23bbf..fcca461f 100644
--- a/packages/compiler/include/ast/flags.h
+++ b/packages/compiler/include/ast/flags.h
@@ -18,10 +18,11 @@
#define SRFLAG_TYPE_VARIADIC (0x80)
#define SRFLAG_TYPE_INT_SIGNED (0x100)
#define SRFLAG_TYPE_FUNCTION (0x200)
-#define SRFLAG_FN_TYPE_VARIADIC (0x400)
-#define SRFLAG_FN_TYPE_EXTERN (0x800)
-#define SRFLAG_FN_TYPE_ASYNC (0x01000)
-#define SRFLAG_FN_TYPE_ANONYMOUS (0x2000)
+#define SRFLAG_TYPE_GENERIC_REF (0x400)
+#define SRFLAG_FN_TYPE_VARIADIC (0x800)
+#define SRFLAG_FN_TYPE_EXTERN (0x1000)
+#define SRFLAG_FN_TYPE_ASYNC (0x02000)
+#define SRFLAG_FN_TYPE_ANONYMOUS (0x4000)
#define SRFLAG_FN_PARAM_DEF_VARIADIC (0x1)
#define SRFLAG_FN_PARAM_DEF_MUTABLE (0x2)
diff --git a/packages/compiler/include/ast/nodes/enumerables.h b/packages/compiler/include/ast/nodes/enumerables.h
index 1d8ce5b8..db2cafa9 100644
--- a/packages/compiler/include/ast/nodes/enumerables.h
+++ b/packages/compiler/include/ast/nodes/enumerables.h
@@ -40,7 +40,10 @@ namespace stride::ast
std::string to_string() override;
- llvm::Value* codegen(llvm::Module* module, llvm::IRBuilderBase* builder) override { return nullptr; }
+ llvm::Value* codegen(llvm::Module* module, llvm::IRBuilderBase* builder) override
+ {
+ return nullptr;
+ }
std::unique_ptr clone() override;
};
@@ -77,14 +80,21 @@ namespace stride::ast
std::string to_string() override;
- llvm::Value* codegen(llvm::Module* module, llvm::IRBuilderBase* builder) override { return nullptr; }
+ llvm::Value* codegen(llvm::Module* module, llvm::IRBuilderBase* builder) override
+ {
+ return nullptr;
+ }
std::unique_ptr clone() override;
+
+ void resolve_forward_references(llvm::Module* module, llvm::IRBuilderBase* builder) override;
};
std::unique_ptr parse_enumerable_member(
const std::shared_ptr& context,
- TokenSet& set);
+ TokenSet& set,
+ size_t element_index
+ );
std::unique_ptr parse_enumerable_declaration(
const std::shared_ptr& context,
diff --git a/packages/compiler/include/ast/nodes/expression.h b/packages/compiler/include/ast/nodes/expression.h
index a5b5f3d3..a46928fd 100644
--- a/packages/compiler/include/ast/nodes/expression.h
+++ b/packages/compiler/include/ast/nodes/expression.h
@@ -5,6 +5,11 @@
#include "errors.h"
#include "ast/symbols.h"
+namespace llvm
+{
+ class Function;
+}
+
namespace stride::ast
{
enum class VisibilityModifier;
@@ -181,6 +186,9 @@ namespace stride::ast
IAstExpression(symbol.symbol_position, context),
_symbol(std::move(symbol)) {}
+ [[nodiscard]]
+ std::optional get_definition() const;
+
[[nodiscard]]
const std::string& get_name() const
{
@@ -216,24 +224,24 @@ namespace stride::ast
class AstArrayMemberAccessor
: public IAstExpression
{
- std::unique_ptr _array_identifier;
+ std::unique_ptr _array_base;
std::unique_ptr _index_accessor_expr;
public:
explicit AstArrayMemberAccessor(
const SourceFragment& source,
const std::shared_ptr& context,
- std::unique_ptr array_identifier,
+ std::unique_ptr array_base,
std::unique_ptr index_expr
) :
IAstExpression(source, context),
- _array_identifier(std::move(array_identifier)),
+ _array_base(std::move(array_base)),
_index_accessor_expr(std::move(index_expr)) {}
[[nodiscard]]
- AstIdentifier* get_array_identifier() const
+ IAstExpression* get_array_base() const
{
- return this->_array_identifier.get();
+ return this->_array_base.get();
}
[[nodiscard]]
@@ -257,40 +265,36 @@ namespace stride::ast
std::unique_ptr clone() override;
};
- class AstMemberAccessor
+ /// Represents a chained postfix expression: base.member, where base is any expression
+ /// and followup is an AstIdentifier (the member name). Multi-step chains like a.b.c
+ /// are represented left-recursively: ChainedExpression(ChainedExpression(a, b), c).
+ class AstChainedExpression
: public IAstExpression
{
- // When adding function chaining support,
- // these types will have to be changed to IAstExpression
- std::unique_ptr _base;
- std::vector> _members;
+ std::unique_ptr _base;
+ std::unique_ptr _followup; // always AstIdentifier at each leaf
public:
- explicit AstMemberAccessor(
+ explicit AstChainedExpression(
const SourceFragment& source,
const std::shared_ptr& context,
- std::unique_ptr base,
- std::vector> members
+ std::unique_ptr base,
+ std::unique_ptr followup
) :
IAstExpression(source, context),
_base(std::move(base)),
- _members(std::move(members)) {}
+ _followup(std::move(followup)) {}
[[nodiscard]]
- AstIdentifier* get_base() const
+ IAstExpression* get_base() const
{
return this->_base.get();
}
[[nodiscard]]
- std::vector get_members() const;
-
- [[nodiscard]]
- AstIdentifier* get_last_member() const
+ IAstExpression* get_followup() const
{
- return this->_members.empty()
- ? nullptr
- : this->_members.back().get();
+ return this->_followup.get();
}
llvm::Value* codegen(
@@ -315,23 +319,76 @@ namespace stride::ast
) const;
};
+ /// Represents an indirect (expression-based) function call: expr(args).
+ /// Used when the callee is not a simple named identifier, e.g. arr[i]() or obj.fn().
+ class AstIndirectCall
+ : public IAstExpression
+ {
+ std::unique_ptr _callee;
+ ExpressionList _args;
+
+ public:
+ explicit AstIndirectCall(
+ const SourceFragment& source,
+ const std::shared_ptr& context,
+ std::unique_ptr callee,
+ ExpressionList args
+ ) :
+ IAstExpression(source, context),
+ _callee(std::move(callee)),
+ _args(std::move(args)) {}
+
+ [[nodiscard]]
+ IAstExpression* get_callee() const
+ {
+ return this->_callee.get();
+ }
+
+ [[nodiscard]]
+ const ExpressionList& get_args() const
+ {
+ return this->_args;
+ }
+
+ llvm::Value* codegen(
+ llvm::Module* module,
+ llvm::IRBuilderBase* builder
+ ) override;
+
+ std::string to_string() override;
+
+ bool is_reducible() override
+ {
+ return false;
+ }
+
+ std::optional> reduce() override
+ {
+ return std::nullopt;
+ }
+
+ std::unique_ptr clone() override;
+
+ void validate() override;
+ };
+
class AstFunctionCall
: public IAstExpression
{
ExpressionList _arguments;
- const Symbol _symbol;
+ std::unique_ptr _function_name_identifier;
int _flags;
public:
explicit AstFunctionCall(
const std::shared_ptr& context,
- Symbol function_call_sym,
+ std::unique_ptr function_name_identifier,
ExpressionList arguments,
const int flags = SRFLAG_NONE
) :
- IAstExpression(function_call_sym.symbol_position, context),
+ IAstExpression(function_name_identifier->get_source_fragment(), context),
_arguments(std::move(arguments)),
- _symbol(std::move(function_call_sym)),
+ _function_name_identifier(std::move(function_name_identifier)),
_flags(flags) {}
[[nodiscard]]
@@ -346,13 +403,19 @@ namespace stride::ast
[[nodiscard]]
const std::string& get_function_name() const
{
- return this->_symbol.name;
+ return this->_function_name_identifier->get_name();
}
[[nodiscard]]
const std::string& get_scoped_function_name() const
{
- return this->_symbol.internal_name;
+ return this->_function_name_identifier->get_scoped_name();
+ }
+
+ [[nodiscard]]
+ AstIdentifier* get_function_name_identifier() const
+ {
+ return this->_function_name_identifier.get();
}
[[nodiscard]]
@@ -376,8 +439,11 @@ namespace stride::ast
void validate() override;
+ [[nodiscard]]
std::string get_formatted_call() const;
+ void resolve_forward_references(llvm::Module* module, llvm::IRBuilderBase* builder) override;
+
private:
[[nodiscard]]
std::string format_function_name() const;
@@ -388,6 +454,17 @@ namespace stride::ast
llvm::Module* module,
llvm::IRBuilderBase* builder
) const;
+
+ [[nodiscard]]
+ llvm::Function* resolve_regular_callee(
+ llvm::Module* module
+ ) const;
+
+ llvm::Value* codegen_regular_function_call(
+ llvm::Function* callee,
+ llvm::Module* module,
+ llvm::IRBuilderBase* builder
+ ) const;
};
class AstVariableDeclaration
@@ -681,7 +758,7 @@ namespace stride::ast
class AstVariableReassignment
: public IAstExpression
{
- const std::string _variable_name;
+ std::unique_ptr _identifier;
std::unique_ptr _value;
MutativeAssignmentType _operator;
@@ -691,19 +768,19 @@ namespace stride::ast
explicit AstVariableReassignment(
const SourceFragment& source,
const std::shared_ptr& context,
- std::string variable_name,
+ std::unique_ptr identifier,
const MutativeAssignmentType op,
std::unique_ptr value
) :
IAstExpression(source, context),
- _variable_name(std::move(variable_name)),
+ _identifier(std::move(identifier)),
_value(std::move(value)),
_operator(op) {}
[[nodiscard]]
const std::string& get_variable_name() const
{
- return _variable_name;
+ return this->_identifier->get_name();
}
[[nodiscard]]
@@ -712,6 +789,12 @@ namespace stride::ast
return this->_value.get();
}
+ [[nodiscard]]
+ AstIdentifier* get_identifier() const
+ {
+ return this->_identifier.get();
+ }
+
[[nodiscard]]
MutativeAssignmentType get_operator() const
{
@@ -742,7 +825,7 @@ namespace stride::ast
class AstObjectInitializer
: public IAstExpression
{
- std::string _struct_name;
+ std::string _object_type_name;
std::vector _member_initializers;
GenericTypeList _generic_type_arguments;
@@ -758,7 +841,7 @@ namespace stride::ast
GenericTypeList generic_type_arguments = {}
) :
IAstExpression(source, context),
- _struct_name(std::move(struct_name)),
+ _object_type_name(std::move(struct_name)),
_member_initializers(std::move(member_initializers)),
_generic_type_arguments(std::move(generic_type_arguments)) {}
@@ -771,7 +854,7 @@ namespace stride::ast
[[nodiscard]]
const std::string& get_struct_name() const
{
- return _struct_name;
+ return _object_type_name;
}
[[nodiscard]]
@@ -794,6 +877,8 @@ namespace stride::ast
std::unique_ptr clone() override;
+ void resolve_forward_references(llvm::Module* module, llvm::IRBuilderBase* builder) override;
+
private:
std::unique_ptr get_instantiated_object_type();
};
@@ -934,14 +1019,14 @@ namespace stride::ast
/// Parses a function invocation into an AstFunctionCall expression node
std::unique_ptr parse_function_call(
const std::shared_ptr& context,
- const SymbolNameSegments& function_name_segments,
+ AstIdentifier* identifier,
TokenSet& set);
/// Parses a variable assignment statement
std::optional>
parse_variable_reassignment(
const std::shared_ptr& context,
- const std::string& variable_name,
+ AstIdentifier* identifier,
TokenSet& set);
/// Parses a binary arithmetic operation using precedence climbing
@@ -953,11 +1038,11 @@ namespace stride::ast
int min_precedence
);
- /// This parses both function call chaining, and struct member access
- std::unique_ptr parse_chained_member_access(
+ /// Parses a single chained member access step: consumes `.identifier` and wraps lhs
+ std::unique_ptr parse_chained_member_access(
const std::shared_ptr& context,
TokenSet& set,
- const std::unique_ptr& lhs
+ std::unique_ptr lhs
);
/// Parses a unary operator expression
@@ -972,11 +1057,18 @@ namespace stride::ast
TokenSet& set
);
- /// Parses an array member accessor expression, e.g., []
- std::unique_ptr parse_array_member_accessor(
+ /// Parses an array subscript: consumes `[]` and wraps the base expression
+ std::unique_ptr parse_array_member_accessor(
const std::shared_ptr& context,
TokenSet& set,
- std::unique_ptr array_identifier
+ std::unique_ptr array_base
+ );
+
+ /// Parses an indirect call: consumes `()` and wraps the callee expression
+ std::unique_ptr parse_indirect_call(
+ const std::shared_ptr& context,
+ TokenSet& set,
+ std::unique_ptr callee
);
/// Parses a struct initializer expression into an AstObjectInitializer node
@@ -986,7 +1078,11 @@ namespace stride::ast
);
/// Parses a dot-separated identifier into its individual name segments, e.g., `foo::bar::baz`
- SymbolNameSegments parse_segmented_identifier(TokenSet& set, const std::string& error_message);
+ std::unique_ptr parse_segmented_identifier(
+ const std::shared_ptr& context,
+ TokenSet& set,
+ const std::string& error_message
+ );
/// Parses a lambda function literal into an expression node
std::unique_ptr parse_anonymous_fn_expression(
@@ -1037,7 +1133,6 @@ namespace stride::ast
/// This is the case if an expression starts with `{ }`
bool is_struct_initializer(const TokenSet& set);
- /// Checks whether the subsequent tokens might be member accessors
- /// e.g., .
- bool is_member_accessor(IAstExpression* lhs, const TokenSet& set);
+ /// Checks whether the next tokens begin a member access: `.identifier`
+ bool is_member_accessor(const TokenSet& set);
} // namespace stride::ast
diff --git a/packages/compiler/include/ast/nodes/import.h b/packages/compiler/include/ast/nodes/import.h
index 21f77da2..38516978 100644
--- a/packages/compiler/include/ast/nodes/import.h
+++ b/packages/compiler/include/ast/nodes/import.h
@@ -1,5 +1,6 @@
#pragma once
+#include "expression.h"
#include "ast/symbols.h"
#include "ast/nodes/ast_node.h"
@@ -19,32 +20,31 @@ namespace stride::ast
class AstImport
: public IAstNode
{
- const Dependency _dependency;
+
+ std::unique_ptr _package_identifier;
+ std::vector> _import_list;
public:
explicit AstImport(
const SourceFragment& source,
const std::shared_ptr& context,
- Dependency dependency) :
+ std::unique_ptr package_identifier,
+ std::vector> import_list
+ ) :
IAstNode(source, context),
- _dependency(std::move(dependency)) {}
-
- [[nodiscard]]
- const Symbol& get_module() const
- {
- return this->_dependency.package_name;
- }
+ _package_identifier(std::move(package_identifier)),
+ _import_list(std::move(import_list)) {}
[[nodiscard]]
- const Dependency& get_dependency() const
+ AstIdentifier* get_package_identifier() const
{
- return this->_dependency;
+ return this->_package_identifier.get();
}
[[nodiscard]]
- const std::vector& get_submodules() const
+ const std::vector>& get_import_list() const
{
- return this->_dependency.submodules;
+ return this->_import_list;
}
void validate() override;
diff --git a/packages/compiler/include/ast/nodes/types.h b/packages/compiler/include/ast/nodes/types.h
index 661c208f..2f85817b 100644
--- a/packages/compiler/include/ast/nodes/types.h
+++ b/packages/compiler/include/ast/nodes/types.h
@@ -283,10 +283,7 @@ namespace stride::ast
[[nodiscard]]
std::unique_ptr clone() override;
- std::string get_type_name() override
- {
- return this->get_name();
- }
+ std::string get_type_name() override;
std::string to_string() override;
@@ -313,18 +310,14 @@ namespace stride::ast
[[nodiscard]]
std::optional> get_reference_type() const;
- /// Returns the cached underlying type pointer (resolved including generics),
- /// populating the cache via get_underlying_type() if needed.
- /// The returned pointer is valid as long as this AstAliasType is alive.
- [[nodiscard]]
- IAstType* get_underlying_type_ptr();
-
/// Returns the super base type of the reference, e.g., if we have:
/// type RootType = i32;
/// type MidType = RootType;
/// type LeafType = MidType;
/// Then, calling `get_base_reference_type` on `LeafType` will return `i32`.
- [[nodiscard]] std::optional> get_underlying_type();
+ /// Throws an exception if no underlying type can be resolved.
+ [[nodiscard]]
+ IAstType* get_underlying_type();
[[nodiscard]]
std::optional get_type_definition() const;
@@ -482,31 +475,22 @@ namespace stride::ast
_instantiated_generics(std::move(instantiated_generics)),
_type_name(std::move(type_name)) {}
- [[nodiscard]]
- const GenericTypeList& get_instantiated_generics() const
- {
- return _instantiated_generics;
- }
+ [[nodiscard]] const GenericTypeList& get_instantiated_generics() const;
- [[nodiscard]]
- ObjectTypeMemberList get_members() const;
+ [[nodiscard]] ObjectTypeMemberList get_members() const;
- [[nodiscard]]
- std::optional get_member_field_type(const std::string& field_name) const;
+ [[nodiscard]] std::optional get_member_field_type(const std::string& field_name) const;
- [[nodiscard]]
- std::optional get_member_field_index(const std::string& field_name) const;
+ [[nodiscard]] std::optional get_member_field_index(const std::string& field_name) const;
- [[nodiscard]]
- std::string get_type_name() override
+ [[nodiscard]] const std::string& get_base_name() const
{
return this->_type_name;
}
- std::string to_string() override
- {
- return this->get_type_name();
- }
+ [[nodiscard]] std::string get_type_name() override;
+
+ std::string to_string() override;
[[nodiscard]]
bool equals(IAstType* other) override;
@@ -582,6 +566,18 @@ namespace stride::ast
std::string error_message;
std::string type_name;
int flags;
+ const std::vector& generic_types;
+
+ TypeParsingOptions(
+ std::string error_message,
+ std::string type_name = "",
+ const int flags = SRFLAG_NONE,
+ const std::vector& generic_types = {}
+ ) :
+ error_message(std::move(error_message)),
+ type_name(std::move(type_name)),
+ flags(flags),
+ generic_types(generic_types) {}
};
std::unique_ptr parse_type(
@@ -593,8 +589,7 @@ namespace stride::ast
std::unique_ptr parse_type_metadata(
std::unique_ptr base_type,
- TokenSet& set,
- int context_type_flags
+ TokenSet& set
);
std::optional> parse_primitive_type_optional(
@@ -602,7 +597,7 @@ namespace stride::ast
TokenSet& set,
const TypeParsingOptions& options);
- std::optional> parse_named_type_optional(
+ std::optional> parse_alias_type_optional(
const std::shared_ptr& context,
TokenSet& set,
const TypeParsingOptions& options);
diff --git a/packages/compiler/include/ast/parsing_context.h b/packages/compiler/include/ast/parsing_context.h
index 21f9a14c..95dc0d3c 100644
--- a/packages/compiler/include/ast/parsing_context.h
+++ b/packages/compiler/include/ast/parsing_context.h
@@ -217,6 +217,12 @@ namespace stride::ast
return this->_flags;
}
+ [[nodiscard]]
+ bool is_variadic() const
+ {
+ return (this->_flags & SRFLAG_FN_TYPE_VARIADIC) != 0;
+ }
+
~FunctionDefinition() override = default;
bool matches_type_signature(const std::string& name, const AstFunctionType* signature) const;
@@ -350,7 +356,7 @@ namespace stride::ast
) const;
[[nodiscard]]
- std::optional get_struct_type(const std::string& name) const;
+ std::optional get_object_type(const std::string& name) const;
[[nodiscard]]
const definition::IdentifiableSymbolDef* get_symbol_def(
diff --git a/packages/compiler/include/ast/type_inference.h b/packages/compiler/include/ast/type_inference.h
index c953c2c4..172941ca 100644
--- a/packages/compiler/include/ast/type_inference.h
+++ b/packages/compiler/include/ast/type_inference.h
@@ -11,7 +11,8 @@ namespace stride::ast
class IAstExpression;
class AstFunctionCall;
class AstLiteral;
- class AstMemberAccessor;
+ class AstChainedExpression;
+ class AstIndirectCall;
class AstObjectInitializer;
class AstUnaryOp;
class IAstType;
@@ -46,10 +47,13 @@ namespace stride::ast
const AstVariableDeclaration* declaration,
int recursion_guard);
- /// Infers the type of the field accessed via a member accessor expression
- std::unique_ptr infer_object_member_accessor_type(const AstMemberAccessor* member_accessor_expr);
+ /// Infers the type produced by a chained expression (base.followup member access)
+ std::unique_ptr infer_chained_expression_type(const AstChainedExpression* chained_expr);
- /// Infers the type of the element accessed via an array member accessor expression, which is the element type of the array being accessed
+ /// Infers the return type of an indirect call expression (expr(args))
+ std::unique_ptr infer_indirect_call_type(const AstIndirectCall* call_expr);
+
+ /// Infers the element type produced by an array subscript expression
std::unique_ptr infer_array_accessor_type(const AstArrayMemberAccessor* accessor, int recursion_guard);
std::unique_ptr infer_function_type(
diff --git a/packages/compiler/include/ast/visitor.h b/packages/compiler/include/ast/visitor.h
index 28041c37..65008451 100644
--- a/packages/compiler/include/ast/visitor.h
+++ b/packages/compiler/include/ast/visitor.h
@@ -40,8 +40,8 @@ namespace stride::ast
{
std::string _current_file_name; // temporary values
std::map<
- std::string, /* package_name */
- std::string /* file_name */
+ std::string, /* package_name */
+ std::vector /* file_names */
> _package_file_mapping;
std::map<
std::string, /* file_name */
diff --git a/packages/compiler/src/ast/ast.cpp b/packages/compiler/src/ast/ast.cpp
index f0e06567..87e5b3df 100644
--- a/packages/compiler/src/ast/ast.cpp
+++ b/packages/compiler/src/ast/ast.cpp
@@ -127,8 +127,6 @@ std::unique_ptr stride::ast::parse_next_statement(
return parse_package_declaration(context, set);
case TokenType::KEYWORD_IMPORT:
return parse_import_statement(context, set);
- case TokenType::KEYWORD_TYPE:
- return parse_type_definition(context, set, visibility_modifier);
case TokenType::KEYWORD_CONTINUE:
return parse_continue_statement(context, set);
case TokenType::KEYWORD_BREAK:
diff --git a/packages/compiler/src/ast/closures.cpp b/packages/compiler/src/ast/closures.cpp
index 516b89c1..a669b66d 100644
--- a/packages/compiler/src/ast/closures.cpp
+++ b/packages/compiler/src/ast/closures.cpp
@@ -89,7 +89,8 @@ namespace stride::ast::closures
llvm::Function* find_lambda_function(
llvm::Module* module,
- const llvm::FunctionType* fn_type
+ const llvm::FunctionType* fn_type,
+ const bool prefer_captures
)
{
if (!module || !fn_type)
@@ -102,6 +103,13 @@ namespace stride::ast::closures
// 1. Return type must match
// 2. The LAST N parameters must match (where N = declared params)
// 3. The lambda should have >= N parameters
+ //
+ // prefer_captures controls disambiguation when multiple lambdas match:
+ // true → prefer match WITH captures (callee is a closure env, e.g. struct field)
+ // false → prefer match WITHOUT captures (callee is a raw fn ptr)
+
+ llvm::Function* exact_match = nullptr;
+ llvm::Function* capture_match = nullptr;
for (auto& fn : module->functions())
{
@@ -138,11 +146,24 @@ namespace stride::ast::closures
if (params_match)
{
- return &fn;
+ if (lambda_params == num_declared)
+ {
+ if (!exact_match)
+ exact_match = &fn;
+ }
+ else
+ {
+ if (!capture_match)
+ capture_match = &fn;
+ }
}
}
}
- return nullptr;
+
+ if (prefer_captures)
+ return capture_match ? capture_match : exact_match;
+
+ return exact_match ? exact_match : capture_match;
}
std::vector generate_capture_arguments(
diff --git a/packages/compiler/src/ast/context/type_registry.cpp b/packages/compiler/src/ast/context/type_registry.cpp
index f2f04f59..ce6571fe 100644
--- a/packages/compiler/src/ast/context/type_registry.cpp
+++ b/packages/compiler/src/ast/context/type_registry.cpp
@@ -52,7 +52,7 @@ std::optional ParsingContext::get_type_definition(const std::st
/// Gets the root struct type layout for name.
/// Will recursively look up the parent struct definition if name is a reference struct type.
-std::optional ParsingContext::get_struct_type(const std::string& name) const
+std::optional ParsingContext::get_object_type(const std::string& name) const
{
const auto type_def = get_type_definition(name);
diff --git a/packages/compiler/src/ast/generics.cpp b/packages/compiler/src/ast/generics.cpp
index 0b0b4a0a..a272424f 100644
--- a/packages/compiler/src/ast/generics.cpp
+++ b/packages/compiler/src/ast/generics.cpp
@@ -118,16 +118,28 @@ std::unique_ptr stride::ast::resolve_generics(
}
GenericTypeList resolved_generics;
- resolved_generics.reserve(object_type->get_instantiated_generics().size());
- for (const auto& gen : object_type->get_instantiated_generics())
+ if (const auto& instantiated_generics = object_type->get_instantiated_generics();
+ instantiated_generics.empty() && !instantiated_types.empty() && !param_names.empty())
{
- resolved_generics.push_back(resolve_generics(gen.get(), param_names, instantiated_types));
+ resolved_generics.reserve(instantiated_types.size());
+ for (const auto& gen : instantiated_types)
+ {
+ resolved_generics.push_back(gen->clone_ty());
+ }
+ }
+ else
+ {
+ resolved_generics.reserve(instantiated_generics.size());
+ for (const auto& gen : instantiated_generics)
+ {
+ resolved_generics.push_back(resolve_generics(gen.get(), param_names, instantiated_types));
+ }
}
return std::make_unique(
object_type->get_source_fragment(),
object_type->get_context(),
- object_type->get_type_name(),
+ object_type->get_base_name(),
std::move(resolved_members),
object_type->get_flags(),
std::move(resolved_generics)
@@ -172,9 +184,9 @@ std::unique_ptr stride::ast::resolve_generics(
std::unique_ptr stride::ast::instantiate_generic_type(
const AstAliasType* alias_type,
- const definition::TypeDefinition* type_definition
-)
+ const definition::TypeDefinition* type_definition)
{
+
const auto& instantiated_types = alias_type->get_instantiated_generic_types();
const auto& generic_param_names = type_definition->get_generics_parameters();
@@ -183,10 +195,21 @@ std::unique_ptr stride::ast::instantiate_generic_type(
// Ensure we instantiate the type with the correct amount of parameters
if (instantiated_types.size() != generic_param_names.size())
{
+ if (generic_param_names.empty())
+ {
+ throw parsing_error(
+ ErrorType::TYPE_ERROR,
+ std::format(
+ "Failed to resolve generic for type '{}': type is not generic",
+ alias_type->get_name()
+ ),
+ alias_type->get_source_fragment()
+ );
+ }
throw parsing_error(
ErrorType::TYPE_ERROR,
std::format(
- "Failed to instantiate generic type type '{}': expected {} parameters, got {}",
+ "Failed to instantiate generic type '{}': expected {} parameters, got {}",
alias_type->get_name(),
generic_param_names.size(),
instantiated_types.size()
@@ -248,7 +271,7 @@ std::unique_ptr stride::ast::instantiate_generic_type(
return std::make_unique(
type->get_source_fragment(),
type->get_context(),
- type->get_type_name(),
+ type->get_base_name(),
std::move(resolved_members),
type->get_flags(),
std::move(resolved_args)
diff --git a/packages/compiler/src/ast/nodes/enumerables.cpp b/packages/compiler/src/ast/nodes/enumerables.cpp
index de6576e1..6c683eb5 100644
--- a/packages/compiler/src/ast/nodes/enumerables.cpp
+++ b/packages/compiler/src/ast/nodes/enumerables.cpp
@@ -5,7 +5,6 @@
#include "ast/tokens/token_set.h"
#include
-#include
using namespace stride::ast;
using namespace stride::ast::definition;
@@ -18,7 +17,9 @@ using namespace stride::ast::definition;
*/
std::unique_ptr stride::ast::parse_enumerable_member(
const std::shared_ptr& context,
- TokenSet& set)
+ TokenSet& set,
+ size_t element_index
+)
{
const auto member_name_tok = set.expect(TokenType::IDENTIFIER);
auto member_sym = member_name_tok.get_lexeme();
@@ -27,21 +28,33 @@ std::unique_ptr stride::ast::parse_enumerable_member(
Symbol(
member_name_tok.get_source_fragment(),
context->get_name(),
- member_sym,
- /* internal_name = */
- member_sym),
- SymbolType::ENUM_MEMBER);
+ member_sym
+ ),
+ SymbolType::ENUM_MEMBER
+ );
+
+ // Using index as element value if no explicit value is provided, and allowing optional trailing comma
+ if (!set.has_next() || !set.peek_next_eq(TokenType::COLON))
+ {
+ return std::make_unique(
+ member_name_tok.get_source_fragment(),
+ context,
+ std::move(member_sym),
+ std::make_unique(
+ member_name_tok.get_source_fragment(),
+ context,
+ PrimitiveType::INT32,
+ element_index
+ )
+ );
+ }
set.expect(TokenType::COLON, "Expected a colon after enum member name");
auto value = parse_literal_optional(context, set);
if (!value.has_value())
- {
set.throw_error("Expected a literal value for enum member");
- }
-
- set.expect(TokenType::COMMA, "Expected a comma after enum member value");
return std::make_unique(
member_name_tok.get_source_fragment(),
@@ -58,8 +71,7 @@ std::unique_ptr stride::ast::parse_enumerable_declaration(
)
{
const auto reference_token = set.expect(TokenType::KEYWORD_ENUM);
- const auto enumerable_name_tok = set.expect(TokenType::IDENTIFIER);
- auto enumerable_name = enumerable_name_tok.get_lexeme();
+ const auto enumerable_name = set.expect(TokenType::IDENTIFIER).get_lexeme();
context->define_symbol(
Symbol(reference_token.get_source_fragment(),
@@ -70,15 +82,18 @@ std::unique_ptr stride::ast::parse_enumerable_declaration(
auto enum_body_subset = collect_block_required(set, "Expected a block in enum declaration");
- std::vector> members = {};
+ std::vector> members;
auto enum_definition_context = std::make_shared(
context,
context->get_context_type());
- while (enum_body_subset.has_next())
+ members.push_back(parse_enumerable_member(enum_definition_context, enum_body_subset, 0));
+
+ for (size_t i = 1; enum_body_subset.has_next(); ++i)
{
- members.push_back(parse_enumerable_member(enum_definition_context, enum_body_subset));
+ enum_body_subset.expect(TokenType::COMMA, "Expected a comma between enum members");
+ members.push_back(parse_enumerable_member(enum_definition_context, enum_body_subset, i));
}
return std::make_unique(
@@ -89,12 +104,20 @@ std::unique_ptr stride::ast::parse_enumerable_declaration(
);
}
+void AstEnumerable::resolve_forward_references(llvm::Module* module, llvm::IRBuilderBase* builder)
+{
+ for (const auto& member : this->_members)
+ {
+ member->resolve_forward_references(module, builder);
+ }
+}
+
std::unique_ptr AstEnumerable::clone()
{
std::vector> cloned_members;
- cloned_members.reserve(this->get_members().size());
+ cloned_members.reserve(this->_members.size());
- for (const auto& member : this->get_members())
+ for (const auto& member : this->_members)
{
cloned_members.push_back(member->clone_as());
}
@@ -125,20 +148,15 @@ std::string AstEnumerableMember::to_string()
std::string AstEnumerable::to_string()
{
- std::ostringstream imploded;
+ std::vector members;
- if (this->get_members().empty())
- {
- return std::format("Enumerable {} (empty)", this->get_name());
- }
-
- imploded << this->get_members()[0]->to_string();
- for (size_t i = 1; i < this->get_members().size(); ++i)
+ for (const auto& member : this->get_members())
{
- imploded << "\n " << this->get_members()[i]->to_string();
+ members.push_back(member->to_string());
}
- return std::format("Enumerable {} (\n {}\n)",
- this->get_name(),
- imploded.str());
+ return std::format(
+ "Enumerable {} (\n {}\n)",
+ this->get_name(),
+ join(members, ",\n "));
}
diff --git a/packages/compiler/src/ast/nodes/expressions/array_initializer.cpp b/packages/compiler/src/ast/nodes/expressions/array_initializer.cpp
index 31c91601..73d68a32 100644
--- a/packages/compiler/src/ast/nodes/expressions/array_initializer.cpp
+++ b/packages/compiler/src/ast/nodes/expressions/array_initializer.cpp
@@ -42,15 +42,14 @@ std::unique_ptr stride::ast::parse_array_initializer(
}
}
- const auto& ref_src_pos = reference_token.get_source_fragment();
const auto& last_token_pos = set.peek(-1).get_source_fragment();
// `]` is already consumed, so we peek back at it
+ const auto position = SourceFragment::combine(reference_token.get_source_fragment(), last_token_pos);
+
return std::make_unique(
- SourceFragment(
- ref_src_pos.source,
- ref_src_pos.offset,
- last_token_pos.offset + last_token_pos.length - ref_src_pos.offset),
+ position,
context,
- std::move(elements));
+ std::move(elements)
+ );
}
diff --git a/packages/compiler/src/ast/nodes/expressions/array_member_accessor.cpp b/packages/compiler/src/ast/nodes/expressions/array_member_accessor.cpp
index 75704bc1..eb9cdd31 100644
--- a/packages/compiler/src/ast/nodes/expressions/array_member_accessor.cpp
+++ b/packages/compiler/src/ast/nodes/expressions/array_member_accessor.cpp
@@ -7,14 +7,15 @@
#include "ast/tokens/token_set.h"
#include
+#include
#include
using namespace stride::ast;
-std::unique_ptr stride::ast::parse_array_member_accessor(
+std::unique_ptr stride::ast::parse_array_member_accessor(
const std::shared_ptr& context,
TokenSet& set,
- std::unique_ptr array_identifier)
+ std::unique_ptr array_base)
{
auto expression_block = collect_block_variant(
set,
@@ -32,19 +33,19 @@ std::unique_ptr stride::ast::parse_array_member_accessor(
auto index_expression = parse_inline_expression(context, expression_block.value());
- const auto source_pos = SourceFragment::combine(array_identifier->get_source_fragment(), last_src_pos);
+ const auto source_pos = SourceFragment::combine(array_base->get_source_fragment(), last_src_pos);
return std::make_unique(
source_pos,
context,
- std::move(array_identifier),
+ std::move(array_base),
std::move(index_expression)
);
}
void AstArrayMemberAccessor::validate()
{
- this->_array_identifier->validate();
+ this->_array_base->validate();
this->_index_accessor_expr->validate();
const auto index_accessor_type = this->_index_accessor_expr->get_type();
@@ -78,31 +79,17 @@ llvm::Value* AstArrayMemberAccessor::codegen(
llvm::IRBuilderBase* builder
)
{
- std::unique_ptr array_iden_type = this->_array_identifier->get_type()->clone_ty();
+ std::unique_ptr array_base_type = this->_array_base->get_type()->clone_ty();
- if (const auto named_ty = cast_type(array_iden_type.get()))
+ if (const auto named_ty = cast_type(array_base_type.get()))
{
- if (const auto reference_type = named_ty->get_underlying_type();
- reference_type.has_value())
- {
- array_iden_type = reference_type.value()->clone_ty();
- }
- else
- {
- throw parsing_error(
- ErrorType::SEMANTIC_ERROR,
- "Array member accessor used on non-reference type",
- this->get_source_fragment()
- );
- }
+ array_base_type = named_ty->get_underlying_type()->clone_ty();
}
- llvm::Value* base_ptr = this->_array_identifier->codegen(module, builder);
+ llvm::Value* base_val = this->_array_base->codegen(module, builder);
llvm::Value* index_val = this->_index_accessor_expr->codegen(module, builder);
- // Element type, not the array type.
- // Assumes `array_iden_type` is something like "T[]" and has an element type you can extract.
- const auto* array_ty = cast_type(array_iden_type.get());
+ const auto* array_ty = cast_type(array_base_type.get());
if (!array_ty)
{
throw parsing_error(
@@ -113,16 +100,52 @@ llvm::Value* AstArrayMemberAccessor::codegen(
llvm::Type* elem_llvm_ty = array_ty->get_element_type()->get_llvm_type(module);
- // Treat base_ptr as elem*
- llvm::Value* typed_base_ptr = builder->CreateBitCast(
- base_ptr,
- llvm::PointerType::getUnqual(module->getContext()),
- "array_base_cast"
- );
+ // Ensure we have a pointer to GEP into. The base expression's codegen may
+ // have produced a non-pointer value (e.g. identifier codegen loads the
+ // alloca's allocated type which can be [0 x T] for dynamically-sized arrays).
+ // In that case, look through the load to recover the source pointer.
+ llvm::Value* base_ptr = base_val;
+ if (!base_ptr->getType()->isPointerTy())
+ {
+ if (auto* load_inst = llvm::dyn_cast(base_val))
+ {
+ base_ptr = load_inst->getPointerOperand();
+ load_inst->eraseFromParent();
+ }
+ }
+ // If the base is a pointer to an alloca of array type, load the stored
+ // pointer first (arrays are always behind a pointer in stride).
+ if (base_ptr->getType()->isPointerTy())
+ {
+ if (const auto* alloca_inst = llvm::dyn_cast(base_ptr))
+ {
+ llvm::Type* allocated_ty = alloca_inst->getAllocatedType();
+ if (allocated_ty->isArrayTy())
+ {
+ // The alloca holds a pointer to the array data. Load it.
+ llvm::Value* array_ptr = builder->CreateLoad(
+ llvm::PointerType::getUnqual(module->getContext()),
+ base_ptr,
+ "array_ptr"
+ );
+
+ llvm::Value* element_ptr = builder->CreateInBoundsGEP(
+ elem_llvm_ty,
+ array_ptr,
+ index_val,
+ "array_elem_ptr"
+ );
+
+ return builder->CreateLoad(elem_llvm_ty, element_ptr, "array_load");
+ }
+ }
+ }
+
+ // Opaque pointer or decayed pointer-to-element: single-index GEP.
llvm::Value* element_ptr = builder->CreateInBoundsGEP(
elem_llvm_ty,
- typed_base_ptr,
+ base_ptr,
index_val,
"array_elem_ptr"
);
@@ -135,7 +158,7 @@ std::unique_ptr AstArrayMemberAccessor::clone()
return std::make_unique(
this->get_source_fragment(),
this->get_context(),
- this->_array_identifier->clone_as(),
+ this->_array_base->clone_as(),
this->_index_accessor_expr->clone_as()
);
}
@@ -144,7 +167,7 @@ std::string AstArrayMemberAccessor::to_string()
{
return std::format(
"ArrayAccess({}, {})",
- this->_array_identifier->to_string(),
+ this->_array_base->to_string(),
this->_index_accessor_expr->to_string()
);
}
@@ -152,7 +175,7 @@ std::string AstArrayMemberAccessor::to_string()
bool AstArrayMemberAccessor::is_reducible()
{
// If the value is a literal, it's reducible for sure.
- if (cast_expr(this->_array_identifier.get()))
+ if (cast_expr(this->_array_base.get()))
{
return true;
}
diff --git a/packages/compiler/src/ast/nodes/expressions/expression.cpp b/packages/compiler/src/ast/nodes/expressions/expression.cpp
index a48516ae..0ce7383a 100644
--- a/packages/compiler/src/ast/nodes/expressions/expression.cpp
+++ b/packages/compiler/src/ast/nodes/expressions/expression.cpp
@@ -29,102 +29,113 @@ std::unique_ptr stride::ast::parse_inline_expression_part(
TokenSet& set
)
{
+ std::unique_ptr result;
+
if (auto lit = parse_literal_optional(context, set); lit.has_value())
{
- return std::move(lit.value());
+ result = std::move(lit.value());
}
-
// Will try to parse ::{ ... }
- if (is_struct_initializer(set))
+ else if (is_struct_initializer(set))
{
- return parse_object_initializer(context, set);
+ result = parse_object_initializer(context, set);
}
-
// Will try to parse [ ... ]
- if (is_array_initializer(set))
+ else if (is_array_initializer(set))
{
- return parse_array_initializer(context, set);
+ result = parse_array_initializer(context, set);
}
-
- // Could either be a function call, or object access
- if (set.peek_next_eq(TokenType::IDENTIFIER))
+ // Could either be a function call, or object/array access
+ else if (set.peek_next_eq(TokenType::IDENTIFIER))
{
- /// Regular identifier parsing; can be variable reference
const auto reference_token = set.peek_next();
// Mangled name including module, e.g., `Math__PI`
- const SymbolNameSegments name_segments = parse_segmented_identifier(
- set,
- "Expected identifier in expression");
- const auto internal_name = resolve_internal_name(name_segments);
-
- auto identifier = std::make_unique(
+ auto identifier = parse_segmented_identifier(
context,
- Symbol(reference_token.get_source_fragment(), internal_name));
+ set,
+ "Expected identifier in expression"
+ );
- if (auto reassignment = parse_variable_reassignment(context, internal_name, set);
+ if (auto reassignment = parse_variable_reassignment(context, identifier.get(), set);
reassignment.has_value())
{
return std::move(reassignment.value());
}
- /// Function invocations, e.g., `(...)`, or `::`
+ // Named function invocations, e.g., `(...)` or `::(...)`
if (set.peek_next_eq(TokenType::LPAREN))
{
- return parse_function_call(context, name_segments, set);
+ result = parse_function_call(context, identifier.get(), set);
}
-
- if (set.peek_next_eq(TokenType::LSQUARE_BRACKET))
+ else
{
- return parse_array_member_accessor(
- context,
- set,
- std::move(identifier)
- );
+ result = std::move(identifier);
}
-
- if (is_member_accessor(identifier.get(), set))
- {
- return parse_chained_member_access(
- context,
- set,
- std::move(identifier)
- );
- }
-
- return std::move(identifier);
}
-
// If the next token is a '(', we'll try to descend into it
// until we find another one, e.g. `(1 + (2 * 3))` with nested parentheses
- if (set.peek_next_eq(TokenType::LPAREN))
+ else if (set.peek_next_eq(TokenType::LPAREN))
{
if ((set.peek_eq(TokenType::IDENTIFIER, 1) // Checks for "(: ..."
&& set.peek_eq(TokenType::COLON, 2))
|| (set.peek_eq(TokenType::RPAREN, 1) && // Checks for "():"
set.peek_eq(TokenType::COLON, 2)))
{
- return parse_anonymous_fn_expression(context, set);
+ result = parse_anonymous_fn_expression(context, set);
+ }
+ else
+ {
+ set.next();
+ // Fixed: Use parse_inline_expression (full expression parser) instead of
+ // parse_inline_expression_part to allow binary operations inside parentheses.
+ auto expr = parse_inline_expression(context, set);
+ // TODO: If we have a comma next, it might be a tuple expression
+ set.expect(TokenType::RPAREN, "Expected ')' after expression");
+ result = std::move(expr);
}
-
- set.next();
- // Fixed: Use parse_inline_expression (full expression parser) instead of
- // parse_inline_expression_part to allow binary operations inside parentheses.
- auto expr = parse_inline_expression(context, set);
- // TODO: If we have a comma next, it might be a tuple expression
- set.expect(TokenType::RPAREN, "Expected ')' after expression");
- return expr;
}
-
- if (set.peek_next_eq(TokenType::THREE_DOTS))
+ else if (set.peek_next_eq(TokenType::THREE_DOTS))
{
const auto& ref = set.next();
- return std::make_unique(
+ result = std::make_unique(
ref.get_source_fragment(),
context
);
}
+ else
+ {
+ set.throw_error("Invalid token found in expression");
+ }
- set.throw_error("Invalid token found in expression");
+ // Unified postfix operator loop:
+ // Handles `.member`, `[index]`, and `(args)` chaining on any primary expression.
+ // Builds a left-recursive tree so each step's base is the result of the previous step.
+ int recursion_depth = 0;
+ while (true)
+ {
+ if (is_member_accessor(set))
+ {
+ result = parse_chained_member_access(context, set, std::move(result));
+ }
+ else if (set.peek_next_eq(TokenType::LSQUARE_BRACKET))
+ {
+ result = parse_array_member_accessor(context, set, std::move(result));
+ }
+ else if (set.peek_next_eq(TokenType::LPAREN))
+ {
+ result = parse_indirect_call(context, set, std::move(result));
+ }
+ else
+ {
+ break;
+ }
+ if (++recursion_depth > MAX_RECURSION_DEPTH)
+ {
+ set.throw_error("Expression too complex");
+ }
+ }
+
+ return result;
}
/*
@@ -300,13 +311,17 @@ std::unique_ptr stride::ast::parse_inline_expression(
return parse_expression_internal(context, set);
}
-SymbolNameSegments stride::ast::parse_segmented_identifier(
+std::unique_ptr stride::ast::parse_segmented_identifier(
+ const std::shared_ptr& context,
TokenSet& set,
const std::string& error_message)
{
- std::vector segments = {};
+ std::vector segments;
+
+ const auto initial_identifier = set.expect(TokenType::IDENTIFIER, error_message);
+ segments.push_back(initial_identifier.get_lexeme());
- segments.push_back(set.expect(TokenType::IDENTIFIER, error_message).get_lexeme());
+ std::optional last_fragment = std::nullopt;
while (set.peek_eq(TokenType::DOUBLE_COLON, 0)
&& set.peek_eq(TokenType::IDENTIFIER, 1))
@@ -317,7 +332,15 @@ SymbolNameSegments stride::ast::parse_segmented_identifier(
error_message
);
segments.push_back(subseq_iden.get_lexeme());
+ last_fragment = subseq_iden.get_source_fragment();
}
- return segments;
+ const auto source_pos = last_fragment.has_value()
+ ? SourceFragment::combine(initial_identifier.get_source_fragment(), last_fragment.value())
+ : initial_identifier.get_source_fragment();
+
+ return std::make_unique(
+ context,
+ Symbol(source_pos, resolve_internal_name(segments))
+ );
}
diff --git a/packages/compiler/src/ast/nodes/expressions/identifier.cpp b/packages/compiler/src/ast/nodes/expressions/identifier.cpp
index aab1200a..cf14a845 100644
--- a/packages/compiler/src/ast/nodes/expressions/identifier.cpp
+++ b/packages/compiler/src/ast/nodes/expressions/identifier.cpp
@@ -8,37 +8,50 @@
using namespace stride::ast;
-llvm::Value* AstIdentifier::codegen(
- llvm::Module* module,
- llvm::IRBuilderBase* builder
-)
+std::optional AstIdentifier::get_definition() const
{
- llvm::Value* val = nullptr;
-
- std::string internal_name = this->get_scoped_name();
+ const std::string internal_name = this->get_scoped_name();
- // Prefer exact internal-name match first. This ensures that an unqualified
- // reference like `field` resolves to the global `field` (internal_name="field")
- // rather than a same-named symbol in a sibling module like `Foo__field`.
if (const auto var_def = this->get_context()->lookup_variable(internal_name, false))
{
- internal_name = var_def->get_internal_symbol_name();
+ return var_def;
}
+
// Fall back to name-based lookup, which resolves short names to their internal
// names (e.g. `x` → `x.0` for locals with a counter suffix).
- else if (const auto symbol_definition = this->get_context()->lookup_symbol(internal_name))
+ if (const auto symbol_definition = this->get_context()->lookup_symbol(internal_name))
{
- internal_name = symbol_definition->get_internal_symbol_name();
+ return symbol_definition;
}
- else
+
+ // Last resort: raw name match (handles captured variables).
+ if (const auto definition = this->get_context()->lookup_variable(internal_name, true))
{
- // Last resort: raw name match (handles captured variables).
- if (const auto definition = this->get_context()->lookup_variable(internal_name, true))
- {
- internal_name = definition->get_internal_symbol_name();
- }
+ return definition;
+ }
+
+ return std::nullopt;
+}
+
+llvm::Value* AstIdentifier::codegen(
+ llvm::Module* module,
+ llvm::IRBuilderBase* builder
+)
+{
+ const auto definition = this->get_definition();
+
+ if (!definition.has_value())
+ {
+ throw parsing_error(
+ ErrorType::REFERENCE_ERROR,
+ std::format("Identifier '{}' not found in this scope", this->get_name()),
+ this->get_source_fragment()
+ );
}
+ const std::string internal_name = definition.value()->get_internal_symbol_name();
+ llvm::Value* val = nullptr;
+
if (const auto block = builder->GetInsertBlock())
{
if (llvm::Function* function = block->getParent())
@@ -64,11 +77,15 @@ llvm::Value* AstIdentifier::codegen(
{
return builder->CreateLoad(
global->getValueType(),
- global,
- internal_name
+ global
);
}
+ if (auto* arg = llvm::dyn_cast_or_null(val))
+ {
+ return arg;
+ }
+
throw parsing_error(
ErrorType::REFERENCE_ERROR,
std::format("Identifier '{}' not found in this scope", this->get_name()),
@@ -83,14 +100,18 @@ llvm::Value* AstIdentifier::codegen(
return arg;
}
+ if (auto* load = llvm::dyn_cast_or_null(val))
+ {
+ return load;
+ }
+
if (auto* alloca = llvm::dyn_cast_or_null(val))
{
// Load the value from the allocated variable
// Note: This is safe because 'val' is only found if GetInsertBlock() was not null
return builder->CreateLoad(
alloca->getAllocatedType(),
- alloca,
- internal_name
+ alloca
);
}
@@ -101,8 +122,7 @@ llvm::Value* AstIdentifier::codegen(
{
return builder->CreateLoad(
global->getValueType(),
- global,
- internal_name
+ global
);
}
diff --git a/packages/compiler/src/ast/nodes/expressions/member_accessor.cpp b/packages/compiler/src/ast/nodes/expressions/member_accessor.cpp
index 0cb21e98..32fb9514 100644
--- a/packages/compiler/src/ast/nodes/expressions/member_accessor.cpp
+++ b/packages/compiler/src/ast/nodes/expressions/member_accessor.cpp
@@ -1,9 +1,11 @@
#include "errors.h"
#include "formatting.h"
#include "ast/casting.h"
+#include "ast/closures.h"
#include "ast/parsing_context.h"
#include "ast/nodes/enumerables.h"
#include "ast/nodes/expression.h"
+#include "ast/nodes/blocks.h"
#include "ast/tokens/token_set.h"
#include
@@ -11,160 +13,129 @@
using namespace stride::ast;
-/// This parses both function call chaining, and struct member access
-/// e.g.,
-/// foo().bar.baz()
-/// or
-/// struct_var.member.member2 ...
-std::unique_ptr stride::ast::parse_chained_member_access(
+/// Checks whether the next two tokens begin a member access (`.identifier`).
+bool stride::ast::is_member_accessor(const TokenSet& set)
+{
+ return set.peek_eq(TokenType::DOT, 0)
+ && set.peek_eq(TokenType::IDENTIFIER, 1);
+}
+
+/// Consumes `.identifier` and wraps `lhs` in an AstChainedExpression.
+std::unique_ptr stride::ast::parse_chained_member_access(
const std::shared_ptr& context,
TokenSet& set,
- const std::unique_ptr& lhs
+ std::unique_ptr lhs
)
{
- std::vector> chained_accessors = {};
-
- // Initial accessors
- const auto reference_token = set.peek_next();
-
- while (set.peek_next_eq(TokenType::DOT))
- {
- set.expect(TokenType::DOT, "Expected '.' after identifier in member access");
+ set.expect(TokenType::DOT, "Expected '.' in member access");
+ const auto member_tok = set.expect(TokenType::IDENTIFIER, "Expected identifier after '.' in member access");
- const auto accessor_iden_tok = set.expect(TokenType::IDENTIFIER,
- "Expected identifier after '.' in member access");
-
- auto symbol = Symbol(accessor_iden_tok.get_source_fragment(),
- accessor_iden_tok.get_lexeme());
-
- chained_accessors.push_back(std::make_unique(context, symbol));
- }
-
- const auto lhs_source_pos = lhs->get_source_fragment();
-
- // TODO: Allow function calls to be the last element as well.
- auto lhs_identifier = cast_expr(lhs.get());
- if (!lhs_identifier)
- {
- throw parsing_error(
- ErrorType::TYPE_ERROR,
- "Member access base must be an identifier",
- lhs_source_pos);
- }
+ auto member_id = std::make_unique(
+ context,
+ Symbol(member_tok.get_source_fragment(), member_tok.get_lexeme())
+ );
- const auto last_source_pos = chained_accessors.back().get()->get_source_fragment();
- const auto source_pos = SourceFragment::combine(lhs_source_pos, last_source_pos);
+ const auto source = SourceFragment::combine(lhs->get_source_fragment(), member_tok.get_source_fragment());
- return std::make_unique(
- source_pos,
+ return std::make_unique(
+ source,
context,
- lhs_identifier->clone_as(),
- std::move(chained_accessors)
+ std::move(lhs),
+ std::move(member_id)
);
}
-std::vector AstMemberAccessor::get_members() const
+/// Consumes `()` and wraps `callee` in an AstIndirectCall.
+std::unique_ptr stride::ast::parse_indirect_call(
+ const std::shared_ptr& context,
+ TokenSet& set,
+ std::unique_ptr callee
+)
{
- // We don't wanna transfer ownership to anyone else...
- std::vector result;
- result.reserve(this->_members.size());
- std::ranges::transform(
- this->_members,
- std::back_inserter(result),
- [](const std::unique_ptr& member)
- {
- return member.get();
- });
- return result;
-}
+ const auto callee_src = callee->get_source_fragment();
+ auto param_block = collect_parenthesized_block(set);
-bool stride::ast::is_member_accessor(IAstExpression* lhs, const TokenSet& set)
-{
- // We assume the expression and subsequent tokens are "member accessors"
- // if the LHS is an identifier, and it's followed by `.`
- // E.g., `struct_var.member`
- if (cast_expr(lhs))
+ ExpressionList args;
+ if (param_block.has_value())
{
- return set.peek_eq(TokenType::DOT, 0)
- && set.peek_eq(TokenType::IDENTIFIER, 1);
+ auto subset = param_block.value();
+ if (subset.has_next())
+ {
+ args.push_back(parse_inline_expression(context, subset));
+ while (subset.has_next())
+ {
+ subset.expect(TokenType::COMMA, "Expected ',' between arguments");
+ args.push_back(parse_inline_expression(context, subset));
+ }
+ }
}
- return false;
+
+ const auto close_src = set.peek(-1).get_source_fragment();
+ const auto source = SourceFragment::combine(callee_src, close_src);
+
+ return std::make_unique(
+ source,
+ context,
+ std::move(callee),
+ std::move(args)
+ );
}
-llvm::Value* AstMemberAccessor::codegen_global_member_accessor(
+llvm::Value* AstChainedExpression::codegen_global_member_accessor(
llvm::Module* module,
llvm::IRBuilderBase* builder
) const
{
llvm::Value* base_val = this->_base->codegen(module, builder);
- IAstType* cloned_base_type = this->_base->get_type();
- std::string base_type_name = cloned_base_type->get_type_name();
+ IAstType* current_type = this->_base->get_type();
+ std::string current_struct_name = current_type->get_type_name();
- // We look for a GlobalVariable with an initializer
auto* global_var = llvm::dyn_cast_or_null(base_val);
if (!global_var || !global_var->hasInitializer())
{
- // If the base isn't a global with an initializer, we can't fold it.
return nullptr;
}
llvm::Constant* current_const = global_var->getInitializer();
- std::string current_struct_name = base_type_name;
- for (const auto& accessor : this->_members)
+ const auto* member_id = cast_expr(this->_followup.get());
+ if (!member_id)
{
- auto struct_def_opt = this->get_context()->get_struct_type(current_struct_name);
- if (!struct_def_opt.has_value())
- return nullptr;
-
- const auto struct_def = struct_def_opt.value();
-
- const auto member_index = struct_def->get_member_field_index(accessor->get_name());
- if (!member_index.has_value())
- {
- return nullptr;
- }
+ return nullptr;
+ }
- // Extract the constant field value
- current_const = current_const->getAggregateElement(member_index.value());
- if (!current_const)
- {
- // Index out of bounds or invalid aggregate
- throw parsing_error(
- ErrorType::COMPILATION_ERROR,
- std::format("Invalid member access on constant '{}'", base_type_name),
- this->get_source_fragment()
- );
- }
+ auto struct_def_opt = this->get_context()->get_object_type(current_struct_name);
+ if (!struct_def_opt.has_value())
+ return nullptr;
- auto member_field_type = struct_def->get_member_field_type(accessor->get_name());
- if (!member_field_type.has_value())
- {
- return nullptr;
- }
+ const auto struct_def = struct_def_opt.value();
+ const auto member_index = struct_def->get_member_field_index(member_id->get_name());
+ if (!member_index.has_value())
+ return nullptr;
- cloned_base_type = member_field_type.value();
- current_struct_name = cloned_base_type->get_type_name();
+ current_const = current_const->getAggregateElement(member_index.value());
+ if (!current_const)
+ {
+ throw parsing_error(
+ ErrorType::COMPILATION_ERROR,
+ std::format("Invalid member access on constant '{}'", current_struct_name),
+ this->get_source_fragment()
+ );
}
return current_const;
}
-llvm::Value* AstMemberAccessor::codegen(
+llvm::Value* AstChainedExpression::codegen(
llvm::Module* module,
llvm::IRBuilderBase* builder
)
{
- // Global struct definitions have no insertion point, so we need to do
- // constant folding here by looking up the global variable and its initializer.
if (!builder->GetInsertBlock())
{
return codegen_global_member_accessor(module, builder);
}
- // Standard Code Generation (Function context)
-
- // Will codegen identifier - reference to variable (getptr)
llvm::Value* current_val = this->_base->codegen(module, builder);
if (!current_val)
{
@@ -172,191 +143,279 @@ llvm::Value* AstMemberAccessor::codegen(
}
const auto base_struct_type = get_object_type_from_type(this->_base->get_type());
-
- // Base must be a struct for member access to be valid.
- // This would be okay if there were on members, however, this should never happen
if (!base_struct_type.has_value())
{
throw parsing_error(
ErrorType::TYPE_ERROR,
- "Member access base must be a struct",
+ "Member access base must be a struct type",
+ this->get_source_fragment()
+ );
+ }
+
+ const auto* member_id = cast_expr(this->_followup.get());
+ if (!member_id)
+ {
+ throw parsing_error(
+ ErrorType::TYPE_ERROR,
+ "Chained expression followup must be an identifier (member name)",
this->get_source_fragment()
);
}
IAstType* parent_type = base_struct_type.value();
- std::string parent_struct_internalized_name = base_struct_type.value()->get_internalized_name();
- std::string current_accessor_name = this->_base->get_name();
+ const std::string parent_struct_internalized_name = base_struct_type.value()->get_internalized_name();
- // With opaque pointers, we need to know if we are operating on an address (L-value)
- // or a loaded struct value (R-value). Pointers allow GEP, values require ExtractValue.
const bool is_pointer_ty = current_val->getType()->isPointerTy();
- for (const auto& accessor : this->_members)
+ auto parent_struct_type_opt = get_object_type_from_type(parent_type);
+ if (!parent_struct_type_opt.has_value())
{
- auto parent_struct_type_opt = get_object_type_from_type(parent_type);
-
- // In next iteration, it's possible that the previous member produced a non-struct type,
- // hence yielding a nullptr.
- if (!parent_struct_type_opt.has_value())
- {
- throw parsing_error(
- ErrorType::TYPE_ERROR,
- std::format(
- "Cannot access member '{}' of non-struct type '{}'",
- accessor->get_name(),
- current_accessor_name
- ),
- this->get_source_fragment()
- );
- }
-
- const auto parent_struct_type = parent_struct_type_opt.value();
+ throw parsing_error(
+ ErrorType::TYPE_ERROR,
+ std::format(
+ "Cannot access member '{}' of non-struct type",
+ member_id->get_name()
+ ),
+ this->get_source_fragment()
+ );
+ }
- parent_struct_internalized_name = parent_struct_type->get_internalized_name();
+ const auto parent_struct_type = parent_struct_type_opt.value();
+ const std::string internalized_name = parent_struct_type->get_internalized_name();
- const auto member_index = parent_struct_type->get_member_field_index(accessor->get_name());
+ const auto member_index = parent_struct_type->get_member_field_index(member_id->get_name());
+ if (!member_index.has_value())
+ {
+ throw parsing_error(
+ ErrorType::TYPE_ERROR,
+ std::format(
+ "Cannot access member '{}' of object type '{}': member does not exist",
+ member_id->get_name(),
+ parent_struct_type->get_type_name()
+ ),
+ this->get_source_fragment()
+ );
+ }
- // Validate whether the previous struct contains the "member" field
- if (!member_index.has_value())
+ if (is_pointer_ty)
+ {
+ llvm::StructType* struct_llvm_type = llvm::StructType::getTypeByName(
+ module->getContext(),
+ internalized_name
+ );
+ if (!struct_llvm_type)
{
throw parsing_error(
- ErrorType::TYPE_ERROR,
- std::format(
- "Cannot access member '{}' of object type '{}': member does not exist",
- accessor->get_name(),
- parent_struct_type->get_type_name()
- ),
+ ErrorType::COMPILATION_ERROR,
+ std::format("Object '{}' not registered internally", parent_struct_type->get_type_name()),
this->get_source_fragment()
);
}
- if (is_pointer_ty)
- {
- // Get the LLVM type for the current struct to generate the GEP
- llvm::StructType* struct_llvm_type = llvm::StructType::getTypeByName(
- module->getContext(),
- parent_struct_internalized_name
- );
- if (!struct_llvm_type)
- {
- throw parsing_error(
- ErrorType::COMPILATION_ERROR,
- std::format(
- "Object '{}' not registered internally",
- parent_struct_type->get_type_name()),
- this->get_source_fragment());
- }
-
- // Create the GEP (GetElementPtr) instruction: ¤t_ptr->member
- current_val = builder->CreateStructGEP(
- struct_llvm_type,
- current_val,
- member_index.value(),
- "ptr_" + accessor->get_name()
- );
- }
- else
- {
- // We have a direct value, extract the member: current_val.member
- current_val = builder->CreateExtractValue(
- current_val,
- member_index.value(),
- "val_" + accessor->get_name()
- );
- }
+ llvm::Value* member_ptr = builder->CreateStructGEP(
+ struct_llvm_type,
+ current_val,
+ member_index.value(),
+ "ptr_" + member_id->get_name()
+ );
- auto member_field_type = parent_struct_type->get_member_field_type(accessor->get_name());
+ auto member_field_type = parent_struct_type->get_member_field_type(member_id->get_name());
if (!member_field_type.has_value())
{
throw parsing_error(
ErrorType::COMPILATION_ERROR,
- std::format(
- "Unknown member type '{}' in object '{}'",
- accessor->get_name(),
- parent_struct_type->get_type_name()
- ),
- this->get_source_fragment());
+ std::format("Unknown member type '{}' in object '{}'",
+ member_id->get_name(), parent_struct_type->get_type_name()),
+ this->get_source_fragment()
+ );
}
- parent_type = member_field_type.value();
+ llvm::Type* final_llvm_type = member_field_type.value()->get_llvm_type(module);
+ return builder->CreateLoad(final_llvm_type, member_ptr, "val_member_access");
}
- if (!parent_type)
+ // Value (not pointer) — use ExtractValue
+ llvm::Value* val = builder->CreateExtractValue(
+ current_val,
+ member_index.value(),
+ "val_" + member_id->get_name()
+ );
+
+ return val;
+}
+
+std::unique_ptr AstChainedExpression::clone()
+{
+ return std::make_unique(
+ this->get_source_fragment(),
+ this->get_context(),
+ this->_base->clone_as(),
+ this->_followup->clone_as()
+ );
+}
+
+void AstChainedExpression::validate()
+{
+ this->_base->validate();
+ this->_followup->validate();
+}
+
+std::string AstChainedExpression::to_string()
+{
+ return std::format(
+ "ChainedExpression(base: {}, followup: {})",
+ this->_base->to_string(),
+ this->_followup->to_string()
+ );
+}
+
+std::optional> AstChainedExpression::reduce()
+{
+ return std::nullopt;
+}
+
+bool AstChainedExpression::is_reducible()
+{
+ return false;
+}
+
+llvm::Value* AstIndirectCall::codegen(
+ llvm::Module* module,
+ llvm::IRBuilderBase* builder
+)
+{
+ llvm::Value* callee_val = this->get_callee()->codegen(module, builder);
+ if (!callee_val)
{
throw parsing_error(
ErrorType::COMPILATION_ERROR,
- std::format(
- "Invalid member access on non-struct type '{}'",
- this->_base->get_type()->get_type_name()
- ),
- this->get_source_fragment());
+ "Indirect call: could not evaluate callee expression",
+ this->get_source_fragment()
+ );
}
- // if we were working with pointers, we need to load the final result
- if (is_pointer_ty)
+ // Derive the LLVM function type from the callee's AST type
+ auto callee_ast_type = this->get_callee()->get_type()->clone_ty();
+ if (auto* alias_ty = cast_type(callee_ast_type.get()))
{
- llvm::Type* final_llvm_type = parent_type->get_llvm_type(module);
+ callee_ast_type = alias_ty->get_underlying_type()->clone_ty();
+ }
- return builder->CreateLoad(
- final_llvm_type,
- current_val,
- "val_member_access"
+ const auto* fn_type = dynamic_cast(callee_ast_type.get());
+ if (!fn_type)
+ {
+ throw parsing_error(
+ ErrorType::TYPE_ERROR,
+ "Indirect call: callee expression does not have a function type",
+ this->get_source_fragment()
);
}
- // If we were working with values (ExtractValue), we already have the result.
- return current_val;
+ std::vector llvm_param_types;
+ llvm_param_types.reserve(fn_type->get_parameter_types().size());
+ for (const auto& param : fn_type->get_parameter_types())
+ {
+ llvm_param_types.push_back(param->get_llvm_type(module));
+ }
+
+ llvm::FunctionType* llvm_fn_type = llvm::FunctionType::get(
+ fn_type->get_return_type()->get_llvm_type(module),
+ llvm_param_types,
+ fn_type->is_variadic()
+ );
+
+ std::vector args_v;
+ llvm::FunctionType* call_fn_type = llvm_fn_type;
+ llvm::Value* actual_fn_ptr = callee_val;
+
+ // When the callee is a struct field access, the value is likely a closure
+ // env ptr (heap-allocated {fn_ptr, captures...}). Prefer the lambda with
+ // captures so we extract them correctly. For other callees (e.g. return
+ // values from function calls), prefer the exact-match lambda.
+ const bool callee_is_field_access =
+ cast_expr(this->get_callee()) != nullptr;
+
+ // Check if this is a closure call that needs capture extraction
+ if (llvm::Function* lambda_fn =
+ closures::find_lambda_function(module, llvm_fn_type, callee_is_field_access))
+ {
+ const size_t num_captures = lambda_fn->arg_size()
+ - fn_type->get_parameter_types().size();
+
+ if (num_captures > 0)
+ {
+ auto capture_args = closures::extract_closure_captures(
+ module, builder, callee_val, lambda_fn);
+
+ if (capture_args.size() == num_captures)
+ {
+ call_fn_type = lambda_fn->getFunctionType();
+
+ // Extract the actual function pointer from offset 0 of the closure env
+ actual_fn_ptr = builder->CreateLoad(
+ lambda_fn->getType(),
+ callee_val,
+ "closure_fn_ptr"
+ );
+
+ args_v.insert(args_v.end(), capture_args.begin(), capture_args.end());
+ }
+ }
+ }
+
+ // Add the user-provided arguments
+ for (const auto& arg : this->get_args())
+ {
+ llvm::Value* arg_val = arg->codegen(module, builder);
+ if (!arg_val)
+ return nullptr;
+ args_v.push_back(arg_val);
+ }
+
+ const auto instruction_name =
+ call_fn_type->getReturnType()->isVoidTy() ? "" : "indcalltmp";
+
+ return builder->CreateCall(call_fn_type, actual_fn_ptr, args_v, instruction_name);
}
-std::unique_ptr AstMemberAccessor::clone()
+std::unique_ptr AstIndirectCall::clone()
{
- std::vector> members;
- members.reserve(this->_members.size());
-
- for (const auto& member : this->_members)
+ ExpressionList cloned_args;
+ cloned_args.reserve(this->get_args().size());
+ for (const auto& arg : this->get_args())
{
- members.push_back(member->clone_as());
+ cloned_args.push_back(arg->clone_as());
}
- return std::make_unique(
+ return std::make_unique(
this->get_source_fragment(),
this->get_context(),
- this->_base->clone_as(),
- std::move(members)
+ this->_callee->clone_as(),
+ std::move(cloned_args)
);
}
-void AstMemberAccessor::validate()
+void AstIndirectCall::validate()
{
- for (const auto& member : this->_members)
+ this->get_callee()->validate();
+ for (const auto& arg : this->get_args())
{
- member->validate();
+ arg->validate();
}
}
-std::string AstMemberAccessor::to_string()
+std::string AstIndirectCall::to_string()
{
- std::vector member_names;
- member_names.reserve(this->_members.size());
-
- for (const auto& member : this->_members)
+ std::vector arg_strs;
+ arg_strs.reserve(this->get_args().size());
+ for (const auto& arg : this->get_args())
{
- member_names.push_back(member->get_name());
+ arg_strs.push_back(arg->to_string());
}
return std::format(
- "MemberAccessor(base: {}, member: {})",
- this->get_base()->to_string(),
- join(member_names, ","));
-}
-
-std::optional> AstMemberAccessor::reduce()
-{
- return std::nullopt; // Not yet reducible
-}
-
-bool AstMemberAccessor::is_reducible()
-{
- return false;
+ "IndirectCall(callee: {}, args: [{}])",
+ this->_callee->to_string(),
+ join(arg_strs, ", ")
+ );
}
diff --git a/packages/compiler/src/ast/nodes/expressions/object_initializer.cpp b/packages/compiler/src/ast/nodes/expressions/object_initializer.cpp
index a2a977d4..dafac600 100644
--- a/packages/compiler/src/ast/nodes/expressions/object_initializer.cpp
+++ b/packages/compiler/src/ast/nodes/expressions/object_initializer.cpp
@@ -86,17 +86,33 @@ std::unique_ptr AstObjectInitializer::get_instantiated_object_typ
return this->_object_type->clone_as();
}
- const auto type_def = this->get_context()->get_type_definition(this->_struct_name);
+ const auto type_def = this->get_context()->get_type_definition(this->_object_type_name);
if (!type_def.has_value())
{
throw parsing_error(
ErrorType::COMPILATION_ERROR,
- std::format("Struct type '{}' is undefined", this->_struct_name),
+ std::format("Object type '{}' is undefined", this->_object_type_name),
this->get_source_fragment()
);
}
+ if (!this->_generic_type_arguments.empty() && !type_def.value()->get_generics_parameters().empty())
+ {
+ if (this->_generic_type_arguments.size() != type_def.value()->get_generics_parameters().size())
+ {
+ throw parsing_error(
+ ErrorType::TYPE_ERROR,
+ std::format("Invalid instantiation of object type '{}': expected {} generic arguments, got {}",
+ this->_object_type_name,
+ type_def.value()->get_generics_parameters().size(),
+ this->_generic_type_arguments.size()
+ ),
+ this->get_source_fragment()
+ );
+ }
+ }
+
if (const auto* object_def = cast_type(type_def.value()->get_type()))
{
auto resolved_type = instantiate_generic_type(this, const_cast(object_def), type_def.value());
@@ -110,17 +126,7 @@ std::unique_ptr AstObjectInitializer::get_instantiated_object_typ
{
const auto underlying_type = alias_def->get_underlying_type();
- if (!underlying_type.has_value())
- {
- throw parsing_error(
- ErrorType::COMPILATION_ERROR,
- std::format("Named type '{}' does not reference another type, cannot be used as object type",
- alias_def->get_name()),
- this->get_source_fragment()
- );
- }
-
- if (auto* object_def = cast_type(underlying_type.value().get()))
+ if (auto* object_def = cast_type(underlying_type))
{
auto resolved_type = instantiate_generic_type(this, object_def, type_def.value());
@@ -132,7 +138,7 @@ std::unique_ptr AstObjectInitializer::get_instantiated_object_typ
throw parsing_error(
ErrorType::COMPILATION_ERROR,
- std::format("Type '{}' is not an object", this->_struct_name),
+ std::format("Type '{}' is not an object", this->_object_type_name),
this->get_source_fragment()
);
}
@@ -199,7 +205,7 @@ void AstObjectInitializer::validate()
std::format(
"Too {} members found in object '{}': expected {}, got {}",
object_members.size() > this->_member_initializers.size() ? "few" : "many",
- this->_struct_name,
+ this->_object_type_name,
object_members.size(),
this->_member_initializers.size()),
this->get_source_fragment()
@@ -217,7 +223,7 @@ void AstObjectInitializer::validate()
ErrorType::TYPE_ERROR,
std::format(
"Object type '{}' has no member named '{}'",
- this->_struct_name,
+ this->_object_type_name,
field_name
),
this->get_source_fragment());
@@ -266,6 +272,14 @@ void AstObjectInitializer::validate()
}
}
+void AstObjectInitializer::resolve_forward_references(llvm::Module* module, llvm::IRBuilderBase* builder)
+{
+ for (const auto& val : this->_member_initializers | std::views::values)
+ {
+ val->resolve_forward_references(module, builder);
+ }
+}
+
llvm::Value* AstObjectInitializer::codegen(
llvm::Module* module,
llvm::IRBuilderBase* builder
@@ -308,7 +322,7 @@ llvm::Value* AstObjectInitializer::codegen(
{
throw parsing_error(
ErrorType::COMPILATION_ERROR,
- std::format("Struct type '{}' is undefined", this->_struct_name),
+ std::format("Struct type '{}' is undefined", this->_object_type_name),
this->get_source_fragment()
);
}
@@ -336,6 +350,7 @@ llvm::Value* AstObjectInitializer::codegen(
}
else
{
+ // FIXME:
// For non-pointers, if they are both structs but have different names,
// LLVM doesn't allow bitcast. However, they should have been the same.
// If we reach here, we might need a more complex conversion or
@@ -359,9 +374,9 @@ llvm::Value* AstObjectInitializer::codegen(
for (size_t i = 0; i < dynamic_members.size(); ++i)
{
auto* member_val = dynamic_members[i];
- auto* target_type = struct_type->getElementType(i);
- if (member_val->getType() != target_type)
+ if (auto* target_type = struct_type->getElementType(i);
+ member_val->getType() != target_type)
{
if (member_val->getType()->isPointerTy() && target_type->isPointerTy())
{
@@ -376,8 +391,7 @@ llvm::Value* AstObjectInitializer::codegen(
current_struct_val = builder->CreateInsertValue(
current_struct_val,
member_val,
- { static_cast(i) },
- "object.construct"
+ { static_cast(i) }
);
}
@@ -405,7 +419,7 @@ std::unique_ptr AstObjectInitializer::clone()
return std::make_unique(
this->get_source_fragment(),
this->get_context(),
- this->_struct_name,
+ this->_object_type_name,
std::move(member_initializers),
std::move(member_generic_types)
);
@@ -413,5 +427,23 @@ std::unique_ptr AstObjectInitializer::clone()
std::string AstObjectInitializer::to_string()
{
- return std::format("StructInit{{...}}");
+ const auto object_name = this->_object_type_name;
+ std::string members;
+
+ for (const auto& [name, expr] : this->_member_initializers)
+ {
+ members += std::format("{}: {}, ", name, expr->to_string());
+ }
+
+ if (!this->_generic_type_arguments.empty())
+ {
+ std::vector generic_names;
+ for (const auto& generic : this->_generic_type_arguments)
+ {
+ generic_names.push_back(generic->to_string());
+ }
+ return std::format("(Object) {}<{}>{{ {} }}", object_name, join(generic_names, ", "), members);
+ }
+
+ return std::format("Object {}{{ {} }}", object_name, members);
}
diff --git a/packages/compiler/src/ast/nodes/expressions/variable_reassignation.cpp b/packages/compiler/src/ast/nodes/expressions/variable_reassignation.cpp
index 6905ae32..5b2a59bd 100644
--- a/packages/compiler/src/ast/nodes/expressions/variable_reassignation.cpp
+++ b/packages/compiler/src/ast/nodes/expressions/variable_reassignation.cpp
@@ -103,7 +103,7 @@ MutativeAssignmentType parse_mutative_assignment_type(const Token& token)
std::optional> stride::ast::parse_variable_reassignment(
const std::shared_ptr& context,
- const std::string& variable_name,
+ AstIdentifier* identifier,
TokenSet& set
)
{
@@ -132,7 +132,7 @@ std::optional> stride::ast::parse_varia
return std::make_unique(
reference_token.get_source_fragment(),
context,
- variable_name,
+ identifier->clone_as(),
operation,
std::move(expression)
);
@@ -141,9 +141,10 @@ std::optional> stride::ast::parse_varia
void AstVariableReassignment::validate()
{
this->_value->validate();
- const auto identifier_def = this->get_context()->lookup_variable(this->get_variable_name(), true);
- if (!identifier_def)
+ const auto definition = this->get_identifier()->get_definition();
+
+ if (!definition.has_value())
{
throw parsing_error(
ErrorType::REFERENCE_ERROR,
@@ -153,11 +154,12 @@ void AstVariableReassignment::validate()
this->get_source_fragment());
}
- this->_internal_name = identifier_def->get_internal_symbol_name();
+ this->_internal_name = definition.value()->get_internal_symbol_name();
+ const auto& identifier_ty = this->get_identifier()->get_type();
- if (is_bitwise_mutative_operation(this->_operator) &&
- this->_value->get_type()->is_primitive() &&
- cast_type(this->_value->get_type())->is_fp())
+ if (is_bitwise_mutative_operation(this->get_operator()) &&
+ this->get_value()->get_type()->is_primitive() &&
+ cast_type