[Triton][CDNA4] Optimize gluon blockscale a8w8 gemm kernel#3307
Open
lijinpei-amd wants to merge 2 commits into
Open
[Triton][CDNA4] Optimize gluon blockscale a8w8 gemm kernel#3307lijinpei-amd wants to merge 2 commits into
lijinpei-amd wants to merge 2 commits into
Conversation
Contributor
🏷️ CI GuideRuns automatically on every PR:
Extended tests (opt-in via labels):
|
69aed97 to
0f5e60c
Compare
* test_gemm_a8w8_blockscale: enable the "gluon" parametrize entry and
add small-K shapes (K in {128, 192, 256, 320}) that exercise the
wind-down's num_k_iter guards.
* bench_gemm_a8w8_blockscale: add a -test flag that runs each
benchmarked shape against a torch reference via checkAllclose.
Reimplement the gluon a8w8 blockscale kernel around
gl.amd.cdna4.mfma_scaled with an explicit async-copy / LDS multi-buffer
pipeline.
* Split the main loop into an aligned-K body (EVEN_K=True
_prefetch_tensors) plus a statically unrolled wind-down for the masked
tail iterations.
* Runtime-guard the wind-down iters for small num_k_iter so the Final
iter is the only MFMA that runs when K is short.
* In the main loop, commit the prefetch group before loading scales so
the compiler schedules buffer_load earlier in the iteration.
* Refresh tuning configs for gfx950.
perf on MI350:
python3 bench_gemm_a8w8_blockscale.py -gluon
bench_gemm_a8w8_blockscale:
M N K TFLOPS (Throughput (TFLOPS))
0 1.0 1280.0 8192.0 0.604139
1 32.0 1280.0 8192.0 19.064667
2 64.0 1280.0 8192.0 37.522605
3 128.0 1280.0 8192.0 100.565860
4 192.0 1280.0 8192.0 69.512152
5 256.0 1280.0 8192.0 89.348881
6 320.0 1280.0 8192.0 115.422745
7 512.0 1280.0 8192.0 175.689190
8 1024.0 1280.0 8192.0 345.129363
9 2048.0 1280.0 8192.0 677.299835
10 4096.0 1280.0 8192.0 863.537762
11 8192.0 1280.0 8192.0 887.143030
12 16384.0 1280.0 8192.0 1164.919752
13 4096.0 4096.0 4096.0 1271.401835
14 4096.0 4096.0 4160.0 1076.085957
python3 bench_gemm_a8w8_blockscale.py
bench_gemm_a8w8_blockscale:
M N K TFLOPS (Throughput (TFLOPS))
0 1.0 1280.0 8192.0 0.455752
1 32.0 1280.0 8192.0 13.141420
2 64.0 1280.0 8192.0 24.324535
3 128.0 1280.0 8192.0 51.085179
4 192.0 1280.0 8192.0 85.387665
5 256.0 1280.0 8192.0 109.271191
6 320.0 1280.0 8192.0 138.334302
7 512.0 1280.0 8192.0 218.300780
8 1024.0 1280.0 8192.0 172.178122
9 2048.0 1280.0 8192.0 341.678502
10 4096.0 1280.0 8192.0 670.851040
11 8192.0 1280.0 8192.0 683.083809
12 16384.0 1280.0 8192.0 899.010470
13 4096.0 4096.0 4096.0 1013.235796
14 4096.0 4096.0 4160.0 862.656740
python3 bench_gemm_a8w8_blockscale.py -gluon and some non-upstream llvm hack
bench_gemm_a8w8_blockscale:
M N K TFLOPS (Throughput (TFLOPS))
0 1.0 1280.0 8192.0 0.554379
1 32.0 1280.0 8192.0 17.488356
2 64.0 1280.0 8192.0 34.616803
3 128.0 1280.0 8192.0 89.166024
4 192.0 1280.0 8192.0 73.594313
5 256.0 1280.0 8192.0 97.149177
6 320.0 1280.0 8192.0 121.220295
7 512.0 1280.0 8192.0 192.800737
8 1024.0 1280.0 8192.0 379.143869
9 2048.0 1280.0 8192.0 742.228581
10 4096.0 1280.0 8192.0 921.613818
11 8192.0 1280.0 8192.0 957.073812
12 16384.0 1280.0 8192.0 1237.227918
13 4096.0 4096.0 4096.0 1449.159948
14 4096.0 4096.0 4160.0 1332.912692
0f5e60c to
84d30c3
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Perf improvement for mi350/355 gluon blockscale a8w8 gemm kernel. Compared to the triton version, the gluon version also pipeline scale factor load, which triton compiler currently doesn't do. Also scale factor and a/b operand lds loads are different stage, which is hard for triton-compiler to figure out, but easy for kernel writing. Compared to the old gluon version, add pipeline through lds.
perf on MI350:
python3 bench_gemm_a8w8_blockscale.py -gluon
python3 bench_gemm_a8w8_blockscale.py
python3 bench_gemm_a8w8_blockscale.py -gluon and some non-upstream llvm hack