align RMSNorm API with hipDNN #388
Conversation
Could you populate the PR description with the "what" and "why" for this change? The title says align the API with hipDNN, but I don't know what was misaligned and how we're fixing that. The PR description also gets rolled into the final commit message, so it's good for record-keeping (especially for agents looking back at the history).
Force-pushed from c1e112e to 79a9506.
RMSNorm computes RMS over what PyTorch calls `normalized_shape`. cuDNN (and hipDNN) derive `normalized_shape` from the maximal trailing suffix of input where `scale[i] == input[i]`, with the leading region (where `scale[i] == 1`) preserved on the inv_rms output.

cuDNN's rules:

* scale rank == input rank
* `scale[0] == 1` (batch broadcast)
* `normalized_shape` = maximal trailing suffix where `scale[i] == input[i]`, must be non-empty
* in the leading region (dims before the trailing match), `scale[i] == 1`

Examples:

* input `[N,C,H,W]`, scale `[1,C,H,W]` -> `normalized_shape=[C,H,W]`, `inv_rms=[N,1,1,1]`
* input `[N,C,H,W]`, scale `[1,1,H,W]` -> `normalized_shape=[H,W]`, `inv_rms=[N,C,1,1]`
* input `[N,C,H,W]`, scale `[1,1,1,W]` -> `normalized_shape=[W]`, `inv_rms=[N,C,H,1]`
* (degenerate) input `[N,1,1,1]`, scale `[1,1,1,1]` -> `normalized_shape=[1,1,1]`, `inv_rms=[N,1,1,1]`

Equivalently, `inv_rms[i] = (scale[i] == 1) ? input[i] : 1`: every dim that scale broadcasts over is kept; every dim it matches is collapsed to 1.

Fusilli previously took all non-batch dims of x for `normalized_shape` regardless of scale, treated SCALE as optional (with a no-scale path nothing downstream actually supported), and rejected any inv_rms shape other than `[N, 1, ..., 1]`.

Changes:

* Make SCALE a required input (*breaking change*: callers passing null SCALE (C++) or omitting it (Python binding) now hit `Tensor SCALE not set`). The optional-SCALE fallback in preValidateNode, inferPropertiesNode, getNormalizedShape, and the ASM emitter is removed.
* Replace `normalized_shape = all non-batch dims of x` with the cuDNN trailing-suffix rule. The validator enforces rank, batch broadcast, a non-empty trailing match, and an all-1 leading region. The ASM emitter derives the MLIR `normalized_shape` from scale's trailing match.
* Replace inv_rms shape inference and validation with the per-dim rule `invRms[i] = (scale[i] == 1) ? x[i] : 1`. The stride preserves x's stride order via norm_utils::getScaleBiasStride (matching hipDNN's RMSNormNode and our scale/bias stride convention).
* New norm_utils helpers: getRMSNormNormalizedShape(xDim, sDim) and getRMSNormInvRmsDimAndStride(xDim, sDim, xStride). Their names are distinct from the existing batch-driven LayerNorm helper to keep call sites unambiguous about which derivation rule is in play.

Tests:

* trailing-suffix accept (full canonical, `[H,W]`, `[W]`), leading-region reject, no-trailing-match reject, degenerate all-1 accept, partial-suffix inv_rms accept and canonical-when-partial reject
* new unit tests for inv_rms shape inference (canonical and partial-suffix variants) and stride preservation under NHWC
* lit tests for partial-suffix `[H,W]` / `[W]` emission

The deleted `samples/rmsnorm/rmsnorm_infer_nchw.cpp` and `tests/lit/test_rmsnorm_infer_asm_emitter_nchw.cpp` exercised the no-scale path that no longer exists.

Empirical confirmation of cuDNN's contract on A100 + cuDNN 9.19.0 via `rmsnorm_semantics_demo/test_scale_shape_reduction.py` and `test_rmsnorm_no_scale.py`.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
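For quick reference, here's a minimal standalone sketch of the two derivations described above. The names `normalizedShapeFromScale` and `invRmsDimFromScale` are illustrative stand-ins for the new norm_utils helpers, not the actual fusilli implementation (which also reports errors instead of asserting):

```cpp
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Trailing-suffix rule (sketch): normalized_shape is the maximal trailing
// suffix of xDim where sDim[i] == xDim[i]; dim 0 stays the broadcast batch,
// and every leading dim before the match must be 1 in sDim.
std::vector<int64_t>
normalizedShapeFromScale(const std::vector<int64_t> &xDim,
                         const std::vector<int64_t> &sDim) {
  assert(xDim.size() == sDim.size() && sDim[0] == 1);
  std::size_t start = xDim.size();
  while (start > 1 && sDim[start - 1] == xDim[start - 1])
    --start;
  assert(start < xDim.size()); // trailing match must be non-empty
  for (std::size_t i = 1; i < start; ++i)
    assert(sDim[i] == 1); // leading region must be all 1s
  return {xDim.begin() + start, xDim.end()};
}

// Per-dim inv_rms rule (sketch): keep every dim scale broadcasts over,
// collapse every dim it matches to 1.
std::vector<int64_t> invRmsDimFromScale(const std::vector<int64_t> &xDim,
                                        const std::vector<int64_t> &sDim) {
  std::vector<int64_t> invRms(xDim.size());
  for (std::size_t i = 0; i < xDim.size(); ++i)
    invRms[i] = (sDim[i] == 1) ? xDim[i] : 1;
  return invRms;
}

int main() {
  std::vector<int64_t> x = {8, 4, 16, 32};    // [N, C, H, W]
  std::vector<int64_t> s = {1, 1, 16, 32};    // [1, 1, H, W]
  auto norm = normalizedShapeFromScale(x, s); // {16, 32} == [H, W]
  auto invRms = invRmsDimFromScale(x, s);     // {8, 4, 1, 1} == [N, C, 1, 1]
  std::cout << norm.size() << " " << invRms[1] << "\n"; // prints "2 4"
}
```

Feeding the other rows of the examples table through these two functions reproduces the same `normalized_shape` / `inv_rms` pairs.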
Force-pushed from 79a9506 to 1cf8710.
sjain-stanford left a comment:
I think we should make normalized_shape an input for RMSNorm and LayerNorm, rather than deriving it from the maximal trailing suffix of input. Here's why:
For the derivation to work, we either need scale to be required (to infer norm_shape) or we need a norm_shape to be specified.
Neither LayerNorm nor RMSNorm in PyTorch requires a scale; one is only present when elementwise_affine is true. Making it required in fusilli is a bit jarring for users who trained without scaling.
So the lowest common denominator is to expect a norm_shape on both, and then let scale/bias be the same shape as norm_shape.
We should do this at the fusilli layer. hipDNN can keep doing what it does today, deriving the norm shape from the trailing suffix of input, and feed that "inferred" normalized shape into fusilli's op builder.
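To sketch what that contract could look like: `scaleMatchesNormShape` below is a hypothetical helper, not existing fusilli or hipDNN API, assuming `std::optional` models the no-affine case:

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical contract for the proposed builder surface: norm_shape is
// passed explicitly, and scale/bias (optional, as in PyTorch when
// elementwise_affine=false) must match norm_shape when present.
bool scaleMatchesNormShape(const std::vector<int64_t> &normShape,
                           const std::optional<std::vector<int64_t>> &scaleDim) {
  if (!scaleDim)
    return true; // no affine transform, nothing to validate
  return *scaleDim == normShape;
}

int main() {
  std::vector<int64_t> normShape = {4, 16, 32}; // e.g. [C, H, W]
  bool noAffine = scaleMatchesNormShape(normShape, std::nullopt);                  // true
  bool affine = scaleMatchesNormShape(normShape, std::vector<int64_t>{4, 16, 32}); // true
  return (noAffine && affine) ? 0 : 1;
}
```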
Why

RMSNorm over an example 1d vector $x \in \mathbb{R}^n$ with scale $\gamma$ is defined as

$$y_i = \frac{x_i}{\texttt{RMS}(x)} \cdot \gamma_i, \qquad \texttt{RMS}(x) = \sqrt{\frac{1}{n} \sum_{j=1}^{n} x_j^2 + \epsilon}$$

where, in the general case, the RMS is computed over `normalized_shape`. Fusilli currently always uses all non-batch dimensions for $\texttt{RMS}$ over `normalized_shape`. For example, with a `[N, C, H, W]` input, fusilli computes `[C, H, W]` regardless of what scale is set, and expects the inverse RMS output to be `[N, 1, 1, 1]` for all input shapes.

These semantics diverge from cuDNN/hipDNN, which derive `normalized_shape` from the maximal trailing suffix of `input` where `scale[i] == input[i]`.

cuDNN's rules for determining normalized shape from scale:

* scale rank == input rank
* `scale[0] == 1` (batch broadcast)
* `normalized_shape` = maximal trailing suffix where `scale[i] == input[i]`, must be non-empty
* in the leading region, `scale[i] == 1`

The inverse RMS shape is always the "kept dimensions", i.e. the input dimensions where `scale[i] == 1`, exactly the dimensions that are not normalized over.

Examples:

* input `[N, C, H, W]`, scale `[1, C, H, W]` -> `normalized_shape=[C, H, W]`, `inv_rms=[N, 1, 1, 1]`
* input `[N, C, H, W]`, scale `[1, 1, H, W]` -> `normalized_shape=[H, W]`, `inv_rms=[N, C, 1, 1]`
* input `[N, C, H, W]`, scale `[1, 1, 1, W]` -> `normalized_shape=[W]`, `inv_rms=[N, C, H, 1]`
* input `[N, 1, 1, 1]`, scale `[1, 1, 1, 1]` -> `normalized_shape=[1, 1, 1]`, `inv_rms=[N, 1, 1, 1]`

Note: cuDNN and PyTorch frontends require `normalized_shape` to match the input dims without broadcasts. For example, input `[N, C, H, W]` with scale `[1, 1, H, 1]` would be rejected rather than assuming that scale is broadcast over the `W` dimension.
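A tiny self-contained illustration of the Note (`trailingMatchLen` is a hypothetical helper, not fusilli source): scale `[1, 1, H, 1]` produces an empty trailing match, so it is rejected instead of being broadcast over `W`.

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Length of the maximal trailing suffix where sDim[i] == xDim[i]
// (dim 0 is reserved for batch broadcast). Sketch, not fusilli source.
std::size_t trailingMatchLen(const std::vector<int64_t> &xDim,
                             const std::vector<int64_t> &sDim) {
  std::size_t len = 0;
  for (std::size_t i = xDim.size(); i > 1 && sDim[i - 1] == xDim[i - 1]; --i)
    ++len;
  return len;
}

int main() {
  std::vector<int64_t> x = {8, 4, 16, 32};  // [N, C, H, W]
  std::vector<int64_t> ok = {1, 1, 16, 32}; // [1, 1, H, W] -> suffix [H, W]
  std::vector<int64_t> bad = {1, 1, 16, 1}; // [1, 1, H, 1] -> empty suffix
  std::cout << trailingMatchLen(x, ok) << "\n";  // 2 -> accepted
  std::cout << trailingMatchLen(x, bad) << "\n"; // 0 -> rejected
}
```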
What

* `normalized_shape` calculation and emission updated to match cuDNN/hipDNN semantics
* `inv_rms` shape inferred as the kept dimensions