Skip to content

Tune gfx950 batched GEMM configs and benchmarks#3283

Draft
nidal567 wants to merge 1 commit into
ROCm:mainfrom
nidal567:golden-tot-batched-gemm-perf
Draft

Tune gfx950 batched GEMM configs and benchmarks#3283
nidal567 wants to merge 1 commit into
ROCm:mainfrom
nidal567:golden-tot-batched-gemm-perf

Conversation

@nidal567
Copy link
Copy Markdown
Contributor

Motivation

This PR improves gfx950 batched GEMM benchmarking coverage and tunes several batched GEMM config files to reduce TOT (triton upstream) compiler regressions (when compared against our baseline golden compiler stack)

The regression methodology follows the agreed timing rule:

Ignore shapes where golden_time < 4 us.
Otherwise, flag regression iff:
tot_time > 1.005 * golden_time + 0.5 us

Timing is measured with rocprofv3, using repeated runs and the 5% stability heuristic.

Before this tuning work, the batched GEMM comparison had 22 regressions in total. After the config tuning and benchmark coverage updates in this PR, the full comparison is down to 14 regressions across the 5-kernel batched GEMM suite (and continuous work will be done to drop this number over time as this is ongoing progress). For the raw numbers and status of each shape and individual test case, refer to the performance spreadsheet.

Technical Details

Benchmark/config coverage

Kernel Benchmark file? Config files?
batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant.py Yes — added Yes
batched_gemm_a8w8.py Yes Yes
batched_gemm_a16wfp4.py Yes Yes — and additionally added gfx950 config + per-(N,K) config
batched_gemm_afp4wfp4.py Yes Yes
batched_gemm_bf16.py Yes — renamed from bench_batched_gemm_a16w16.py because it calls the BF16 kernel Yes — routed through existing A16W16 config

Benchmark Updates

This PR adds a dedicated benchmark for:

op_tests/op_benchmarks/triton/bench_batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant.py

It also renames the previous A16W16 benchmark to BF16:

op_tests/op_benchmarks/triton/bench_batched_gemm_a16w16.py
to op_tests/op_benchmarks/triton/bench_batched_gemm_bf16.py

BF16 config routing

batched_gemm_bf16.py continues to use the existing A16W16 config family:

gfx950-BATCHED_GEMM-A16W16.json

This keeps behavior consistent with the existing kernel/config naming, while making the benchmark naming reflect the actual BF16 kernel.

Config Tuning Work

1. batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant

This PR tunes the gfx950 configs for:

batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant

Touched config files:

gfx950-BATCHED_GEMM-A8W8-A_PER_TOKEN_GROUP_PREQUANT_W_PER_BATCHED_TENSOR_QUANT-N=512-K=128.json
gfx950-BATCHED_GEMM-A8W8-A_PER_TOKEN_GROUP_PREQUANT_W_PER_BATCHED_TENSOR_QUANT-N=128-K=512.json
gfx950-BATCHED_GEMM-A8W8-A_PER_TOKEN_GROUP_PREQUANT_W_PER_BATCHED_TENSOR_QUANT.json

This was done in three config-only tuning passes.

Regression count for this kernel:

Before tuning: 11 / 15 regressions
After pass 1:   9 / 15
After pass 2:   7 / 15
After pass 3:   4 / 15

The default-config pass fixed the large N=8192, K=8192 group for:

M=64
M=128
M=256

The primary default-config target improved from roughly:

M=128, N=8192, K=8192
TOT: ~7905 us → ~3680 us
~53% reduction in TOT time (can be seen in the spreadsheet)

Small-M buckets that were harmed by the winning config were intentionally left untouched such that no regressions would be created from these changes.

2. batched_gemm_a16wfp4

This PR adds the base gfx950 config:

aiter/ops/triton/configs/gemm/gfx950-BATCHED_GEMM-A16WFP4.json

It also adds a per-(N,K) specialization for the N=128, K=512 family:

aiter/ops/triton/configs/gemm/gfx950-BATCHED_GEMM-A16WFP4-N=128-K=512.json

A default-config edit was tested first, but rejected because it introduced regressions in the previously OK N=512, K=128 family. The per-(N,K) specialization fixes the target family while preserving the previously working ones.

The new per-(N,K) file uses a hybrid bucket layout:

Bucket Source
M_LEQ_16 copied from default
M_LEQ_32 copied from default
M_LEQ_64 tuned winner
M_LEQ_128 tuned winner
M_LEQ_256 tuned winner
any tuned winner

The "tuned winner" config uses:

{
  "BLOCK_SIZE_M": 64,
  "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 256,
  "GROUP_SIZE_M": 1,
  "num_warps": 8,
  "num_stages": 2,
  "waves_per_eu": 4,
  "matrix_instr_nonkdim": 16,
  "kpack": 1,
  "cache_modifier": ".cg",
  "NUM_KSPLIT": 1
}

Test Plan

Full batched GEMM benchmark/compare was run across 5 kernels and 61 shapes.

Commands:

bash /workspace/run_bench_golden.sh
bash /workspace/run_bench_tot.sh

python3 /workspace/bench_kernels.py compare \
  --golden /workspace/results/golden/results_golden.csv \
  --tot /workspace/results/tot/results_tot.csv \
  --out /workspace/results/compare.csv

These wrapper scripts run the batched GEMM benchmark suite through /workspace/bench_kernels.py.

run_bench_golden.sh runs the benchmark suite with the golden environment, using the golden Python/Triton stack.
run_bench_tot.sh runs the same benchmark suite with the TOT environment, using the latest TOT Python/Triton stack.
Both scripts use the same benchmark driver and shape numbers so the results are comparable.

Outputs are written to:

/workspace/results/golden/results_golden.csv
/workspace/results/tot/results_tot.csv

The comparison command is:

python3 /workspace/compare_envs.py \
    --golden /workspace/results/golden/results_golden.csv \
    --tot    /workspace/results/tot/results_tot.csv \
    --out    /workspace/results/comparison.csv

This reads the golden and TOT results, joins matching rows by kernel and shape, computes timing deltas, and labels each row using the regression rule:

ignore golden < 4 us
otherwise regression iff tot_us > 1.005 * golden_us + 0.5 us

The final comparison table is written to:
/workspace/results/compare.csv which closely resembles the values in the spreadsheet when not parsed.

Environment Information:

The comparison was run between the golden compiler stack and the TOT compiler stack.

Golden

{
  "env": "golden",
  "captured": "2026-05-06T23:14:49",
  "aiter_commit": "226c2a87b",
  "aiter_branch": "aiter-baseline-226c2a87",
  "triton_version": "3.4.0",
  "rocm_version": "7.2.0",
  "python": "/workspace/venvs/golden/bin/python",
  "async_copy": "0"
}

Note: when running tests to compare with TOT, I run the golden compiler stack with present aiter so that the latest changes are all included for testing with respect to both compilers, but I have the baseline to benchmark any test or kernel work that was done before any of the ASYNC_COPY changes described in my earlier GitHub Issue.

TOT

{
  "env": "tot",
  "captured": "2026-05-06T23:17:29",
  "aiter_commit": "1a7f97205",
  "aiter_branch": "main",
  "triton_version": "3.7.0",
  "rocm_version": "7.2.0",
  "python": "/workspace/venvs/tot/bin/python",
  "async_copy": "1"
}

Timing methodology

rocprofv3 --kernel-trace --stats
time-only metric
3-run avg / 5%-stability heuristic
regression iff tot_us > 1.005 * golden_us + 0.5 us
ignore golden_us < 4 us

Test Result

Overall regression count (subject to change over time):

Before tuning: 22 regressions
After tuning: 14 regressions

Be sure to check the spreadsheet for all the raw numbers and status change per shape run.

Submission Checklist

@github-actions
Copy link
Copy Markdown
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label Tests
ci:triton-300x Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X
ci:sglang SGLang integration tests: DeepSeek-R1-MXFP4 accuracy, Qwen 3.5 accuracy
ci:atom ATOM benchmark: DeepSeek-R1-0528, GPT-OSS-120B
ci:atom_full ATOM accuracy suite for PR and main models from ATOM models_accuracy.json
ci:vllm vLLM benchmark: GPT-OSS-120B, DeepSeek-R1-0528, Kimi-K2.5
ci:all All standard extended tests (excludes ci:atom_full)

Only add ci:atom_full for FlyDSL or Triton upgrades.
Add labels via the sidebar or gh pr edit 3283 --add-label <label>

Copy link
Copy Markdown
Contributor

@brunomazzottiamd brunomazzottiamd left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Leaving some comments as a 1st review round.

mem_read = (
x.numel() * x.element_size()
+ w.numel() * w.element_size()
+ bias.numel() * bias.element_size()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is bias tensor always present? If this is the case then we're in a good position. Otherwise, bias.something() may trigger an exception. I don't know, maybe the benchmark script has a option to enable/disable bias. Please double check it.

PS: You don't need to add an option to enable/disable bias if we don't have such a thing.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked the current benchmark path in this kernel and bias is always generated and passed today:

x, w, bias, y = generate_batched_gemm_a16w16_inputs(...)
batched_gemm_bf16(x, w, bias, c_dtype, YQ=y)

There is no --no-bias / no_bias path in this benchmark right now, so it does not fail.
However, since the wrapper supports optional bias, I agree it is better to change it to be defensive in the case that it gets changed in the future. I can update it to:

+ (bias.numel() * bias.element_size() if bias is not None else 0)

x.numel() * x.element_size()
+ weight.numel() * weight.element_size()
+ w_scale.numel() * w_scale.element_size()
+ (bias.numel() * bias.element_size() if bias is not None else 0)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My comment in bench_batched_gemm_bf16.py was motivated by a situation like this one.

@brunomazzottiamd
Copy link
Copy Markdown
Contributor

Cleanup suggestion: Remove kpack: 1 from the gfx950 config files you're touching in this PR.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants