diff --git a/include/fusilli/node/norm_utils.h b/include/fusilli/node/norm_utils.h index a849787e..cdc07ed7 100644 --- a/include/fusilli/node/norm_utils.h +++ b/include/fusilli/node/norm_utils.h @@ -15,7 +15,9 @@ #include "fusilli/attributes/tensor_attributes.h" +#include #include +#include #include #include @@ -35,6 +37,23 @@ inline std::vector getNormalizedShape(const std::vector &xDim, return shape; } +// Returns normalized_shape as the trailing suffix of `xDim` where +// `sDim[i] == xDim[i]`. +// +// Precondition: callers must have validated `sDim` against `xDim` (done in +// RmsNormNode::preValidateNode). +inline std::vector +getRMSNormNormalizedShape(const std::vector &xDim, + const std::vector &sDim) { + const auto [sMismatch, _] = + std::mismatch(sDim.rbegin(), sDim.rend(), xDim.rbegin(), xDim.rend()); + // Keep matchCount as the iterator's signed difference_type so the + // `xDim.end() - matchCount` subtraction below doesn't trip + // clang-tidy's bugprone-narrowing-conversions check. + const auto matchCount = std::distance(sDim.rbegin(), sMismatch); + return std::vector(xDim.end() - matchCount, xDim.end()); +} + // Returns [1, ..., B, ..., 1] dim and unit strides for training forward outputs // (e.g. MEAN, INV_VARIANCE, INV_RMS), where only the batch dimension is // preserved from xDim. @@ -66,6 +85,27 @@ getScaleBiasStride(const std::vector &scaleBiasDim, return generateStrideFromDim(scaleBiasDim, strideOrder); } +// Returns dim and stride for an RmsNormNode INV_RMS output: +// invRms[i] = x[i] if scale[i] == 1 (leading region preserved) +// = 1 otherwise (normalized region collapsed) +// Stride preserves x's stride order via getScaleBiasStride (matches hipDNN's +// RMSNormNode.hpp invRms inference and our scale/bias stride convention). +// +// Precondition: callers must have validated `sDim` against `xDim` (done in +// RmsNormNode::preValidateNode). 
+inline std::pair, std::vector> +getRMSNormInvRmsDimAndStride(const std::vector &xDim, + const std::vector &sDim, + const std::vector &xStride) { + std::vector dim = xDim; + for (size_t i = 0; i < dim.size(); ++i) { + if (sDim[i] != 1) + dim[i] = 1; + } + std::vector stride = getScaleBiasStride(dim, xStride); + return {dim, stride}; +} + // Infers dim of a tensor if not already set. inline void inferDim(std::shared_ptr &tensor, const std::vector &dim) { diff --git a/include/fusilli/node/rmsnorm_node.h b/include/fusilli/node/rmsnorm_node.h index a7d725f7..e1a4220b 100644 --- a/include/fusilli/node/rmsnorm_node.h +++ b/include/fusilli/node/rmsnorm_node.h @@ -22,8 +22,10 @@ #include "fusilli/node/norm_utils.h" #include "fusilli/support/logging.h" +#include #include #include +#include #include #include #include @@ -72,6 +74,8 @@ class RmsNormNode : public NodeCRTP { // Ensure mandatory input and output tensors are set. FUSILLI_RETURN_ERROR_IF(!xT, ErrorCode::AttributeNotSet, "RmsNorm input tensor X not set"); + FUSILLI_RETURN_ERROR_IF(!sT, ErrorCode::AttributeNotSet, + "RmsNorm input tensor SCALE not set"); FUSILLI_RETURN_ERROR_IF(!yT, ErrorCode::AttributeNotSet, "RmsNorm output tensor Y not set"); @@ -89,25 +93,55 @@ class RmsNormNode : public NodeCRTP { // Shape and layout checks on scale tensor. // If scale tensor's dims/strides are not set, they will be inferred in // inferPropertiesNode(). - if (sT) { - if (!sT->getDim().empty()) { - FUSILLI_RETURN_ERROR_IF(sT->getDim() != - norm_utils::getScaleBiasDim(xT->getDim()), - ErrorCode::InvalidAttribute, - "RmsNorm input tensor SCALE must have shape as " - "tensor X with single batch"); - } + // + // Scale encodes the normalized_shape as a trailing suffix of x: + // reduction = maximal trailing suffix where scale[i] == x[i], + // leading region (excluding batch) must be all-1, + // batch dim (scale[0]) must be 1 (broadcast across batch). 
+ if (!sT->getDim().empty()) { + const auto &xDim = xT->getDim(); + const auto &sDim = sT->getDim(); + constexpr size_t batchDim = 0; + + FUSILLI_RETURN_ERROR_IF( + sDim.size() != xDim.size(), ErrorCode::InvalidAttribute, + "RmsNorm SCALE tensor must have the same rank as X"); + + FUSILLI_RETURN_ERROR_IF( + sDim[batchDim] != 1, ErrorCode::InvalidAttribute, + "RmsNorm SCALE tensor must have batch dim equal to 1 " + "(broadcast across batch)"); + + // matchCount = number of trailing dims where scale[i] == x[i]. + const auto [sMismatch, _] = + std::mismatch(sDim.rbegin(), sDim.rend(), xDim.rbegin(), xDim.rend()); + const size_t matchCount = + static_cast(std::distance(sDim.rbegin(), sMismatch)); - if (!sT->getStride().empty()) { + FUSILLI_RETURN_ERROR_IF( + matchCount == 0, ErrorCode::InvalidAttribute, + "RmsNorm SCALE has no trailing dims matching X — at least " + "one normalized dim is required"); + + // Leading region (between batch and the matching trailing + // suffix) must be all-1. + for (size_t i = batchDim + 1; i < sDim.size() - matchCount; ++i) { FUSILLI_RETURN_ERROR_IF( - !sT->isContiguous() && !sT->isChannelsLast(), - ErrorCode::NotImplemented, - "Tensor '" + sT->getName() + - "' is neither contiguous nor channels-last as " - "defined by its stride"); + sDim[i] != 1, ErrorCode::InvalidAttribute, + "RmsNorm SCALE leading region (before normalized shape) " + "must be 1"); } } + if (!sT->getStride().empty()) { + FUSILLI_RETURN_ERROR_IF( + !sT->isContiguous() && !sT->isChannelsLast(), + ErrorCode::NotImplemented, + "Tensor '" + sT->getName() + + "' is neither contiguous nor channels-last as " + "defined by its stride"); + } + // Output tensor checks for training and inference forward phases. if (isTrainingForwardPhase()) { FUSILLI_RETURN_ERROR_IF(!rT, ErrorCode::AttributeNotSet, @@ -141,17 +175,15 @@ class RmsNormNode : public NodeCRTP { // Infer shape and stride of input SCALE tensor if they're not set. 
std::shared_ptr sT = rmsnormAttr.getSCALE(); - if (sT) { - norm_utils::inferScaleBiasDimAndStride(sT, xDim, xT->getStride()); - } + norm_utils::inferScaleBiasDimAndStride(sT, xDim, xT->getStride()); // Infer shape and stride of output Y tensor. // When stride is unspecified, preserve the stride order of xT. norm_utils::inferDimAndStride(yT, xDim, xT->getStride()); if (isTrainingForwardPhase()) { - const auto &[dim, stride] = - norm_utils::getTrainingForwardOutputDimAndStride(xDim); + const auto &[dim, stride] = norm_utils::getRMSNormInvRmsDimAndStride( + xDim, rmsnormAttr.getSCALE()->getDim(), xT->getStride()); // Infer shape and stride of output INV_RMS tensor. std::shared_ptr rT = rmsnormAttr.getINV_RMS(); @@ -183,21 +215,22 @@ class RmsNormNode : public NodeCRTP { "defined by its stride"); if (isTrainingForwardPhase()) { - const auto &[dim, stride] = - norm_utils::getTrainingForwardOutputDimAndStride(xDim); + const auto &[dim, stride] = norm_utils::getRMSNormInvRmsDimAndStride( + xDim, rmsnormAttr.getSCALE()->getDim(), xT->getStride()); std::shared_ptr rT = rmsnormAttr.getINV_RMS(); // Shape check for output INV_RMS tensor FUSILLI_RETURN_ERROR_IF( dim != rT->getDim(), ErrorCode::InvalidAttribute, - "RmsNorm output INV_RMS tensor must have shape [B, 1, ..., 1] with " - "rank equal to input X tensor's rank, and batch dimension equal " - "to input X tensor's batch dimension"); + "RmsNorm output INV_RMS tensor must have x's leading (broadcast) " + "dims preserved and the normalized (trailing) region collapsed to " + "1, with rank equal to input X tensor's rank"); // Stride check for output INV_RMS tensor FUSILLI_RETURN_ERROR_IF( stride != rT->getStride(), ErrorCode::InvalidAttribute, - "RmsNorm output INV_RMS tensor must have unit strides"); + "RmsNorm output INV_RMS tensor must have strides preserving " + "input X tensor's stride order"); } FUSILLI_RETURN_ERROR_IF( @@ -214,7 +247,8 @@ class RmsNormNode : public NodeCRTP { } std::vector getNormalizedShape() const { 
- return norm_utils::getNormalizedShape(rmsnormAttr.getX()->getDim()); + return norm_utils::getRMSNormNormalizedShape( + rmsnormAttr.getX()->getDim(), rmsnormAttr.getSCALE()->getDim()); } }; diff --git a/include/fusilli/support/asm_emitter.h b/include/fusilli/support/asm_emitter.h index d7d0daae..cb017455 100644 --- a/include/fusilli/support/asm_emitter.h +++ b/include/fusilli/support/asm_emitter.h @@ -1477,10 +1477,8 @@ inline std::string RmsNormNode::getOperandNamesAsm() const { oss << rmsnormAttr.getX()->getValueNameAsm() << "_" << suffix << "_perm, "; oss << "%normalized_shape_" << suffix << ", "; - - auto sT = rmsnormAttr.getSCALE(); - oss << (sT ? sT->getValueNameAsm() + "_" + suffix + "_perm, " - : "%none_scale_" + suffix + ", "); + oss << rmsnormAttr.getSCALE()->getValueNameAsm() << "_" << suffix + << "_perm, "; oss << "%eps_" << suffix; return oss.str(); @@ -1494,11 +1492,8 @@ inline std::string RmsNormNode::getOperandTypesAsm() const { /*useLogicalDims=*/true) << ", "; oss << "!torch.list" << ", "; - - auto sT = rmsnormAttr.getSCALE(); - oss << (sT ? sT->getTensorTypeAsm(/*isValueTensor=*/true, - /*useLogicalDims=*/true) - : "!torch.none") + oss << rmsnormAttr.getSCALE()->getTensorTypeAsm(/*isValueTensor=*/true, + /*useLogicalDims=*/true) << ", "; oss << "!torch.float"; @@ -1558,10 +1553,8 @@ inline std::string RmsNormNode::emitNodePreAsm() const { std::string permuteY = getLayoutConversionOpsAsm( rmsnormAttr.getY(), "permute_y", uniqueSSASuffix, /*isInput=*/false); std::string permuteScale = - rmsnormAttr.getSCALE() - ? 
getLayoutConversionOpsAsm(rmsnormAttr.getSCALE(), "permute_scale", - uniqueSSASuffix, /*isInput=*/true) - : torchNoneAsm("none_scale", uniqueSSASuffix); + getLayoutConversionOpsAsm(rmsnormAttr.getSCALE(), "permute_scale", + uniqueSSASuffix, /*isInput=*/true); constexpr std::string_view schema = R"( {0} diff --git a/samples/CMakeLists.txt b/samples/CMakeLists.txt index c830792e..bfc5f2d8 100644 --- a/samples/CMakeLists.txt +++ b/samples/CMakeLists.txt @@ -99,7 +99,6 @@ add_fusilli_samples( add_fusilli_samples( PREFIX fusilli_rmsnorm_samples SRCS - rmsnorm/rmsnorm_infer_nchw.cpp rmsnorm/rmsnorm_infer_nchw_scale.cpp rmsnorm/rmsnorm_infer_nhwc_scale.cpp DEPS diff --git a/samples/rmsnorm/rmsnorm_infer_nchw.cpp b/samples/rmsnorm/rmsnorm_infer_nchw.cpp deleted file mode 100644 index e8222c1e..00000000 --- a/samples/rmsnorm/rmsnorm_infer_nchw.cpp +++ /dev/null @@ -1,110 +0,0 @@ -// Copyright 2026 Advanced Micro Devices, Inc. -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include - -#include "rmsnorm_utils.h" -#include "utils.h" - -#include - -#include -#include -#include -#include -#include -#include -#include - -using namespace fusilli; - -TEST_CASE("RMS normalization; inference mode; NCHW layout; no scale", - "[rmsnorm][graph]") { - constexpr int64_t n = 2, c = 3, h = 32, w = 32; - constexpr float eps = 1e-5f; - - auto buildNewGraph = [=](const Handle &handle) { - auto graph = std::make_shared(); - graph->setName("rmsnorm_infer_sample_nchw"); - graph->setIODataType(DataType::Float).setComputeDataType(DataType::Float); - - auto xT = graph->tensor(TensorAttr() - .setName("x") - .setDim({n, c, h, w}) - .setStride({c * h * w, h * w, w, 1})); // NCHW - - auto epsilonT = graph->tensor(TensorAttr(eps)); - - auto rmsnormAttr = RmsnormAttr() - .setForwardPhase(NormFwdPhase::INFERENCE) - .setEpsilon(epsilonT) - .setName("rmsnorm"); - - // RmsNorm - auto [yT, rT] = graph->rmsnorm(xT, nullptr, rmsnormAttr); - - yT->setName("y").setDataType(DataType::Float).setOutput(true); - - // Validate, infer missing properties - FUSILLI_REQUIRE_OK(graph->validate()); - - // Compile - FUSILLI_REQUIRE_OK(graph->compile(handle, /*remove=*/true)); - - return std::make_tuple(graph, xT, yT); - }; - - // Create handle for the target backend. - FUSILLI_REQUIRE_ASSIGN(Handle handle, Handle::create(kDefaultBackend)); - - auto [graph, xT, yT] = buildNewGraph(handle); - - auto [inputVals, expectedVals] = - rmsnorm_utils::generateIOTensorsForInferForward(n, c, h, w, 1.f, eps); - - FUSILLI_REQUIRE_ASSIGN(auto xBuf, - allocateBufferOfType(handle, xT, inputVals)); - FUSILLI_REQUIRE_ASSIGN( - auto yBuf, allocateBufferOfType(handle, yT, DataType::Float, 0.0f)); - - // Create variant pack. - const std::unordered_map, std::shared_ptr> - variantPack = { - {xT, xBuf}, - {yT, yBuf}, - }; - - // Allocate workspace buffer if needed. 
- FUSILLI_REQUIRE_ASSIGN(auto workspaceSize, graph->getWorkspaceSize()); - FUSILLI_REQUIRE_ASSIGN(auto workspace, - allocateWorkspace(handle, workspaceSize)); - - // Execute graph once. - FUSILLI_REQUIRE_OK(graph->execute(handle, variantPack, workspace)); - - std::vector yVals; - FUSILLI_REQUIRE_OK(yBuf->read(handle, yVals)); - - REQUIRE(yVals.size() == expectedVals.size()); - constexpr float tolerance = 1e-4f; - for (size_t i = 0; i < yVals.size(); ++i) { - REQUIRE(std::abs(yVals[i] - expectedVals[i]) < tolerance); - } - - // Execute graph a few times to verify consistent results. - constexpr size_t numIters = 1; - for (size_t i = 0; i < numIters; ++i) - FUSILLI_REQUIRE_OK(graph->execute(handle, variantPack, workspace)); - - // Repeat output buffer checks. - yVals.clear(); - FUSILLI_REQUIRE_OK(yBuf->read(handle, yVals)); - - REQUIRE(yVals.size() == expectedVals.size()); - for (size_t i = 0; i < yVals.size(); ++i) { - REQUIRE(std::abs(yVals[i] - expectedVals[i]) < tolerance); - } -} diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 5e5bcd5f..d1b1ade7 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -202,7 +202,8 @@ add_fusilli_lit_tests( lit/test_layernorm_infer_asm_emitter_scale_bias_nhwc_small_batch.cpp lit/test_layernorm_train_asm_emitter_nchw.cpp lit/test_layernorm_train_asm_emitter_scale_bias_nhwc.cpp - lit/test_rmsnorm_infer_asm_emitter_nchw.cpp + lit/test_rmsnorm_infer_asm_emitter_partial_suffix_hw.cpp + lit/test_rmsnorm_infer_asm_emitter_partial_suffix_w.cpp lit/test_rmsnorm_infer_asm_emitter_scale_nhwc.cpp lit/test_matmul_asm_emitter_basic.cpp lit/test_matmul_asm_emitter_batched.cpp diff --git a/tests/lit/test_rmsnorm_infer_asm_emitter_nchw.cpp b/tests/lit/test_rmsnorm_infer_asm_emitter_nchw.cpp deleted file mode 100644 index a81ae493..00000000 --- a/tests/lit/test_rmsnorm_infer_asm_emitter_nchw.cpp +++ /dev/null @@ -1,108 +0,0 @@ -// Copyright 2026 Advanced Micro Devices, Inc. 
-// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -// RUN: %{TEST_EXE} | iree-opt --verify-roundtrip -// RUN: %{TEST_EXE} | FileCheck %s --check-prefix=TORCH-CHECK -// RUN: %{TEST_EXE} stats | FileCheck %s --check-prefix=%{BACKEND}-STATS-CHECK - -// clang-format off -// -// TORCH-CHECK: module @module { -// TORCH-CHECK: func.func @main(%result_: !torch.tensor<[16,128,64,32],f32>, %arg0_x: !torch.vtensor<[16,128,64,32],f32>) attributes {torch.assume_strict_symbolic_shapes} { -// TORCH-CHECK: %rmsnorm_infer_EPSILON = torch.vtensor.literal(dense<0x3727C5AC> : tensor<1xf32>) : !torch.vtensor<[1],f32> -// TORCH-CHECK: %normalized_shape_val_0_rmsnorm_infer = torch.constant.int 128 -// TORCH-CHECK: %normalized_shape_val_1_rmsnorm_infer = torch.constant.int 64 -// TORCH-CHECK: %normalized_shape_val_2_rmsnorm_infer = torch.constant.int 32 -// TORCH-CHECK: %normalized_shape_rmsnorm_infer = torch.prim.ListConstruct %normalized_shape_val_0_rmsnorm_infer, %normalized_shape_val_1_rmsnorm_infer, %normalized_shape_val_2_rmsnorm_infer : (!torch.int, !torch.int, !torch.int) -> !torch.list -// TORCH-CHECK: %eps_rmsnorm_infer = torch.aten.item %rmsnorm_infer_EPSILON : !torch.vtensor<[1],f32> -> !torch.float -// TORCH-CHECK: %permute_x_val_0_rmsnorm_infer = torch.constant.int 0 -// TORCH-CHECK: %permute_x_val_1_rmsnorm_infer = torch.constant.int 1 -// TORCH-CHECK: %permute_x_val_2_rmsnorm_infer = torch.constant.int 2 -// TORCH-CHECK: %permute_x_val_3_rmsnorm_infer = torch.constant.int 3 -// TORCH-CHECK: %permute_x_rmsnorm_infer = torch.prim.ListConstruct %permute_x_val_0_rmsnorm_infer, %permute_x_val_1_rmsnorm_infer, %permute_x_val_2_rmsnorm_infer, %permute_x_val_3_rmsnorm_infer : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list -// TORCH-CHECK: %arg0_x_rmsnorm_infer_perm = torch.aten.permute %arg0_x, %permute_x_rmsnorm_infer : 
!torch.vtensor<[16,128,64,32],f32>, !torch.list -> !torch.vtensor<[16,128,64,32],f32> -// TORCH-CHECK: %none_scale_rmsnorm_infer = torch.constant.none -// TORCH-CHECK: %result_rmsnorm_infer_perm = torch.aten.rms_norm %arg0_x_rmsnorm_infer_perm, %normalized_shape_rmsnorm_infer, %none_scale_rmsnorm_infer, %eps_rmsnorm_infer : !torch.vtensor<[16,128,64,32],f32>, !torch.list, !torch.none, !torch.float -> !torch.vtensor<[16,128,64,32],f32> -// TORCH-CHECK: %permute_y_val_0_rmsnorm_infer = torch.constant.int 0 -// TORCH-CHECK: %permute_y_val_1_rmsnorm_infer = torch.constant.int 1 -// TORCH-CHECK: %permute_y_val_2_rmsnorm_infer = torch.constant.int 2 -// TORCH-CHECK: %permute_y_val_3_rmsnorm_infer = torch.constant.int 3 -// TORCH-CHECK: %permute_y_rmsnorm_infer = torch.prim.ListConstruct %permute_y_val_0_rmsnorm_infer, %permute_y_val_1_rmsnorm_infer, %permute_y_val_2_rmsnorm_infer, %permute_y_val_3_rmsnorm_infer : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list -// TORCH-CHECK: %result = torch.aten.permute %result_rmsnorm_infer_perm, %permute_y_rmsnorm_infer : !torch.vtensor<[16,128,64,32],f32>, !torch.list -> !torch.vtensor<[16,128,64,32],f32> -// TORCH-CHECK: torch.overwrite.tensor.contents %result overwrites %result_ : !torch.vtensor<[16,128,64,32],f32>, !torch.tensor<[16,128,64,32],f32> -// TORCH-CHECK: return -// TORCH-CHECK: } -// TORCH-CHECK: } -// -// AMDGPU-STATS-CHECK: "transient-memory-size": 0 -// AMDGPU-STATS-CHECK: "dispatch-count": 1 -// CPU-STATS-CHECK: "transient-memory-size": 0 -// CPU-STATS-CHECK: "dispatch-count": 1 -// -// clang-format on - -#include - -#include "utils.h" - -#include -#include -#include -#include - -using namespace fusilli; - -static ErrorObject testRmsnormInferAsmEmitterNchw(const std::string &mode) { - int64_t n = 16, c = 128, h = 64, w = 32; - auto graph = std::make_shared(); - graph->setName("rmsnorm_infer_asm_emitter_nchw"); - graph->setIODataType(DataType::Float).setComputeDataType(DataType::Float); - - auto xT = 
graph->tensor(TensorAttr() - .setName("arg0_x") - .setDim({n, c, h, w}) - .setStride({c * h * w, h * w, w, 1})); // NCHW - - auto epsilonT = graph->tensor(TensorAttr(1e-5f)); - - auto rmsnormAttr = RmsnormAttr() - .setForwardPhase(NormFwdPhase::INFERENCE) - .setEpsilon(epsilonT) - .setName("rmsnorm_infer"); - - auto [yT, rT] = graph->rmsnorm(xT, nullptr, rmsnormAttr); - - yT->setName("result").setOutput(true); - - FUSILLI_CHECK_ERROR(graph->validate()); - - if (mode == "default") { - FUSILLI_ASSIGN_OR_RETURN(auto generatedAsm, graph->emitAsm()); - FUSILLI_CHECK_ERROR(checkMlirIndentation(generatedAsm)); - std::cout << generatedAsm << std::endl; - } - - if (mode == "stats") { - FUSILLI_ASSIGN_OR_RETURN(Handle handle, Handle::create(kDefaultBackend)); - FUSILLI_CHECK_ERROR(graph->compile(handle, /*remove=*/true)); - FUSILLI_ASSIGN_OR_RETURN(auto stats, graph->readCompilationCacheFile( - CachedAssetsType::Statistics)); - std::cout << stats << std::endl; - } - - return ok(); -} - -int main(int argc, char **argv) { - std::string mode = (argc > 1) ? argv[1] : "default"; - - auto status = testRmsnormInferAsmEmitterNchw(mode); - if (isError(status)) { - std::cerr << "Test failed: " << status << std::endl; - return 1; - } - return 0; -} diff --git a/tests/lit/test_rmsnorm_infer_asm_emitter_partial_suffix_hw.cpp b/tests/lit/test_rmsnorm_infer_asm_emitter_partial_suffix_hw.cpp new file mode 100644 index 00000000..172a5d94 --- /dev/null +++ b/tests/lit/test_rmsnorm_infer_asm_emitter_partial_suffix_hw.cpp @@ -0,0 +1,100 @@ +// Copyright 2026 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +// Verifies that an RMSNorm with a partial trailing-suffix scale shape +// (scale=[1,1,H,W] for x=[N,C,H,W]) emits torch.aten.rms_norm with +// normalized_shape=[H, W] (only the trailing matching dims of x), not the +// whole non-batch shape. + +// RUN: %{TEST_EXE} | iree-opt --verify-roundtrip +// RUN: %{TEST_EXE} | FileCheck %s --check-prefix=TORCH-CHECK +// RUN: %{TEST_EXE} stats | FileCheck %s --check-prefix=%{BACKEND}-STATS-CHECK + +// clang-format off +// +// TORCH-CHECK: module @module { +// TORCH-CHECK: func.func @main({{.*}}) attributes {torch.assume_strict_symbolic_shapes} { +// TORCH-CHECK: %normalized_shape_val_0_rmsnorm_infer = torch.constant.int 64 +// TORCH-CHECK: %normalized_shape_val_1_rmsnorm_infer = torch.constant.int 32 +// TORCH-CHECK: %normalized_shape_rmsnorm_infer = torch.prim.ListConstruct %normalized_shape_val_0_rmsnorm_infer, %normalized_shape_val_1_rmsnorm_infer : (!torch.int, !torch.int) -> !torch.list +// +// AMDGPU-STATS-CHECK: "transient-memory-size": 0 +// AMDGPU-STATS-CHECK: "dispatch-count": 1 +// CPU-STATS-CHECK: "transient-memory-size": 0 +// CPU-STATS-CHECK: "dispatch-count": 1 +// +// clang-format on + +#include + +#include "utils.h" + +#include +#include +#include +#include + +using namespace fusilli; + +static ErrorObject +testRmsnormInferAsmEmitterPartialSuffixHw(const std::string &mode) { + int64_t n = 16, c = 128, h = 64, w = 32; + auto graph = std::make_shared(); + graph->setName("rmsnorm_infer_asm_emitter_partial_suffix_hw"); + graph->setIODataType(DataType::Float).setComputeDataType(DataType::Float); + + auto xT = graph->tensor(TensorAttr() + .setName("arg0_x") + .setDim({n, c, h, w}) + .setStride({c * h * w, h * w, w, 1})); // NCHW + + // scale=[1, 1, H, W]: trailing match is the last 2 dims of x → expect + // normalized_shape=[H, W]. 
+ auto scaleT = graph->tensor(TensorAttr() + .setName("arg0_scale") + .setDim({1, 1, h, w}) + .setStride({h * w, h * w, w, 1})); + + auto epsilonT = graph->tensor(TensorAttr(1e-5f)); + + auto rmsnormAttr = RmsnormAttr() + .setForwardPhase(NormFwdPhase::INFERENCE) + .setEpsilon(epsilonT) + .setName("rmsnorm_infer"); + + auto [yT, rT] = graph->rmsnorm(xT, scaleT, rmsnormAttr); + + yT->setName("result").setOutput(true); + + FUSILLI_CHECK_ERROR(graph->validate()); + + if (mode == "default") { + FUSILLI_ASSIGN_OR_RETURN(auto generatedAsm, graph->emitAsm()); + FUSILLI_CHECK_ERROR(checkMlirIndentation(generatedAsm)); + std::cout << generatedAsm << std::endl; + } + + if (mode == "stats") { + FUSILLI_ASSIGN_OR_RETURN(Handle handle, Handle::create(kDefaultBackend)); + FUSILLI_CHECK_ERROR(graph->compile(handle, /*remove=*/true)); + FUSILLI_ASSIGN_OR_RETURN(auto stats, graph->readCompilationCacheFile( + CachedAssetsType::Statistics)); + std::cout << stats << std::endl; + } + + return ok(); +} + +int main(int argc, char **argv) { + std::string mode = (argc > 1) ? argv[1] : "default"; + + auto status = testRmsnormInferAsmEmitterPartialSuffixHw(mode); + if (isError(status)) { + std::cerr << "Test failed: " << status << std::endl; + return 1; + } + return 0; +} diff --git a/tests/lit/test_rmsnorm_infer_asm_emitter_partial_suffix_w.cpp b/tests/lit/test_rmsnorm_infer_asm_emitter_partial_suffix_w.cpp new file mode 100644 index 00000000..ae920a05 --- /dev/null +++ b/tests/lit/test_rmsnorm_infer_asm_emitter_partial_suffix_w.cpp @@ -0,0 +1,99 @@ +// Copyright 2026 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +// Verifies that an RMSNorm with a single-dim trailing-suffix scale +// (scale=[1,1,1,W] for x=[N,C,H,W]) emits torch.aten.rms_norm with +// normalized_shape=[W] (only the last dim of x), exercising the narrowest +// valid trailing suffix under the cuDNN/hipDNN rule. + +// RUN: %{TEST_EXE} | iree-opt --verify-roundtrip +// RUN: %{TEST_EXE} | FileCheck %s --check-prefix=TORCH-CHECK +// RUN: %{TEST_EXE} stats | FileCheck %s --check-prefix=%{BACKEND}-STATS-CHECK + +// clang-format off +// +// TORCH-CHECK: module @module { +// TORCH-CHECK: func.func @main({{.*}}) attributes {torch.assume_strict_symbolic_shapes} { +// TORCH-CHECK: %normalized_shape_val_0_rmsnorm_infer = torch.constant.int 32 +// TORCH-CHECK: %normalized_shape_rmsnorm_infer = torch.prim.ListConstruct %normalized_shape_val_0_rmsnorm_infer : (!torch.int) -> !torch.list +// +// AMDGPU-STATS-CHECK: "transient-memory-size": 0 +// AMDGPU-STATS-CHECK: "dispatch-count": 1 +// CPU-STATS-CHECK: "transient-memory-size": 0 +// CPU-STATS-CHECK: "dispatch-count": 1 +// +// clang-format on + +#include + +#include "utils.h" + +#include +#include +#include +#include + +using namespace fusilli; + +static ErrorObject +testRmsnormInferAsmEmitterPartialSuffixW(const std::string &mode) { + int64_t n = 16, c = 128, h = 64, w = 32; + auto graph = std::make_shared(); + graph->setName("rmsnorm_infer_asm_emitter_partial_suffix_w"); + graph->setIODataType(DataType::Float).setComputeDataType(DataType::Float); + + auto xT = graph->tensor(TensorAttr() + .setName("arg0_x") + .setDim({n, c, h, w}) + .setStride({c * h * w, h * w, w, 1})); // NCHW + + // scale=[1, 1, 1, W]: trailing match is just W → expect + // normalized_shape=[W]. 
+ auto scaleT = graph->tensor(TensorAttr() + .setName("arg0_scale") + .setDim({1, 1, 1, w}) + .setStride({w, w, w, 1})); + + auto epsilonT = graph->tensor(TensorAttr(1e-5f)); + + auto rmsnormAttr = RmsnormAttr() + .setForwardPhase(NormFwdPhase::INFERENCE) + .setEpsilon(epsilonT) + .setName("rmsnorm_infer"); + + auto [yT, rT] = graph->rmsnorm(xT, scaleT, rmsnormAttr); + + yT->setName("result").setOutput(true); + + FUSILLI_CHECK_ERROR(graph->validate()); + + if (mode == "default") { + FUSILLI_ASSIGN_OR_RETURN(auto generatedAsm, graph->emitAsm()); + FUSILLI_CHECK_ERROR(checkMlirIndentation(generatedAsm)); + std::cout << generatedAsm << std::endl; + } + + if (mode == "stats") { + FUSILLI_ASSIGN_OR_RETURN(Handle handle, Handle::create(kDefaultBackend)); + FUSILLI_CHECK_ERROR(graph->compile(handle, /*remove=*/true)); + FUSILLI_ASSIGN_OR_RETURN(auto stats, graph->readCompilationCacheFile( + CachedAssetsType::Statistics)); + std::cout << stats << std::endl; + } + + return ok(); +} + +int main(int argc, char **argv) { + std::string mode = (argc > 1) ? 
argv[1] : "default"; + + auto status = testRmsnormInferAsmEmitterPartialSuffixW(mode); + if (isError(status)) { + std::cerr << "Test failed: " << status << std::endl; + return 1; + } + return 0; +} diff --git a/tests/test_rmsnorm_node.cpp b/tests/test_rmsnorm_node.cpp index ed491642..a0bc46da 100644 --- a/tests/test_rmsnorm_node.cpp +++ b/tests/test_rmsnorm_node.cpp @@ -58,47 +58,53 @@ TEST_CASE("RmsNormNode preValidateNode detects missing attributes", REQUIRE(status.getMessage() == "RmsNorm input tensor X not set"); } - SECTION("Output Y missing") { + SECTION("Input SCALE missing") { attr.setForwardPhase(NormFwdPhase::INFERENCE) .setEpsilon(std::make_shared(1e-5f)); attr.setX(std::make_shared( TensorAttr().setDim({2, 3}).setStride({3, 1}))); + attr.setY(std::make_shared( + TensorAttr().setDim({2, 3}).setStride({3, 1}))); RmsNormNode node(std::move(attr), ctx); auto status = node.preValidateNode(); REQUIRE(isError(status)); REQUIRE(status.getCode() == ErrorCode::AttributeNotSet); - REQUIRE(status.getMessage() == "RmsNorm output tensor Y not set"); + REQUIRE(status.getMessage() == "RmsNorm input tensor SCALE not set"); } - SECTION("Epsilon missing") { - attr.setForwardPhase(NormFwdPhase::INFERENCE); + SECTION("Output Y missing") { + attr.setForwardPhase(NormFwdPhase::INFERENCE) + .setEpsilon(std::make_shared(1e-5f)); attr.setX(std::make_shared( TensorAttr().setDim({2, 3}).setStride({3, 1}))); - attr.setY(std::make_shared( - TensorAttr().setDim({2, 3}).setStride({3, 1}))); + attr.setSCALE(std::make_shared( + TensorAttr().setDim({1, 3}).setStride({3, 1}))); RmsNormNode node(std::move(attr), ctx); auto status = node.preValidateNode(); REQUIRE(isError(status)); REQUIRE(status.getCode() == ErrorCode::AttributeNotSet); - REQUIRE(status.getMessage() == "RmsNorm epsilon not set"); + REQUIRE(status.getMessage() == "RmsNorm output tensor Y not set"); } - SECTION("All required attributes present for INFERENCE forward phase") { - attr.setForwardPhase(NormFwdPhase::INFERENCE) - 
.setEpsilon(std::make_shared(1e-5f)); + SECTION("Epsilon missing") { + attr.setForwardPhase(NormFwdPhase::INFERENCE); attr.setX(std::make_shared( TensorAttr().setDim({2, 3}).setStride({3, 1}))); + attr.setSCALE(std::make_shared( + TensorAttr().setDim({1, 3}).setStride({3, 1}))); attr.setY(std::make_shared( TensorAttr().setDim({2, 3}).setStride({3, 1}))); RmsNormNode node(std::move(attr), ctx); - FUSILLI_REQUIRE_OK(node.preValidateNode()); + auto status = node.preValidateNode(); + REQUIRE(isError(status)); + REQUIRE(status.getCode() == ErrorCode::AttributeNotSet); + REQUIRE(status.getMessage() == "RmsNorm epsilon not set"); } - SECTION("All required and optional attributes present for INFERENCE forward " - "phase") { + SECTION("All required attributes present for INFERENCE forward phase") { attr.setForwardPhase(NormFwdPhase::INFERENCE) .setEpsilon(std::make_shared(1e-5f)); attr.setX(std::make_shared( @@ -117,6 +123,8 @@ TEST_CASE("RmsNormNode preValidateNode detects missing attributes", .setEpsilon(std::make_shared(1e-5f)); attr.setX(std::make_shared( TensorAttr().setDim({2, 3}).setStride({3, 1}))); + attr.setSCALE(std::make_shared( + TensorAttr().setDim({1, 3}).setStride({3, 1}))); attr.setY(std::make_shared( TensorAttr().setDim({2, 3}).setStride({3, 1}))); attr.setINV_RMS(std::make_shared( @@ -135,6 +143,8 @@ TEST_CASE("RmsNormNode preValidateNode detects missing attributes", .setEpsilon(std::make_shared(1e-5f)); attr.setX(std::make_shared( TensorAttr().setDim({2, 3}).setStride({3, 1}))); + attr.setSCALE(std::make_shared( + TensorAttr().setDim({1, 3}).setStride({3, 1}))); attr.setY(std::make_shared( TensorAttr().setDim({2, 3}).setStride({3, 1}))); RmsNormNode node(std::move(attr), ctx); @@ -146,21 +156,6 @@ TEST_CASE("RmsNormNode preValidateNode detects missing attributes", } SECTION("All required attributes present for TRAINING forward phase") { - attr.setForwardPhase(NormFwdPhase::TRAINING) - .setEpsilon(std::make_shared(1e-5f)); - attr.setX(std::make_shared( - 
TensorAttr().setDim({2, 3}).setStride({3, 1}))); - attr.setY(std::make_shared( - TensorAttr().setDim({2, 3}).setStride({3, 1}))); - attr.setINV_RMS(std::make_shared( - TensorAttr().setDim({2, 1}).setStride({1, 1}))); - RmsNormNode node(std::move(attr), ctx); - - FUSILLI_REQUIRE_OK(node.preValidateNode()); - } - - SECTION("All required and optional attributes present for TRAINING forward " - "phase") { attr.setForwardPhase(NormFwdPhase::TRAINING) .setEpsilon(std::make_shared(1e-5f)); attr.setX(std::make_shared( @@ -188,6 +183,8 @@ TEST_CASE( attr.setX(std::make_shared( TensorAttr().setDim({n, c}).setStride({c, 1}))); + attr.setSCALE(std::make_shared( + TensorAttr().setDim({1, c}).setStride({c, 1}))); attr.setY(std::make_shared( TensorAttr().setDim({n, c}).setStride({c, 1}))); attr.setINV_RMS(std::make_shared( @@ -215,6 +212,8 @@ TEST_CASE( attr.setX(std::make_shared( TensorAttr().setDim({n, c}).setStride({c, 1}))); + attr.setSCALE(std::make_shared( + TensorAttr().setDim({1, c}).setStride({c, 1}))); attr.setY(std::make_shared()); attr.setINV_RMS(std::make_shared()); @@ -289,25 +288,132 @@ TEST_CASE("RmsNormNode shape checks on SCALE tensor", "[rmsnorm_node]") { Context ctx; RmsnormAttr attr; - int64_t n = 2, c = 3, d = 4; + // 4D NCHW-packed x for trailing-suffix coverage. 
+ int64_t n = 2, c = 4, h = 8, w = 8; + std::vector xDim = {n, c, h, w}; + std::vector xStride = {c * h * w, h * w, w, 1}; - SECTION("Incorrect SCALE shape") { - attr.setForwardPhase(NormFwdPhase::INFERENCE) - .setEpsilon(std::make_shared(1e-5f)); - attr.setX(std::make_shared( - TensorAttr().setDim({n, c, d}).setStride({c * d, d, 1}))); + attr.setForwardPhase(NormFwdPhase::INFERENCE) + .setEpsilon(std::make_shared(1e-5f)); + attr.setX(std::make_shared( + TensorAttr().setDim(xDim).setStride(xStride))); + attr.setY(std::make_shared()); + + // Each SECTION re-runs from the top of the TEST_CASE body, so `attr` is + // freshly populated; we override SCALE per SECTION to test the trailing- + // suffix rule: + // reduction = maximal trailing suffix where scale[i] == x[i], + // leading region (excluding batch) must be all-1, + // batch dim (scale[0]) must be 1. + + SECTION("Canonical full suffix accepted") { attr.setSCALE(std::make_shared( - TensorAttr().setDim({n, c, d}).setStride({c * d, d, 1}))); - attr.setY(std::make_shared()); + TensorAttr().setDim({1, c, h, w}).setStride({c * h * w, h * w, w, 1}))); + RmsNormNode node(std::move(attr), ctx); + FUSILLI_REQUIRE_OK(node.preValidateNode()); + } + + SECTION("Trailing suffix [H,W] accepted") { + attr.setSCALE(std::make_shared( + TensorAttr().setDim({1, 1, h, w}).setStride({h * w, h * w, w, 1}))); + RmsNormNode node(std::move(attr), ctx); + FUSILLI_REQUIRE_OK(node.preValidateNode()); + } + + SECTION("Trailing suffix [W] only accepted") { + attr.setSCALE(std::make_shared( + TensorAttr().setDim({1, 1, 1, w}).setStride({w, w, w, 1}))); + RmsNormNode node(std::move(attr), ctx); + FUSILLI_REQUIRE_OK(node.preValidateNode()); + } + + SECTION("Leading region non-1 rejected (sandwich)") { + // scale=[1, c, 1, w]: trailing match is [w] (matchCount=1), but the + // leading region (positions 1..2) contains scale[1]=c != 1. 
+ attr.setSCALE(std::make_shared( + TensorAttr().setDim({1, c, 1, w}).setStride({c * w, w, w, 1}))); + RmsNormNode node(std::move(attr), ctx); + + auto status = node.preValidateNode(); + REQUIRE(isError(status)); + REQUIRE(status.getCode() == ErrorCode::InvalidAttribute); + REQUIRE(status.getMessage() == "RmsNorm SCALE leading region (before " + "normalized shape) must be 1"); + } + + SECTION("No trailing match rejected (per-channel)") { + // scale=[1, c, 1, 1]: trailing dim 1 vs x's W=w -> no match. + attr.setSCALE(std::make_shared( + TensorAttr().setDim({1, c, 1, 1}).setStride({c, 1, 1, 1}))); + RmsNormNode node(std::move(attr), ctx); + + auto status = node.preValidateNode(); + REQUIRE(isError(status)); + REQUIRE(status.getCode() == ErrorCode::InvalidAttribute); + REQUIRE(status.getMessage() == + "RmsNorm SCALE has no trailing dims matching X — at least " + "one normalized dim is required"); + } + + SECTION("All-1 scale on non-degenerate x rejected") { + // x=[N,C,H,W] with non-1 spatials, scale=[1,1,1,1] -> no trailing + // match (scale[3]=1 vs x[3]=W). 
+ attr.setSCALE(std::make_shared( + TensorAttr().setDim({1, 1, 1, 1}).setStride({1, 1, 1, 1}))); + RmsNormNode node(std::move(attr), ctx); + + auto status = node.preValidateNode(); + REQUIRE(isError(status)); + REQUIRE(status.getCode() == ErrorCode::InvalidAttribute); + REQUIRE(status.getMessage() == + "RmsNorm SCALE has no trailing dims matching X — at least " + "one normalized dim is required"); + } + + SECTION("Different rank rejected") { + attr.setSCALE(std::make_shared( + TensorAttr().setDim({c, h, w}).setStride({h * w, w, 1}))); RmsNormNode node(std::move(attr), ctx); auto status = node.preValidateNode(); REQUIRE(isError(status)); REQUIRE(status.getCode() == ErrorCode::InvalidAttribute); REQUIRE(status.getMessage() == - "RmsNorm input tensor SCALE must have shape as " - "tensor X with single batch"); + "RmsNorm SCALE tensor must have the same rank as X"); } + + SECTION("Non-1 batch dim rejected") { + attr.setSCALE(std::make_shared( + TensorAttr().setDim({n, c, h, w}).setStride({c * h * w, h * w, w, 1}))); + RmsNormNode node(std::move(attr), ctx); + + auto status = node.preValidateNode(); + REQUIRE(isError(status)); + REQUIRE(status.getCode() == ErrorCode::InvalidAttribute); + REQUIRE(status.getMessage() == + "RmsNorm SCALE tensor must have batch dim equal to 1 " + "(broadcast across batch)"); + } +} + +TEST_CASE("RmsNormNode degenerate all-1 x accepted with all-1 scale", + "[rmsnorm_node]") { + // Special case from the cuDNN/hipDNN rule: x=[N,1,1,...] with + // scale=[1,1,1,...] passes because trailing match goes all the way to + // the batch boundary. 
+ Context ctx; + RmsnormAttr attr; + + attr.setForwardPhase(NormFwdPhase::INFERENCE) + .setEpsilon(std::make_shared(1e-5f)); + attr.setX(std::make_shared( + TensorAttr().setDim({2, 1, 1, 1}).setStride({1, 1, 1, 1}))); + attr.setSCALE(std::make_shared( + TensorAttr().setDim({1, 1, 1, 1}).setStride({1, 1, 1, 1}))); + attr.setY(std::make_shared()); + RmsNormNode node(std::move(attr), ctx); + + FUSILLI_REQUIRE_OK(node.preValidateNode()); } TEST_CASE("RmsNormNode postValidateNode detects incorrect shapes and strides", @@ -322,6 +428,8 @@ TEST_CASE("RmsNormNode postValidateNode detects incorrect shapes and strides", .setEpsilon(std::make_shared(1e-5f)); attr.setX(std::make_shared( TensorAttr().setDim({n, c, d}).setStride({c * d, d, 1}))); + attr.setSCALE(std::make_shared( + TensorAttr().setDim({1, c, d}).setStride({c * d, d, 1}))); attr.setY(std::make_shared( TensorAttr().setDim({n + 1, c, d}).setStride({c * d, d, 1}))); RmsNormNode node(std::move(attr), ctx); @@ -341,6 +449,8 @@ TEST_CASE("RmsNormNode postValidateNode detects incorrect shapes and strides", .setEpsilon(std::make_shared(1e-5f)); attr.setX(std::make_shared( TensorAttr().setDim({n, c, d}).setStride({c * d, d, 1}))); + attr.setSCALE(std::make_shared( + TensorAttr().setDim({1, c, d}).setStride({c * d, d, 1}))); attr.setY(std::make_shared(TensorAttr() .setDim({n, c, d}) .setStride({d, c * d, 1}) @@ -362,6 +472,8 @@ TEST_CASE("RmsNormNode postValidateNode detects incorrect shapes and strides", .setEpsilon(std::make_shared(1e-5f)); attr.setX(std::make_shared( TensorAttr().setDim({n, c, d}).setStride({c * d, d, 1}))); + attr.setSCALE(std::make_shared( + TensorAttr().setDim({1, c, d}).setStride({c * d, d, 1}))); attr.setY(std::make_shared( TensorAttr().setDim({n, c, d}).setStride({c * d, d, 1}))); attr.setINV_RMS(std::make_shared( @@ -374,9 +486,9 @@ TEST_CASE("RmsNormNode postValidateNode detects incorrect shapes and strides", REQUIRE(isError(status)); REQUIRE(status.getCode() == ErrorCode::InvalidAttribute); 
REQUIRE(status.getMessage() == - "RmsNorm output INV_RMS tensor must have shape [B, 1, ..., 1] with " - "rank equal to input X tensor's rank, and batch dimension equal " - "to input X tensor's batch dimension"); + "RmsNorm output INV_RMS tensor must have x's leading (broadcast) " + "dims preserved and the normalized (trailing) region collapsed to " + "1, with rank equal to input X tensor's rank"); } SECTION("Output INV_RMS has incorrect stride") { @@ -384,6 +496,8 @@ TEST_CASE("RmsNormNode postValidateNode detects incorrect shapes and strides", .setEpsilon(std::make_shared(1e-5f)); attr.setX(std::make_shared( TensorAttr().setDim({n, c, d}).setStride({c * d, d, 1}))); + attr.setSCALE(std::make_shared( + TensorAttr().setDim({1, c, d}).setStride({c * d, d, 1}))); attr.setY(std::make_shared( TensorAttr().setDim({n, c, d}).setStride({c * d, d, 1}))); attr.setINV_RMS(std::make_shared( @@ -397,7 +511,59 @@ TEST_CASE("RmsNormNode postValidateNode detects incorrect shapes and strides", REQUIRE(isError(status)); REQUIRE(status.getCode() == ErrorCode::InvalidAttribute); REQUIRE(status.getMessage() == - "RmsNorm output INV_RMS tensor must have unit strides"); + "RmsNorm output INV_RMS tensor must have strides preserving " + "input X tensor's stride order"); + } + + SECTION("Output INV_RMS partial trailing suffix [H,W] accepted") { + // x=[N,C,H,W], scale=[1,1,H,W] -> normalized=[H,W], invRms=[N,C,1,1]. 
+ int64_t n = 2, c = 4, h = 8, w = 8; + attr.setForwardPhase(NormFwdPhase::TRAINING) + .setEpsilon(std::make_shared(1e-5f)); + attr.setX(std::make_shared( + TensorAttr().setDim({n, c, h, w}).setStride({c * h * w, h * w, w, 1}))); + attr.setSCALE(std::make_shared( + TensorAttr().setDim({1, 1, h, w}).setStride({h * w, h * w, w, 1}))); + attr.setY(std::make_shared( + TensorAttr().setDim({n, c, h, w}).setStride({c * h * w, h * w, w, 1}))); + attr.setINV_RMS(std::make_shared( + TensorAttr().setDim({n, c, 1, 1}).setStride({c, 1, 1, 1}))); + RmsNormNode node(std::move(attr), ctx); + + FUSILLI_REQUIRE_OK(node.preValidateNode()); + FUSILLI_REQUIRE_OK(node.inferPropertiesNode()); + auto status = node.postValidateNode(); + // Training is gated NotImplemented at the very end; INV_RMS shape/stride + // checks succeed first, so this is the only remaining error we expect. + REQUIRE(isError(status)); + REQUIRE(status.getCode() == ErrorCode::NotImplemented); + } + + SECTION("Output INV_RMS canonical shape rejected for partial-suffix scale") { + // x=[N,C,H,W], scale=[1,1,H,W] -> expected invRms=[N,C,1,1]; pre-setting + // [N,1,1,1] (the old "always-canonical" shape) must now be rejected. 
+ int64_t n = 2, c = 4, h = 8, w = 8; + attr.setForwardPhase(NormFwdPhase::TRAINING) + .setEpsilon(std::make_shared(1e-5f)); + attr.setX(std::make_shared( + TensorAttr().setDim({n, c, h, w}).setStride({c * h * w, h * w, w, 1}))); + attr.setSCALE(std::make_shared( + TensorAttr().setDim({1, 1, h, w}).setStride({h * w, h * w, w, 1}))); + attr.setY(std::make_shared( + TensorAttr().setDim({n, c, h, w}).setStride({c * h * w, h * w, w, 1}))); + attr.setINV_RMS(std::make_shared( + TensorAttr().setDim({n, 1, 1, 1}).setStride({1, 1, 1, 1}))); + RmsNormNode node(std::move(attr), ctx); + + FUSILLI_REQUIRE_OK(node.preValidateNode()); + FUSILLI_REQUIRE_OK(node.inferPropertiesNode()); + auto status = node.postValidateNode(); + REQUIRE(isError(status)); + REQUIRE(status.getCode() == ErrorCode::InvalidAttribute); + REQUIRE(status.getMessage() == + "RmsNorm output INV_RMS tensor must have x's leading (broadcast) " + "dims preserved and the normalized (trailing) region collapsed to " + "1, with rank equal to input X tensor's rank"); } SECTION("TRAINING forward phase is not yet supported") { @@ -405,6 +571,8 @@ TEST_CASE("RmsNormNode postValidateNode detects incorrect shapes and strides", .setEpsilon(std::make_shared(1e-5f)); attr.setX(std::make_shared( TensorAttr().setDim({n, c, d}).setStride({c * d, d, 1}))); + attr.setSCALE(std::make_shared( + TensorAttr().setDim({1, c, d}).setStride({c * d, d, 1}))); attr.setY(std::make_shared( TensorAttr().setDim({n, c, d}).setStride({c * d, d, 1}))); attr.setINV_RMS(std::make_shared( @@ -418,3 +586,108 @@ TEST_CASE("RmsNormNode postValidateNode detects incorrect shapes and strides", REQUIRE(status.getCode() == ErrorCode::NotImplemented); } } + +TEST_CASE("RmsNormNode inferPropertiesNode infers INV_RMS shape from scale's " + "broadcast pattern", + "[rmsnorm_node]") { + // The INV_RMS output collapses x's trailing (normalized) dims to 1 while + // preserving the leading (broadcast) dims (cuDNN/hipDNN trailing-suffix + // rule): + // invRms[i] = 
(scale[i] == 1) ? x[i] : 1 + Context ctx; + RmsnormAttr attr; + + // 4D NCHW-packed x for trailing-suffix coverage. + int64_t n = 2, c = 4, h = 8, w = 8; + std::vector xDim = {n, c, h, w}; + std::vector xStride = {c * h * w, h * w, w, 1}; + + attr.setForwardPhase(NormFwdPhase::TRAINING) + .setEpsilon(std::make_shared(1e-5f)); + attr.setX(std::make_shared( + TensorAttr().setDim(xDim).setStride(xStride))); + attr.setY(std::make_shared()); + attr.setINV_RMS(std::make_shared()); + + SECTION("Canonical full suffix: scale=[1,C,H,W] -> INV_RMS=[N,1,1,1]") { + attr.setSCALE(std::make_shared( + TensorAttr().setDim({1, c, h, w}).setStride({c * h * w, h * w, w, 1}))); + RmsNormNode node(std::move(attr), ctx); + + FUSILLI_REQUIRE_OK(node.preValidateNode()); + FUSILLI_REQUIRE_OK(node.inferPropertiesNode()); + + auto rT = node.rmsnormAttr.getINV_RMS(); + REQUIRE(rT->getDim() == std::vector{n, 1, 1, 1}); + REQUIRE(rT->getStride() == std::vector{1, 1, 1, 1}); + } + + SECTION("Trailing suffix [H,W]: scale=[1,1,H,W] -> INV_RMS=[N,C,1,1]") { + attr.setSCALE(std::make_shared( + TensorAttr().setDim({1, 1, h, w}).setStride({h * w, h * w, w, 1}))); + RmsNormNode node(std::move(attr), ctx); + + FUSILLI_REQUIRE_OK(node.preValidateNode()); + FUSILLI_REQUIRE_OK(node.inferPropertiesNode()); + + auto rT = node.rmsnormAttr.getINV_RMS(); + REQUIRE(rT->getDim() == std::vector{n, c, 1, 1}); + REQUIRE(rT->getStride() == std::vector{c, 1, 1, 1}); + } + + SECTION("Trailing suffix [W]: scale=[1,1,1,W] -> INV_RMS=[N,C,H,1]") { + attr.setSCALE(std::make_shared( + TensorAttr().setDim({1, 1, 1, w}).setStride({w, w, w, 1}))); + RmsNormNode node(std::move(attr), ctx); + + FUSILLI_REQUIRE_OK(node.preValidateNode()); + FUSILLI_REQUIRE_OK(node.inferPropertiesNode()); + + auto rT = node.rmsnormAttr.getINV_RMS(); + REQUIRE(rT->getDim() == std::vector{n, c, h, 1}); + REQUIRE(rT->getStride() == std::vector{c * h, h, 1, 1}); + } +} + +TEST_CASE("RmsNormNode infers INV_RMS strides from input layout", + 
          "[rmsnorm_node]") { + // NHWC x -> INV_RMS strides preserve NHWC order, mirroring the scale-stride + // pattern from getScaleBiasStride. + Context ctx; + RmsnormAttr attr; + + int64_t n = 2, c = 4, h = 8, w = 8; + // NHWC stride order (rank per dim, 0 = outermost) = [0, 3, 1, 2]: outermost is N, then H, W, C innermost. + // Stride values: N stride = H*W*C, H stride = W*C, W stride = C, C stride + // = 1. + std::vector<int64_t> xDim = {n, c, h, w}; + std::vector<int64_t> xStride = {h * w * c, 1, w * c, c}; + + attr.setForwardPhase(NormFwdPhase::TRAINING) + .setEpsilon(std::make_shared<float>(1e-5f)); + attr.setX(std::make_shared<TensorAttr>( + TensorAttr().setDim(xDim).setStride(xStride))); + attr.setSCALE(std::make_shared<TensorAttr>( + TensorAttr().setDim({1, 1, 1, w}).setStride({w, 1, w, 1}))); + attr.setY(std::make_shared<TensorAttr>()); + attr.setINV_RMS(std::make_shared<TensorAttr>()); + RmsNormNode node(std::move(attr), ctx); + + FUSILLI_REQUIRE_OK(node.preValidateNode()); + FUSILLI_REQUIRE_OK(node.inferPropertiesNode()); + + auto rT = node.rmsnormAttr.getINV_RMS(); + // dim collapses W (the normalized region) -> [N, C, H, 1] + REQUIRE(rT->getDim() == std::vector<int64_t>{n, c, h, 1}); + // Stride preserves x's NHWC order. xStride [256, 1, 32, 4] sorts the dims + // outer-to-inner as [N, H, W, C] (idx 0, 2, 3, 1). Walking that order on + // dim [N=2, C=4, H=8, W=1]: + // C (innermost): stride[1] = 1 + // W (next, dim=1): stride[3] = 1 * dim[C]=4 = 4 + // H (next, dim=8): stride[2] = 4 * dim[W]=1 = 4 + // N (outermost, dim=2): stride[0] = 4 * dim[H]=8 = 32 + // Stride[3]==stride[2]==4 because W's dim is 1 (collapsed); both indices + // step the same number of elements, which is fine since W=1 is never indexed + // beyond 0. + REQUIRE(rT->getStride() == std::vector<int64_t>{h * c, 1, c, c});