feat: add quantized tensor support #335
Conversation
Pull request overview
This PR introduces a unified low-precision tensor wrapper (QuantizedTensor / GroupedQuantizedTensor) and rewires FP8/FP4 GEMM + grouped GEMM ops to accept either raw tensors or pre-quantized wrappers, enabling quantization reuse and optional transpose-cache reuse across forward/backward.
Changes:
- Added `primus_turbo/pytorch/core/quantized_tensor.py` implementing the wrapper abstraction + serialization/view support.
- Updated FP8/FP4 GEMM and grouped-FP8 GEMM to consume wrapper inputs (skipping internal quantization when provided).
- Added/extended substantial test coverage for the new wrapper behavior and wrapper-aware ops; removed legacy `Float8Tensor`.
Reviewed changes
Copilot reviewed 41 out of 41 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| `primus_turbo/pytorch/core/quantized_tensor.py` | New QuantizedTensor / GroupedQuantizedTensor implementation with quantization, transpose-cache, view/reshape, serialization. |
| `primus_turbo/pytorch/core/low_precision.py` | Adds padding-align constants and clarifies ScalingRecipe docs. |
| `primus_turbo/pytorch/core/__init__.py` | Updates core exports to include QuantizedTensor and explicitly list `__all__`. |
| `primus_turbo/pytorch/core/float8_tensor.py` | Removes the legacy Float8Tensor wrapper. |
| `primus_turbo/pytorch/ops/gemm_fp8.py` | Makes gemm_fp8 accept wrapper inputs; reuses cached transpose in backward paths. |
| `primus_turbo/pytorch/ops/gemm_fp4.py` | Makes gemm_fp4 accept wrapper inputs and reuse wrapper caches/recipes. |
| `primus_turbo/pytorch/ops/grouped_gemm_fp8.py` | Makes grouped_gemm_fp8 accept wrapper inputs for activations/weights and use wrapper metadata. |
| `primus_turbo/pytorch/ops/grouped_gemm.py` | Removes group_offs parameter and recomputes offsets internally. |
| `primus_turbo/pytorch/ops/gemm.py` | Uses torch.promote_types for output dtype inference. |
| `primus_turbo/pytorch/ops/quantization.py` | Switches MX block-size constants to MXFP8_BLOCK_SIZE / MXFP4_BLOCK_SIZE and updates assertions. |
| `primus_turbo/pytorch/ops/moe/permutation.py` | Swaps legacy Float8Tensor references toward QuantizedTensor (but FP8 path remains unsupported). |
| `primus_turbo/pytorch/ops/moe/moe_dispatch_combine.py` | Adds `__all__` exports. |
| `primus_turbo/pytorch/ops/moe/indices_converter.py` | Adds `__all__` exports. |
| `primus_turbo/pytorch/ops/moe/fused_moe_router.py` | Adds `__all__` exports. |
| `primus_turbo/pytorch/kernels/quantization/quantization_impl.py` | Threads through padding-align size and tightens MX block-size validation. |
| `primus_turbo/pytorch/kernels/grouped_gemm/grouped_gemm_utils.py` | Renames `_group_offs_from_lens` to `group_offs_from_lens` and updates uses. |
| `primus_turbo/pytorch/kernels/grouped_gemm/grouped_gemm_fp8_impl.py` | Removes grouped_gemm_compute_offs helper. |
| `csrc/pytorch/quantization/quantization.cpp` | Adds padding-align-size parameter to MX quantization entry points. |
| `csrc/pytorch/quantization/quantization_meta.cpp` | Meta kernels updated for padding-align-size parameter. |
| `csrc/pytorch/extensions.h` | Updates declarations to match new MX quantization signatures. |
| `csrc/pytorch/bindings_pytorch.cpp` | Updates PyTorch bindings signatures for MX quantization ops. |
| `csrc/include/primus_turbo/quantization.h` | Removes hardcoded padding-align constants from header. |
| `csrc/kernels/grouped_gemm/hipblaslt_grouped_gemm.cu` | Adjusts hipBLASLt grouped GEMM stream/handle management and sync logic. |
| `csrc/kernels/deep_ep/utils.cuh` | Removes custom barrier macros (replaced by simpler sync logic). |
| `csrc/kernels/deep_ep/intranode.cu` | Updates synchronization primitives and warp intrinsics usage. |
| `primus_turbo/jax/primitive/moe/moe_dispatch.py` | Drops has_side_effect=True in ffi lowering calls. |
| `primus_turbo/jax/primitive/moe/moe_combine.py` | Drops has_side_effect=True in ffi lowering calls. |
| `primus_turbo/jax/lax/moe/moe_utils.py` | Replaces C++ Config import with a Python NamedTuple definition. |
| `primus_turbo/jax/lax/moe/moe_dispatch_combine.py` | Imports Config from moe_utils, updates exports, minor docstring updates. |
| `primus_turbo/jax/lax/moe/__init__.py` | Updates exported symbols to match new lax API surface. |
| `csrc/jax/deep_ep/deep_ep.cpp` | Modifies input layout/alignment checks for intranode dispatch. |
| `tests/pytorch/core/test_quantized_tensor.py` | New tests for QuantizedTensor construction/serialization/view/dequantize. |
| `tests/pytorch/core/test_grouped_quantized_tensor.py` | New tests for GroupedQuantizedTensor group metadata + behavior. |
| `tests/pytorch/ops/test_gemm_fp8.py` | Adds wrapper-input FP8 GEMM tests across granularities/backends. |
| `tests/pytorch/ops/test_gemm_fp4.py` | Adds wrapper-input FP4 GEMM tests for MX path. |
| `tests/pytorch/ops/test_grouped_gemm_fp8.py` | Adds wrapper-input grouped FP8 GEMM tests (tensorwise/rowwise). |
| `tests/pytorch/ops/test_grouped_gemm.py` | Adds temporary gfx942 hipBLASLt flake skips. |
| `setup.py` | Removes hipblas from extension link libraries. |
| `SECURITY.md` | Deletes repository security policy file. |
| `benchmark/ops/bench_grouped_gemm_gmm.py` | Removes the legacy grouped_gemm baseline benchmark script. |
Comments suppressed due to low confidence (2)
primus_turbo/pytorch/ops/moe/permutation.py:95
- The FP8/`QuantizedTensor` path here is internally inconsistent: `use_fp8` immediately raises, but the code below still references legacy `Float8Tensor` fields (`_scale`, `_fp8_dtype`, `_config`) and even tries to construct `QuantizedTensor` using the old `Float8Tensor`-style kwargs (`data=..., scale=..., orig_dtype=...`). With the new `QuantizedTensor` API this branch would crash if it ever became reachable. Please either (a) remove `QuantizedTensor` from the accepted input type and delete the dead FP8 code paths, or (b) implement the FP8 permute/unpermute support using the new wrapper’s `data`/`scale_inv` semantics and constructor signature.
primus_turbo/pytorch/ops/grouped_gemm_fp8.py:466
- `grouped_gemm_fp8` declares `config: Float8QuantConfig | None = None` but then unconditionally reads `config.granularity`. If `config` is actually `None` this will crash, and it contradicts the docstring claim that `None` uses a default config. Add the missing `if config is None: config = Float8QuantConfig()` (and ideally validate `out_dtype` when wrappers are passed).
if out_dtype is None:
out_dtype = torch.promote_types(a.dtype, b.dtype)
args = (a, b, group_lens, trans_b, out_dtype, config, num_cu)
if config.granularity == ScalingGranularity.TENSORWISE:
return FP8GroupedGemmTensorFunc.apply(*args)
elif config.granularity == ScalingGranularity.ROWWISE:
return FP8GroupedGemmRowFunc.apply(*args)
elif config.granularity == ScalingGranularity.BLOCKWISE:
60fb133 to efe9b3b
Pull request overview
Copilot reviewed 29 out of 29 changed files in this pull request and generated 7 comments.
Comments suppressed due to low confidence (2)
primus_turbo/pytorch/ops/grouped_gemm_fp8.py:465
- grouped_gemm_fp8() allows config=None (per signature/docstring) but immediately dereferences config.granularity. This will raise at runtime for default callers. Initialize config = Float8QuantConfig() when config is None (mirroring gemm_fp8).
if out_dtype is None:
out_dtype = torch.promote_types(a.dtype, b.dtype)
args = (a, b, group_lens, trans_b, out_dtype, config, num_cu)
if config.granularity == ScalingGranularity.TENSORWISE:
return FP8GroupedGemmTensorFunc.apply(*args)
elif config.granularity == ScalingGranularity.ROWWISE:
return FP8GroupedGemmRowFunc.apply(*args)
primus_turbo/pytorch/ops/moe/permutation.py:95
- This module was switched from Float8Tensor to QuantizedTensor, but the FP8 branches still reference legacy attributes (_scale, _fp8_dtype, _config) and construct QuantizedTensor with unsupported kwargs (data=..., scale=..., orig_dtype=...). Even though FP8 is currently blocked by a raise/assert, this code is now dead/incorrect and will break if FP8 support is re-enabled. Please delete the unreachable FP8 branches or update them to the QuantizedTensor API.
use_fp8 = isinstance(inp, QuantizedTensor)
if use_fp8:
raise ValueError("FP8 is not supported for now.")
if use_fp8:
fp8_scale = inp._scale
fp8_dtype = inp._fp8_dtype
scale_hidden_dim = fp8_scale.shape[1]
else:
fp8_scale = None
fp8_dtype = None
scale_hidden_dim = None
output, permuted_scale, permuted_probs = permutation.permute_with_mask_map(
inp,
row_id_map,
probs,
fp8_scale,
num_tokens,
num_experts,
num_out_tokens,
hidden_size,
scale_hidden_dim,
)
if use_fp8:
output = QuantizedTensor(
data=output,
scale=permuted_scale,
orig_dtype=inp._orig_dtype,
fp8_dtype=fp8_dtype,
config=inp._config,
)
) and tensor._scale_inv is not None:
    block_size = tensor._block_size
    packing = _get_packing_factor(tensor)
    scale_target_shape = _compute_scale_shape(padded_target_shape, block_size, packing)
Pull request overview
Copilot reviewed 29 out of 29 changed files in this pull request and generated 11 comments.
Comments suppressed due to low confidence (1)
primus_turbo/pytorch/ops/moe/permutation.py:66
- This file still treats QuantizedTensor like the removed Float8Tensor (FP8 metadata fields like _scale/_fp8_dtype/_config are referenced below). If FP8 remains unsupported here, consider dropping QuantizedTensor from the accepted types; otherwise update the FP8 path to the new QuantizedTensor API to avoid future runtime attribute errors.
use_fp8 = isinstance(inp, QuantizedTensor)
if use_fp8:
raise ValueError("FP8 is not supported for now.")
)

if isinstance(a, GroupedQuantizedTensor):
    check_grouped_quantized_tensor(a, config, group_lens)
assert quantized_tensor.granularity == config.granularity, (
    f"QuantizedTensor granularity {quantized_tensor.granularity} does not match config "
    f"granularity {config.granularity}"
)
assert quantized_tensor.block_size == config.block_size, (
# TODO(ruibin): ``_compute_scale_shape`` assumes axis=-1
# chunked scale layout (row-quant). ``_scale_inv_t`` is
# chunked along axis=-2 (col-quant), so its shape differs.
# Wrapper-input BLOCKWISE/MX_BLOCKWISE is not supported yet
# (see grouped_gemm_fp8 BlockFunc), so we keep the existing
# call here as a placeholder; revisit when that path lands.
scale_t_target_shape = _compute_scale_shape(
    padded_target_shape, tensor._block_size, _get_packing_factor(tensor)
)
out_scale_inv_t = op(out_scale_inv_t, *scale_t_target_shape)
# FP8 dtypes to cover across tests. GroupedQuantizedTensor explicitly does
# *not* accept FP4; that case is checked in TestGroupSpecific.
_DTYPES = set([float8_e4m3, float8_e5m2])

if MXFP4_SUPPORT:
    granularity: ScalingGranularity = ScalingGranularity.MX_BLOCKWISE,
    block_size: int = MXFP8_BLOCK_SIZE,
    dest_dtype: torch.dtype = float8_e4m3,
    keep_trans_cache: bool = False,
) -> GroupedQuantizedTensor:
    """Unified helper: construct a GroupedQuantizedTensor with sensible
    defaults for each granularity.

    - TENSORWISE: no block_size, no recipe (a single global scale)
    - ROWWISE: no block_size, no recipe
    - BLOCKWISE: block_size required (1D blockwise; 2D-block / weight-style
      blockwise does not apply to packed-M activations)
    - MX_BLOCKWISE: block_size + ScalingRecipe required
a: torch.Tensor,
b: torch.Tensor,
group_lens: torch.Tensor,  # [B,] int64
group_offs: torch.Tensor,  # [B+1,] int64
`group_offs` can be inferred from `group_lens`; they carry essentially the same information. Keeping both is ambiguous and not user-friendly, and it can cause problems when the `group_lens` and `group_offs` passed by the user do not match. Is there a reason this argument must be kept?
torchtitan needs this argument.
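For reference, the redundancy discussed in this thread can be sketched in plain Python, with lists standing in for the int64 tensors. `offsets_from_lengths` and `resolve_group_offs` are hypothetical names for illustration, not the project's helpers, and the sketch assumes `group_offs` is the exclusive prefix sum of `group_lens`, matching the `[B,]` / `[B+1,]` shapes in the signature above.

```python
from itertools import accumulate
from typing import List, Optional, Sequence


def offsets_from_lengths(group_lens: Sequence[int]) -> List[int]:
    """Exclusive prefix sum: B group lengths -> B + 1 offsets starting at 0."""
    if any(n < 0 for n in group_lens):
        raise ValueError("group lengths must be non-negative")
    return [0, *accumulate(group_lens)]


def resolve_group_offs(
    group_lens: Sequence[int], group_offs: Optional[Sequence[int]] = None
) -> List[int]:
    """Derive offsets when omitted; otherwise check they agree with the lengths."""
    expected = offsets_from_lengths(group_lens)
    if group_offs is None:
        return expected
    if list(group_offs) != expected:
        raise ValueError(
            f"group_offs {list(group_offs)} does not match offsets {expected} "
            "derived from group_lens"
        )
    return expected
```

If the argument has to stay for callers that already hold offsets, a consistency check along these lines would at least surface a mismatched pair early instead of letting it reach the kernel.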
    __torch_function__ = torch._C._disabled_torch_function_impl


class GroupedQuantizedTensor(QuantizedTensor):
Why do we need GroupedQuantizedTensor?
The activation of a grouped GEMM needs a `group_lens` attribute. For MX and blockwise, the quantization and dequantization functions also differ somewhat. To keep the internal functions clear, I abstracted this into a new class named `GroupedQuantizedTensor`.
    """
    return self._orig_ndim

def t(self) -> Tuple[torch.Tensor, torch.Tensor]:
There is a risk that it may cause ambiguity.
`t` originally comes from `torch.Tensor`. I would like `QuantizedTensor` to keep the same API as `torch.Tensor`. Do you have any suggestions for this function?
Return a `QuantizedTensor`.
    return self._data_t, self._scale_inv_t

@property
def T(self) -> Tuple[torch.Tensor, torch.Tensor]:
same as above
"""Wrapper subclass that carries low-precision quantized data, scale_inv"""

@staticmethod
def __new__(
Launching a quantize kernel inside __new__ is somewhat against the conventions of PyTorch tensor subclasses.
Conventionally, __new__ should only handle _make_wrapper_subclass(...) plus field initialization, while quantization should go through a separate factory method.
class QuantizedTensor(torch.Tensor):
    # Pure wrapper: only stores fields, does not launch kernels
    @staticmethod
    def __new__(cls, data, scale_inv, *, dest_dtype, granularity, ...):
        self = torch.Tensor._make_wrapper_subclass(cls, ...)
        self._data = data
        self._scale_inv = scale_inv
        # ... store fields
        return self

    # Factory method: quantization kernel is launched here
    @classmethod
    def quantize(cls, hp_tensor, dest_dtype, granularity, *, keep_trans_cache=False, ...):
        data, scale_inv, data_t, scale_inv_t = _do_quantize(hp_tensor, ...)
        return cls(data, scale_inv, dest_dtype=..., ...)

Usage:
fp8_tensor = QuantizedTensor.quantize(fp16_tensor, ...)
fp8_tensor = QuantizedTensor(fp8_data, scale, ....)
    return self

@classmethod
@torch.no_grad()
no grad?
The @torch.no_grad() on _quantize causes QuantizedTensor(x) to break the autograd chain of x — after wrapping an nn.Parameter, x.grad will always be None. Is this intended?
This is expected. The quantized tensor is effectively a buffer; gradients accumulate into the original tensor's `x.grad`.
# Dispatch hooks
# ------------------------------------------------------------------
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
__torch_dispatch__ silently dequantizing all ops by default is quite risky.
Recommend making the default behavior raise NotImplementedError instead.
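The recommended default can be illustrated with a plain-Python analogy. This is not a `__torch_dispatch__` implementation; `StrictDispatchTable` and the op names are hypothetical, and the sketch only shows the allowlist idea: operations that were explicitly registered run, everything else raises instead of silently dequantizing.

```python
from typing import Any, Callable, Dict


class StrictDispatchTable:
    """Hypothetical allowlist: registered ops run, all others fail loudly."""

    def __init__(self) -> None:
        self._handlers: Dict[str, Callable[..., Any]] = {}

    def register(self, op_name: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
        def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
            self._handlers[op_name] = fn
            return fn

        return decorator

    def dispatch(self, op_name: str, *args: Any, **kwargs: Any) -> Any:
        handler = self._handlers.get(op_name)
        if handler is None:
            # Default path: refuse rather than fall back to dequantization.
            raise NotImplementedError(
                f"{op_name} is not supported on quantized data; dequantize explicitly first"
            )
        return handler(*args, **kwargs)


table = StrictDispatchTable()


@table.register("detach")
def _detach(payload):
    # Metadata-only op: the quantized payload is passed through untouched.
    return payload
```

With this shape, supporting a new op is an explicit decision, and an unexpected op shows up as an error at the call site rather than as a hidden precision and performance change.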
"""
if out_dtype is None:
    out_dtype = torch.promote_types(a.dtype, b.dtype)
Missing the `if config is None` check.
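A minimal sketch of the missing guard, using stand-in types so it runs without the library. `Granularity`, `QuantConfig`, and `select_path` are hypothetical; in particular, the default granularity chosen here is an assumption and may not match what `Float8QuantConfig()` actually defaults to.

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional


class Granularity(Enum):  # hypothetical stand-in for ScalingGranularity
    TENSORWISE = auto()
    ROWWISE = auto()


@dataclass(frozen=True)
class QuantConfig:  # hypothetical stand-in for Float8QuantConfig
    granularity: Granularity = Granularity.TENSORWISE  # assumed default


def select_path(config: Optional[QuantConfig] = None) -> str:
    # Resolve the default before the first attribute access on config.
    if config is None:
        config = QuantConfig()
    if not isinstance(config, QuantConfig):
        raise TypeError(f"expected QuantConfig, got {type(config).__name__}")
    if config.granularity is Granularity.TENSORWISE:
        return "tensorwise"
    if config.granularity is Granularity.ROWWISE:
        return "rowwise"
    raise NotImplementedError(f"unsupported granularity: {config.granularity}")
```

Resolving the default at the top keeps the signature, the docstring, and the dispatch on `config.granularity` consistent with each other.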
Pull request overview
Copilot reviewed 30 out of 30 changed files in this pull request and generated 3 comments.
Comments suppressed due to low confidence (2)
primus_turbo/pytorch/ops/grouped_gemm_fp8.py:476
- `grouped_gemm_fp8()` declares `config: Union[Float8QuantConfig, None] = None`, but it unconditionally reads `config.granularity` (and passes `config` into autograd Functions). Calling `grouped_gemm_fp8(..., config=None)` will raise an AttributeError. Add `if config is None: config = Float8QuantConfig()` before accessing `config.granularity` (and consider validating the type early).
if out_dtype is None:
out_dtype = torch.promote_types(a.dtype, b.dtype)
args = (a, b, group_lens, trans_b, out_dtype, config, num_cu)
if config.granularity == ScalingGranularity.TENSORWISE:
return FP8GroupedGemmTensorFunc.apply(*args)
elif config.granularity == ScalingGranularity.ROWWISE:
return FP8GroupedGemmRowFunc.apply(*args)
elif config.granularity == ScalingGranularity.BLOCKWISE:
primus_turbo/pytorch/ops/moe/permutation.py:95
- There is now dead / inconsistent FP8-handling code that still references the removed Float8Tensor fields (`_scale`, `_fp8_dtype`, `_config`) and constructs `QuantizedTensor` using the old Float8Tensor constructor signature. Even though `use_fp8` currently raises immediately, this code is misleading and will break if FP8 support is re-enabled. Either delete the unreachable FP8 branch or update it to the new QuantizedTensor API (data + scale_inv + granularity metadata).
use_fp8 = isinstance(inp, QuantizedTensor)
if use_fp8:
raise ValueError("FP8 is not supported for now.")
if use_fp8:
fp8_scale = inp._scale
fp8_dtype = inp._fp8_dtype
scale_hidden_dim = fp8_scale.shape[1]
else:
fp8_scale = None
fp8_dtype = None
scale_hidden_dim = None
output, permuted_scale, permuted_probs = permutation.permute_with_mask_map(
inp,
row_id_map,
probs,
fp8_scale,
num_tokens,
num_experts,
num_out_tokens,
hidden_size,
scale_hidden_dim,
)
if use_fp8:
output = QuantizedTensor(
data=output,
scale=permuted_scale,
orig_dtype=inp._orig_dtype,
fp8_dtype=fp8_dtype,
config=inp._config,
)
Pull request overview
Copilot reviewed 30 out of 30 changed files in this pull request and generated 18 comments.
Comments suppressed due to low confidence (5)
primus_turbo/pytorch/core/quantized_tensor.py:607
- `_get_padding_align_size()` returns 0 for TENSORWISE, ROWWISE, and BLOCKWISE tensors, so any view/reshape to a different same-rank shape calls `_pad_inner_dim(..., 0)` and divides by zero before reaching the underlying view. This makes the public `view`/`reshape` methods unusable for non-MX granularities except for the exact same shape.
align = _get_padding_align_size(tensor)
padded_target_shape = _pad_inner_dim(target_shape, align)
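One way to avoid the division by zero described above is to treat a non-positive alignment as "no padding". The helper below is a hypothetical stand-in written for illustration; the real `_pad_inner_dim` is not shown in this review, so its exact rounding behavior is an assumption.

```python
from typing import Sequence, Tuple


def pad_inner_dim(shape: Sequence[int], align: int) -> Tuple[int, ...]:
    """Round the last dimension up to a multiple of align.

    An align of 0 (or less) is interpreted as "this granularity needs no
    padding", so the shape is returned unchanged instead of dividing by zero.
    """
    if len(shape) == 0:
        raise ValueError("shape must have at least one dimension")
    if align <= 0:
        return tuple(shape)
    *outer, inner = shape
    padded_inner = (inner + align - 1) // align * align
    return (*outer, padded_inner)
```

An alternative would be to have the alignment lookup return 1 for unpadded granularities, which keeps the arithmetic uniform at the cost of a less explicit "no padding" signal.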
primus_turbo/pytorch/ops/gemm_fp8.py:284
- For BLOCKWISE wrapper inputs this accepts an activation without a transpose cache, but the code unconditionally calls `a_fp8.t()` below. Since blockwise dequantization is not implemented in `QuantizedTensor.dequantize()`, a valid-looking pre-quantized activation with `keep_trans_cache=False` will fail at runtime instead of being rejected with a clear validation error.
if isinstance(a, QuantizedTensor):
check_quantized_tensor(a, config)
a_fp8 = a
primus_turbo/pytorch/ops/grouped_gemm_fp8.py:440
- Disabling Dynamo on the public grouped FP8 GEMM API forces graph breaks for all callers, including raw tensor inputs that do not use the new wrappers. This makes full-graph `torch.compile` incompatible with grouped FP8 GEMM rather than only guarding the unsupported wrapper-specific path.
@torch._dynamo.disable(
recursive=True,
reason=(
"Grouped FP8 GEMM constructs (Grouped)QuantizedTensor wrapper subclasses "
"inside its autograd.Function.forward and reads their inner tensors "
"(data / scale_inv / group_lens / group_offs). Dynamo cannot recover Python "
"sources for those graph-internal inner tensors, tripping gb0116 "
"('SourcelessBuilder.create cannot wrap FakeTensor'). "
),
)
primus_turbo/pytorch/core/quantized_tensor.py:659
- These overrides pass `shape` through exactly as collected from `*shape`, so common PyTorch forms like `qt.view(torch.Size([m, n]))`, `qt.view((m, n))`, or `qt.reshape(-1, n)` are not normalized the way `Tensor.view`/`reshape` does. They will be interpreted as a one-element shape tuple (or fail before `-1` inference), making the wrapper incompatible with standard view/reshape call patterns.
def view(self, *shape) -> "QuantizedTensor":
    """View without dequantizing (autograd-aware)."""
    return _ViewFunc.apply(self, shape)

def reshape(self, *shape) -> "QuantizedTensor":
    """Reshape without dequantizing (autograd-aware)."""
    return _ReshapeFunc.apply(self, shape)
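The varargs normalization this comment asks for might look like the following sketch. `normalize_shape_args` is a hypothetical helper; it relies only on the fact that `torch.Size` is a `tuple` subclass, so a plain `tuple` check covers it without importing torch.

```python
from typing import Tuple


def normalize_shape_args(*shape) -> Tuple[int, ...]:
    """Accept f(m, n), f((m, n)) and f([m, n]) and return a flat int tuple."""
    # A single sequence argument is unpacked; torch.Size is a tuple subclass,
    # so it takes this branch as well.
    if len(shape) == 1 and isinstance(shape[0], (tuple, list)):
        shape = tuple(shape[0])
    for dim in shape:
        if not isinstance(dim, int) or isinstance(dim, bool):
            raise TypeError(f"shape entries must be ints, got {dim!r}")
    return tuple(shape)
```

The overrides could call such a helper on `shape` before handing it to the autograd function, so every call form reaches the view logic as the same flat tuple. Resolving `-1` is a separate step that needs the element count.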
primus_turbo/pytorch/core/quantized_tensor.py:635
- This assumes `_data_t` has the same physical shape as `_data`, but the MXFP8/MXFP4 dual quantization kernels return the colwise buffer with the last two dimensions transposed (e.g. `[N, M_pad]` rather than `[M, N_pad]`). Applying the same target shape to `_data_t` will corrupt or fail non-identity views of MX wrappers with a transpose cache.
out_data_t = tensor._data_t
out_scale_inv_t = tensor._scale_inv_t
if out_data_t is not None:
# ``_data_t`` is the col-quantized counterpart at the SAME
# physical layout as ``_data`` — apply the identical target
# shape, no last-two-dim swap.
assert (
out_data_t.numel() == torch.Size(padded_target_shape).numel()
), "data_t numel and padded_target_shape must have the same number of elements"
out_data_t = op(out_data_t, *padded_target_shape)
    return list(tensors.keys()), metadata

@staticmethod
def __tensor_unflatten__(inner_tensors, metadata, outer_size):
    return keys, metadata

@staticmethod
def __tensor_unflatten__(inner_tensors, metadata, outer_size):
def __tensor_flatten__(self):
    keys, metadata = super().__tensor_flatten__()
    metadata["_group_lens"] = self._group_lens
assert quantized_tensor.granularity == config.granularity, (
    f"QuantizedTensor granularity {quantized_tensor.granularity} does not match config "
    f"granularity {config.granularity}"
)
assert quantized_tensor.block_size == config.block_size, (
    f"QuantizedTensor block_size {quantized_tensor.block_size} does not match config "
    f"block_size {config.block_size}"
)
assert granularity in [
    ScalingGranularity.ROWWISE,
    ScalingGranularity.TENSORWISE,
], "GroupedQuantizedTensor only supports ROWWISE and TENSORWISE granularity"
align = _get_padding_align_size(tensor)
padded_target_shape = _pad_inner_dim(target_shape, align)

assert (
    tensor._data.numel() == torch.Size(padded_target_shape).numel()
), "data numel and padded_target_shape must have the same number of elements"

out_data = op(tensor._data, *padded_target_shape)
@torch.no_grad()
def dequantize(self) -> torch.Tensor:
    """Dequantize back to the original high-precision dtype."""
    from primus_turbo.pytorch.ops.quantization import dequantize_fp4, dequantize_fp8

    axis = _normalize_axis(-1, self._data.ndim)

    if self._dest_dtype in [float8_e4m3, float8_e5m2]:
        return dequantize_fp8(
            self._data,
            self._orig_dtype,
            self._granularity,
            block_size=self._block_size,
            axis=axis,
            scale_inv=self._scale_inv,
            scaling_recipe=self._scaling_recipe,
        )
    elif self._dest_dtype == float4_e2m1fn_x2:
        return dequantize_fp4(
            self._data,
            self._orig_dtype,
            self._granularity,
            block_size=self._block_size,
            axis=axis,
            scale_inv=self._scale_inv,
            scaling_recipe=self._scaling_recipe,
        )
def grouped_gemm(
    a: torch.Tensor,
    b: torch.Tensor,
    group_lens: torch.Tensor,
    group_offs: torch.Tensor | None = None,
    trans_b: bool = False,
    num_cu: int | None = None,
) -> torch.Tensor:
Pull request overview
Copilot reviewed 32 out of 32 changed files in this pull request and generated 12 comments.
Comments suppressed due to low confidence (5)
primus_turbo/pytorch/core/quantized_tensor.py:241
- Grouped wrappers accept any `group_lens` tensor but do not validate that it is 1D int64 on the same device or that `group_lens.sum()` equals the packed-M dimension. A stale or mismatched `group_lens` can produce invalid `group_offs` and drive grouped GEMM kernels with offsets outside the quantized activation buffer.
is_grouped_tensor = group_lens is not None
if is_grouped_tensor:
assert (
hp_tensor.ndim == 2
), f"Grouped quantized tensor expects a 2D packed-M tensor, got {hp_tensor.ndim}D"
assert (
dest_dtype in _SUPPORTED_QUANTIZED_DTYPES and dest_dtype != float4_e2m1fn_x2
), "Unsupported quantized dtype (FP4 not supported for grouped activations)"
assert granularity in [
ScalingGranularity.ROWWISE,
ScalingGranularity.TENSORWISE,
], "Grouped quantized tensor only supports ROWWISE and TENSORWISE granularity"
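The missing consistency check can be sketched in plain Python, with a list of ints standing in for the `group_lens` tensor and `packed_m` standing in for the packed-M dimension of the activation. `validate_group_lens` is a hypothetical name; the dtype and device checks mentioned in the comment have no list equivalent and are left out of this sketch.

```python
from typing import Sequence


def validate_group_lens(group_lens: Sequence[int], packed_m: int) -> None:
    """Raise if the per-group lengths cannot describe a packed-M buffer of size packed_m."""
    for index, length in enumerate(group_lens):
        if not isinstance(length, int) or isinstance(length, bool):
            raise TypeError(f"group_lens[{index}] must be an int, got {length!r}")
        if length < 0:
            raise ValueError(f"group_lens[{index}] is negative: {length}")
    total = sum(group_lens)
    if total != packed_m:
        raise ValueError(
            f"group_lens sums to {total}, but the packed-M dimension is {packed_m}"
        )
```

Running such a check once at wrapper construction would keep stale lengths from ever producing offsets that point past the quantized buffer.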
primus_turbo/pytorch/core/quantized_tensor.py:426
- The flatten metadata omits `_quantized_axis`. After FSDP/compile unflattening, rowwise/MX tensors lose the axis needed by `check_quantized_tensor(..., axis=...)` and `dequantize()`, so serialized wrappers can no longer be validated or dequantized correctly.
metadata = {
"_orig_dtype": self._orig_dtype,
"_dest_dtype": self._dest_dtype,
"_granularity": self._granularity,
"_block_size": self._block_size,
"_scaling_recipe": self._scaling_recipe,
"_is_grouped_tensor": self._is_grouped_tensor,
}
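One pattern that makes this kind of omission harder is a single field list shared by the flatten, unflatten, and copy paths. The sketch below is a plain-Python illustration under that assumption; `META_FIELDS`, `flatten_metadata`, and `unflatten_metadata` are hypothetical names, and a `SimpleNamespace` stands in for the tensor wrapper.

```python
from types import SimpleNamespace
from typing import Any, Dict

# Single source of truth for the non-tensor fields that must survive a round trip.
META_FIELDS = (
    "_orig_dtype",
    "_dest_dtype",
    "_granularity",
    "_block_size",
    "_scaling_recipe",
    "_is_grouped_tensor",
    "_quantized_axis",
)


def flatten_metadata(wrapper: Any) -> Dict[str, Any]:
    """Collect every registered field; a missing attribute raises AttributeError."""
    return {name: getattr(wrapper, name) for name in META_FIELDS}


def unflatten_metadata(metadata: Dict[str, Any]) -> SimpleNamespace:
    """Rebuild from metadata, refusing payloads that lack a registered field."""
    missing = [name for name in META_FIELDS if name not in metadata]
    if missing:
        raise KeyError(f"metadata is missing fields: {missing}")
    return SimpleNamespace(**{name: metadata[name] for name in META_FIELDS})
```

A round-trip test that compares every entry of the field list before and after flattening would then catch a newly added attribute that was not registered.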
primus_turbo/pytorch/core/quantized_tensor.py:478
- `_make_like` does not carry over `_quantized_axis`. Any tensor returned from `view()`/`reshape()` loses the axis metadata, so later `dequantize()` or `check_quantized_tensor(..., axis=...)` on rowwise/MX wrappers will fail or use `None` as the axis.
return cls(
data,
scale_inv,
shape=shape,
orig_dtype=tensor._orig_dtype,
dest_dtype=tensor._dest_dtype,
granularity=tensor._granularity,
block_size=tensor._block_size,
scaling_recipe=tensor._scaling_recipe,
group_lens=tensor._group_lens,
group_offs=tensor._group_offs,
is_grouped_tensor=tensor._is_grouped_tensor,
requires_grad=data.requires_grad,
)
primus_turbo/pytorch/core/quantized_tensor.py:492
- For TENSORWISE and ROWWISE tensors `_get_padding_align_size()` returns 0, so any rank-preserving view/reshape to a different shape reaches `_pad_inner_dim(target_shape, 0)` and divides by zero. This makes the advertised view/reshape support work only for the same-shape fast path.
align = _get_padding_align_size(tensor)
padded_target_shape = _pad_inner_dim(target_shape, align)
primus_turbo/pytorch/ops/gemm_fp4.py:136
- `b_fp4.dtype` also resolves to the original fp16/bf16 dtype, so this transpose-cache quantization passes an unsupported `dest_dtype` and fails before backward can run. Use the quantized storage dtype metadata instead.
quantized_b_t = QuantizedTensor.quantize(
b_fp4.dequantize(),
b_fp4.dtype,
config.granularity,
def quantize(
    cls,
    hp_tensor: Union[torch.Tensor, torch.nn.Parameter],
    dest_dtype: torch.dtype,
    granularity: ScalingGranularity,
    axis: int,
    *,
    group_lens: Optional[torch.Tensor] = None,
    block_size: Optional[int] = None,
    scaling_recipe: Optional[ScalingRecipe] = None,
) -> "QuantizedTensor":
@staticmethod
def __tensor_unflatten__(inner_tensors, metadata, outer_size):
    return QuantizedTensor(
        inner_tensors["_data"],
        inner_tensors["_scale_inv"],
        shape=outer_size,
        orig_dtype=metadata["_orig_dtype"],
        dest_dtype=metadata["_dest_dtype"],
        granularity=metadata["_granularity"],
        block_size=metadata["_block_size"],
        scaling_recipe=metadata["_scaling_recipe"],
        group_lens=inner_tensors.get("_group_lens"),
        group_offs=inner_tensors.get("_group_offs"),
        is_grouped_tensor=metadata["_is_grouped_tensor"],
    )
def grouped_gemm(
    a: torch.Tensor,
    b: torch.Tensor,
    group_lens: torch.Tensor,
    group_offs: torch.Tensor | None = None,
    trans_b: bool = False,
    num_cu: int | None = None,
    Output tensor with shape [m, n] (same dtype as input)
"""
if out_dtype is None:
    out_dtype = torch.promote_types(a.dtype, b.dtype)
quantized_a_t = QuantizedTensor.quantize(
    a_fp4.dequantize(),
    a_fp4.dtype,
    config.granularity,
    scaling_recipe_for_trans=unpermuted_act_grad._scaling_recipe_for_trans,
    keep_trans_cache=unpermuted_act_grad._keep_trans_cache,
from .moe_dispatch_combine import (
    Config,
    get_combine_config,
    get_dispatch_config,
    moe_combine,
    moe_dispatch,
from .quantized_tensor import QuantizedTensor
from .stream import TurboStream
from .symm_mem import SymmetricMemory, get_symm_mem_workspace

__all__ = [
    "QuantizedTensor",
    "SymmetricMemory",
# BLOCKWISE is intentionally not handled by the QuantizedTensor wrapper
assert (
    granularity != ScalingGranularity.BLOCKWISE
), "BLOCKWISE is not supported by QuantizedTensor; Please call the low-level kernels directly."
def quantize(
    cls,
    hp_tensor: Union[torch.Tensor, torch.nn.Parameter],
    dest_dtype: torch.dtype,
    granularity: ScalingGranularity,
    axis: int,
    *,
2bcf803 to 73e0716
    group_offs=tensor._group_offs,
    is_grouped_tensor=tensor._is_grouped_tensor,
    quantized_axis=tensor._quantized_axis,
    requires_grad=data.requires_grad,
out_scale_inv = tensor._scale_inv
if tensor._granularity == ScalingGranularity.MX_BLOCKWISE and tensor._scale_inv is not None:
    block_size = tensor._block_size
    scale_target_shape = _compute_scale_shape(data_target_shape, block_size, packing)
    out_scale_inv = op(tensor._scale_inv, *scale_target_shape)
assert hp_tensor.ndim in (2, 3), f"data must be a 2D or 3D tensor, got {hp_tensor.ndim}D"
assert dest_dtype in _SUPPORTED_QUANTIZED_DTYPES, "Unsupported quantized dtype"
class QuantizedTensor(torch.Tensor):
    """Wrapper subclass that carries low-precision quantized data, scale_inv."""
| if not x.is_contiguous(): | ||
| x = x.contiguous() | ||
| x_fp8, scale_inv = torch.ops.primus_turbo_cpp_extension.quantize_fp8_rowwise(x, out_dtype, axis, scale) | ||
| assert x.is_contiguous(), "The x tensor must be contiguous." |
| shuffle_out=enable_preshuffle(), | ||
| ), | ||
| if isinstance(a, QuantizedTensor): | ||
| check_quantized_tensor(a, config, scaling_recipe=a_scaling_recipe) |
| shuffle_scale=enable_preshuffle(), | ||
| ) | ||
| if isinstance(b, QuantizedTensor): | ||
| check_quantized_tensor(b, config, scaling_recipe=b_scaling_recipe) |
if a_t is not None:
    quantized_a_t = a_t
else:
    a_t_scaling_recipe = ScalingRecipe(
        use_2d_block=False,
        use_sr=False,
        use_rht=True,
        shuffle_scale=enable_preshuffle(),
        shuffle_out=enable_preshuffle(),
    )
if b_t is not None:
    quantized_b_t = b_t
else:
    b_t_scaling_recipe = ScalingRecipe(
        use_2d_block=True,
        use_sr=False,
        use_rht=True,
        shuffle_scale=enable_preshuffle(),
        shuffle_out=enable_preshuffle(),
    )
# and passed it via ``a_t`` / ``b_t``, reuse it directly; otherwise
# derive it (dequantize + re-quantize along the other axis), mirroring
# FP8GemmRowFunction in gemm_fp8.py.
if a_t is not None:
45e212e to 02c53a4
Pull request overview
Copilot reviewed 33 out of 33 changed files in this pull request and generated 9 comments.
Comments suppressed due to low confidence (1)
primus_turbo/pytorch/core/quantized_tensor.py:517
- `_view_data_and_scale_inv` treats `target_shape` literally (via `torch.Size(target_shape)`) and doesn't handle `-1` shape inference. Calling `qt.view(-1, N)`/`reshape(-1, N)` would either create an invalid wrapper shape (containing -1) or fail unexpectedly. Either implement -1 inference to match PyTorch semantics or explicitly assert that all dims are non-negative and `-1` is not used.
def _view_data_and_scale_inv(tensor: QuantizedTensor, target_shape: torch.Size, op: Callable):
    """Apply *op* (``view`` or ``reshape``) to ``_data``, ``_scale_inv``"""
    assert len(target_shape) == tensor.data.ndim
    out_shape = torch.Size(target_shape)
    wrapper_shape = tensor.shape
    if out_shape == wrapper_shape:
        return tensor._data, tensor._scale_inv, out_shape
    align = _get_padding_align_size(tensor)
    padded_target_shape = _pad_inner_dim(target_shape, align)
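The `-1` inference option from this comment can be sketched as a small pure-Python function. `infer_shape` is a hypothetical helper, and `numel` is assumed to be the logical (unpadded) element count of the wrapper; the rules mirror the usual convention that at most one dimension may be `-1` and the remaining dimensions must divide the element count.

```python
from math import prod
from typing import Sequence, Tuple


def infer_shape(target: Sequence[int], numel: int) -> Tuple[int, ...]:
    """Resolve at most one -1 entry in target so the product equals numel."""
    if any(dim < -1 for dim in target):
        raise ValueError(f"invalid dimension in shape {tuple(target)}")
    unknown = [i for i, dim in enumerate(target) if dim == -1]
    if len(unknown) > 1:
        raise ValueError("only one dimension can be inferred")
    if not unknown:
        if prod(target) != numel:
            raise ValueError(f"shape {tuple(target)} is invalid for {numel} elements")
        return tuple(target)
    known = prod(dim for dim in target if dim != -1)
    if known == 0 or numel % known != 0:
        raise ValueError(f"cannot infer -1 in {tuple(target)} for {numel} elements")
    resolved = list(target)
    resolved[unknown[0]] = numel // known
    return tuple(resolved)
```

Resolving the shape before padding and before the same-shape fast path means the rest of the view logic only ever sees concrete, non-negative dimensions.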
| use_fp8 = isinstance(inp, QuantizedTensor) | ||
|
|
||
| if use_fp8: | ||
| raise ValueError("FP8 is not supported for now.") | ||
|
|
||
| if use_fp8: | ||
| fp8_scale = inp._scale | ||
| fp8_dtype = inp._fp8_dtype | ||
| fp8_scale = inp._scale_inv | ||
| fp8_dtype = inp._dest_dtype | ||
| scale_hidden_dim = fp8_scale.shape[1] |
| if not x.is_contiguous(): | ||
| x = x.contiguous() | ||
| x_fp8, scale_inv = torch.ops.primus_turbo_cpp_extension.quantize_fp8_rowwise(x, out_dtype, axis, scale) | ||
| assert x.is_contiguous(), "The x tensor must be contiguous." |
| if a_t is None: | ||
| quantized_a_t = QuantizedTensor.quantize( | ||
| quantized_a.dequantize(), | ||
| quantized_a.real_dtype, | ||
| config.granularity, | ||
| axis=-2, | ||
| block_size=config.block_size, | ||
| ) | ||
| else: | ||
| assert isinstance(a_t, QuantizedTensor) | ||
| quantized_a_t = a_t | ||
|
|
||
| if isinstance(b, QuantizedTensor): | ||
| check_quantized_tensor(b, config, axis=-1 if trans_b else -2) | ||
| quantized_b = b | ||
| else: | ||
| b_dtype = _get_fp8_dtype(config.format, True) | ||
| quantized_b = QuantizedTensor.quantize( | ||
| b, | ||
| b_dtype, | ||
| config.granularity, | ||
| axis=-1 if trans_b else -2, | ||
| block_size=config.block_size, | ||
| ) | ||
|
|
||
| if b_t is None: | ||
| # B's row-wise axis is (-1 if trans_b else -2); the col-wise / trans | ||
| # cache used by backward is the other axis. | ||
| quantized_b_t = QuantizedTensor.quantize( | ||
| quantized_b.dequantize(), | ||
| quantized_b.real_dtype, | ||
| config.granularity, | ||
| axis=-2 if trans_b else -1, | ||
| block_size=config.block_size, | ||
| ) | ||
| else: | ||
| assert isinstance(b_t, QuantizedTensor) | ||
| quantized_b_t = b_t |
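The axis bookkeeping in the hunk above (forward axis `-1 if trans_b else -2`, cache axis the opposite) can be captured in a small helper. This is a sketch only: `b_quant_axes` and `other_axis` are hypothetical names that do not appear in the PR, and the helper assumes the quantized axis is always one of the last two dimensions.

```python
from typing import Tuple


def other_axis(axis: int) -> int:
    """Map -1 <-> -2, i.e. pick the other of the last two dimensions."""
    if axis not in (-1, -2):
        raise ValueError(f"expected axis -1 or -2, got {axis}")
    return -3 - axis


def b_quant_axes(trans_b: bool) -> Tuple[int, int]:
    """Return (forward_axis, cache_axis) for operand B.

    forward_axis is the axis B is quantized along for the forward GEMM;
    cache_axis is the axis used for the transpose cache read in backward.
    """
    forward_axis = -1 if trans_b else -2
    return forward_axis, other_axis(forward_axis)
```

Deriving the cache axis from the forward axis, rather than writing the two conditionals separately, keeps the pair from drifting apart if one side is edited later.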
| if a_t is None: | ||
| # MX_BLOCKWISE requires a scaling_recipe; reuse the forward recipe | ||
| # for A's col-wise direction (same recipe as forward). | ||
| quantized_a_t = QuantizedTensor.quantize( | ||
| quantized_a.dequantize(), | ||
| quantized_a.real_dtype, | ||
| config.granularity, | ||
| axis=-2, | ||
| block_size=config.block_size, | ||
| scaling_recipe=a_scaling_recipe, | ||
| ) | ||
| else: | ||
| assert isinstance(a_t, QuantizedTensor) | ||
| quantized_a_t = a_t | ||
|
|
||
| b_scaling_recipe = ScalingRecipe(use_2d_block=True) | ||
| if isinstance(b, QuantizedTensor): | ||
| quantized_b = b | ||
| check_quantized_tensor(quantized_b, config, axis=-1, scaling_recipe=b_scaling_recipe) | ||
| else: | ||
| b_dtype = _get_fp8_dtype(config.format, True) | ||
| quantized_b = QuantizedTensor.quantize( | ||
| b, | ||
| b_dtype, | ||
| config.granularity, | ||
| axis=-1, | ||
| block_size=config.block_size, | ||
| scaling_recipe=b_scaling_recipe, | ||
| ) | ||
|
|
||
| if b_t is None: | ||
| quantized_b_t = QuantizedTensor.quantize( | ||
| quantized_b.dequantize(), | ||
| quantized_b.real_dtype, | ||
| config.granularity, | ||
| axis=-2, | ||
| block_size=config.block_size, | ||
| scaling_recipe=b_scaling_recipe, | ||
| ) | ||
| else: | ||
| assert isinstance(b_t, QuantizedTensor) | ||
| quantized_b_t = b_t |
Pull request overview
Copilot reviewed 33 out of 33 changed files in this pull request and generated 9 comments.
Comments suppressed due to low confidence (1)
primus_turbo/pytorch/ops/gemm_fp4.py:118
- When callers provide `a_t` / `b_t` (the col-wise/RHT transpose cache), there is no validation that these tensors were quantized with the expected axis (`axis=0`) and the expected RHT scaling recipe. A mismatched cache would silently produce wrong gradients. Consider validating `a_t` / `b_t` via `check_quantized_tensor(..., axis=0, scaling_recipe=...)` before saving/using them.
# If the caller pre-quantized this and passed it via ``a_t`` / ``b_t``,
# reuse it directly; otherwise derive it from the forward fp4 tensor.
if a_t is not None:
    quantized_a_t = a_t
else:
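The validation suggested here might look roughly like the sketch below. It deliberately does not call the real `check_quantized_tensor`, whose full signature is not shown on this page. `RecipeMeta`, `CacheMeta`, and `validate_transpose_cache` are hypothetical stand-ins that carry only the fields visible in the diff (`use_2d_block`, `use_sr`, `use_rht`, and a quantized axis).

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class RecipeMeta:
    """Stand-in for the ScalingRecipe fields that matter for the check."""
    use_2d_block: bool = False
    use_sr: bool = False
    use_rht: bool = False


@dataclass(frozen=True)
class CacheMeta:
    """Stand-in for the metadata a pre-quantized a_t / b_t would carry."""
    quantized_axis: int
    scaling_recipe: RecipeMeta


def validate_transpose_cache(
    cache: Optional[CacheMeta],
    expected_axis: int,
    expected_recipe: RecipeMeta,
    name: str,
) -> None:
    """Raise if a caller-provided transpose cache does not match expectations."""
    if cache is None:
        return  # nothing supplied; the op derives the cache itself
    if cache.quantized_axis != expected_axis:
        raise ValueError(
            f"{name}: quantized along axis {cache.quantized_axis}, expected {expected_axis}"
        )
    if cache.scaling_recipe != expected_recipe:
        raise ValueError(f"{name}: scaling recipe does not match the expected RHT recipe")


# Usage: A's cache is expected along axis 0 with the RHT, non-2D-block recipe.
a_t_recipe = RecipeMeta(use_2d_block=False, use_sr=False, use_rht=True)
validate_transpose_cache(CacheMeta(0, a_t_recipe), 0, a_t_recipe, "a_t")
```

Failing loudly at the op boundary costs a few attribute comparisons and avoids the silent wrong-gradient case the comment describes.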
| def gemm_fp8( | ||
| a: torch.Tensor, | ||
| b: torch.Tensor, | ||
| a: Union[torch.Tensor, QuantizedTensorPair], |
Union[torch.Tensor, QuantizedTensor, QuantizedTensorPair]
| "autograd.Function.forward and reads their inner tensors (data / scale_inv). " | ||
| "Dynamo cannot recover Python sources for those graph-internal inner tensors, " | ||
| ), | ||
| ) |
| if isinstance(a, QuantizedTensor): | ||
| assert a._is_grouped_tensor, "A QuantizedTensor input must be a grouped tensor" | ||
| check_quantized_tensor(a, config, axis=-1) | ||
| assert torch.equal(a.group_lens, group_lens), "a.group_lens must match the given group_lens" |
Will this `torch.equal` check cause a host-device synchronization?


Description
This PR introduces a unified `QuantizedTensor` wrapper class as the canonical container for all FP8 / FP4 low-precision paths. As part of the work, the FP8/FP4 quantization kernels are reorganized, a new rowwise dequantize C++ binding and an optimized col-major dequant kernel are added, and all GEMM / grouped GEMM / MoE permutation ops are migrated to the new `QuantizedTensor` API.

Motivation:
- `Float8Tensor` was still under development, did not cover FP4 or the full set of granularities (rowwise / tensorwise / MX-blockwise), and could not be passed as a first-class citizen through grouped GEMM / MoE permutation paths.
- `quantization.cu` lumped tensorwise / rowwise / MXFP8 / MXFP4 into a single TU, leaving shared `QuantOp` / amax→scale helpers tightly coupled and hard to reuse or extend.
- Redundant `scale_inv` loads in the col-major rowwise dequantize kernel consumed a non-trivial fraction of bandwidth on tall-M shapes.

Fixes # (issue)

Type of change

Changes

New unified `QuantizedTensor` wrapper
- `primus_turbo/pytorch/core/quantized_tensor.py`:
  - `QuantizedTensor`: a wrapper subclass built on `torch.Tensor._make_wrapper_subclass`, carrying `_data` / `_scale_inv` together with metadata such as `orig_dtype` / `dest_dtype` / `granularity` / `block_size` / `scaling_recipe` / `quantized_axis`, plus `group_lens` / `group_offs` for the grouped case.
  - `QuantizedTensor.quantize(...)` (auto-dispatched to `quantize_fp8` / `quantize_fp4`), as well as `dequantize()`, autograd-aware `view` / `reshape` (no dequant round-trip), and `__tensor_flatten__` / `__tensor_unflatten__` (compatible with `torch.compile` / FSDP).
  - `QuantizedTensorPair` and `check_quantized_tensor` to help ops enforce config consistency.
- Remove `primus_turbo/pytorch/core/float8_tensor.py` (`Float8Tensor`).
- Update the exports in `primus_turbo/pytorch/core/__init__.py`.

Quantization C++ kernel refactor + new ops
- Split `csrc/kernels/quantization/quantization.cu` into:
  - `csrc/kernels/quantization/quantization_rowwise.cu`: rowwise FP8 quantize / dequantize (both row-major and col-major layouts).
  - `csrc/kernels/quantization/quantization_tensorwise.cu`: tensorwise FP8 quantize / dequantize, and the single explicit instantiation of `compute_scale_from_amax<float>`.
  - `csrc/include/primus_turbo/device/quant_utils.cuh`: `QuantOpBase` / `QuantOp` and the host/device implementations of `compute_scale_from_amax`, shared across rowwise and tensorwise paths.
- `csrc/include/primus_turbo/quantization.h`: declare the new `dequantize_rowwise_row_major_impl` / `dequantize_rowwise_col_major_impl` interfaces.
- `csrc/pytorch/quantization/quantization.cpp` + `bindings_pytorch.cpp`:
  - Add `dequantize_fp8_rowwise(Tensor, Tensor, int axis, ScalarType) -> Tensor`, supporting rowwise dequant for arbitrary `axis` (auto-dispatched between row-major and col-major impls).
  - Add a `padding_align_size` argument to all `quantize_mxfp4{,_dual}` / `quantize_mxfp8{,_dual}` schemas, with C++ sanity checks (must equal `MXFP{4,8}_PADDING_ALIGN_SIZE`) so the alignment constant is no longer hard-coded inside the kernels.
- `csrc/pytorch/quantization/quantization_meta.cpp`: add the corresponding meta implementations.

Rowwise dequantize kernel optimization (f1cfe45)
- Rework `dequantize_rowwise_col_major_kernel` to use 2D tiling (`TILE_INNER × ELEMS_PER_THREAD`): each block now loads `ELEMS_PER_THREAD` `scale_inv` values into registers once and reuses them `TILE_INNER` times along the M direction, significantly reducing redundant `scale_inv` traffic.
- `TILE_INNER` / threads-per-row are chosen adaptively based on `next_pow2(inner_len)` to improve occupancy.

Op-level adoption of `QuantizedTensor`
- `primus_turbo/pytorch/ops/gemm_fp8.py` / `gemm_fp4.py` / `grouped_gemm_fp8.py` / `grouped_gemm.py` / `gemm.py`:
  - Accept `QuantizedTensor` (or auto-quantize plain tensors) and validate config / axis / scaling_recipe through `check_quantized_tensor`.
  - Axis convention: `-1` = inner-K, `-2` = inner-M; grouped paths consistently propagate `group_lens` / `group_offs`.
- `primus_turbo/pytorch/ops/quantization.py` / `kernels/quantization/quantization_impl.py`: expose the rowwise `dequantize_fp8` path and adapt to the new binding.
- `primus_turbo/pytorch/ops/moe/permutation.py`, `fused_moe_router.py`, `indices_converter.py`, `moe_dispatch_combine.py`: pass `QuantizedTensor` transparently through MoE paths (grouped wrappers preserve their grouped identity across view / reshape).
- `primus_turbo/pytorch/kernels/grouped_gemm/grouped_gemm_utils.py` / `grouped_gemm_fp8_impl.py`: align the helpers with the new wrapper semantics.

Tests
- `tests/pytorch/core/test_quantized_tensor.py`: covers quantize / dequantize / view / reshape / flatten–unflatten / grouped-wrapper behavior across TENSORWISE / ROWWISE / MX_BLOCKWISE granularities and across FP8 / MXFP4 / MXFP8 dtype combinations.
- `tests/pytorch/ops/test_gemm_fp8.py` / `test_gemm_fp4.py` / `test_grouped_gemm_fp8.py` / `test_grouped_gemm.py` / `test_quantization.py` / `test_linear_fp8.py`: extend end-to-end tests for the new `QuantizedTensor` input / output paths.

CI
- `.github/workflows/ci.yaml`: bump the single-GPU test job's `timeout_minutes` from 95 to 135 (and `timeout -k 60s 130m` accordingly) to accommodate the additional runtime introduced by the new `QuantizedTensor` test suite.

Checklist: