[Triton][CDNA4] Optimize gluon blockscale a8w8 gemm kernel#3307

Open
lijinpei-amd wants to merge 2 commits into ROCm:main from lijinpei-amd:review/2026-05-22/gluon-blockscale-wind-down-unroll

Conversation

@lijinpei-amd
Contributor

Performance improvement for the Gluon blockscale a8w8 GEMM kernel on MI350/MI355. Compared to the Triton version, the Gluon version also pipelines the scale-factor loads, which the Triton compiler currently does not do. The scale-factor loads and the A/B operand LDS loads also sit in different pipeline stages, which is hard for the Triton compiler to work out on its own but easy to express when writing the kernel by hand. Compared to the old Gluon version, this one adds pipelining through LDS.

perf on MI350:

python3 bench_gemm_a8w8_blockscale.py -gluon

bench_gemm_a8w8_blockscale:
          M       N       K  TFLOPS (Throughput (TFLOPS))
0       1.0  1280.0  8192.0                      0.604139
1      32.0  1280.0  8192.0                     19.064667
2      64.0  1280.0  8192.0                     37.522605
3     128.0  1280.0  8192.0                    100.565860
4     192.0  1280.0  8192.0                     69.512152
5     256.0  1280.0  8192.0                     89.348881
6     320.0  1280.0  8192.0                    115.422745
7     512.0  1280.0  8192.0                    175.689190
8    1024.0  1280.0  8192.0                    345.129363
9    2048.0  1280.0  8192.0                    677.299835
10   4096.0  1280.0  8192.0                    863.537762
11   8192.0  1280.0  8192.0                    887.143030
12  16384.0  1280.0  8192.0                   1164.919752
13   4096.0  4096.0  4096.0                   1271.401835
14   4096.0  4096.0  4160.0                   1076.085957

python3 bench_gemm_a8w8_blockscale.py

bench_gemm_a8w8_blockscale:
          M       N       K  TFLOPS (Throughput (TFLOPS))
0       1.0  1280.0  8192.0                      0.455752
1      32.0  1280.0  8192.0                     13.141420
2      64.0  1280.0  8192.0                     24.324535
3     128.0  1280.0  8192.0                     51.085179
4     192.0  1280.0  8192.0                     85.387665
5     256.0  1280.0  8192.0                    109.271191
6     320.0  1280.0  8192.0                    138.334302
7     512.0  1280.0  8192.0                    218.300780
8    1024.0  1280.0  8192.0                    172.178122
9    2048.0  1280.0  8192.0                    341.678502
10   4096.0  1280.0  8192.0                    670.851040
11   8192.0  1280.0  8192.0                    683.083809
12  16384.0  1280.0  8192.0                    899.010470
13   4096.0  4096.0  4096.0                   1013.235796
14   4096.0  4096.0  4160.0                    862.656740

python3 bench_gemm_a8w8_blockscale.py -gluon (with some non-upstream LLVM hacks applied)

bench_gemm_a8w8_blockscale:
          M       N       K  TFLOPS (Throughput (TFLOPS))
0       1.0  1280.0  8192.0                      0.554379
1      32.0  1280.0  8192.0                     17.488356
2      64.0  1280.0  8192.0                     34.616803
3     128.0  1280.0  8192.0                     89.166024
4     192.0  1280.0  8192.0                     73.594313
5     256.0  1280.0  8192.0                     97.149177
6     320.0  1280.0  8192.0                    121.220295
7     512.0  1280.0  8192.0                    192.800737
8    1024.0  1280.0  8192.0                    379.143869
9    2048.0  1280.0  8192.0                    742.228581
10   4096.0  1280.0  8192.0                    921.613818
11   8192.0  1280.0  8192.0                    957.073812
12  16384.0  1280.0  8192.0                   1237.227918
13   4096.0  4096.0  4096.0                   1449.159948
14   4096.0  4096.0  4160.0                   1332.912692
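
To make the first two tables (Gluon and Triton, both without the LLVM hacks) easier to compare, here is a small sketch that computes the Gluon/Triton ratio per shape. The TFLOPS figures are copied from those tables; `ROWS` and `speedups` are illustrative names and not part of the benchmark script.

```python
# (M, N, K, gluon_tflops, triton_tflops), copied from the first two tables above.
ROWS = [
    (1, 1280, 8192, 0.604139, 0.455752),
    (32, 1280, 8192, 19.064667, 13.141420),
    (64, 1280, 8192, 37.522605, 24.324535),
    (128, 1280, 8192, 100.565860, 51.085179),
    (192, 1280, 8192, 69.512152, 85.387665),
    (256, 1280, 8192, 89.348881, 109.271191),
    (320, 1280, 8192, 115.422745, 138.334302),
    (512, 1280, 8192, 175.689190, 218.300780),
    (1024, 1280, 8192, 345.129363, 172.178122),
    (2048, 1280, 8192, 677.299835, 341.678502),
    (4096, 1280, 8192, 863.537762, 670.851040),
    (8192, 1280, 8192, 887.143030, 683.083809),
    (16384, 1280, 8192, 1164.919752, 899.010470),
    (4096, 4096, 4096, 1271.401835, 1013.235796),
    (4096, 4096, 4160, 1076.085957, 862.656740),
]


def speedups(rows):
    """Return (M, N, K, gluon / triton) for each benchmarked shape."""
    return [(m, n, k, gluon / triton) for m, n, k, gluon, triton in rows]


# Print one line per shape; a ratio below 1.0 means the Gluon path is slower.
for m, n, k, ratio in speedups(ROWS):
    print(f"M={m:<6} N={n:<5} K={k:<5} gluon/triton = {ratio:.2f}x")
```

On these figures the Gluon path is ahead for 11 of the 15 shapes and behind for M = 192, 256, 320, and 512 (all with N = 1280, K = 8192).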

@lijinpei-amd lijinpei-amd requested a review from a team May 21, 2026 20:38
@github-actions
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

| Label | Tests |
| --- | --- |
| ci:triton-300x | Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X |
| ci:sglang | SGLang integration tests: DeepSeek-R1-MXFP4 accuracy, Qwen 3.5 accuracy |
| ci:atom | ATOM benchmark: DeepSeek-R1-0528, GPT-OSS-120B |
| ci:atom_full | ATOM accuracy suite for PR and main models from ATOM models_accuracy.json |
| ci:vllm | vLLM benchmark: GPT-OSS-120B, DeepSeek-R1-0528, Kimi-K2.5 |
| ci:all | All standard extended tests (excludes ci:atom_full) |

Only add ci:atom_full for FlyDSL or Triton upgrades.
Add labels via the sidebar or gh pr edit 3307 --add-label <label>

@lijinpei-amd lijinpei-amd force-pushed the review/2026-05-22/gluon-blockscale-wind-down-unroll branch from 69aed97 to 0f5e60c Compare May 22, 2026 07:02
* test_gemm_a8w8_blockscale: enable the "gluon" parametrize entry and
  add small-K shapes (K in {128, 192, 256, 320}) that exercise the
  wind-down's num_k_iter guards.
* bench_gemm_a8w8_blockscale: add a -test flag that runs each
  benchmarked shape against a torch reference via checkAllclose.
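
As a rough illustration of why those small-K shapes matter, the sketch below counts K iterations and tail elements for each of them. `BLOCK_K = 128` is an assumption made for illustration (the real tile size comes from the tuning config), and `cdiv` and `k_iteration_plan` are hypothetical helpers, not names from the kernel.

```python
BLOCK_K = 128  # ASSUMPTION: tile size along K; the actual value is set by the tuning config.


def cdiv(a: int, b: int) -> int:
    """Ceiling division, as commonly used to count tiles."""
    return (a + b - 1) // b


def k_iteration_plan(K: int, block_k: int = BLOCK_K) -> dict:
    """Describe how a K extent splits into full tiles and a partial tail tile."""
    return {
        "K": K,
        "num_k_iter": cdiv(K, block_k),   # total loop iterations, including the tail
        "full_tiles": K // block_k,       # iterations that need no masking
        "tail_elems": K % block_k,        # valid elements in the masked tail (0 = no tail)
    }


# The small-K shapes added to the test.
for K in (128, 192, 256, 320):
    print(k_iteration_plan(K))
```

Under that assumption, K = 128 runs a single unmasked iteration, while K = 192 and K = 320 each end in a half-filled tail tile, so together the shapes cover both very short loops and masked tails.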

Reimplement the gluon a8w8 blockscale kernel around
gl.amd.cdna4.mfma_scaled with an explicit async-copy / LDS multi-buffer
pipeline.

* Split the main loop into an aligned-K body (EVEN_K=True
  _prefetch_tensors) plus a statically unrolled wind-down for the masked
  tail iterations.
* Runtime-guard the wind-down iters for small num_k_iter so the Final
  iter is the only MFMA that runs when K is short.
* In the main loop, commit the prefetch group before loading scales so
  the compiler schedules buffer_load earlier in the iteration.
* Refresh tuning configs for gfx950.
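
The aligned-body-plus-masked-tail split from the first bullet can be sketched in plain Python on a block-scaled dot product. This is only a conceptual model: it does not use Gluon, does not model the async-copy / LDS pipeline or the unrolled wind-down, and the function names and the one-scale-per-K-block layout are assumptions made for illustration.

```python
def scaled_dot_reference(a, b, a_scale, b_scale, block_k):
    """Element-by-element reference: each K block has its own scale pair."""
    total = 0.0
    for k, (x, y) in enumerate(zip(a, b)):
        blk = k // block_k
        total += x * y * a_scale[blk] * b_scale[blk]
    return total


def scaled_dot_split(a, b, a_scale, b_scale, block_k):
    """Same result, computed as an unmasked aligned body plus one masked tail tile."""
    K = len(a)
    full_tiles = K // block_k
    acc = 0.0

    # Aligned body: every tile is fully in bounds, so no mask is needed.
    for t in range(full_tiles):
        lo = t * block_k
        partial = sum(x * y for x, y in zip(a[lo:lo + block_k], b[lo:lo + block_k]))
        acc += partial * a_scale[t] * b_scale[t]

    # Masked tail: runs only when K is not a multiple of block_k.
    if K % block_k:
        lo = full_tiles * block_k
        partial = 0.0
        for i in range(block_k):
            k = lo + i
            in_bounds = k < K  # plays the role of the load mask
            x = a[k] if in_bounds else 0
            y = b[k] if in_bounds else 0
            partial += x * y
        acc += partial * a_scale[full_tiles] * b_scale[full_tiles]

    return acc


# Example: K = 320 with block_k = 128 gives two full tiles and a 64-element tail.
a = list(range(1, 321))
b = [(i % 7) - 3 for i in range(320)]
a_scale = [0.5, 2.0, 0.25]
b_scale = [4.0, 0.5, 1.0]
print(scaled_dot_split(a, b, a_scale, b_scale, 128))
print(scaled_dot_reference(a, b, a_scale, b_scale, 128))
```

Keeping the mask out of the body is the point of the split: only the tail pays for the bounds check, and for short K the body loop simply runs zero times.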

@lijinpei-amd lijinpei-amd force-pushed the review/2026-05-22/gluon-blockscale-wind-down-unroll branch from 0f5e60c to 84d30c3 Compare May 23, 2026 15:45