
[Common] MOE Split dBias#2674

Open
Oleg-Goncharov wants to merge 7 commits into NVIDIA:main from Oleg-Goncharov:pr_split_dbias

Conversation

@Oleg-Goncharov (Collaborator)

Description

This PR adds a new kernel that computes dbias separately for each tensor in a group and outputs a grouped dbias tensor containing per-tensor dbias values.
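To make the semantics concrete, here is a plain C++ reference sketch (not the CUDA kernel; the function name and layout are illustrative only): for each tensor `t` of shape `[rows_t, K]`, the grouped dbias is that tensor's column sum, and the output has shape `[num_tensors, K]` instead of a single `[K]` vector.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Reference semantics of grouped dbias (illustrative, not the CUDA kernel):
// dbias[t][k] is the column sum of tensor t's gradient.
std::vector<std::vector<float>> grouped_dbias_ref(
    const std::vector<std::vector<float>> &grads,  // each flattened as [rows_t * K]
    const std::vector<std::size_t> &rows, std::size_t K) {
  std::vector<std::vector<float>> dbias(grads.size(), std::vector<float>(K, 0.0f));
  for (std::size_t t = 0; t < grads.size(); ++t) {
    for (std::size_t r = 0; r < rows[t]; ++r) {
      for (std::size_t k = 0; k < K; ++k) {
        dbias[t][k] += grads[t][r * K + k];  // per-tensor column sum
      }
    }
  }
  return dbias;
}
```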

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Added the grouped dbias kernel

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps bot (Contributor) commented on Feb 11, 2026

Greptile Summary

This PR extends the grouped MXFP8 quantization to compute separate dbias values for each tensor in a group, changing from a single global dbias output to a per-tensor grouped dbias tensor.

Key changes:

  • Added group_reduce_dbias_kernel that computes per-tensor dbias from the workspace by reducing partial sums with correct offset calculations for SAME_BOTH_DIMS and VARYING_FIRST_DIM shape representations
  • Changed dbias parameter type from Tensor* to GroupedTensor* across all quantization and activation APIs
  • Added device-side validation ensuring each tensor's first dimension is divisible by 128 (required for correct workspace alignment)
  • Updated expected dbias shape from {K} to {num_tensors, K} for grouped operations
  • Strengthened grid computation check from last_logical_dim % 128 == 0 to elts_total % ELTS_PER_CHUNK == 0 for varying-dimension cases
  • Tests updated to validate per-tensor dbias correctness with new test cases

Important limitation: Grouped dbias remains unsupported for VARYING_LAST_DIM cases (when tensors have different column dimensions), as documented in the API.
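A host-side mirror of the two workspace offset formulas the summary describes may help; the function name and parameters below are assumptions for illustration, not the library's API. The workspace keeps one partial-dbias row per `chunk_dim_Y` (128) input rows, so offsets are counted in chunk rows.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Sketch of the two offset formulas (names are illustrative assumptions).
std::size_t workspace_row_offset(bool same_both_dims, std::size_t tensor_id,
                                 std::size_t tensor_rows, std::size_t cols,
                                 const std::vector<std::size_t> &elem_offsets,
                                 std::size_t chunk_dim_Y = 128) {
  if (same_both_dims) {
    // SAME_BOTH_DIMS: every tensor contributes tensor_rows / 128 chunk rows.
    return tensor_id * (tensor_rows / chunk_dim_Y);
  }
  // VARYING_FIRST_DIM: elem_offsets[t] = sum of M_i * cols over previous tensors;
  // dividing by cols and then chunk_dim_Y recovers the chunk-row offset.
  return elem_offsets[tensor_id] / cols / chunk_dim_Y;
}
```

For example, with `cols = 64` and first dims 128 and 256, the third tensor's element offset is 24576, giving chunk-row offset 24576 / 64 / 128 = 3, which matches 128/128 + 256/128.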

Confidence Score: 4/5

  • This PR is safe to merge with minor considerations around workspace offset correctness
  • The implementation is well-structured with proper validation checks. The new group_reduce_dbias_kernel correctly handles the two supported shape representations (SAME_BOTH_DIMS and VARYING_FIRST_DIM). The device-side validation at line 109-111 ensures first dimensions are divisible by 128, which is critical for workspace alignment. Tests have been updated to cover per-tensor dbias outputs. Score is 4 (not 5) due to complexity of offset arithmetic in VARYING_FIRST_DIM case that could benefit from additional edge case testing, though the logic appears sound.
  • Pay close attention to transformer_engine/common/cast/core/common.cuh line 110 - verify the workspace offset calculation offsets_ptr[tensor_id] / cols / chunk_dim_Y is correct for all VARYING_FIRST_DIM edge cases

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/common/cast/core/common.cuh | Added group_reduce_dbias_kernel to compute per-tensor dbias values from the workspace. Logic looks correct for SAME_BOTH_DIMS and VARYING_FIRST_DIM (the only supported cases). |
| transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh | Updated the dbias signature to GroupedTensor* and added validation that first dims are divisible by 128. Changed the full-tile check from last_logical_dim % 128 == 0 to elts_total % ELTS_PER_CHUNK == 0 for varying-dimension cases. |
| transformer_engine/common/include/transformer_engine/cast.h | Updated function signatures to use NVTEGroupedTensor instead of NVTETensor for the dbias parameter. Documented that grouped dbias is not supported for a varying last dimension. |
| tests/cpp/operator/test_cast_mxfp8_grouped.cu | Updated tests to support grouped dbias with per-tensor outputs, switching from a single dbias vector to per-tensor tracking. Added new test cases and updated skip conditions for a varying last dimension. |

Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Input: Grouped Tensor] --> B{Check Shape Rep}
    B -->|SAME_BOTH_DIMS or<br/>VARYING_FIRST_DIM| C[Supported: Single Tensor Mode]
    B -->|VARYING_LAST_DIM or<br/>VARYING_BOTH_DIMS| D[Not Supported: Skip DBias]

    C --> E[Validate: First dim % 128 == 0]
    E --> F[group_quantize_mxfp8_kernel]
    F --> G[Compute partial dbias<br/>per 128x128 chunk]
    G --> H[Store to workspace<br/>M/128 x K layout]

    H --> I[group_reduce_dbias_kernel]
    I --> J{Shape Rep}
    J -->|SAME_BOTH_DIMS| K[Offset = tensor_id * rows/128]
    J -->|VARYING_FIRST_DIM| L[Offset = offsets_ptr / cols / 128]

    K --> M[Reduce rows in workspace]
    L --> M
    M --> N[Output: Per-tensor dbias<br/>num_tensors x K]
```
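The reduce step from the flowchart can be sketched in plain C++; the row-major `[total_chunks, K]` workspace layout below is an assumption for illustration. Each tensor reduces its own slice of chunk rows into a K-length dbias vector.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Illustrative sketch of the reduce step: sum a tensor's chunk rows in a
// row-major [total_chunks, K] workspace into one K-length dbias vector.
std::vector<float> reduce_tensor_dbias(const std::vector<float> &workspace,
                                       std::size_t K, std::size_t chunk_row_offset,
                                       std::size_t num_chunk_rows) {
  std::vector<float> dbias(K, 0.0f);
  for (std::size_t r = 0; r < num_chunk_rows; ++r) {
    for (std::size_t k = 0; k < K; ++k) {
      dbias[k] += workspace[(chunk_row_offset + r) * K + k];
    }
  }
  return dbias;
}
```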

Last reviewed commit: a7a06d2

@greptile-apps (Contributor) left a comment

9 files reviewed, 1 comment

@Oleg-Goncharov (Collaborator, Author)

/te-ci

Oleg-Goncharov and others added 2 commits February 12, 2026 14:57
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
@greptile-apps (Contributor) left a comment

9 files reviewed, 3 comments

```cpp
        ? (first_logical_dim / num_tensors)
        : first_dims_ptr[tensor_id];

const size_t rows = tensor_rows / chunk_dim_Y;
```
@greptile-apps (Contributor) commented:

Verify that tensor_rows is always divisible by chunk_dim_Y (128), otherwise this division silently truncates and skips tail row reduction.

Comment on lines 147 to 150:

```cpp
if (global_dim_X % CHUNK_DIM_X != 0) {
  NVTE_DEVICE_ERROR(
      "The grouped tensor must be divisible by 128x128 tiles without a tail tile.");
}
```
A repository Member commented:

Let's see the performance impact of having this here.

@Oleg-Goncharov (Collaborator, Author) replied:

On B300, the difference is within measurement noise. Over 3 runs, nsys shows ~59.69 µs with the check vs. ~59.62 µs without.

Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
@greptile-apps (Contributor) left a comment

9 files reviewed, 1 comment


```cpp
const size_t dbias_in_offset_Y = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS)
                                     ? (tensor_id * (tensor_rows / chunk_dim_Y))
                                     : (offsets_ptr[tensor_id] / cols / chunk_dim_Y);
```
@greptile-apps (Contributor) commented:

For VARYING_FIRST_DIM, the offset computation offsets_ptr[tensor_id] / cols / chunk_dim_Y assumes the data offset is divisible by cols * chunk_dim_Y. However, when tensors have varying first dimensions, the cumulative offset offsets_ptr[tensor_id] equals the sum of M_i * K for all previous tensors. If any M_i % chunk_dim_Y != 0, this division will truncate and compute an incorrect workspace offset, causing data corruption.

The kernel in group_quantize_mxfp8.cuh:109-111 validates each tensor's first dimension is divisible by 128, which ensures M_i % chunk_dim_Y == 0, but the workspace offset depends on the sum of all previous tensor sizes being correctly aligned. Verify this is always satisfied for VARYING_FIRST_DIM case.
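The alignment argument above can be checked with a small worked example (plain C++; the function name and the concrete sizes are hypothetical): if every first dimension M_i is a multiple of chunk_dim_Y (128), then each cumulative element offset sum(M_i * cols) is a multiple of cols * 128, so the division is exact and never truncates.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical worked check: returns true iff every tensor's cumulative
// element offset is divisible by cols * chunk_dim_Y, i.e. the expression
// offsets_ptr[tensor_id] / cols / chunk_dim_Y is exact for all tensors.
bool offsets_are_chunk_aligned(const std::vector<std::size_t> &first_dims,
                               std::size_t cols, std::size_t chunk_dim_Y = 128) {
  std::size_t offset = 0;  // element offset of the current tensor
  for (std::size_t m : first_dims) {
    if (offset % (cols * chunk_dim_Y) != 0) return false;  // division would truncate
    offset += m * cols;
  }
  return offset % (cols * chunk_dim_Y) == 0;
}
```

With first dims {128, 256, 384} every offset is aligned; inserting a tensor with M_i = 100 breaks alignment for every tensor after it, which is exactly what the device-side divisibility check rules out.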

Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
@greptile-apps (Contributor) left a comment

10 files reviewed, no comments

@greptile-apps (Contributor) left a comment

10 files reviewed, no comments

@Oleg-Goncharov (Collaborator, Author)

/te-ci

@Oleg-Goncharov (Collaborator, Author)

/te-ci

@greptile-apps (Contributor) left a comment

10 files reviewed, no comments
