diff --git a/include/fusilli/attributes/pointwise_attributes.h b/include/fusilli/attributes/pointwise_attributes.h index a1418101..1795a562 100644 --- a/include/fusilli/attributes/pointwise_attributes.h +++ b/include/fusilli/attributes/pointwise_attributes.h @@ -46,7 +46,7 @@ namespace fusilli { OP(GELU_APPROX_TANH_FWD) \ /* OP(GELU_BWD) */ \ OP(GELU_FWD) \ - /* OP(GEN_INDEX) */ \ + OP(GEN_INDEX) \ OP(IDENTITY) \ OP(LOG) \ OP(LOGICAL_AND) \ @@ -116,6 +116,11 @@ class PointwiseAttr : public AttributesCRTP { return *this; } + PointwiseAttr &setGenIdxAxis(int64_t axis) { + genIdxAxis_ = axis; + return *this; + } + // Getters: FUSILLI_GENERIC_INPUT_TENSOR_GETTER(InputNames, IN_0) FUSILLI_GENERIC_INPUT_TENSOR_GETTER(InputNames, IN_1) @@ -126,6 +131,7 @@ class PointwiseAttr : public AttributesCRTP { float getEluAlpha() const { return eluAlpha_; } float getSoftplusBeta() const { return softplusBeta_; } float getSoftplusThreshold() const { return softplusThreshold_; } + int64_t getGenIdxAxis() const { return genIdxAxis_; } // Utilities for pointwise modes. static const std::unordered_map kModeToStr; @@ -137,6 +143,7 @@ class PointwiseAttr : public AttributesCRTP { float eluAlpha_ = 1.0f; float softplusBeta_ = 1.0f; float softplusThreshold_ = 20.0f; + int64_t genIdxAxis_ = 0; }; #define FUSILLI_DECLARE_STRINGIFY_POINTWISE_MODE(mode) \ @@ -165,6 +172,7 @@ inline const std::unordered_map {PointwiseAttr::Mode::FLOOR, 1}, {PointwiseAttr::Mode::GELU_APPROX_TANH_FWD, 1}, {PointwiseAttr::Mode::GELU_FWD, 1}, + {PointwiseAttr::Mode::GEN_INDEX, 1}, {PointwiseAttr::Mode::LOG, 1}, {PointwiseAttr::Mode::LOGICAL_AND, 2}, {PointwiseAttr::Mode::LOGICAL_NOT, 1}, diff --git a/include/fusilli/attributes/types.h b/include/fusilli/attributes/types.h index d900c5fb..9dd45808 100644 --- a/include/fusilli/attributes/types.h +++ b/include/fusilli/attributes/types.h @@ -22,25 +22,32 @@ namespace fusilli { +// Category of a DataType. Boolean is treated as integer since it is stored as +// i1 and participates in integer arithmetic / index_cast. +enum class DataTypeCategory : uint8_t { + Float, + Integer, +}; + // Define a macro to iterate over all fusilli datatypes and the corresponding -// torch datatypes and mlir asm. +// torch datatypes, mlir asm, and category. #define FUSILLI_FORALL_DATA_TYPES(_) \ - _(Half, Half, "f16") \ - _(BFloat16, BFloat16, "bf16") \ - _(Float, Float, "f32") \ - _(Double, Double, "f64") \ - _(Uint8, Byte, "ui8") \ - _(Int4, Undefined, "si4") \ - _(Int8, Char, "si8") \ - _(Int16, Short, "si16") \ - _(Int32, Int, "si32") \ - _(Int64, Long, "si64") \ - _(Boolean, Bool, "i1") \ - _(FP8E5M2, Float8_e5m2, "f8E5M2") + _(Half, Half, "f16", Float) \ + _(BFloat16, BFloat16, "bf16", Float) \ + _(Float, Float, "f32", Float) \ + _(Double, Double, "f64", Float) \ + _(Uint8, Byte, "ui8", Integer) \ + _(Int4, Undefined, "si4", Integer) \ + _(Int8, Char, "si8", Integer) \ + _(Int16, Short, "si16", Integer) \ + _(Int32, Int, "si32", Integer) \ + _(Int64, Long, "si64", Integer) \ + _(Boolean, Bool, "i1", Integer) \ + _(FP8E5M2, Float8_e5m2, "f8E5M2", Float) enum class DataType : uint8_t { NotSet, -#define DEFINE_ENUM(FUSILLI_TYPE, TORCH_TYPE, MLIR_TYPE) FUSILLI_TYPE, +#define DEFINE_ENUM(FUSILLI_TYPE, TORCH_TYPE, MLIR_TYPE, CATEGORY) FUSILLI_TYPE, FUSILLI_FORALL_DATA_TYPES(DEFINE_ENUM) #undef DEFINE_ENUM }; @@ -48,7 +55,7 @@ enum class DataType : uint8_t { // Map from Fusilli types to MLIR types. static const std::unordered_map kDataTypeToMlirTypeAsm = { -#define DEFINE_ENUM(FUSILLI_TYPE, TORCH_TYPE, MLIR_TYPE) \ +#define DEFINE_ENUM(FUSILLI_TYPE, TORCH_TYPE, MLIR_TYPE, CATEGORY) \ {DataType::FUSILLI_TYPE, MLIR_TYPE}, FUSILLI_FORALL_DATA_TYPES(DEFINE_ENUM) #undef DEFINE_ENUM @@ -57,7 +64,7 @@ static const std::unordered_map kDataTypeToMlirTypeAsm = // Map from Fusilli types to Torch types. static const std::unordered_map kDataTypeToTorchType = { -#define DEFINE_ENUM(FUSILLI_TYPE, TORCH_TYPE, MLIR_TYPE) \ +#define DEFINE_ENUM(FUSILLI_TYPE, TORCH_TYPE, MLIR_TYPE, CATEGORY) \ {DataType::FUSILLI_TYPE, torch_upstream::ScalarType::TORCH_TYPE}, FUSILLI_FORALL_DATA_TYPES(DEFINE_ENUM) #undef DEFINE_ENUM @@ -66,12 +73,47 @@ static const std::unordered_map // Map from MLIR type ASM strings to Fusilli types. static const std::unordered_map kMlirTypeAsmToDataType = { -#define DEFINE_ENUM(FUSILLI_TYPE, TORCH_TYPE, MLIR_TYPE) \ +#define DEFINE_ENUM(FUSILLI_TYPE, TORCH_TYPE, MLIR_TYPE, CATEGORY) \ {MLIR_TYPE, DataType::FUSILLI_TYPE}, FUSILLI_FORALL_DATA_TYPES(DEFINE_ENUM) #undef DEFINE_ENUM }; +// Map from Fusilli types to their category (float vs integer). +static const std::unordered_map kDataTypeCategory = + { +#define DEFINE_ENUM(FUSILLI_TYPE, TORCH_TYPE, MLIR_TYPE, CATEGORY) \ + {DataType::FUSILLI_TYPE, DataTypeCategory::CATEGORY}, + FUSILLI_FORALL_DATA_TYPES(DEFINE_ENUM) +#undef DEFINE_ENUM +}; + +// Returns true iff `dtype` is a floating-point type (e.g. f16, bf16, f32). +inline bool isFloatDataType(DataType dtype) { + auto it = kDataTypeCategory.find(dtype); + return it != kDataTypeCategory.end() && it->second == DataTypeCategory::Float; +} + +// Returns true iff `dtype` is an integer type (including Boolean, which is +// stored as i1). +inline bool isIntegerDataType(DataType dtype) { + auto it = kDataTypeCategory.find(dtype); + return it != kDataTypeCategory.end() && + it->second == DataTypeCategory::Integer; +} + +// Returns the signless MLIR element type for a DataType, stripping the +// leading "s"/"u" that kDataTypeToMlirTypeAsm carries on integer types +// ("si32" -> "i32", "ui8" -> "i8"). Builtin MLIR tensors and arith ops use +// signless integers, unlike torch vtensors which preserve signedness. +inline std::string getSignlessElementTypeAsm(DataType dtype) { + std::string elemType = kDataTypeToMlirTypeAsm.at(dtype); + if (elemType.size() >= 2 && + (elemType.substr(0, 2) == "si" || elemType.substr(0, 2) == "ui")) + elemType = "i" + elemType.substr(2); + return elemType; +} + } // namespace fusilli #endif // FUSILLI_ATTRIBUTES_TYPES_H diff --git a/include/fusilli/node/pointwise_node.h b/include/fusilli/node/pointwise_node.h index 89047326..08261659 100644 --- a/include/fusilli/node/pointwise_node.h +++ b/include/fusilli/node/pointwise_node.h @@ -103,6 +103,28 @@ class PointwiseNode : public NodeCRTP { return ok(); } + ErrorObject postValidateNode() const override final { + FUSILLI_LOG_LABEL_ENDL("INFO: Post-Validating PointwiseNode '" + << pointwiseAttr.getName() << "'"); + + if (pointwiseAttr.getMode() == PointwiseAttr::Mode::GEN_INDEX) { + const auto &out = pointwiseAttr.getOUT_0(); + const int64_t axis = pointwiseAttr.getGenIdxAxis(); + const int64_t outRank = static_cast(out->getDim().size()); + FUSILLI_RETURN_ERROR_IF( + axis < 0 || axis >= outRank, ErrorCode::InvalidAttribute, + "GEN_INDEX axis " + std::to_string(axis) + + " is out of range for output of rank " + std::to_string(outRank)); + const DataType outDtype = out->getDataType(); + FUSILLI_RETURN_ERROR_IF( + !isFloatDataType(outDtype) && !isIntegerDataType(outDtype), + ErrorCode::InvalidAttribute, + "GEN_INDEX only supports integer or float output dtypes"); + } + + return ok(); + } + ErrorObject inferPropertiesNode() override final { FUSILLI_LOG_LABEL_ENDL("INFO: Inferring properties for PointwiseNode '" << pointwiseAttr.getName() << "'"); diff --git a/include/fusilli/support/asm_emitter.h b/include/fusilli/support/asm_emitter.h index bd5407cd..4d76371e 100644 --- a/include/fusilli/support/asm_emitter.h +++ b/include/fusilli/support/asm_emitter.h @@ -48,6 +48,7 @@ #include #include // C++20 #include +#include #include #include #include @@ -128,6 +129,47 @@ inline std::string buildTensorTypeStr(const std::vector &dims, return oss.str(); } +// Builds a builtin MLIR tensor type string (e.g. "tensor<16x256xf32>") from +// explicit dims and dtype. Uses the signless element type so the result +// bridges cleanly with torch vtensors via torch_c.from_builtin_tensor. +inline std::string buildBuiltinTensorTypeStr(const std::vector &dims, + DataType dtype) { + std::ostringstream oss; + oss << "tensor<"; + for (int64_t dim : dims) + oss << dim << "x"; + oss << getSignlessElementTypeAsm(dtype) << ">"; + return oss.str(); +} + +// Builds an identity affine_map over `rank` dimensions, e.g. +// rank=3 → "affine_map<(d0, d1, d2) -> (d0, d1, d2)>" +// Used for linalg.generic indexing_maps where every operand/result accesses +// the full iteration space in natural order. +inline std::string getIdentityAffineMapAsm(size_t rank) { + std::ostringstream dims; + std::vector indices(rank); + std::iota(indices.begin(), indices.end(), 0); + interleave( + indices.begin(), indices.end(), [&](size_t i) { dims << "d" << i; }, + [&] { dims << ", "; }); + return std::format("affine_map<({0}) -> ({0})>", dims.str()); +} + +// Builds a comma-separated list of `rank` linalg iterator type attrs all of +// the given kind (e.g. "parallel" or "reduction"), e.g. +// rank=3, kind="parallel" → "\"parallel\", \"parallel\", \"parallel\"" +// Paired with an indexing_maps attribute on a linalg.generic. +inline std::string getIteratorTypesAsm(size_t rank, std::string_view kind) { + std::ostringstream oss; + std::vector indices(rank); + std::iota(indices.begin(), indices.end(), 0); + interleave( + indices.begin(), indices.end(), + [&](size_t) { oss << "\"" << kind << "\""; }, [&] { oss << ", "; }); + return oss.str(); +} + // --------------------------------------------------------------------------- // Torch IR constant helpers // @@ -1894,6 +1936,82 @@ inline std::string PointwiseNode::emitNodePreAsm() const { FUSILLI_DECLARE_SUB_ADD_TORCH_EMITTER(ADD, torch.aten.add.Tensor) FUSILLI_DECLARE_SUB_ADD_TORCH_EMITTER(SUB, torch.aten.sub.Tensor) + case PointwiseAttr::Mode::GEN_INDEX: { + const auto &out0 = pointwiseAttr.getOUT_0(); + const std::string suffix = getName(); + const std::vector &outDims = out0->getDim(); + const int64_t axis = pointwiseAttr.getGenIdxAxis(); + // Range and dtype correctness are enforced in postValidateNode; these + // asserts are safety nets in case emitNodePreAsm is called without a + // prior graph->validate(). + assert(axis >= 0 && axis < static_cast(outDims.size()) && + "GEN_INDEX axis out of range"); + + const size_t rank = outDims.size(); + const DataType outDtype = out0->getDataType(); + const bool isFloat = isFloatDataType(outDtype); + assert((isFloat || isIntegerDataType(outDtype)) && + "GEN_INDEX only supports integer or float output dtypes"); + const std::string signlessDtype = getSignlessElementTypeAsm(outDtype); + + const std::string builtinTensorType = + buildBuiltinTensorTypeStr(outDims, out0->getDataType()); + const std::string vtensorType = out0->getTensorTypeAsm( + /*isValueTensor=*/true, /*useLogicalDims=*/true); + + const std::string indexingMap = getIdentityAffineMapAsm(rank); + const std::string iterators = getIteratorTypesAsm(rank, "parallel"); + + // One full schema per cast flavour: integer outputs take a single + // index_cast, floats go through i64 and then sitofp. Keeping the + // linalg.generic and the cast together lets each variant be read as a + // cohesive block of IR. + constexpr std::string_view kGenIndexIntSchema = R"( + %gen_index_empty_{0} = tensor.empty() : {1} + %gen_index_linalg_{0} = linalg.generic {{indexing_maps = [{2}], iterator_types = [{3}]}} outs(%gen_index_empty_{0} : {1}) {{ + ^bb0(%gen_index_out_{0}: {4}): + %gen_index_idx_{0} = linalg.index {5} : index + %gen_index_val_{0} = arith.index_cast %gen_index_idx_{0} : index to {4} + linalg.yield %gen_index_val_{0} : {4} + }} -> {1} + {6} = torch_c.from_builtin_tensor %gen_index_linalg_{0} : {1} -> {7} + {8})"; + + constexpr std::string_view kGenIndexFloatSchema = R"( + %gen_index_empty_{0} = tensor.empty() : {1} + %gen_index_linalg_{0} = linalg.generic {{indexing_maps = [{2}], iterator_types = [{3}]}} outs(%gen_index_empty_{0} : {1}) {{ + ^bb0(%gen_index_out_{0}: {4}): + %gen_index_idx_{0} = linalg.index {5} : index + %gen_index_int_{0} = arith.index_cast %gen_index_idx_{0} : index to i64 + %gen_index_val_{0} = arith.sitofp %gen_index_int_{0} : i64 to {4} + linalg.yield %gen_index_val_{0} : {4} + }} -> {1} + {6} = torch_c.from_builtin_tensor %gen_index_linalg_{0} : {1} -> {7} + {8})"; + + if (isFloat) + return std::format(kGenIndexFloatSchema, suffix, /* {0} */ + builtinTensorType, /* {1} */ + indexingMap, /* {2} */ + iterators, /* {3} */ + signlessDtype, /* {4} */ + axis, /* {5} */ + getResultNamesAsm(), /* {6} */ + vtensorType, /* {7} */ + permuteOUT0 /* {8} */ + ); + return std::format(kGenIndexIntSchema, suffix, /* {0} */ + builtinTensorType, /* {1} */ + indexingMap, /* {2} */ + iterators, /* {3} */ + signlessDtype, /* {4} */ + axis, /* {5} */ + getResultNamesAsm(), /* {6} */ + vtensorType, /* {7} */ + permuteOUT0 /* {8} */ + ); + } + default: assert(false && "Unsupported pointwise mode"); return ""; diff --git a/samples/pointwise/pointwise_unary_ops.cpp b/samples/pointwise/pointwise_unary_ops.cpp index 975856b4..5cfa2fd0 100644 --- a/samples/pointwise/pointwise_unary_ops.cpp +++ b/samples/pointwise/pointwise_unary_ops.cpp @@ -330,3 +330,68 @@ TEST_CASE("Pointwise unary ops", "[pointwise][graph]") { if (supportsFloat(mode) && supportsNegative(mode)) execute(handle, DataType::Half, half(-3.14)); } + +// GEN_INDEX is a unary pointwise op but the expected value varies per output +// position (not a single scalar), so it doesn't fit the test harness above. +// Keep it in this file so all unary pointwise ops are exercised together. +TEST_CASE("Pointwise gen_index", "[pointwise][graph]") { + const auto dim = std::vector{2, 3, 4, 5}; + const int64_t axis = GENERATE(int64_t(0), int64_t(1), int64_t(2), int64_t(3)); + + auto buildNewGraph = [&](Handle &handle, DataType dt) { + auto graph = std::make_shared(); + graph->setName(std::format("pointwise_gen_index_dt{}_axis{}", + kDataTypeToMlirTypeAsm.at(dt), axis)); + graph->setIODataType(dt).setComputeDataType(dt); + + auto xT = graph->tensor(TensorAttr().setName("in0").setDim(dim).setStride( + generateStrideFromDim(dim, getContiguousStrideOrder(dim.size())))); + + auto pointwiseAttr = PointwiseAttr() + .setMode(PointwiseAttr::Mode::GEN_INDEX) + .setGenIdxAxis(axis); + auto pointwiseResult = graph->pointwise(xT, pointwiseAttr); + + pointwiseResult->setName("result").setOutput(true); + + FUSILLI_REQUIRE_OK(graph->validate()); + FUSILLI_REQUIRE_OK(graph->compile(handle, /*remove=*/true)); + + return std::make_tuple(graph, xT, pointwiseResult); + }; + + auto execute = [&](Handle &handle, DataType dt) { + auto [graph, xT, yT] = buildNewGraph(handle, dt); + + FUSILLI_REQUIRE_ASSIGN(auto xBuf, + allocateBufferOfType(handle, xT, dt, T(0))); + FUSILLI_REQUIRE_ASSIGN(auto yBuf, + allocateBufferOfType(handle, yT, dt, T(0))); + + const std::unordered_map, + std::shared_ptr> + variantPack = {{xT, xBuf}, {yT, yBuf}}; + + FUSILLI_REQUIRE_ASSIGN( + auto workspace, allocateWorkspace(handle, graph->getWorkspaceSize())); + FUSILLI_REQUIRE_OK(graph->execute(handle, variantPack, workspace)); + + std::vector result; + FUSILLI_REQUIRE_OK(yBuf->read(handle, result)); + + // Expected value at flat index `i` is the coordinate along `axis`. + const size_t rank = dim.size(); + std::vector strides(rank, 1); + for (int64_t k = static_cast(rank) - 2; k >= 0; --k) + strides[k] = strides[k + 1] * dim[k + 1]; + + for (size_t i = 0; i < result.size(); ++i) { + int64_t coord = (static_cast(i) / strides[axis]) % dim[axis]; + REQUIRE(result[i] == static_cast(coord)); + } + }; + + FUSILLI_REQUIRE_ASSIGN(Handle handle, Handle::create(kDefaultBackend)); + execute.template operator()(handle, DataType::Float); + execute.template operator()(handle, DataType::Int32); +} diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index c85947f9..419c6c6a 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -97,6 +97,7 @@ add_fusilli_tests( PREFIX fusilli_support_tests SRCS + test_asm_emitter_helpers.cpp test_cache.cpp test_dllib.cpp test_extras.cpp @@ -162,6 +163,8 @@ add_fusilli_lit_tests( lit/test_pointwise_asm_emitter_floor.cpp lit/test_pointwise_asm_emitter_gelu_approx_tanh_fwd.cpp lit/test_pointwise_asm_emitter_gelu_fwd.cpp + lit/test_pointwise_asm_emitter_gen_index.cpp + lit/test_pointwise_asm_emitter_gen_index_int.cpp lit/test_pointwise_asm_emitter_identity.cpp lit/test_pointwise_asm_emitter_log.cpp lit/test_pointwise_asm_emitter_logical_and.cpp diff --git a/tests/lit/test_pointwise_asm_emitter_gen_index.cpp b/tests/lit/test_pointwise_asm_emitter_gen_index.cpp new file mode 100644 index 00000000..9091d903 --- /dev/null +++ b/tests/lit/test_pointwise_asm_emitter_gen_index.cpp @@ -0,0 +1,57 @@ +// Copyright 2026 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +// RUN: %{TEST_EXE} | iree-opt --verify-roundtrip +// RUN: %{TEST_EXE} | FileCheck %s --check-prefix=TORCH-CHECK +// RUN: %{TEST_EXE} stats | FileCheck %s --check-prefix=%{BACKEND}-STATS-CHECK + +// clang-format off +// +// TORCH-CHECK: module @module { +// TORCH-CHECK: func.func @main(%result_: !torch.tensor<[16,256,64,32],f32>, %arg0: !torch.vtensor<[16,256,64,32],f32>) attributes {torch.assume_strict_symbolic_shapes} { +// TORCH-CHECK: %gen_index_empty_pointwise_gen_index = tensor.empty() : tensor<16x256x64x32xf32> +// TORCH-CHECK: %gen_index_linalg_pointwise_gen_index = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%gen_index_empty_pointwise_gen_index : tensor<16x256x64x32xf32>) { +// TORCH-CHECK: ^bb0(%gen_index_out_pointwise_gen_index: f32): +// TORCH-CHECK: %gen_index_idx_pointwise_gen_index = linalg.index 2 : index +// TORCH-CHECK: %gen_index_int_pointwise_gen_index = arith.index_cast %gen_index_idx_pointwise_gen_index : index to i64 +// TORCH-CHECK: %gen_index_val_pointwise_gen_index = arith.sitofp %gen_index_int_pointwise_gen_index : i64 to f32 +// TORCH-CHECK: linalg.yield %gen_index_val_pointwise_gen_index : f32 +// TORCH-CHECK: } -> tensor<16x256x64x32xf32> +// TORCH-CHECK: %result_pointwise_gen_index_perm = torch_c.from_builtin_tensor %gen_index_linalg_pointwise_gen_index : tensor<16x256x64x32xf32> -> !torch.vtensor<[16,256,64,32],f32> +// TORCH-CHECK: %result = torch.aten.permute %result_pointwise_gen_index_perm, %permute_OUT_0_pointwise_gen_index : !torch.vtensor<[16,256,64,32],f32>, !torch.list -> !torch.vtensor<[16,256,64,32],f32> +// TORCH-CHECK: torch.overwrite.tensor.contents %result overwrites %result_ : !torch.vtensor<[16,256,64,32],f32>, !torch.tensor<[16,256,64,32],f32> +// TORCH-CHECK: return +// TORCH-CHECK: } +// TORCH-CHECK: } +// +// AMDGPU-STATS-CHECK: "transient-memory-size": 0 +// AMDGPU-STATS-CHECK: "dispatch-count": 1 +// CPU-STATS-CHECK: "transient-memory-size": 0 +// CPU-STATS-CHECK: "dispatch-count": 1 +// +// clang-format on + +#include + +#include "pointwise_utils.h" + +#include +#include + +using namespace fusilli; + +int main(int argc, char **argv) { + std::string mode = (argc > 1) ? argv[1] : "default"; + + auto status = testGenIndexAsmEmitter("pointwise_asm_emitter_gen_index", + "pointwise_gen_index", mode, + {16, 256, 64, 32}, /*axis=*/2); + if (isError(status)) { + std::cerr << "Test failed: " << status << std::endl; + return 1; + } + return 0; +} diff --git a/tests/lit/test_pointwise_asm_emitter_gen_index_int.cpp b/tests/lit/test_pointwise_asm_emitter_gen_index_int.cpp new file mode 100644 index 00000000..592474bc --- /dev/null +++ b/tests/lit/test_pointwise_asm_emitter_gen_index_int.cpp @@ -0,0 +1,56 @@ +// Copyright 2026 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +// RUN: %{TEST_EXE} | iree-opt --verify-roundtrip +// RUN: %{TEST_EXE} | FileCheck %s --check-prefix=TORCH-CHECK +// RUN: %{TEST_EXE} stats | FileCheck %s --check-prefix=%{BACKEND}-STATS-CHECK + +// clang-format off +// +// TORCH-CHECK: module @module { +// TORCH-CHECK: func.func @main(%result_: !torch.tensor<[16,256,64,32],si32>, %arg0: !torch.vtensor<[16,256,64,32],si32>) attributes {torch.assume_strict_symbolic_shapes} { +// TORCH-CHECK: %gen_index_empty_pointwise_gen_index = tensor.empty() : tensor<16x256x64x32xi32> +// TORCH-CHECK: %gen_index_linalg_pointwise_gen_index = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%gen_index_empty_pointwise_gen_index : tensor<16x256x64x32xi32>) { +// TORCH-CHECK: ^bb0(%gen_index_out_pointwise_gen_index: i32): +// TORCH-CHECK: %gen_index_idx_pointwise_gen_index = linalg.index 2 : index +// TORCH-CHECK: %gen_index_val_pointwise_gen_index = arith.index_cast %gen_index_idx_pointwise_gen_index : index to i32 +// TORCH-CHECK: linalg.yield %gen_index_val_pointwise_gen_index : i32 +// TORCH-CHECK: } -> tensor<16x256x64x32xi32> +// TORCH-CHECK: %result_pointwise_gen_index_perm = torch_c.from_builtin_tensor %gen_index_linalg_pointwise_gen_index : tensor<16x256x64x32xi32> -> !torch.vtensor<[16,256,64,32],si32> +// TORCH-CHECK: %result = torch.aten.permute %result_pointwise_gen_index_perm, %permute_OUT_0_pointwise_gen_index : !torch.vtensor<[16,256,64,32],si32>, !torch.list -> !torch.vtensor<[16,256,64,32],si32> +// TORCH-CHECK: torch.overwrite.tensor.contents %result overwrites %result_ : !torch.vtensor<[16,256,64,32],si32>, !torch.tensor<[16,256,64,32],si32> +// TORCH-CHECK: return +// TORCH-CHECK: } +// TORCH-CHECK: } +// +// AMDGPU-STATS-CHECK: "transient-memory-size": 0 +// AMDGPU-STATS-CHECK: "dispatch-count": 1 +// CPU-STATS-CHECK: "transient-memory-size": 0 +// CPU-STATS-CHECK: "dispatch-count": 1 +// +// clang-format on + +#include + +#include "pointwise_utils.h" + +#include +#include + +using namespace fusilli; + +int main(int argc, char **argv) { + std::string mode = (argc > 1) ? argv[1] : "default"; + + auto status = testGenIndexAsmEmitter( + "pointwise_asm_emitter_gen_index_int", "pointwise_gen_index", mode, + {16, 256, 64, 32}, /*axis=*/2, DataType::Int32); + if (isError(status)) { + std::cerr << "Test failed: " << status << std::endl; + return 1; + } + return 0; +} diff --git a/tests/pointwise_utils.h b/tests/pointwise_utils.h index 3d334e02..087631d4 100644 --- a/tests/pointwise_utils.h +++ b/tests/pointwise_utils.h @@ -99,6 +99,45 @@ inline ErrorObject testBinaryPointwiseAsmEmitter(const std::string &graphName, return ok(); } +inline ErrorObject +testGenIndexAsmEmitter(const std::string &graphName, const std::string &opName, + const std::string &mode, std::vector inDims, + int64_t axis, DataType dtype = DataType::Float) { + + auto graph = std::make_shared(); + graph->setName(graphName); + graph->setIODataType(dtype).setComputeDataType(dtype); + + auto xT = createTestTensor("arg0", inDims, graph.get()); + + auto pointwiseAttr = PointwiseAttr() + .setMode(PointwiseAttr::Mode::GEN_INDEX) + .setName(opName) + .setGenIdxAxis(axis); + + auto yT = graph->pointwise(xT, pointwiseAttr); + + yT->setName("result").setOutput(true); + + FUSILLI_CHECK_ERROR(graph->validate()); + + if (mode == "default") { + FUSILLI_ASSIGN_OR_RETURN(auto generatedAsm, graph->emitAsm()); + FUSILLI_CHECK_ERROR(checkMlirIndentation(generatedAsm)); + std::cout << generatedAsm << std::endl; + } + + if (mode == "stats") { + FUSILLI_ASSIGN_OR_RETURN(Handle handle, Handle::create(kDefaultBackend)); + FUSILLI_CHECK_ERROR(graph->compile(handle, /*remove=*/true)); + FUSILLI_ASSIGN_OR_RETURN(auto stats, graph->readCompilationCacheFile( + CachedAssetsType::Statistics)); + std::cout << stats << std::endl; + } + + return ok(); +} + } // namespace fusilli #endif // FUSILLI_TESTS_POINTWISE_UTILS_H diff --git a/tests/test_asm_emitter_helpers.cpp b/tests/test_asm_emitter_helpers.cpp new file mode 100644 index 00000000..3ec4fbb1 --- /dev/null +++ b/tests/test_asm_emitter_helpers.cpp @@ -0,0 +1,88 @@ +// Copyright 2026 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include + +#include + +#include +#include + +using namespace fusilli; + +TEST_CASE("getSignlessElementTypeAsm strips integer signedness prefix", + "[asm_emitter_helpers]") { + REQUIRE(getSignlessElementTypeAsm(DataType::Int8) == "i8"); + REQUIRE(getSignlessElementTypeAsm(DataType::Int16) == "i16"); + REQUIRE(getSignlessElementTypeAsm(DataType::Int32) == "i32"); + REQUIRE(getSignlessElementTypeAsm(DataType::Int64) == "i64"); + REQUIRE(getSignlessElementTypeAsm(DataType::Uint8) == "i8"); +} + +TEST_CASE("getSignlessElementTypeAsm leaves floats and i1 untouched", + "[asm_emitter_helpers]") { + REQUIRE(getSignlessElementTypeAsm(DataType::Half) == "f16"); + REQUIRE(getSignlessElementTypeAsm(DataType::BFloat16) == "bf16"); + REQUIRE(getSignlessElementTypeAsm(DataType::Float) == "f32"); + REQUIRE(getSignlessElementTypeAsm(DataType::Double) == "f64"); + REQUIRE(getSignlessElementTypeAsm(DataType::FP8E5M2) == "f8E5M2"); + REQUIRE(getSignlessElementTypeAsm(DataType::Boolean) == "i1"); +} + +TEST_CASE("isFloatDataType classifies all float flavours", + "[asm_emitter_helpers]") { + REQUIRE(isFloatDataType(DataType::Half)); + REQUIRE(isFloatDataType(DataType::BFloat16)); + REQUIRE(isFloatDataType(DataType::Float)); + REQUIRE(isFloatDataType(DataType::Double)); + REQUIRE(isFloatDataType(DataType::FP8E5M2)); + REQUIRE_FALSE(isFloatDataType(DataType::Int32)); + REQUIRE_FALSE(isFloatDataType(DataType::Boolean)); + REQUIRE_FALSE(isFloatDataType(DataType::NotSet)); +} + +TEST_CASE("isIntegerDataType classifies integer types (including Boolean)", + "[asm_emitter_helpers]") { + REQUIRE(isIntegerDataType(DataType::Uint8)); + REQUIRE(isIntegerDataType(DataType::Int4)); + REQUIRE(isIntegerDataType(DataType::Int8)); + REQUIRE(isIntegerDataType(DataType::Int16)); + REQUIRE(isIntegerDataType(DataType::Int32)); + REQUIRE(isIntegerDataType(DataType::Int64)); + REQUIRE(isIntegerDataType(DataType::Boolean)); + REQUIRE_FALSE(isIntegerDataType(DataType::Float)); + REQUIRE_FALSE(isIntegerDataType(DataType::Half)); + REQUIRE_FALSE(isIntegerDataType(DataType::NotSet)); +} + +TEST_CASE("buildBuiltinTensorTypeStr uses signless element type", + "[asm_emitter_helpers]") { + const std::vector dims = {16, 256}; + REQUIRE(buildBuiltinTensorTypeStr(dims, DataType::Float) == + "tensor<16x256xf32>"); + REQUIRE(buildBuiltinTensorTypeStr(dims, DataType::Int32) == + "tensor<16x256xi32>"); + REQUIRE(buildBuiltinTensorTypeStr(dims, DataType::Uint8) == + "tensor<16x256xi8>"); + REQUIRE(buildBuiltinTensorTypeStr({4}, DataType::Boolean) == "tensor<4xi1>"); +} + +TEST_CASE("getIdentityAffineMapAsm produces (dN) -> (dN)", + "[asm_emitter_helpers]") { + REQUIRE(getIdentityAffineMapAsm(1) == "affine_map<(d0) -> (d0)>"); + REQUIRE(getIdentityAffineMapAsm(2) == "affine_map<(d0, d1) -> (d0, d1)>"); + REQUIRE(getIdentityAffineMapAsm(4) == + "affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>"); +} + +TEST_CASE("getIteratorTypesAsm repeats the kind `rank` times", + "[asm_emitter_helpers]") { + REQUIRE(getIteratorTypesAsm(1, "parallel") == "\"parallel\""); + REQUIRE(getIteratorTypesAsm(3, "parallel") == + "\"parallel\", \"parallel\", \"parallel\""); + REQUIRE(getIteratorTypesAsm(2, "reduction") == + "\"reduction\", \"reduction\""); +} diff --git a/tests/utils.h b/tests/utils.h index 2a7260f6..26a45265 100644 --- a/tests/utils.h +++ b/tests/utils.h @@ -297,11 +297,12 @@ inline ErrorObject checkMlirIndentation(const std::string &mlir) { continue; } - // Every op inside @main must be indented at exactly 4 spaces. - FUSILLI_RETURN_ERROR_IF(indent != 4, ErrorCode::InvalidAttribute, + // Ops inside @main must be indented at least 4 spaces. Deeper + // indents are allowed for nested region bodies (e.g. linalg.generic). + FUSILLI_RETURN_ERROR_IF(indent < 4, ErrorCode::InvalidAttribute, "MLIR indentation error on line " + std::to_string(lineNum) + - ": expected 4-space indent, got " + + ": expected >= 4-space indent, got " + std::to_string(indent) + ": " + line); } return ok();