
feat: add quantized tensor support #335

Open
RuibinCheung wants to merge 5 commits into main from dev/zhangrb/quantized_tensor

Conversation

@RuibinCheung
Collaborator

@RuibinCheung RuibinCheung commented May 9, 2026

Description

This PR introduces a unified QuantizedTensor wrapper class as the canonical container for all FP8 / FP4 low-precision paths. As part of the work, the FP8/FP4 quantization kernels are reorganized, a new rowwise dequantize C++ binding and an optimized col-major dequant kernel are added, and all GEMM / grouped GEMM / MoE permutation ops are migrated to the new QuantizedTensor API.

Motivation:

  • The previous Float8Tensor was still under development, did not cover FP4 or the full set of granularities (rowwise / tensorwise / MX-blockwise), and could not be passed as a first-class citizen through grouped GEMM / MoE permutation paths.
  • On the C++ side, quantization.cu lumped tensorwise / rowwise / MXFP8 / MXFP4 into a single TU, leaving shared QuantOp / amax→scale helpers tightly coupled and hard to reuse or extend.
  • The col-major path of rowwise dequantize processed only one row per block, so re-loading scale_inv consumed a non-trivial fraction of bandwidth on tall-M shapes.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Breaking change notes: the legacy primus_turbo.pytorch.core.float8_tensor.Float8Tensor has been removed; downstream callers must migrate to the new QuantizedTensor API. The C++ schemas of quantize_mxfp4* / quantize_mxfp8* have a new required argument padding_align_size.

Changes

New unified QuantizedTensor wrapper

  • Add primus_turbo/pytorch/core/quantized_tensor.py:
    • QuantizedTensor: a wrapper subclass built on torch.Tensor._make_wrapper_subclass, carrying _data / _scale_inv together with metadata such as orig_dtype / dest_dtype / granularity / block_size / scaling_recipe / quantized_axis, plus group_lens / group_offs for the grouped case.
    • Provide a factory method QuantizedTensor.quantize(...) (auto-dispatched to quantize_fp8 / quantize_fp4), as well as dequantize(), autograd-aware view / reshape (no dequant round-trip), and __tensor_flatten__ / __tensor_unflatten__ (compatible with torch.compile / FSDP).
    • Add QuantizedTensorPair and check_quantized_tensor to help ops enforce config consistency.
  • Remove the legacy primus_turbo/pytorch/core/float8_tensor.py (Float8Tensor).
  • Export the new APIs from primus_turbo/pytorch/core/__init__.py.

Quantization C++ kernel refactor + new ops

  • Split the original csrc/kernels/quantization/quantization.cu into:
    • csrc/kernels/quantization/quantization_rowwise.cu: rowwise FP8 quantize / dequantize (both row-major and col-major layouts).
    • csrc/kernels/quantization/quantization_tensorwise.cu: tensorwise FP8 quantize / dequantize, and the single explicit instantiation of compute_scale_from_amax<float>.
  • Extract shared device helpers into a new header csrc/include/primus_turbo/device/quant_utils.cuh: QuantOpBase / QuantOp and the host/device implementations of compute_scale_from_amax, shared across rowwise and tensorwise paths.
  • csrc/include/primus_turbo/quantization.h: declare the new dequantize_rowwise_row_major_impl / dequantize_rowwise_col_major_impl interfaces.
  • csrc/pytorch/quantization/quantization.cpp + bindings_pytorch.cpp:
    • Add dequantize_fp8_rowwise(Tensor, Tensor, int axis, ScalarType) -> Tensor, supporting rowwise dequant for arbitrary axis (auto-dispatched between row-major and col-major impls).
    • Add an explicit padding_align_size argument to all quantize_mxfp4{,_dual} / quantize_mxfp8{,_dual} schemas, with C++ sanity checks (must equal MXFP{4,8}_PADDING_ALIGN_SIZE) so the alignment constant is no longer hard-coded inside the kernels.
  • csrc/pytorch/quantization/quantization_meta.cpp: add the corresponding meta implementations.
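
As a reading aid, the axis-aware rowwise dequantize described above can be sketched in pure Python. This is a minimal reference under stated assumptions, not the PR's kernel: the name dequantize_rowwise_ref, the nested-list representation, and the convention that axis=-1 means one scale_inv per row while axis=-2 means one scale_inv per column are all illustrative.

```python
from typing import List


def dequantize_rowwise_ref(
    data: List[List[float]], scale_inv: List[float], axis: int = -1
) -> List[List[float]]:
    """Reference rowwise dequantize for a 2D matrix (hypothetical helper).

    Assumed convention: axis=-1 reduces over the last dim, so there is one
    scale_inv per row; axis=-2 reduces over rows, so there is one per column.
    """
    rows = len(data)
    cols = len(data[0]) if rows else 0
    if axis == -1:
        if len(scale_inv) != rows:
            raise ValueError("axis=-1 expects one scale_inv per row")
        # Every element of row i is multiplied by the same scale_inv[i].
        return [[v * scale_inv[i] for v in row] for i, row in enumerate(data)]
    if axis == -2:
        if len(scale_inv) != cols:
            raise ValueError("axis=-2 expects one scale_inv per column")
        # Every element of column j is multiplied by the same scale_inv[j].
        return [[v * scale_inv[j] for j, v in enumerate(row)] for row in data]
    raise ValueError("only axis=-1 and axis=-2 are handled in this sketch")


q = [[1.0, 2.0], [3.0, 4.0]]
by_row = dequantize_rowwise_ref(q, [0.5, 2.0], axis=-1)
by_col = dequantize_rowwise_ref(q, [0.5, 2.0], axis=-2)
```

The two calls differ only in which index selects the scale, which is the distinction the row-major and col-major implementations have to handle.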

Rowwise dequantize kernel optimization (f1cfe45)

  • Rewrite dequantize_rowwise_col_major_kernel to use 2D tiling (TILE_INNER × ELEMS_PER_THREAD): each block now loads ELEMS_PER_THREAD scale_inv values into registers once and reuses them TILE_INNER times along the M direction, significantly reducing redundant scale_inv traffic. TILE_INNER / threads-per-row are chosen adaptively based on next_pow2(inner_len) to improve occupancy.
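
To make the traffic argument concrete, the sketch below counts scale_inv loads for a col-major dequantize over an M × N matrix under a simplified model. The model and the names scale_loads_one_row_per_block and scale_loads_tiled are assumptions for illustration; the real kernel's tile sizes and access pattern are not shown in this description.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division for positive integers."""
    return (a + b - 1) // b


def next_pow2(x: int) -> int:
    """Smallest power of two that is >= x (returns 1 for x <= 1)."""
    return 1 if x <= 1 else 1 << (x - 1).bit_length()


def scale_loads_one_row_per_block(m: int, n: int) -> int:
    # Simplified model: every block handles one row and re-reads all
    # n per-column scale_inv values.
    return m * n


def scale_loads_tiled(m: int, n: int, tile_inner: int) -> int:
    # Simplified model: a block covers tile_inner rows and reads each
    # per-column scale_inv value once, reusing it from registers.
    return cdiv(m, tile_inner) * n


before = scale_loads_one_row_per_block(4096, 128)
after = scale_loads_tiled(4096, 128, tile_inner=8)
```

Under this model the number of scale_inv loads drops by a factor of TILE_INNER whenever M is a multiple of the tile height, which is why tall-M shapes benefit most.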

Op-level adoption of QuantizedTensor

  • primus_turbo/pytorch/ops/gemm_fp8.py / gemm_fp4.py / grouped_gemm_fp8.py / grouped_gemm.py / gemm.py:
    • Forward / backward accept and return QuantizedTensor (or auto-quantize plain tensors) and validate config / axis / scaling_recipe through check_quantized_tensor.
    • Unify axis semantics (-1 = inner-K, -2 = inner-M); grouped paths consistently propagate group_lens / group_offs.
  • primus_turbo/pytorch/ops/quantization.py / kernels/quantization/quantization_impl.py: expose the rowwise dequantize_fp8 path and adapt to the new binding.
  • primus_turbo/pytorch/ops/moe/permutation.py, fused_moe_router.py, indices_converter.py, moe_dispatch_combine.py: pass QuantizedTensor transparently through MoE paths (grouped wrappers preserve their grouped identity across view / reshape).
  • primus_turbo/pytorch/kernels/grouped_gemm/grouped_gemm_utils.py / grouped_gemm_fp8_impl.py: align the helpers with the new wrapper semantics.

Tests

  • Add tests/pytorch/core/test_quantized_tensor.py: covers quantize / dequantize / view / reshape / flatten–unflatten / grouped-wrapper behavior across TENSORWISE / ROWWISE / MX_BLOCKWISE granularities and across FP8 / MXFP4 / MXFP8 dtype combinations.
  • tests/pytorch/ops/test_gemm_fp8.py / test_gemm_fp4.py / test_grouped_gemm_fp8.py / test_grouped_gemm.py / test_quantization.py / test_linear_fp8.py: extend end-to-end tests for the new QuantizedTensor input / output paths.

CI

  • .github/workflows/ci.yaml: bump the single-GPU test job's timeout_minutes from 95 to 135 (and timeout -k 60s 130m accordingly) to accommodate the additional runtime introduced by the new QuantizedTensor test suite.

Checklist:

  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Copilot AI review requested due to automatic review settings May 9, 2026 09:04

@cursor cursor Bot left a comment


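Security review complete: 2 medium-severity issues found.

  • The alignment check in JAX DeepEP uses the wrong element count, so inputs that are not int4-aligned can reach the dispatch kernel. The returned buffer can then be partially uninitialized or mis-copied, which risks leaking GPU memory contents.
  • The new native padding_align_size argument is passed to cdiv without validation. Calling the public torch.ops entry point directly with 0 triggers a division-by-zero crash, resulting in denial of service.

No unresolved security review threads from earlier automation runs were found.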
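
The division-by-zero finding can be illustrated with a small Python sketch of the guard. This is not the PR's C++ check: checked_padded_len and its expected_align parameter are hypothetical names, and the value 128 used below is an arbitrary stand-in for the real MXFP{4,8}_PADDING_ALIGN_SIZE constants.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division; the caller must guarantee b > 0."""
    return (a + b - 1) // b


def checked_padded_len(length: int, padding_align_size: int, expected_align: int) -> int:
    """Round length up to a multiple of padding_align_size after validating it."""
    # Reject zero and negative values before they can reach cdiv.
    if padding_align_size <= 0:
        raise ValueError(f"padding_align_size must be positive, got {padding_align_size}")
    # Mirror the "must equal the alignment constant" rule from the description.
    if padding_align_size != expected_align:
        raise ValueError(
            f"padding_align_size {padding_align_size} != expected {expected_align}"
        )
    return cdiv(length, padding_align_size) * padding_align_size


padded = checked_padded_len(130, 128, expected_align=128)
```

Validating at the op boundary keeps a direct torch.ops call from reaching the kernel launch with an unusable alignment.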


Sent by Cursor Automation: Find vulnerabilities

Comment thread csrc/jax/deep_ep/deep_ep.cpp Outdated
Comment thread csrc/pytorch/quantization/quantization.cpp
Contributor

Copilot AI left a comment


Pull request overview

This PR introduces a unified low-precision tensor wrapper (QuantizedTensor / GroupedQuantizedTensor) and rewires FP8/FP4 GEMM + grouped GEMM ops to accept either raw tensors or pre-quantized wrappers, enabling quantization reuse and optional transpose-cache reuse across forward/backward.

Changes:

  • Added primus_turbo/pytorch/core/quantized_tensor.py implementing the wrapper abstraction + serialization/view support.
  • Updated FP8/FP4 GEMM and grouped-FP8 GEMM to consume wrapper inputs (skipping internal quantization when provided).
  • Added/extended substantial test coverage for the new wrapper behavior and wrapper-aware ops; removed legacy Float8Tensor.

Reviewed changes

Copilot reviewed 41 out of 41 changed files in this pull request and generated 7 comments.

| File | Description |
| --- | --- |
| primus_turbo/pytorch/core/quantized_tensor.py | New QuantizedTensor / GroupedQuantizedTensor implementation with quantization, transpose-cache, view/reshape, serialization. |
| primus_turbo/pytorch/core/low_precision.py | Adds padding-align constants and clarifies ScalingRecipe docs. |
| primus_turbo/pytorch/core/__init__.py | Updates core exports to include QuantizedTensor and explicitly list __all__. |
| primus_turbo/pytorch/core/float8_tensor.py | Removes the legacy Float8Tensor wrapper. |
| primus_turbo/pytorch/ops/gemm_fp8.py | Makes gemm_fp8 accept wrapper inputs; reuses cached transpose in backward paths. |
| primus_turbo/pytorch/ops/gemm_fp4.py | Makes gemm_fp4 accept wrapper inputs and reuse wrapper caches/recipes. |
| primus_turbo/pytorch/ops/grouped_gemm_fp8.py | Makes grouped_gemm_fp8 accept wrapper inputs for activations/weights and use wrapper metadata. |
| primus_turbo/pytorch/ops/grouped_gemm.py | Removes group_offs parameter and recomputes offsets internally. |
| primus_turbo/pytorch/ops/gemm.py | Uses torch.promote_types for output dtype inference. |
| primus_turbo/pytorch/ops/quantization.py | Switches MX block-size constants to MXFP8_BLOCK_SIZE / MXFP4_BLOCK_SIZE and updates assertions. |
| primus_turbo/pytorch/ops/moe/permutation.py | Swaps legacy Float8Tensor references toward QuantizedTensor (but FP8 path remains unsupported). |
| primus_turbo/pytorch/ops/moe/moe_dispatch_combine.py | Adds __all__ exports. |
| primus_turbo/pytorch/ops/moe/indices_converter.py | Adds __all__ exports. |
| primus_turbo/pytorch/ops/moe/fused_moe_router.py | Adds __all__ exports. |
| primus_turbo/pytorch/kernels/quantization/quantization_impl.py | Threads through padding-align size and tightens MX block-size validation. |
| primus_turbo/pytorch/kernels/grouped_gemm/grouped_gemm_utils.py | Renames _group_offs_from_lens to group_offs_from_lens and updates uses. |
| primus_turbo/pytorch/kernels/grouped_gemm/grouped_gemm_fp8_impl.py | Removes grouped_gemm_compute_offs helper. |
| csrc/pytorch/quantization/quantization.cpp | Adds padding-align-size parameter to MX quantization entry points. |
| csrc/pytorch/quantization/quantization_meta.cpp | Meta kernels updated for padding-align-size parameter. |
| csrc/pytorch/extensions.h | Updates declarations to match new MX quantization signatures. |
| csrc/pytorch/bindings_pytorch.cpp | Updates PyTorch bindings signatures for MX quantization ops. |
| csrc/include/primus_turbo/quantization.h | Removes hardcoded padding-align constants from header. |
| csrc/kernels/grouped_gemm/hipblaslt_grouped_gemm.cu | Adjusts hipBLASLt grouped GEMM stream/handle management and sync logic. |
| csrc/kernels/deep_ep/utils.cuh | Removes custom barrier macros (replaced by simpler sync logic). |
| csrc/kernels/deep_ep/intranode.cu | Updates synchronization primitives and warp intrinsics usage. |
| primus_turbo/jax/primitive/moe/moe_dispatch.py | Drops has_side_effect=True in ffi lowering calls. |
| primus_turbo/jax/primitive/moe/moe_combine.py | Drops has_side_effect=True in ffi lowering calls. |
| primus_turbo/jax/lax/moe/moe_utils.py | Replaces C++ Config import with a Python NamedTuple definition. |
| primus_turbo/jax/lax/moe/moe_dispatch_combine.py | Imports Config from moe_utils, updates exports, minor docstring updates. |
| primus_turbo/jax/lax/moe/__init__.py | Updates exported symbols to match new lax API surface. |
| csrc/jax/deep_ep/deep_ep.cpp | Modifies input layout/alignment checks for intranode dispatch. |
| tests/pytorch/core/test_quantized_tensor.py | New tests for QuantizedTensor construction/serialization/view/dequantize. |
| tests/pytorch/core/test_grouped_quantized_tensor.py | New tests for GroupedQuantizedTensor group metadata + behavior. |
| tests/pytorch/ops/test_gemm_fp8.py | Adds wrapper-input FP8 GEMM tests across granularities/backends. |
| tests/pytorch/ops/test_gemm_fp4.py | Adds wrapper-input FP4 GEMM tests for MX path. |
| tests/pytorch/ops/test_grouped_gemm_fp8.py | Adds wrapper-input grouped FP8 GEMM tests (tensorwise/rowwise). |
| tests/pytorch/ops/test_grouped_gemm.py | Adds temporary gfx942 hipBLASLt flake skips. |
| setup.py | Removes hipblas from extension link libraries. |
| SECURITY.md | Deletes repository security policy file. |
| benchmark/ops/bench_grouped_gemm_gmm.py | Removes the legacy grouped_gemm baseline benchmark script. |

Comments suppressed due to low confidence (2)

primus_turbo/pytorch/ops/moe/permutation.py:95

  • The FP8/QuantizedTensor path here is internally inconsistent: use_fp8 immediately raises, but the code below still references legacy Float8Tensor fields (_scale, _fp8_dtype, _config) and even tries to construct QuantizedTensor using the old Float8Tensor-style kwargs (data=..., scale=..., orig_dtype=...). With the new QuantizedTensor API this branch would crash if it ever became reachable. Please either (a) remove QuantizedTensor from the accepted input type and delete the dead FP8 code paths, or (b) implement the FP8 permute/unpermute support using the new wrapper’s data/scale_inv semantics and constructor signature.

primus_turbo/pytorch/ops/grouped_gemm_fp8.py:466

  • grouped_gemm_fp8 declares config: Float8QuantConfig | None = None but then unconditionally reads config.granularity. If config is actually None this will crash, and it contradicts the docstring claim that None uses a default config. Add the missing if config is None: config = Float8QuantConfig() (and ideally validate out_dtype when wrappers are passed).
    if out_dtype is None:
        out_dtype = torch.promote_types(a.dtype, b.dtype)

    args = (a, b, group_lens, trans_b, out_dtype, config, num_cu)

    if config.granularity == ScalingGranularity.TENSORWISE:
        return FP8GroupedGemmTensorFunc.apply(*args)
    elif config.granularity == ScalingGranularity.ROWWISE:
        return FP8GroupedGemmRowFunc.apply(*args)
    elif config.granularity == ScalingGranularity.BLOCKWISE:


Comment thread primus_turbo/pytorch/core/quantized_tensor.py Outdated
Comment thread primus_turbo/pytorch/core/quantized_tensor.py
Comment thread primus_turbo/pytorch/core/quantized_tensor.py Outdated
Comment thread primus_turbo/pytorch/ops/grouped_gemm.py
Comment thread csrc/jax/deep_ep/deep_ep.cpp Outdated
Comment thread primus_turbo/jax/lax/moe/moe_dispatch_combine.py Outdated
Comment thread primus_turbo/pytorch/ops/gemm_fp8.py Outdated
Copilot AI review requested due to automatic review settings May 11, 2026 02:33
@RuibinCheung RuibinCheung force-pushed the dev/zhangrb/quantized_tensor branch from 60fb133 to efe9b3b Compare May 11, 2026 02:33
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 29 out of 29 changed files in this pull request and generated 7 comments.

Comments suppressed due to low confidence (2)

primus_turbo/pytorch/ops/grouped_gemm_fp8.py:465

  • grouped_gemm_fp8() allows config=None (per signature/docstring) but immediately dereferences config.granularity. This will raise at runtime for default callers. Initialize config = Float8QuantConfig() when config is None (mirroring gemm_fp8).
    if out_dtype is None:
        out_dtype = torch.promote_types(a.dtype, b.dtype)

    args = (a, b, group_lens, trans_b, out_dtype, config, num_cu)

    if config.granularity == ScalingGranularity.TENSORWISE:
        return FP8GroupedGemmTensorFunc.apply(*args)
    elif config.granularity == ScalingGranularity.ROWWISE:
        return FP8GroupedGemmRowFunc.apply(*args)

primus_turbo/pytorch/ops/moe/permutation.py:95

  • This module was switched from Float8Tensor to QuantizedTensor, but the FP8 branches still reference legacy attributes (_scale, _fp8_dtype, _config) and construct QuantizedTensor with unsupported kwargs (data=..., scale=..., orig_dtype=...). Even though FP8 is currently blocked by a raise/assert, this code is now dead/incorrect and will break if FP8 support is re-enabled. Please delete the unreachable FP8 branches or update them to the QuantizedTensor API.
        use_fp8 = isinstance(inp, QuantizedTensor)

        if use_fp8:
            raise ValueError("FP8 is not supported for now.")

        if use_fp8:
            fp8_scale = inp._scale
            fp8_dtype = inp._fp8_dtype
            scale_hidden_dim = fp8_scale.shape[1]
        else:
            fp8_scale = None
            fp8_dtype = None
            scale_hidden_dim = None

        output, permuted_scale, permuted_probs = permutation.permute_with_mask_map(
            inp,
            row_id_map,
            probs,
            fp8_scale,
            num_tokens,
            num_experts,
            num_out_tokens,
            hidden_size,
            scale_hidden_dim,
        )

        if use_fp8:
            output = QuantizedTensor(
                data=output,
                scale=permuted_scale,
                orig_dtype=inp._orig_dtype,
                fp8_dtype=fp8_dtype,
                config=inp._config,
            )

Comment thread primus_turbo/pytorch/core/quantized_tensor.py Outdated
Comment thread primus_turbo/pytorch/core/quantized_tensor.py
Comment thread primus_turbo/pytorch/core/quantized_tensor.py Outdated
) and tensor._scale_inv is not None:
block_size = tensor._block_size
packing = _get_packing_factor(tensor)
scale_target_shape = _compute_scale_shape(padded_target_shape, block_size, packing)
Comment thread primus_turbo/pytorch/ops/grouped_gemm.py
Comment thread primus_turbo/pytorch/ops/grouped_gemm.py
Comment thread primus_turbo/pytorch/ops/gemm_fp8.py Outdated
Copilot AI review requested due to automatic review settings May 12, 2026 02:48
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 29 out of 29 changed files in this pull request and generated 11 comments.

Comments suppressed due to low confidence (1)

primus_turbo/pytorch/ops/moe/permutation.py:66

  • This file still treats QuantizedTensor like the removed Float8Tensor (FP8 metadata fields like _scale/_fp8_dtype/_config are referenced below). If FP8 remains unsupported here, consider dropping QuantizedTensor from the accepted types; otherwise update the FP8 path to the new QuantizedTensor API to avoid future runtime attribute errors.
        use_fp8 = isinstance(inp, QuantizedTensor)

        if use_fp8:
            raise ValueError("FP8 is not supported for now.")

Comment thread primus_turbo/pytorch/core/__init__.py Outdated
Comment thread tests/pytorch/core/test_grouped_quantized_tensor.py Outdated
Comment thread primus_turbo/pytorch/ops/moe/permutation.py
Comment thread primus_turbo/pytorch/ops/grouped_gemm_fp8.py Outdated
Comment thread primus_turbo/pytorch/ops/grouped_gemm.py
)

if isinstance(a, GroupedQuantizedTensor):
check_grouped_quantized_tensor(a, config, group_lens)
Comment on lines +117 to +121
assert quantized_tensor.granularity == config.granularity, (
f"QuantizedTensor granularity {quantized_tensor.granularity} does not match config "
f"granularity {config.granularity}"
)
assert quantized_tensor.block_size == config.block_size, (
Comment on lines +629 to +638
# TODO(ruibin): ``_compute_scale_shape`` assumes axis=-1
# chunked scale layout (row-quant). ``_scale_inv_t`` is
# chunked along axis=-2 (col-quant), so its shape differs.
# Wrapper-input BLOCKWISE/MX_BLOCKWISE is not supported yet
# (see grouped_gemm_fp8 BlockFunc), so we keep the existing
# call here as a placeholder; revisit when that path lands.
scale_t_target_shape = _compute_scale_shape(
padded_target_shape, tensor._block_size, _get_packing_factor(tensor)
)
out_scale_inv_t = op(out_scale_inv_t, *scale_t_target_shape)
Comment on lines +44 to +48
# FP8 dtypes to cover across tests. GroupedQuantizedTensor explicitly does
# *not* accept FP4; that case is checked in TestGroupSpecific.
_DTYPES = set([float8_e4m3, float8_e5m2])

if MXFP4_SUPPORT:
Comment on lines +59 to +71
granularity: ScalingGranularity = ScalingGranularity.MX_BLOCKWISE,
block_size: int = MXFP8_BLOCK_SIZE,
dest_dtype: torch.dtype = float8_e4m3,
keep_trans_cache: bool = False,
) -> GroupedQuantizedTensor:
"""Unified helper: construct a GroupedQuantizedTensor with sensible
defaults for each granularity.

- TENSORWISE: no block_size, no recipe (a single global scale)
- ROWWISE: no block_size, no recipe
- BLOCKWISE: block_size required (1D blockwise; 2D-block / weight-style
blockwise does not apply to packed-M activations)
- MX_BLOCKWISE: block_size + ScalingRecipe required
a: torch.Tensor,
b: torch.Tensor,
group_lens: torch.Tensor, # [B,] int64
group_offs: torch.Tensor, # [B+1,] int64
Collaborator


Keep group_offs.

Collaborator Author


group_offs can be inferred from group_lens; they carry essentially the same information. Keeping both is ambiguous and not user-friendly, I think, and it can cause problems when the group_offs a user passes does not match group_lens. Is there a reason this argument must be kept?

Collaborator Author


torchtitan needs this argument.
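
For context on this thread: going by the shapes in the signature ([B,] and [B+1,]), group_offs is the exclusive prefix sum of group_lens, which is why one can be derived from the other. The sketch below uses plain Python lists instead of int64 tensors, and group_offs_from_lens_ref and check_consistent are hypothetical names (the PR's own helper is group_offs_from_lens, whose body is not shown here).

```python
from itertools import accumulate
from typing import List, Sequence


def group_offs_from_lens_ref(group_lens: Sequence[int]) -> List[int]:
    """Return B+1 offsets for B group lengths, starting at 0."""
    return [0, *accumulate(group_lens)]


def check_consistent(group_lens: Sequence[int], group_offs: Sequence[int]) -> bool:
    """True when a caller-supplied group_offs matches the one implied by group_lens."""
    return list(group_offs) == group_offs_from_lens_ref(group_lens)


offs = group_offs_from_lens_ref([3, 0, 5])
ok = check_consistent([3, 0, 5], offs)
mismatch = check_consistent([3, 0, 5], [0, 3, 4, 8])
```

If both arguments stay in the public signature, a consistency check along these lines would catch the mismatch case raised in the discussion.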

__torch_function__ = torch._C._disabled_torch_function_impl


class GroupedQuantizedTensor(QuantizedTensor):
Collaborator


Why do we need GroupedQuantizedTensor?

Collaborator Author


The grouped GEMM activation needs a group_lens attribute. For MX and blockwise, the quantization and dequantization functions also differ somewhat. To keep the internal functions clear, I abstracted this into a new class named GroupedQuantizedTensor.

"""
return self._orig_ndim

def t(self) -> Tuple[torch.Tensor, torch.Tensor]:
Collaborator


There is a risk that it may cause ambiguity.

Collaborator Author

@RuibinCheung RuibinCheung May 14, 2026


The t function originally comes from torch.Tensor. I would like QuantizedTensor to keep the same API as torch.Tensor. Do you have any suggestions for this function?

Collaborator Author


Return QuantizedTensor

return self._data_t, self._scale_inv_t

@property
def T(self) -> Tuple[torch.Tensor, torch.Tensor]:
Collaborator


same as above

Collaborator Author


same as above

"""Wrapper subclass that carries low-precision quantized data, scale_inv"""

@staticmethod
def __new__(
Collaborator


Launching a quantize kernel inside __new__ is somewhat against the conventions of PyTorch tensor subclasses.

Conventionally, __new__ should only handle _make_wrapper_subclass(...) plus field initialization, while quantization should go through a separate factory method.

class QuantizedTensor(torch.Tensor):
    # Pure wrapper: only stores fields, does not launch kernels
    @staticmethod
    def __new__(cls, data, scale_inv, *, dest_dtype, granularity, ...):
        self = torch.Tensor._make_wrapper_subclass(cls, ...)
        self._data = data
        self._scale_inv = scale_inv
        # ... store fields
        return self

    # Factory method: quantization kernel is launched here
    @classmethod
    def quantize(cls, hp_tensor, dest_dtype, granularity, *, keep_trans_cache=False, ...):
        data, scale_inv, data_t, scale_inv_t = _do_quantize(hp_tensor, ...)
        return cls(data, scale_inv, dest_dtype=..., ...)

usage

fp8_tensor = QuantizedTensor.quantize(fp16_tensor, ...)

fp8_tensor = QuantizedTensor(fp8_data, scale, ....)

Collaborator Author


ACK. Thx.

return self

@classmethod
@torch.no_grad()
Collaborator


no grad?

The @torch.no_grad() on _quantize causes QuantizedTensor(x) to break the autograd chain of x — after wrapping an nn.Parameter, x.grad will always be None. Is this intended?

Collaborator Author

@RuibinCheung RuibinCheung May 14, 2026


This is expected. The quantized tensor is effectively a buffer; the gradient accumulates into the original x.grad.

# Dispatch hooks
# ------------------------------------------------------------------
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
Collaborator


__torch_dispatch__ silently dequantizing all ops by default is quite risky.

Recommend making the default behavior raise NotImplementedError instead.

Collaborator Author


ACK. Thx.

"""
if out_dtype is None:
out_dtype = torch.promote_types(a.dtype, b.dtype)

Collaborator


Missing the if config is None check.

Collaborator Author


ACK. Thx.
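
The fix requested in this thread is the usual default-argument guard. The sketch below uses stand-in types so that it runs without the library: QuantConfig, Granularity, and pick_kernel are hypothetical placeholders for Float8QuantConfig, ScalingGranularity, and the dispatch inside grouped_gemm_fp8, and the default granularity shown is an assumption.

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional


class Granularity(Enum):
    TENSORWISE = auto()
    ROWWISE = auto()


@dataclass(frozen=True)
class QuantConfig:
    # Assumed default; the real Float8QuantConfig default may differ.
    granularity: Granularity = Granularity.TENSORWISE


def pick_kernel(config: Optional[QuantConfig] = None) -> str:
    # Resolve the default before any attribute access on config.
    if config is None:
        config = QuantConfig()
    if config.granularity is Granularity.TENSORWISE:
        return "tensorwise"
    if config.granularity is Granularity.ROWWISE:
        return "rowwise"
    raise NotImplementedError(f"unsupported granularity: {config.granularity}")


default_kernel = pick_kernel()
rowwise_kernel = pick_kernel(QuantConfig(Granularity.ROWWISE))
```

Placing the guard before the args tuple is built also ensures the autograd functions never receive None as their config.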

Copilot AI review requested due to automatic review settings May 13, 2026 06:16
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 30 out of 30 changed files in this pull request and generated 3 comments.

Comments suppressed due to low confidence (2)

primus_turbo/pytorch/ops/grouped_gemm_fp8.py:476

  • grouped_gemm_fp8() declares config: Union[Float8QuantConfig, None] = None, but it unconditionally reads config.granularity (and passes config into autograd Functions). Calling grouped_gemm_fp8(..., config=None) will raise an AttributeError. Add if config is None: config = Float8QuantConfig() before accessing config.granularity (and consider validating the type early).
    if out_dtype is None:
        out_dtype = torch.promote_types(a.dtype, b.dtype)

    args = (a, b, group_lens, trans_b, out_dtype, config, num_cu)

    if config.granularity == ScalingGranularity.TENSORWISE:
        return FP8GroupedGemmTensorFunc.apply(*args)
    elif config.granularity == ScalingGranularity.ROWWISE:
        return FP8GroupedGemmRowFunc.apply(*args)
    elif config.granularity == ScalingGranularity.BLOCKWISE:

primus_turbo/pytorch/ops/moe/permutation.py:95

  • There is now dead / inconsistent FP8-handling code that still references the removed Float8Tensor fields (_scale, _fp8_dtype, _config) and constructs QuantizedTensor using the old Float8Tensor constructor signature. Even though use_fp8 currently raises immediately, this code is misleading and will break if FP8 support is re-enabled. Either delete the unreachable FP8 branch or update it to the new QuantizedTensor API (data + scale_inv + granularity metadata).
        use_fp8 = isinstance(inp, QuantizedTensor)

        if use_fp8:
            raise ValueError("FP8 is not supported for now.")

        if use_fp8:
            fp8_scale = inp._scale
            fp8_dtype = inp._fp8_dtype
            scale_hidden_dim = fp8_scale.shape[1]
        else:
            fp8_scale = None
            fp8_dtype = None
            scale_hidden_dim = None

        output, permuted_scale, permuted_probs = permutation.permute_with_mask_map(
            inp,
            row_id_map,
            probs,
            fp8_scale,
            num_tokens,
            num_experts,
            num_out_tokens,
            hidden_size,
            scale_hidden_dim,
        )

        if use_fp8:
            output = QuantizedTensor(
                data=output,
                scale=permuted_scale,
                orig_dtype=inp._orig_dtype,
                fp8_dtype=fp8_dtype,
                config=inp._config,
            )

Comment thread primus_turbo/pytorch/core/quantized_tensor.py
Comment thread primus_turbo/pytorch/core/__init__.py Outdated
Comment thread primus_turbo/pytorch/ops/grouped_gemm_fp8.py Outdated
Copilot AI review requested due to automatic review settings May 14, 2026 03:04
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 30 out of 30 changed files in this pull request and generated 18 comments.

Comments suppressed due to low confidence (5)

primus_turbo/pytorch/core/quantized_tensor.py:607

  • _get_padding_align_size() returns 0 for TENSORWISE, ROWWISE, and BLOCKWISE tensors, so any view/reshape to a different same-rank shape calls _pad_inner_dim(..., 0) and divides by zero before reaching the underlying view. This makes the public view/reshape methods unusable for non-MX granularities except for the exact same shape.
        align = _get_padding_align_size(tensor)
        padded_target_shape = _pad_inner_dim(target_shape, align)
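
One way to avoid the division by zero is to treat a non-positive alignment as "no padding". The sketch below is an assumption about how such a guard could look; pad_inner_dim_ref is a hypothetical stand-in, since the body of _pad_inner_dim is not shown in this review.

```python
from typing import Sequence, Tuple


def pad_inner_dim_ref(shape: Sequence[int], align: int) -> Tuple[int, ...]:
    """Round the innermost dim up to a multiple of align; align <= 0 means no padding."""
    shape = tuple(shape)
    # Granularities without padding report align == 0; return the shape unchanged
    # instead of dividing by zero.
    if align <= 0 or not shape:
        return shape
    inner = shape[-1]
    padded = (inner + align - 1) // align * align
    return shape[:-1] + (padded,)


unpadded = pad_inner_dim_ref((4, 100), 0)
padded = pad_inner_dim_ref((4, 100), 32)
```

With this behavior, view and reshape on TENSORWISE, ROWWISE, and BLOCKWISE wrappers would pass the target shape through untouched.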

primus_turbo/pytorch/ops/gemm_fp8.py:284

  • For BLOCKWISE wrapper inputs this accepts an activation without a transpose cache, but the code unconditionally calls a_fp8.t() below. Since blockwise dequantization is not implemented in QuantizedTensor.dequantize(), a valid-looking pre-quantized activation with keep_trans_cache=False will fail at runtime instead of being rejected with a clear validation error.
        if isinstance(a, QuantizedTensor):
            check_quantized_tensor(a, config)
            a_fp8 = a

primus_turbo/pytorch/ops/grouped_gemm_fp8.py:440

  • Disabling Dynamo on the public grouped FP8 GEMM API forces graph breaks for all callers, including raw tensor inputs that do not use the new wrappers. This makes full-graph torch.compile incompatible with grouped FP8 GEMM rather than only guarding the unsupported wrapper-specific path.
@torch._dynamo.disable(
    recursive=True,
    reason=(
        "Grouped FP8 GEMM constructs (Grouped)QuantizedTensor wrapper subclasses "
        "inside its autograd.Function.forward and reads their inner tensors "
        "(data / scale_inv / group_lens / group_offs). Dynamo cannot recover Python "
        "sources for those graph-internal inner tensors, tripping gb0116 "
        "('SourcelessBuilder.create cannot wrap FakeTensor'). "
    ),
)

primus_turbo/pytorch/core/quantized_tensor.py:659

  • These overrides pass shape through exactly as collected from *shape, so common PyTorch forms like qt.view(torch.Size([m, n])), qt.view((m, n)), or qt.reshape(-1, n) are not normalized the way Tensor.view/reshape does. They will be interpreted as a one-element shape tuple (or fail before -1 inference), making the wrapper incompatible with standard view/reshape call patterns.
    def view(self, *shape) -> "QuantizedTensor":
        """View without dequantizing (autograd-aware)."""
        return _ViewFunc.apply(self, shape)

    def reshape(self, *shape) -> "QuantizedTensor":
        """Reshape without dequantizing (autograd-aware)."""
        return _ReshapeFunc.apply(self, shape)
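
A minimal sketch of the argument normalization this comment asks for, assuming the overrides keep their `*shape` signature. `normalize_shape_args` is a hypothetical helper, not part of the PR; it only unifies the calling conventions and leaves `-1` inference to a later step.

```python
from typing import Tuple


def normalize_shape_args(*shape) -> Tuple[int, ...]:
    """Accept `f(m, n)`, `f((m, n))`, or `f([m, n])` and return `(m, n)`.

    Hypothetical helper. torch.Size is a tuple subclass, so the tuple
    branch below would also cover `f(torch.Size([m, n]))`.
    """
    if len(shape) == 1 and isinstance(shape[0], (tuple, list)):
        shape = tuple(shape[0])
    # bool is a subclass of int, so reject it explicitly.
    if not all(isinstance(d, int) and not isinstance(d, bool) for d in shape):
        raise TypeError(f"shape must contain only ints, got {shape!r}")
    return tuple(shape)
```

Calling a helper like this at the top of `view` and `reshape` would let `_ViewFunc` and `_ReshapeFunc` always receive a flat tuple of ints.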

primus_turbo/pytorch/core/quantized_tensor.py:635

  • This assumes _data_t has the same physical shape as _data, but the MXFP8/MXFP4 dual quantization kernels return the colwise buffer with the last two dimensions transposed (e.g. [N, M_pad] rather than [M, N_pad]). Applying the same target shape to _data_t will corrupt or fail non-identity views of MX wrappers with a transpose cache.
        out_data_t = tensor._data_t
        out_scale_inv_t = tensor._scale_inv_t
        if out_data_t is not None:
            # ``_data_t`` is the col-quantized counterpart at the SAME
            # physical layout as ``_data`` — apply the identical target
            # shape, no last-two-dim swap.
            assert (
                out_data_t.numel() == torch.Size(padded_target_shape).numel()
            ), "data_t numel and padded_target_shape must have the same number of elements"

            out_data_t = op(out_data_t, *padded_target_shape)

return list(tensors.keys()), metadata

@staticmethod
def __tensor_unflatten__(inner_tensors, metadata, outer_size):
return keys, metadata

@staticmethod
def __tensor_unflatten__(inner_tensors, metadata, outer_size):
Comment on lines +906 to +908
def __tensor_flatten__(self):
keys, metadata = super().__tensor_flatten__()
metadata["_group_lens"] = self._group_lens
Comment on lines +117 to +124
assert quantized_tensor.granularity == config.granularity, (
f"QuantizedTensor granularity {quantized_tensor.granularity} does not match config "
f"granularity {config.granularity}"
)
assert quantized_tensor.block_size == config.block_size, (
f"QuantizedTensor block_size {quantized_tensor.block_size} does not match config "
f"block_size {config.block_size}"
)
assert granularity in [
ScalingGranularity.ROWWISE,
ScalingGranularity.TENSORWISE,
], "GroupedQuantizedTensor only supports ROWWISE and TENSORWISE granularity"
Comment on lines +606 to +613
align = _get_padding_align_size(tensor)
padded_target_shape = _pad_inner_dim(target_shape, align)

assert (
tensor._data.numel() == torch.Size(padded_target_shape).numel()
), "data numel and padded_target_shape must have the same number of elements"

out_data = op(tensor._data, *padded_target_shape)
Comment on lines +456 to +482
@torch.no_grad()
def dequantize(self) -> torch.Tensor:
"""Dequantize back to the original high-precision dtype."""
from primus_turbo.pytorch.ops.quantization import dequantize_fp4, dequantize_fp8

axis = _normalize_axis(-1, self._data.ndim)

if self._dest_dtype in [float8_e4m3, float8_e5m2]:
return dequantize_fp8(
self._data,
self._orig_dtype,
self._granularity,
block_size=self._block_size,
axis=axis,
scale_inv=self._scale_inv,
scaling_recipe=self._scaling_recipe,
)
elif self._dest_dtype == float4_e2m1fn_x2:
return dequantize_fp4(
self._data,
self._orig_dtype,
self._granularity,
block_size=self._block_size,
axis=axis,
scale_inv=self._scale_inv,
scaling_recipe=self._scaling_recipe,
)
Comment on lines 110 to 116
def grouped_gemm(
a: torch.Tensor,
b: torch.Tensor,
group_lens: torch.Tensor,
group_offs: torch.Tensor | None = None,
trans_b: bool = False,
num_cu: int | None = None,
) -> torch.Tensor:
Comment thread primus_turbo/pytorch/ops/moe/permutation.py
Comment thread primus_turbo/pytorch/ops/grouped_gemm_fp8.py Outdated
Copilot AI review requested due to automatic review settings May 15, 2026 09:21
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 32 out of 32 changed files in this pull request and generated 12 comments.

Comments suppressed due to low confidence (5)

primus_turbo/pytorch/core/quantized_tensor.py:241

  • Grouped wrappers accept any group_lens tensor but do not validate that it is 1D int64 on the same device or that group_lens.sum() equals the packed-M dimension. A stale or mismatched group_lens can produce invalid group_offs and drive grouped GEMM kernels with offsets outside the quantized activation buffer.
        is_grouped_tensor = group_lens is not None
        if is_grouped_tensor:
            assert (
                hp_tensor.ndim == 2
            ), f"Grouped quantized tensor expects a 2D packed-M tensor, got {hp_tensor.ndim}D"
            assert (
                dest_dtype in _SUPPORTED_QUANTIZED_DTYPES and dest_dtype != float4_e2m1fn_x2
            ), "Unsupported quantized dtype (FP4 not supported for grouped activations)"
            assert granularity in [
                ScalingGranularity.ROWWISE,
                ScalingGranularity.TENSORWISE,
            ], "Grouped quantized tensor only supports ROWWISE and TENSORWISE granularity"
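
The missing consistency check could look roughly like the sketch below. It works on plain Python ints rather than tensors, `group_offsets` is a hypothetical name, and the assumption that offsets have `len(group_lens) + 1` entries starting at 0 is mine, not something the PR states.

```python
from itertools import accumulate
from typing import List, Sequence


def group_offsets(group_lens: Sequence[int], packed_m: int) -> List[int]:
    """Validate `group_lens` against the packed-M size and return offsets.

    Hypothetical host-side check: each length must be a non-negative int and
    the lengths must sum to `packed_m`. Returns len(group_lens) + 1 offsets.
    """
    if any((not isinstance(n, int)) or n < 0 for n in group_lens):
        raise ValueError(f"group_lens must be non-negative ints, got {list(group_lens)}")
    total = sum(group_lens)
    if total != packed_m:
        raise ValueError(f"group_lens sums to {total}, expected packed M = {packed_m}")
    # Prefix sums: offsets[i] is the first row of group i.
    return [0, *accumulate(group_lens)]
```

For device tensors, comparing the sum against the packed-M size requires reading a value back to the host, so a check like this may be better suited to a debug or validation mode than to the hot path.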

primus_turbo/pytorch/core/quantized_tensor.py:426

  • The flatten metadata omits _quantized_axis. After FSDP/compile unflattening, rowwise/MX tensors lose the axis needed by check_quantized_tensor(..., axis=...) and dequantize(), so serialized wrappers can no longer be validated or dequantized correctly.
        metadata = {
            "_orig_dtype": self._orig_dtype,
            "_dest_dtype": self._dest_dtype,
            "_granularity": self._granularity,
            "_block_size": self._block_size,
            "_scaling_recipe": self._scaling_recipe,
            "_is_grouped_tensor": self._is_grouped_tensor,
        }
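
One way to keep flatten metadata from drifting out of sync with the wrapper's attributes is to derive it from a single declaration. The sketch below is a plain-Python illustration with hypothetical names (`QuantMeta`, `flatten_meta`, `unflatten_meta`); dtypes and granularities are shown as strings only to keep it self-contained.

```python
from dataclasses import dataclass, fields
from typing import Any, Dict, Optional


@dataclass(frozen=True)
class QuantMeta:
    """Hypothetical container for the wrapper's non-tensor attributes."""
    orig_dtype: str
    dest_dtype: str
    granularity: str
    block_size: Optional[int]
    quantized_axis: Optional[int]


def flatten_meta(meta: QuantMeta) -> Dict[str, Any]:
    # Iterating over the dataclass fields means a newly added attribute
    # is serialized automatically instead of being forgotten.
    return {f.name: getattr(meta, f.name) for f in fields(meta)}


def unflatten_meta(metadata: Dict[str, Any]) -> QuantMeta:
    # A missing key raises TypeError here rather than silently becoming None.
    return QuantMeta(**metadata)
```

Because `QuantMeta` has no default values, dropping `quantized_axis` from the metadata dict fails loudly at unflatten time.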

primus_turbo/pytorch/core/quantized_tensor.py:478

  • _make_like does not carry over _quantized_axis. Any tensor returned from view()/reshape() loses the axis metadata, so later dequantize() or check_quantized_tensor(..., axis=...) on rowwise/MX wrappers will fail or use None as the axis.
        return cls(
            data,
            scale_inv,
            shape=shape,
            orig_dtype=tensor._orig_dtype,
            dest_dtype=tensor._dest_dtype,
            granularity=tensor._granularity,
            block_size=tensor._block_size,
            scaling_recipe=tensor._scaling_recipe,
            group_lens=tensor._group_lens,
            group_offs=tensor._group_offs,
            is_grouped_tensor=tensor._is_grouped_tensor,
            requires_grad=data.requires_grad,
        )
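
A copy-with-override pattern is one way to make a `_make_like`-style helper propagate every field by default. The sketch below uses `dataclasses.replace` on a hypothetical `WrapperState`; it illustrates the pattern only and does not model the tensor subclass itself.

```python
from dataclasses import dataclass, replace
from typing import Optional, Tuple


@dataclass(frozen=True)
class WrapperState:
    """Hypothetical stand-in for the fields `_make_like` has to propagate."""
    shape: Tuple[int, ...]
    granularity: str
    block_size: Optional[int]
    quantized_axis: Optional[int]


def make_like(src: WrapperState, shape: Tuple[int, ...]) -> WrapperState:
    # replace() copies every field from `src` and overrides only `shape`,
    # so metadata such as quantized_axis cannot be dropped by omission.
    return replace(src, shape=tuple(shape))
```

The design choice is that new metadata is carried over unless a caller explicitly overrides it, which is the opposite default from listing each keyword argument by hand.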

primus_turbo/pytorch/core/quantized_tensor.py:492

  • For TENSORWISE and ROWWISE tensors _get_padding_align_size() returns 0, so any rank-preserving view/reshape to a different shape reaches _pad_inner_dim(target_shape, 0) and divides by zero. This makes the advertised view/reshape support work only for the same-shape fast path.
        align = _get_padding_align_size(tensor)
        padded_target_shape = _pad_inner_dim(target_shape, align)

primus_turbo/pytorch/ops/gemm_fp4.py:136

  • b_fp4.dtype also resolves to the original fp16/bf16 dtype, so this transpose-cache quantization passes an unsupported dest_dtype and fails before backward can run. Use the quantized storage dtype metadata instead.
        quantized_b_t = QuantizedTensor.quantize(
            b_fp4.dequantize(),
            b_fp4.dtype,
            config.granularity,

Comment on lines +193 to +203
def quantize(
cls,
hp_tensor: Union[torch.Tensor, torch.nn.Parameter],
dest_dtype: torch.dtype,
granularity: ScalingGranularity,
axis: int,
*,
group_lens: Optional[torch.Tensor] = None,
block_size: Optional[int] = None,
scaling_recipe: Optional[ScalingRecipe] = None,
) -> "QuantizedTensor":
Comment on lines +429 to +443
@staticmethod
def __tensor_unflatten__(inner_tensors, metadata, outer_size):
return QuantizedTensor(
inner_tensors["_data"],
inner_tensors["_scale_inv"],
shape=outer_size,
orig_dtype=metadata["_orig_dtype"],
dest_dtype=metadata["_dest_dtype"],
granularity=metadata["_granularity"],
block_size=metadata["_block_size"],
scaling_recipe=metadata["_scaling_recipe"],
group_lens=inner_tensors.get("_group_lens"),
group_offs=inner_tensors.get("_group_offs"),
is_grouped_tensor=metadata["_is_grouped_tensor"],
)
Comment on lines 110 to 115
def grouped_gemm(
a: torch.Tensor,
b: torch.Tensor,
group_lens: torch.Tensor,
group_offs: torch.Tensor | None = None,
trans_b: bool = False,
num_cu: int | None = None,
Output tensor with shape [m, n] (same dtype as input)
"""
if out_dtype is None:
out_dtype = torch.promote_types(a.dtype, b.dtype)
Comment thread primus_turbo/pytorch/ops/gemm_fp4.py Outdated
Comment on lines +117 to +120
quantized_a_t = QuantizedTensor.quantize(
a_fp4.dequantize(),
a_fp4.dtype,
config.granularity,
Comment on lines +277 to +278
scaling_recipe_for_trans=unpermuted_act_grad._scaling_recipe_for_trans,
keep_trans_cache=unpermuted_act_grad._keep_trans_cache,
Comment on lines 7 to 11
from .moe_dispatch_combine import (
Config,
get_combine_config,
get_dispatch_config,
moe_combine,
moe_dispatch,
Comment thread primus_turbo/pytorch/core/__init__.py Outdated
Comment on lines +1 to +7
from .quantized_tensor import QuantizedTensor
from .stream import TurboStream
from .symm_mem import SymmetricMemory, get_symm_mem_workspace

__all__ = [
"QuantizedTensor",
"SymmetricMemory",
Comment on lines +217 to +220
# BLOCKWISE is intentionally not handled by the QuantizedTensor wrapper
assert (
granularity != ScalingGranularity.BLOCKWISE
), "BLOCKWISE is not supported by QuantizedTensor; Please call the low-level kernels directly."
Comment on lines +193 to +199
def quantize(
cls,
hp_tensor: Union[torch.Tensor, torch.nn.Parameter],
dest_dtype: torch.dtype,
granularity: ScalingGranularity,
axis: int,
*,
@RuibinCheung RuibinCheung force-pushed the dev/zhangrb/quantized_tensor branch from 2bcf803 to 73e0716 Compare May 18, 2026 09:37
Copilot AI review requested due to automatic review settings May 20, 2026 01:28
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 33 out of 33 changed files in this pull request and generated 17 comments.

group_offs=tensor._group_offs,
is_grouped_tensor=tensor._is_grouped_tensor,
quantized_axis=tensor._quantized_axis,
requires_grad=data.requires_grad,
Comment on lines +534 to +538
out_scale_inv = tensor._scale_inv
if tensor._granularity == ScalingGranularity.MX_BLOCKWISE and tensor._scale_inv is not None:
block_size = tensor._block_size
scale_target_shape = _compute_scale_shape(data_target_shape, block_size, packing)
out_scale_inv = op(tensor._scale_inv, *scale_target_shape)
Comment on lines +224 to +226
assert hp_tensor.ndim in (2, 3), f"data must be a 2D or 3D tensor, got {hp_tensor.ndim}D"
assert dest_dtype in _SUPPORTED_QUANTIZED_DTYPES, "Unsupported quantized dtype"

Comment on lines +139 to +141
class QuantizedTensor(torch.Tensor):
"""Wrapper subclass that carries low-precision quantized data, scale_inv."""

if not x.is_contiguous():
x = x.contiguous()
x_fp8, scale_inv = torch.ops.primus_turbo_cpp_extension.quantize_fp8_rowwise(x, out_dtype, axis, scale)
assert x.is_contiguous(), "The x tensor must be contiguous."
shuffle_out=enable_preshuffle(),
),
if isinstance(a, QuantizedTensor):
check_quantized_tensor(a, config, scaling_recipe=a_scaling_recipe)
shuffle_scale=enable_preshuffle(),
)
if isinstance(b, QuantizedTensor):
check_quantized_tensor(b, config, scaling_recipe=b_scaling_recipe)
Comment on lines +116 to +125
if a_t is not None:
quantized_a_t = a_t
else:
a_t_scaling_recipe = ScalingRecipe(
use_2d_block=False,
use_sr=False,
use_rht=True,
shuffle_scale=enable_preshuffle(),
shuffle_out=enable_preshuffle(),
)
Comment on lines +135 to +144
if b_t is not None:
quantized_b_t = b_t
else:
b_t_scaling_recipe = ScalingRecipe(
use_2d_block=True,
use_sr=False,
use_rht=True,
shuffle_scale=enable_preshuffle(),
shuffle_out=enable_preshuffle(),
)
# and passed it via ``a_t`` / ``b_t``, reuse it directly; otherwise
# derive it (dequantize + re-quantize along the other axis), mirroring
# FP8GemmRowFunction in gemm_fp8.py.
if a_t is not None:
Copilot AI review requested due to automatic review settings May 20, 2026 04:21
@RuibinCheung RuibinCheung force-pushed the dev/zhangrb/quantized_tensor branch from 45e212e to 02c53a4 Compare May 20, 2026 04:26
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 33 out of 33 changed files in this pull request and generated 9 comments.

Comments suppressed due to low confidence (1)

primus_turbo/pytorch/core/quantized_tensor.py:517

  • _view_data_and_scale_inv treats target_shape literally (via torch.Size(target_shape)) and doesn't handle -1 shape inference. Calling qt.view(-1, N)/reshape(-1, N) would either create an invalid wrapper shape (containing -1) or fail unexpectedly. Either implement -1 inference to match PyTorch semantics or explicitly assert that all dims are non-negative and -1 is not used.
    def _view_data_and_scale_inv(tensor: QuantizedTensor, target_shape: torch.Size, op: Callable):
        """Apply *op* (``view`` or ``reshape``) to ``_data``, ``_scale_inv``"""
        assert len(target_shape) == tensor.data.ndim

        out_shape = torch.Size(target_shape)
        wrapper_shape = tensor.shape

        if out_shape == wrapper_shape:
            return tensor._data, tensor._scale_inv, out_shape

        align = _get_padding_align_size(tensor)
        padded_target_shape = _pad_inner_dim(target_shape, align)
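
If `-1` inference is implemented rather than rejected, it could be resolved against the wrapper's logical (unpadded) element count before any padding is applied. The sketch below is a hypothetical `infer_shape` helper in plain Python, intended to mirror the usual rule that at most one dimension may be `-1`.

```python
from math import prod
from typing import Sequence, Tuple


def infer_shape(target: Sequence[int], numel: int) -> Tuple[int, ...]:
    """Resolve at most one -1 entry in `target` so it holds `numel` elements."""
    target = tuple(target)
    if any(d < -1 for d in target):
        raise ValueError(f"invalid dimension in shape {target}")
    unknown = [i for i, d in enumerate(target) if d == -1]
    if len(unknown) > 1:
        raise ValueError("only one dimension can be inferred")
    # Product of the explicitly given dimensions (1 if there are none).
    known = prod(d for d in target if d != -1)
    if unknown:
        if known == 0 or numel % known != 0:
            raise ValueError(f"shape {target} is invalid for {numel} elements")
        i = unknown[0]
        target = target[:i] + (numel // known,) + target[i + 1:]
    elif known != numel:
        raise ValueError(f"shape {target} is invalid for {numel} elements")
    return target
```

The resolved shape contains only non-negative dimensions, so it can be compared against the wrapper shape and passed on to padding without special cases.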

Comment on lines +62 to 70
use_fp8 = isinstance(inp, QuantizedTensor)

if use_fp8:
raise ValueError("FP8 is not supported for now.")

if use_fp8:
fp8_scale = inp._scale
fp8_dtype = inp._fp8_dtype
fp8_scale = inp._scale_inv
fp8_dtype = inp._dest_dtype
scale_hidden_dim = fp8_scale.shape[1]
if not x.is_contiguous():
x = x.contiguous()
x_fp8, scale_inv = torch.ops.primus_turbo_cpp_extension.quantize_fp8_rowwise(x, out_dtype, axis, scale)
assert x.is_contiguous(), "The x tensor must be contiguous."
Comment on lines +174 to +211
if a_t is None:
quantized_a_t = QuantizedTensor.quantize(
quantized_a.dequantize(),
quantized_a.real_dtype,
config.granularity,
axis=-2,
block_size=config.block_size,
)
else:
assert isinstance(a_t, QuantizedTensor)
quantized_a_t = a_t

if isinstance(b, QuantizedTensor):
check_quantized_tensor(b, config, axis=-1 if trans_b else -2)
quantized_b = b
else:
b_dtype = _get_fp8_dtype(config.format, True)
quantized_b = QuantizedTensor.quantize(
b,
b_dtype,
config.granularity,
axis=-1 if trans_b else -2,
block_size=config.block_size,
)

if b_t is None:
# B's row-wise axis is (-1 if trans_b else -2); the col-wise / trans
# cache used by backward is the other axis.
quantized_b_t = QuantizedTensor.quantize(
quantized_b.dequantize(),
quantized_b.real_dtype,
config.granularity,
axis=-2 if trans_b else -1,
block_size=config.block_size,
)
else:
assert isinstance(b_t, QuantizedTensor)
quantized_b_t = b_t
Comment on lines +427 to +468
if a_t is None:
# MX_BLOCKWISE requires a scaling_recipe; reuse the forward recipe
# for A's col-wise direction (same recipe as forward).
quantized_a_t = QuantizedTensor.quantize(
quantized_a.dequantize(),
quantized_a.real_dtype,
config.granularity,
axis=-2,
block_size=config.block_size,
scaling_recipe=a_scaling_recipe,
)
else:
assert isinstance(a_t, QuantizedTensor)
quantized_a_t = a_t

b_scaling_recipe = ScalingRecipe(use_2d_block=True)
if isinstance(b, QuantizedTensor):
quantized_b = b
check_quantized_tensor(quantized_b, config, axis=-1, scaling_recipe=b_scaling_recipe)
else:
b_dtype = _get_fp8_dtype(config.format, True)
quantized_b = QuantizedTensor.quantize(
b,
b_dtype,
config.granularity,
axis=-1,
block_size=config.block_size,
scaling_recipe=b_scaling_recipe,
)

if b_t is None:
quantized_b_t = QuantizedTensor.quantize(
quantized_b.dequantize(),
quantized_b.real_dtype,
config.granularity,
axis=-2,
block_size=config.block_size,
scaling_recipe=b_scaling_recipe,
)
else:
assert isinstance(b_t, QuantizedTensor)
quantized_b_t = b_t
Comment thread primus_turbo/pytorch/ops/gemm_fp4.py
Comment thread primus_turbo/pytorch/ops/grouped_gemm_fp8.py
Comment thread primus_turbo/pytorch/ops/grouped_gemm_fp8.py
Comment thread primus_turbo/pytorch/core/quantized_tensor.py
Comment thread primus_turbo/pytorch/core/quantized_tensor.py
Copilot AI review requested due to automatic review settings May 21, 2026 01:43
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 33 out of 33 changed files in this pull request and generated 9 comments.

Comments suppressed due to low confidence (1)

primus_turbo/pytorch/ops/gemm_fp4.py:118

  • When callers provide a_t / b_t (the col-wise/RHT transpose cache) there is no validation that these tensors were quantized with the expected axis (axis=0) and expected RHT scaling recipe. A mismatched cache would silently produce wrong gradients. Consider validating a_t/b_t via check_quantized_tensor(..., axis=0, scaling_recipe=...) before saving/using them.
        # If the caller pre-quantized this and passed it via ``a_t`` / ``b_t``,
        # reuse it directly; otherwise derive it from the forward fp4 tensor.
        if a_t is not None:
            quantized_a_t = a_t
        else:
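
The validation suggested here could be sketched as follows. `CacheInfo` and `check_trans_cache` are hypothetical names that summarize only the two properties the comment mentions (quantized axis and the RHT flag); the real check would presumably go through `check_quantized_tensor`.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class CacheInfo:
    """Hypothetical summary of a pre-quantized transpose cache."""
    quantized_axis: int
    use_rht: bool


def check_trans_cache(
    cache: Optional[CacheInfo], ndim: int, expected_axis: int, expect_rht: bool
) -> None:
    """Raise ValueError if a caller-supplied cache does not match expectations."""
    if cache is None:
        return  # nothing supplied; the caller derives the cache itself
    # Compare axes modulo ndim so that -2 and 0 agree for a 2D tensor.
    if cache.quantized_axis % ndim != expected_axis % ndim:
        raise ValueError(
            f"transpose cache quantized along axis {cache.quantized_axis}, "
            f"expected axis {expected_axis}"
        )
    if cache.use_rht != expect_rht:
        raise ValueError(
            f"transpose cache has use_rht={cache.use_rht}, expected {expect_rht}"
        )
```

Rejecting a mismatched cache up front turns a silent wrong-gradient case into an immediate error at the call site.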

Comment thread primus_turbo/pytorch/ops/grouped_gemm_fp8.py
Comment thread primus_turbo/pytorch/ops/gemm_fp8.py
Comment thread primus_turbo/pytorch/ops/gemm_fp4.py
Comment thread primus_turbo/pytorch/ops/gemm_fp8.py
Comment thread primus_turbo/pytorch/core/quantized_tensor.py
Comment thread primus_turbo/pytorch/ops/moe/permutation.py
Comment thread primus_turbo/pytorch/kernels/quantization/quantization_impl.py
Comment thread primus_turbo/pytorch/core/quantized_tensor.py
Comment thread primus_turbo/pytorch/core/quantized_tensor.py
def gemm_fp8(
a: torch.Tensor,
b: torch.Tensor,
a: Union[torch.Tensor, QuantizedTensorPair],
Collaborator

Union[torch.Tensor, QuantizedTensor, QuantizedTensorPair]

"autograd.Function.forward and reads their inner tensors (data / scale_inv). "
"Dynamo cannot recover Python sources for those graph-internal inner tensors, "
),
)
Collaborator

Please check whether this still works with CUDA graphs.

if isinstance(a, QuantizedTensor):
assert a._is_grouped_tensor, "A QuantizedTensor input must be a grouped tensor"
check_quantized_tensor(a, config, axis=-1)
assert torch.equal(a.group_lens, group_lens), "a.group_lens must match the given group_lens"
Collaborator

Will this check cause a host-device synchronization?

3 participants