Required prerequisites
Motivation
mxfp/nvfp types with group scaling are widely used for efficient AI training/inference. Currently, TileLang supports dequant GEMM for mxfp/nvfp by storing the values in uint8, which requires performing manual SIMT scaling followed by T.gemm; this can be complicated for downstream users. Moreover, hardware-native block-scaled GEMM is supported on modern GPU architectures, e.g. SM100.
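To illustrate the manual path described above, here is a minimal numpy sketch of the group-scaling semantics: each group of elements along K shares one scale factor, so the kernel must expand scales and dequantize before the GEMM. The group size, shapes, and the use of small integers as stand-ins for the packed mxfp/nvfp payloads are illustrative assumptions, not TileLang's actual layout.

```python
import numpy as np

# Illustrative parameters (not TileLang's actual configuration):
# each group of GROUP elements along K shares one scale factor.
GROUP = 32
M, N, K = 4, 4, 64

rng = np.random.default_rng(0)
# Small integers stand in for the quantized mxfp/nvfp values
# that would really be packed into uint8 storage.
A_q = rng.integers(-8, 8, size=(M, K)).astype(np.float32)
B_q = rng.integers(-8, 8, size=(K, N)).astype(np.float32)
# One scale per group of GROUP elements along K.
A_s = rng.uniform(0.5, 2.0, size=(M, K // GROUP)).astype(np.float32)
B_s = rng.uniform(0.5, 2.0, size=(K // GROUP, N)).astype(np.float32)

# Manual "SIMT scaling": broadcast each scale over its group,
# dequantize, then run an ordinary GEMM.
A = A_q * np.repeat(A_s, GROUP, axis=1)
B = B_q * np.repeat(B_s, GROUP, axis=0)
C = A @ B
```

A native block-scaled GEMM instruction fuses this scale expansion into the MMA itself, which is what the proposal below targets.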
Solution
I propose exposing a unified T.blockscaled_gemm API in the frontend, which would be lowered to native block-scaled PTX instructions on supported architectures. On older architectures, it can be implemented by appending SIMT scaling operations to the MMA macros.
Alternatives
No response
Additional context
No response