Skip to content

feat(quantization): add fused amax+scale FP8 tensorwise quantize#322

Draft
jasainio wants to merge 1 commit into
mainfrom
opt/quantize/fused-amax-scale
Draft

feat(quantization): add fused amax+scale FP8 tensorwise quantize#322
jasainio wants to merge 1 commit into
mainfrom
opt/quantize/fused-amax-scale

Conversation

@jasainio
Copy link
Copy Markdown
Contributor

Fuse the abs-max reduction and scale computation into a single final-round kernel, reducing total kernel launches from 3 to 2 for the tensorwise FP8 quantization path.

  • Add reduce_amax_final_scale_kernel that computes scale = fp8_max / amax and scale_inv = amax / fp8_max directly in the reduction epilogue
  • Add reduce_amax_and_compute_scale dispatcher with multi-round support for large tensors (reuses existing reduce_row_kernel for intermediate rounds)
  • Register quantize_fp8_tensorwise_fused op in PyTorch bindings (CUDA + Meta)
  • Expose quantize_fp8_fused() public Python API with fallback for non-tensorwise granularities
  • Add tests for fused path: tensorwise + amax correctness (66 cases)
  • Add benchmark script comparing original 3-kernel vs fused 2-kernel latency

Description

Fuse the abs-max reduction and scale computation into a single HIP kernel for the tensorwise FP8 quantization path. The original path required 3 kernel launches (reduce_row -> compute_scale_from_amax -> quantize_tensorwise_impl). The fused path eliminates the standalone scale computation by folding it into the final reduction round, reducing the total to 2 kernel launches (reduce_amax_and_compute_scale -> quantize_tensorwise_impl). For large tensors that require multi-round reduction, intermediate rounds reuse the existing reduce_row_kernel, and only the final round uses the fused kernel. A workspace size debug assertion and ping/pong buffer aliasing safety are documented in the kernel code.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Add reduce_amax_final_scale_kernel in csrc/kernels/quantization/quantization.cu that fuses abs-max reduction with scale/scale_inv computation in the epilogue
  • Add reduce_amax_and_compute_scale host dispatcher with single-tile fast path and multi-round support using ping/pong workspace buffers
  • Add workspace size debug assertion and ping/pong aliasing safety comment for multi-round reduction
  • Declare reduce_amax_and_compute_scale template in csrc/include/primus_turbo/quantization.h
  • Add quantize_fp8_tensorwise_fused C++ implementation in csrc/pytorch/quantization/quantization.cpp with workspace allocation
  • Register quantize_fp8_tensorwise_fused op in csrc/pytorch/bindings_pytorch.cpp (CUDA + Meta) and csrc/pytorch/extensions.h
  • Add quantize_fp8_tensorwise_fused_impl Python wrapper in primus_turbo/pytorch/kernels/quantization/quantization_impl.py
  • Expose quantize_fp8_fused() public API in primus_turbo/pytorch/ops/quantization.py with fallback to quantize_fp8 for non-tensorwise granularities
  • Add test_quantize_fp8_tensorwise_fused (36 cases: 3 dtypes x 2 FP8 types x 3 numel sizes x 2 torch_compile modes) and test_quantize_fp8_tensorwise_fused_amax_correctness (30 cases: partial tile and spike-position regression)

@jasainio jasainio marked this pull request as draft May 3, 2026 05:53
@jasainio jasainio force-pushed the opt/quantize/fused-amax-scale branch from feb2557 to 33f80d0 Compare May 6, 2026 06:11
Fuse the abs-max reduction and scale computation into a single
final-round kernel, reducing total kernel launches from 3 to 2 for
the tensorwise FP8 quantization path.

- Add reduce_amax_final_scale_kernel that computes scale = fp8_max / amax
  and scale_inv = amax / fp8_max directly in the reduction epilogue
- Add reduce_amax_and_compute_scale dispatcher with multi-round support
  for large tensors (reuses existing reduce_row_kernel for intermediate
  rounds)
- Register quantize_fp8_tensorwise_fused op in PyTorch bindings (CUDA + Meta)
- Expose quantize_fp8_fused() public Python API with fallback for
  non-tensorwise granularities
- Add tests for fused path: tensorwise + amax correctness (42 cases)
@jasainio jasainio force-pushed the opt/quantize/fused-amax-scale branch from 33f80d0 to fd70f19 Compare May 11, 2026 12:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant