
align RMSNorm API with hipDNN #388

Open

AaronStGeorge wants to merge 1 commit into iree-org:main from AaronStGeorge:p047-rmsnorm-end-to-end

Conversation

@AaronStGeorge
Contributor

@AaronStGeorge AaronStGeorge commented May 4, 2026

Why

RMSNorm over an example 1-d vector $x$ is defined as

$$y_i = \frac{x_i}{\texttt{RMS}(x)} \cdot \gamma_i$$ $$\texttt{RMS}(x) = \sqrt{\epsilon + \frac{1}{n}\sum_{j=1}^{n} x_j^2}$$

$\texttt{RMS}$ is computed over what PyTorch calls `normalized_shape`.
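As a point of reference, here is a minimal NumPy sketch of that definition (the function name, the eps default, and the concrete shapes are illustrative only, not fusilli's API):

```python
import numpy as np

def rmsnorm(x, gamma, normalized_shape, eps=1e-6):
    """Reference RMSNorm: reduce over the trailing `normalized_shape` dims."""
    axes = tuple(range(x.ndim - len(normalized_shape), x.ndim))
    rms = np.sqrt(eps + np.mean(x * x, axis=axes, keepdims=True))
    return (x / rms) * gamma, 1.0 / rms  # normalized output and inverse RMS

# e.g. an [N, C, H, W] input with normalized_shape = [C, H, W]
x = np.random.randn(2, 3, 4, 5).astype(np.float32)
gamma = np.ones((1, 3, 4, 5), dtype=np.float32)
y, inv_rms = rmsnorm(x, gamma, normalized_shape=(3, 4, 5))
print(y.shape, inv_rms.shape)  # (2, 3, 4, 5) (2, 1, 1, 1)
```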

Fusilli currently always uses all non-batch dimensions for `normalized_shape`. For example, with an [N, C, H, W] input, fusilli computes $\texttt{RMS}$ over [C, H, W] regardless of how scale is set, and expects the inverse RMS output to be [N, 1, 1, 1] for all input shapes.

These semantics diverge from cuDNN/hipDNN, which derive `normalized_shape` from the maximal trailing suffix of the input where scale[i] == input[i].

cuDNN's rules for determining normalized shape from scale:

  • scale rank == input rank
  • scale[0] == 1 (batch broadcast)
  • normalized shape = maximal trailing suffix where scale[i] == input[i]; must be non-empty
  • in the leading region (dims before the trailing match), scale[i] == 1

The inverse RMS shape is always the "kept dimensions", i.e. the input dimensions where scale[i] == 1, which are exactly the dimensions that are not normalized over.

Examples:

  • input [N, C, H, W], scale [1, C, H, W] -> normalized_shape = [C, H, W], inv_rms = [N, 1, 1, 1]
  • input [N, C, H, W], scale [1, 1, H, W] -> normalized_shape = [H, W], inv_rms = [N, C, 1, 1]
  • input [N, C, H, W], scale [1, 1, 1, W] -> normalized_shape = [W], inv_rms = [N, C, H, 1]
  • (degenerate but valid) input [N, 1, 1, 1], scale [1, 1, 1, 1] -> normalized_shape = [1, 1, 1], inv_rms = [N, 1, 1, 1]

Note: cuDNN and PyTorch frontends require normalized_shape to match the input dims without broadcasts. For example, input [N, C, H, W] with scale [1, 1, H, 1] would be rejected rather than assuming that scale is broadcast over the W dimension.
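A short Python sketch of this derivation as described above (illustrative only, not fusilli's actual validator; concrete values stand in for N, C, H, W):

```python
def derive_rmsnorm_shapes(input_shape, scale_shape):
    """Derive (normalized_shape, inv_rms_shape) from scale per the cuDNN rules above."""
    rank = len(input_shape)
    assert len(scale_shape) == rank, "scale rank must equal input rank"
    assert scale_shape[0] == 1, "scale[0] must be 1 (batch broadcast)"
    # Maximal trailing suffix (never including the batch dim) where scale matches input.
    start = rank
    while start > 1 and scale_shape[start - 1] == input_shape[start - 1]:
        start -= 1
    normalized_shape = list(input_shape[start:])
    assert normalized_shape, "trailing match must be non-empty"
    # Leading region (dims before the trailing match) must be broadcast (== 1) in scale.
    assert all(s == 1 for s in scale_shape[:start]), "leading scale dims must be 1"
    # Kept dims: wherever scale broadcasts (== 1), inv_rms keeps the input extent.
    inv_rms_shape = [d if s == 1 else 1 for d, s in zip(input_shape, scale_shape)]
    return normalized_shape, inv_rms_shape

# Matches the examples above, with N=8, C=3, H=4, W=5:
print(derive_rmsnorm_shapes([8, 3, 4, 5], [1, 3, 4, 5]))  # ([3, 4, 5], [8, 1, 1, 1])
print(derive_rmsnorm_shapes([8, 3, 4, 5], [1, 1, 4, 5]))  # ([4, 5], [8, 3, 1, 1])
print(derive_rmsnorm_shapes([8, 3, 4, 5], [1, 1, 1, 5]))  # ([5], [8, 3, 4, 1])
```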

What

  • Makes scale a required parameter for RMSNorm
  • Updates normalized_shape calculation and emission to match cuDNN/hipDNN semantics
  • Updates expected inv_rms shape to be the kept dimensions

@AaronStGeorge AaronStGeorge marked this pull request as ready for review May 4, 2026 03:53
@sjain-stanford
Member

Could you populate the PR description with the "what" and "why" for this change? The title says align the API with hipdnn, but I don't know what was misaligned and how we're fixing that. The PR description also gets rolled into the final commit message, so it's good for record-keeping (especially for agents to look back at the history).

@AaronStGeorge AaronStGeorge force-pushed the p047-rmsnorm-end-to-end branch 6 times, most recently from c1e112e to 79a9506 on May 8, 2026 21:07
RMSNorm computes RMS over what PyTorch calls `normalized_shape`. cuDNN
(and hipDNN) derive `normalized_shape` from the maximal trailing
suffix of input where `scale[i] == input[i]`, with the leading region
(where `scale[i] == 1`) preserved on the inv_rms output.

cuDNN's rules:
* scale rank == input rank
* scale[0] == 1 (batch broadcast)
* normalized_shape = maximal trailing suffix where scale[i] == input[i],
  must be non-empty
* in the leading region (dims before the trailing match), scale[i] == 1

Examples:
* input [N,C,H,W], scale [1,C,H,W] -> normalized_shape=[C,H,W],
  inv_rms=[N,1,1,1]
* input [N,C,H,W], scale [1,1,H,W] -> normalized_shape=[H,W],
  inv_rms=[N,C,1,1]
* input [N,C,H,W], scale [1,1,1,W] -> normalized_shape=[W],
  inv_rms=[N,C,H,1]
* (degenerate) input [N,1,1,1], scale [1,1,1,1] ->
  normalized_shape=[1,1,1], inv_rms=[N,1,1,1]

Equivalently, inv_rms[i] = (scale[i] == 1) ? input[i] : 1 — every dim
that scale broadcasts over is kept; every dim it matches is collapsed
to 1.

Fusilli previously took all non-batch dims of x for `normalized_shape`
regardless of scale, treated SCALE as optional (with a no-scale path
nothing downstream actually supported), and rejected any inv_rms
shape other than [N, 1, ..., 1].

Changes:
* Make SCALE a required input — *breaking change*: callers passing
  null SCALE (C++) or omitting it (Python binding) now hit
  `Tensor SCALE not set`. The optional-SCALE fallback in
  preValidateNode, inferPropertiesNode, getNormalizedShape, and the
  ASM emitter is removed.
* Replace `normalized_shape = all non-batch dims of x` with the
  cuDNN trailing-suffix rule. Validator enforces rank, batch
  broadcast, non-empty trailing match, and leading-region all-1.
  ASM emitter derives the MLIR `normalized_shape` from scale's
  trailing match.
* Replace inv_rms shape inference + validation with the per-dim
  rule `invRms[i] = (scale[i] == 1) ? x[i] : 1`. Stride preserves
  x's stride order via norm_utils::getScaleBiasStride (matches
  hipDNN's RMSNormNode and our scale/bias stride convention).
* New norm_utils helpers: getRMSNormNormalizedShape(xDim, sDim) and
  getRMSNormInvRmsDimAndStride(xDim, sDim, xStride). Distinct names
  from the existing batch-driven LayerNorm helper to keep call
  sites unambiguous about which derivation rule is in play.

Tests: trailing-suffix accept (full canonical, [H,W], [W]),
leading-region reject, no-trailing-match reject, degenerate all-1
accept, partial-suffix inv_rms accept and canonical-when-partial
reject; new unit tests for inv_rms shape inference (canonical,
partial-suffix variants) and stride preservation under NHWC; lit
tests for partial-suffix [H,W] / [W] emission. The deleted
`samples/rmsnorm/rmsnorm_infer_nchw.cpp` and
`tests/lit/test_rmsnorm_infer_asm_emitter_nchw.cpp` exercised the
no-scale path that no longer exists.

Empirical confirmation of cuDNN's contract on A100 + cuDNN 9.19.0
via `rmsnorm_semantics_demo/test_scale_shape_reduction.py` and
`test_rmsnorm_no_scale.py`.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@AaronStGeorge AaronStGeorge force-pushed the p047-rmsnorm-end-to-end branch from 79a9506 to 1cf8710 on May 8, 2026 21:13
@sjain-stanford sjain-stanford left a comment
Member

I think we should make normalized_shape an input for rmsnorm and layernorm and not derive normalized_shape from the maximal trailing suffix of the input. Here's why:

For the current approach to work, we either need scale to be required (to infer norm_shape) or we need a norm_shape to be specified.

Both layernorm and RMSNorm in PT don't require a scale; the weight exists only if elementwise_affine is true. Making it required in fusilli is a bit jarring for users who trained without scaling.

So the lowest common denominator is to expect a norm_shape on both, and then let scale/bias be the same shape as norm_shape.

We should do this at the fusilli layer. hipdnn can do what it does today, which is use the trailing suffix of the input for the norm shape, and feed that "inferred" normalized_shape into fusilli's op builder.
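For context on the PT behavior mentioned above, torch.nn.RMSNorm takes normalized_shape directly and only creates a weight (scale) when elementwise_affine is true; a small illustration with arbitrary shapes:

```python
import torch

x = torch.randn(8, 3, 4, 5)
with_scale = torch.nn.RMSNorm(normalized_shape=[3, 4, 5], elementwise_affine=True)
no_scale = torch.nn.RMSNorm(normalized_shape=[3, 4, 5], elementwise_affine=False)
print(with_scale.weight.shape)  # torch.Size([3, 4, 5])
print(no_scale.weight)          # None -- no scale parameter at all
print(with_scale(x).shape, no_scale(x).shape)  # both torch.Size([8, 3, 4, 5])
```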
