Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions include/fusilli/node/norm_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@

#include "fusilli/attributes/tensor_attributes.h"

#include <algorithm>
#include <cstdint>
#include <iterator>
#include <utility>
#include <vector>

Expand All @@ -35,6 +37,23 @@ inline std::vector<int64_t> getNormalizedShape(const std::vector<int64_t> &xDim,
return shape;
}

// Returns normalized_shape as the trailing suffix of `xDim` where
// `sDim[i] == xDim[i]`.
//
// Precondition: callers must have validated `sDim` against `xDim` (done in
// RmsNormNode::preValidateNode).
inline std::vector<int64_t>
getRMSNormNormalizedShape(const std::vector<int64_t> &xDim,
                          const std::vector<int64_t> &sDim) {
  // Walk both shapes back-to-front and count how many trailing dims agree.
  // Keep the count in the iterators' signed difference_type so the
  // std::prev() below doesn't trip clang-tidy's
  // bugprone-narrowing-conversions check.
  std::vector<int64_t>::difference_type suffixLen = 0;
  auto sIt = sDim.rbegin();
  auto xIt = xDim.rbegin();
  while (sIt != sDim.rend() && xIt != xDim.rend() && *sIt == *xIt) {
    ++sIt;
    ++xIt;
    ++suffixLen;
  }
  // The matching trailing suffix of xDim is the normalized shape.
  return std::vector<int64_t>(std::prev(xDim.end(), suffixLen), xDim.end());
}

// Returns [1, ..., B, ..., 1] dim and unit strides for training forward outputs
// (e.g. MEAN, INV_VARIANCE, INV_RMS), where only the batch dimension is
// preserved from xDim.
Expand Down Expand Up @@ -66,6 +85,27 @@ getScaleBiasStride(const std::vector<int64_t> &scaleBiasDim,
return generateStrideFromDim(scaleBiasDim, strideOrder);
}

// Returns dim and stride for an RmsNormNode INV_RMS output:
//   invRms[i] = x[i]  if scale[i] == 1 (leading/broadcast region preserved)
//             = 1     otherwise        (normalized region collapsed)
// Stride preserves x's stride order via getScaleBiasStride (matches hipDNN's
// RMSNormNode.hpp invRms inference and our scale/bias stride convention).
//
// Precondition: callers must have validated `sDim` against `xDim` (done in
// RmsNormNode::preValidateNode).
inline std::pair<std::vector<int64_t>, std::vector<int64_t>>
getRMSNormInvRmsDimAndStride(const std::vector<int64_t> &xDim,
                             const std::vector<int64_t> &sDim,
                             const std::vector<int64_t> &xStride) {
  // Collapse every normalized dim (scale != 1) to 1; keep the rest from x.
  std::vector<int64_t> invRmsDim;
  invRmsDim.reserve(xDim.size());
  for (size_t i = 0; i < xDim.size(); ++i)
    invRmsDim.push_back(sDim[i] == 1 ? xDim[i] : 1);
  // Derive strides that follow x's stride order, per the scale/bias
  // stride convention.
  std::vector<int64_t> invRmsStride = getScaleBiasStride(invRmsDim, xStride);
  return {invRmsDim, invRmsStride};
}

// Infers dim of a tensor if not already set.
inline void inferDim(std::shared_ptr<TensorAttr> &tensor,
const std::vector<int64_t> &dim) {
Expand Down
86 changes: 60 additions & 26 deletions include/fusilli/node/rmsnorm_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@
#include "fusilli/node/norm_utils.h"
#include "fusilli/support/logging.h"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <memory>
#include <string>
#include <utility>
Expand Down Expand Up @@ -72,6 +74,8 @@ class RmsNormNode : public NodeCRTP<RmsNormNode> {
// Ensure mandatory input and output tensors are set.
FUSILLI_RETURN_ERROR_IF(!xT, ErrorCode::AttributeNotSet,
"RmsNorm input tensor X not set");
FUSILLI_RETURN_ERROR_IF(!sT, ErrorCode::AttributeNotSet,
"RmsNorm input tensor SCALE not set");
FUSILLI_RETURN_ERROR_IF(!yT, ErrorCode::AttributeNotSet,
"RmsNorm output tensor Y not set");

Expand All @@ -89,25 +93,55 @@ class RmsNormNode : public NodeCRTP<RmsNormNode> {
// Shape and layout checks on scale tensor.
// If scale tensor's dims/strides are not set, they will be inferred in
// inferPropertiesNode().
if (sT) {
if (!sT->getDim().empty()) {
FUSILLI_RETURN_ERROR_IF(sT->getDim() !=
norm_utils::getScaleBiasDim(xT->getDim()),
ErrorCode::InvalidAttribute,
"RmsNorm input tensor SCALE must have shape as "
"tensor X with single batch");
}
//
// Scale encodes the normalized_shape as a trailing suffix of x:
// reduction = maximal trailing suffix where scale[i] == x[i],
// leading region (excluding batch) must be all-1,
// batch dim (scale[0]) must be 1 (broadcast across batch).
if (!sT->getDim().empty()) {
const auto &xDim = xT->getDim();
const auto &sDim = sT->getDim();
constexpr size_t batchDim = 0;

FUSILLI_RETURN_ERROR_IF(
sDim.size() != xDim.size(), ErrorCode::InvalidAttribute,
"RmsNorm SCALE tensor must have the same rank as X");

FUSILLI_RETURN_ERROR_IF(
sDim[batchDim] != 1, ErrorCode::InvalidAttribute,
"RmsNorm SCALE tensor must have batch dim equal to 1 "
"(broadcast across batch)");

// matchCount = number of trailing dims where scale[i] == x[i].
const auto [sMismatch, _] =
std::mismatch(sDim.rbegin(), sDim.rend(), xDim.rbegin(), xDim.rend());
const size_t matchCount =
static_cast<size_t>(std::distance(sDim.rbegin(), sMismatch));

if (!sT->getStride().empty()) {
FUSILLI_RETURN_ERROR_IF(
matchCount == 0, ErrorCode::InvalidAttribute,
"RmsNorm SCALE has no trailing dims matching X — at least "
"one normalized dim is required");

// Leading region (between batch and the matching trailing
// suffix) must be all-1.
for (size_t i = batchDim + 1; i < sDim.size() - matchCount; ++i) {
FUSILLI_RETURN_ERROR_IF(
!sT->isContiguous() && !sT->isChannelsLast(),
ErrorCode::NotImplemented,
"Tensor '" + sT->getName() +
"' is neither contiguous nor channels-last as "
"defined by its stride");
sDim[i] != 1, ErrorCode::InvalidAttribute,
"RmsNorm SCALE leading region (before normalized shape) "
"must be 1");
}
}

if (!sT->getStride().empty()) {
FUSILLI_RETURN_ERROR_IF(
!sT->isContiguous() && !sT->isChannelsLast(),
ErrorCode::NotImplemented,
"Tensor '" + sT->getName() +
"' is neither contiguous nor channels-last as "
"defined by its stride");
}

// Output tensor checks for training and inference forward phases.
if (isTrainingForwardPhase()) {
FUSILLI_RETURN_ERROR_IF(!rT, ErrorCode::AttributeNotSet,
Expand Down Expand Up @@ -141,17 +175,15 @@ class RmsNormNode : public NodeCRTP<RmsNormNode> {

// Infer shape and stride of input SCALE tensor if they're not set.
std::shared_ptr<TensorAttr> sT = rmsnormAttr.getSCALE();
if (sT) {
norm_utils::inferScaleBiasDimAndStride(sT, xDim, xT->getStride());
}
norm_utils::inferScaleBiasDimAndStride(sT, xDim, xT->getStride());

// Infer shape and stride of output Y tensor.
// When stride is unspecified, preserve the stride order of xT.
norm_utils::inferDimAndStride(yT, xDim, xT->getStride());

if (isTrainingForwardPhase()) {
const auto &[dim, stride] =
norm_utils::getTrainingForwardOutputDimAndStride(xDim);
const auto &[dim, stride] = norm_utils::getRMSNormInvRmsDimAndStride(
xDim, rmsnormAttr.getSCALE()->getDim(), xT->getStride());

// Infer shape and stride of output INV_RMS tensor.
std::shared_ptr<TensorAttr> rT = rmsnormAttr.getINV_RMS();
Expand Down Expand Up @@ -183,21 +215,22 @@ class RmsNormNode : public NodeCRTP<RmsNormNode> {
"defined by its stride");

if (isTrainingForwardPhase()) {
const auto &[dim, stride] =
norm_utils::getTrainingForwardOutputDimAndStride(xDim);
const auto &[dim, stride] = norm_utils::getRMSNormInvRmsDimAndStride(
xDim, rmsnormAttr.getSCALE()->getDim(), xT->getStride());

std::shared_ptr<TensorAttr> rT = rmsnormAttr.getINV_RMS();

// Shape check for output INV_RMS tensor
FUSILLI_RETURN_ERROR_IF(
dim != rT->getDim(), ErrorCode::InvalidAttribute,
"RmsNorm output INV_RMS tensor must have shape [B, 1, ..., 1] with "
"rank equal to input X tensor's rank, and batch dimension equal "
"to input X tensor's batch dimension");
"RmsNorm output INV_RMS tensor must have x's leading (broadcast) "
"dims preserved and the normalized (trailing) region collapsed to "
"1, with rank equal to input X tensor's rank");
// Stride check for output INV_RMS tensor
FUSILLI_RETURN_ERROR_IF(
stride != rT->getStride(), ErrorCode::InvalidAttribute,
"RmsNorm output INV_RMS tensor must have unit strides");
"RmsNorm output INV_RMS tensor must have strides preserving "
"input X tensor's stride order");
}

FUSILLI_RETURN_ERROR_IF(
Expand All @@ -214,7 +247,8 @@ class RmsNormNode : public NodeCRTP<RmsNormNode> {
}

// Normalized shape this RMSNorm reduces over: the maximal trailing suffix of
// X's dims matched by SCALE's dims (see
// norm_utils::getRMSNormNormalizedShape). Assumes X and SCALE dims are
// already set/validated — TODO confirm callers only use this post-validation.
std::vector<int64_t> getNormalizedShape() const {
return norm_utils::getRMSNormNormalizedShape(
rmsnormAttr.getX()->getDim(), rmsnormAttr.getSCALE()->getDim());
}
}
};

Expand Down
19 changes: 6 additions & 13 deletions include/fusilli/support/asm_emitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -1477,10 +1477,8 @@ inline std::string RmsNormNode::getOperandNamesAsm() const {

oss << rmsnormAttr.getX()->getValueNameAsm() << "_" << suffix << "_perm, ";
oss << "%normalized_shape_" << suffix << ", ";

auto sT = rmsnormAttr.getSCALE();
oss << (sT ? sT->getValueNameAsm() + "_" + suffix + "_perm, "
: "%none_scale_" + suffix + ", ");
oss << rmsnormAttr.getSCALE()->getValueNameAsm() << "_" << suffix
<< "_perm, ";
oss << "%eps_" << suffix;

return oss.str();
Expand All @@ -1494,11 +1492,8 @@ inline std::string RmsNormNode::getOperandTypesAsm() const {
/*useLogicalDims=*/true)
<< ", ";
oss << "!torch.list<int>" << ", ";

auto sT = rmsnormAttr.getSCALE();
oss << (sT ? sT->getTensorTypeAsm(/*isValueTensor=*/true,
/*useLogicalDims=*/true)
: "!torch.none")
oss << rmsnormAttr.getSCALE()->getTensorTypeAsm(/*isValueTensor=*/true,
/*useLogicalDims=*/true)
<< ", ";
oss << "!torch.float";

Expand Down Expand Up @@ -1558,10 +1553,8 @@ inline std::string RmsNormNode::emitNodePreAsm() const {
std::string permuteY = getLayoutConversionOpsAsm(
rmsnormAttr.getY(), "permute_y", uniqueSSASuffix, /*isInput=*/false);
std::string permuteScale =
rmsnormAttr.getSCALE()
? getLayoutConversionOpsAsm(rmsnormAttr.getSCALE(), "permute_scale",
uniqueSSASuffix, /*isInput=*/true)
: torchNoneAsm("none_scale", uniqueSSASuffix);
getLayoutConversionOpsAsm(rmsnormAttr.getSCALE(), "permute_scale",
uniqueSSASuffix, /*isInput=*/true);

constexpr std::string_view schema = R"(
{0}
Expand Down
1 change: 0 additions & 1 deletion samples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ add_fusilli_samples(
add_fusilli_samples(
PREFIX fusilli_rmsnorm_samples
SRCS
rmsnorm/rmsnorm_infer_nchw.cpp
rmsnorm/rmsnorm_infer_nchw_scale.cpp
rmsnorm/rmsnorm_infer_nhwc_scale.cpp
DEPS
Expand Down
110 changes: 0 additions & 110 deletions samples/rmsnorm/rmsnorm_infer_nchw.cpp

This file was deleted.

3 changes: 2 additions & 1 deletion tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,8 @@ add_fusilli_lit_tests(
lit/test_layernorm_infer_asm_emitter_scale_bias_nhwc_small_batch.cpp
lit/test_layernorm_train_asm_emitter_nchw.cpp
lit/test_layernorm_train_asm_emitter_scale_bias_nhwc.cpp
lit/test_rmsnorm_infer_asm_emitter_nchw.cpp
lit/test_rmsnorm_infer_asm_emitter_partial_suffix_hw.cpp
lit/test_rmsnorm_infer_asm_emitter_partial_suffix_w.cpp
lit/test_rmsnorm_infer_asm_emitter_scale_nhwc.cpp
lit/test_matmul_asm_emitter_basic.cpp
lit/test_matmul_asm_emitter_batched.cpp
Expand Down
Loading
Loading