Tune gfx950 batched GEMM configs and benchmarks#3283
Conversation
🏷️ CI GuideRuns automatically on every PR:
Extended tests (opt-in via labels):
|
brunomazzottiamd
left a comment
There was a problem hiding this comment.
Leaving some comments as a 1st review round.
| mem_read = ( | ||
| x.numel() * x.element_size() | ||
| + w.numel() * w.element_size() | ||
| + bias.numel() * bias.element_size() |
There was a problem hiding this comment.
Is bias tensor always present? If this is the case then we're in a good position. Otherwise, bias.something() may trigger an exception. I don't know, maybe the benchmark script has a option to enable/disable bias. Please double check it.
PS: You don't need to add an option to enable/disable bias if we don't have such a thing.
There was a problem hiding this comment.
I checked the current benchmark path in this kernel and bias is always generated and passed today:
x, w, bias, y = generate_batched_gemm_a16w16_inputs(...)
batched_gemm_bf16(x, w, bias, c_dtype, YQ=y)
There is no --no-bias / no_bias path in this benchmark right now, so it does not fail.
However, since the wrapper supports optional bias, I agree it is better to change it to be defensive in the case that it gets changed in the future. I can update it to:
+ (bias.numel() * bias.element_size() if bias is not None else 0)
| x.numel() * x.element_size() | ||
| + weight.numel() * weight.element_size() | ||
| + w_scale.numel() * w_scale.element_size() | ||
| + (bias.numel() * bias.element_size() if bias is not None else 0) |
There was a problem hiding this comment.
My comment in bench_batched_gemm_bf16.py was motivated by a situation like this one.
|
Cleanup suggestion: Remove |
Motivation
This PR improves gfx950 batched GEMM benchmarking coverage and tunes several batched GEMM config files to reduce TOT (triton upstream) compiler regressions (when compared against our baseline golden compiler stack)
The regression methodology follows the agreed timing rule:
Timing is measured with rocprofv3, using repeated runs and the 5% stability heuristic.
Before this tuning work, the batched GEMM comparison had 22 regressions in total. After the config tuning and benchmark coverage updates in this PR, the full comparison is down to 14 regressions across the 5-kernel batched GEMM suite (and continuous work will be done to drop this number over time as this is ongoing progress). For the raw numbers and status of each shape and individual test case, refer to the performance spreadsheet.
Technical Details
Benchmark/config coverage
Benchmark Updates
This PR adds a dedicated benchmark for:
op_tests/op_benchmarks/triton/bench_batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant.pyIt also renames the previous
A16W16benchmark toBF16:op_tests/op_benchmarks/triton/bench_batched_gemm_a16w16.pyto
op_tests/op_benchmarks/triton/bench_batched_gemm_bf16.pyBF16 config routing
batched_gemm_bf16.pycontinues to use the existing A16W16 config family:gfx950-BATCHED_GEMM-A16W16.jsonThis keeps behavior consistent with the existing kernel/config naming, while making the benchmark naming reflect the actual BF16 kernel.
Config Tuning Work
1.
batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quantThis PR tunes the gfx950 configs for:
batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quantTouched config files:
This was done in three config-only tuning passes.
Regression count for this kernel:
The default-config pass fixed the large
N=8192, K=8192group for:The primary default-config target improved from roughly:
Small-M buckets that were harmed by the winning config were intentionally left untouched such that no regressions would be created from these changes.
2.
batched_gemm_a16wfp4This PR adds the base gfx950 config:
aiter/ops/triton/configs/gemm/gfx950-BATCHED_GEMM-A16WFP4.jsonIt also adds a per-(N,K) specialization for the
N=128, K=512family:aiter/ops/triton/configs/gemm/gfx950-BATCHED_GEMM-A16WFP4-N=128-K=512.jsonA default-config edit was tested first, but rejected because it introduced regressions in the previously OK N=512, K=128 family. The per-(N,K) specialization fixes the target family while preserving the previously working ones.
The new per-(N,K) file uses a hybrid bucket layout:
The "tuned winner" config uses:
{ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2, "waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 1, "cache_modifier": ".cg", "NUM_KSPLIT": 1 }Test Plan
Full batched GEMM benchmark/compare was run across 5 kernels and 61 shapes.
Commands:
These wrapper scripts run the batched GEMM benchmark suite through /workspace/bench_kernels.py.
run_bench_golden.shruns the benchmark suite with the golden environment, using the golden Python/Triton stack.run_bench_tot.shruns the same benchmark suite with the TOT environment, using the latest TOT Python/Triton stack.Both scripts use the same benchmark driver and shape numbers so the results are comparable.
Outputs are written to:
The comparison command is:
python3 /workspace/compare_envs.py \ --golden /workspace/results/golden/results_golden.csv \ --tot /workspace/results/tot/results_tot.csv \ --out /workspace/results/comparison.csvThis reads the golden and TOT results, joins matching rows by kernel and shape, computes timing deltas, and labels each row using the regression rule:
The final comparison table is written to:
/workspace/results/compare.csvwhich closely resembles the values in the spreadsheet when not parsed.Environment Information:
The comparison was run between the golden compiler stack and the TOT compiler stack.
Golden
{ "env": "golden", "captured": "2026-05-06T23:14:49", "aiter_commit": "226c2a87b", "aiter_branch": "aiter-baseline-226c2a87", "triton_version": "3.4.0", "rocm_version": "7.2.0", "python": "/workspace/venvs/golden/bin/python", "async_copy": "0" }Note: when running tests to compare with TOT, I run the golden compiler stack with present aiter so that the latest changes are all included for testing with respect to both compilers, but I have the baseline to benchmark any test or kernel work that was done before any of the
ASYNC_COPYchanges described in my earlier GitHub Issue.TOT
{ "env": "tot", "captured": "2026-05-06T23:17:29", "aiter_commit": "1a7f97205", "aiter_branch": "main", "triton_version": "3.7.0", "rocm_version": "7.2.0", "python": "/workspace/venvs/tot/bin/python", "async_copy": "1" }Timing methodology
Test Result
Overall regression count (subject to change over time):
Before tuning: 22 regressions
After tuning: 14 regressions
Be sure to check the spreadsheet for all the raw numbers and status change per shape run.
Submission Checklist