feat(quantization): add fused amax+scale FP8 tensorwise quantize #322
Draft
jasainio wants to merge 1 commit into
Fuse the abs-max reduction and scale computation into a single final-round kernel, reducing total kernel launches from 3 to 2 for the tensorwise FP8 quantization path.

- Add reduce_amax_final_scale_kernel that computes scale = fp8_max / amax and scale_inv = amax / fp8_max directly in the reduction epilogue
- Add reduce_amax_and_compute_scale dispatcher with multi-round support for large tensors (reuses existing reduce_row_kernel for intermediate rounds)
- Register quantize_fp8_tensorwise_fused op in PyTorch bindings (CUDA + Meta)
- Expose quantize_fp8_fused() public Python API with fallback for non-tensorwise granularities
- Add tests for fused path: tensorwise + amax correctness (42 cases)
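The epilogue arithmetic described above can be sketched in plain Python as a reference for what the fused kernel is expected to produce. This is a minimal sketch, not the kernel itself: the function name `amax_and_scale`, the default `fp8_max` of 448.0 (the largest finite value of the OCP E4M3 FN format) and the `eps` clamp for an all-zero input are assumptions for illustration; the PR text does not say which FP8 variant is used or how a zero amax is handled.

```python
FP8_E4M3_MAX = 448.0  # assumption: OCP float8 E4M3 (FN) max; other FP8 formats differ


def amax_and_scale(values, fp8_max=FP8_E4M3_MAX, eps=1e-12):
    """Hypothetical reference: one pass yields amax, scale and scale_inv."""
    # Abs-max reduction over all elements.
    amax = 0.0
    for v in values:
        amax = max(amax, abs(v))

    # Assumption: clamp to eps so an all-zero tensor does not divide by zero.
    safe_amax = max(amax, eps)

    # The two quantities the fused epilogue computes, per the commit message.
    scale = fp8_max / safe_amax
    scale_inv = safe_amax / fp8_max
    return amax, scale, scale_inv


amax, scale, scale_inv = amax_and_scale([0.5, -3.5, 1.25])
print(amax, scale, scale_inv)  # 3.5 128.0 0.0078125
```

Because `scale` and `scale_inv` are both derived from the same `amax` in one place, a separate scale-computation launch is not needed, which is the saving the commit message describes.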
Description
Fuse the abs-max reduction and scale computation into a single HIP kernel for the tensorwise FP8 quantization path. The original path required 3 kernel launches (`reduce_row` -> `compute_scale_from_amax` -> `quantize_tensorwise_impl`). The fused path eliminates the standalone scale computation by folding it into the final reduction round, reducing the total to 2 kernel launches (`reduce_amax_and_compute_scale` -> `quantize_tensorwise_impl`). For large tensors that require multi-round reduction, intermediate rounds reuse the existing `reduce_row_kernel`, and only the final round uses the fused kernel. A workspace size debug assertion and ping/pong buffer aliasing safety are documented in the kernel code.

Type of change
Changes
Please list the changes introduced in this PR:
- `reduce_amax_final_scale_kernel` in `csrc/kernels/quantization/quantization.cu` that fuses abs-max reduction with scale/scale_inv computation in the epilogue
- `reduce_amax_and_compute_scale` host dispatcher with single-tile fast path and multi-round support using ping/pong workspace buffers
- `reduce_amax_and_compute_scale` template in `csrc/include/primus_turbo/quantization.h`
- `quantize_fp8_tensorwise_fused` C++ implementation in `csrc/pytorch/quantization/quantization.cpp` with workspace allocation
- `quantize_fp8_tensorwise_fused` op in `csrc/pytorch/bindings_pytorch.cpp` (CUDA + Meta) and `csrc/pytorch/extensions.h`
- `quantize_fp8_tensorwise_fused_impl` Python wrapper in `primus_turbo/pytorch/kernels/quantization/quantization_impl.py`
- `quantize_fp8_fused()` public API in `primus_turbo/pytorch/ops/quantization.py` with fallback to `quantize_fp8` for non-tensorwise granularities
- `test_quantize_fp8_tensorwise_fused` (36 cases: 3 dtypes x 2 FP8 types x 3 numel sizes x 2 torch_compile modes) and `test_quantize_fp8_tensorwise_fused_amax_correctness` (30 cases: partial tile and spike-position regression)
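The dispatcher flow in the list above (single-tile fast path, intermediate rounds over ping/pong buffers, fused final round) can be modeled in plain Python, together with the kind of partial-tile and spike-position check the correctness test describes. This is only a sketch under stated assumptions: the names `reduce_round` and `reduce_amax_and_compute_scale_ref`, the tile width of 4, the `fp8_max` of 448.0 and the choice of test sizes are hypothetical, and the sketch treats the input itself as the first buffer, which the real workspace layout may not do. It also assumes a nonzero amax.

```python
FP8_E4M3_MAX = 448.0  # assumption: OCP E4M3 (FN) max; not confirmed by the PR
TILE = 4              # hypothetical tile width; real kernels use far larger tiles


def reduce_round(src, tile=TILE):
    """Intermediate round: every tile of `src` collapses to one abs-max."""
    return [max(abs(v) for v in src[i:i + tile])
            for i in range(0, len(src), tile)]


def reduce_amax_and_compute_scale_ref(values, fp8_max=FP8_E4M3_MAX, tile=TILE):
    """Hypothetical model of the dispatcher: returns (amax, scale, scale_inv, rounds)."""
    if not values:
        raise ValueError("empty input")
    buffers = [list(values), []]  # ping/pong: read one, write the other
    src, rounds = 0, 0
    # Intermediate rounds run only while the data exceeds one tile.
    while len(buffers[src]) > tile:
        buffers[1 - src] = reduce_round(buffers[src], tile)
        src = 1 - src  # swap roles so a round never reads what it writes
        rounds += 1
    # Final round: reduce the last tile and compute both scales together.
    amax = max(abs(v) for v in buffers[src])
    rounds += 1
    return amax, fp8_max / amax, amax / fp8_max, rounds


# Spike-position check over sizes that include partial tiles (hypothetical sizes).
for numel in (1, 3, 4, 5, 17, 33):
    for pos in (0, numel // 2, numel - 1):
        data = [0.25] * numel
        data[pos] = -7.0  # single spike the reduction must not miss
        amax, scale, scale_inv, _ = reduce_amax_and_compute_scale_ref(data)
        assert amax == 7.0 and scale == 64.0
```

With a tile width of 4, an input of 4 elements takes the single-round fast path, while 33 elements need two intermediate rounds (33 to 9, then 9 to 3) before the fused final round.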