[ck gemm a8w8 blockscale] shape-aware kernel selection heuristic for untuned shapes #3316
Open
eppaneamd wants to merge 4 commits into
Pull request overview
Adds a Python-side, shape-aware fallback dispatch for untuned a8w8 blockscale GEMM shapes (intended for gfx942), and updates CK instance codegen to ensure the heuristic target kernels are always compiled and present in the name-keyed C++ lookup table even when tuned CSV coverage is incomplete.
Changes:
- Extend CK codegen to always include a fixed set of “heuristic kernels” in the generated registry alongside tuned CSV-selected kernels.
- Introduce `heuristic_kernels_dict` to pin the exact kernel instances the runtime heuristic will reference.
- Add a no-tuned-row shape heuristic in `gemm_a8w8_blockscale()` that selects a kernelName based on `(M×N, K)` and calls the CK wrapper with that explicit kernelName.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| csrc/ck_gemm_a8w8_blockscale/gen_instances.py | Merges heuristic kernels into the generated set so name-keyed lookup contains them regardless of tuned CSV install state. |
| csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_instance.py | Defines heuristic_kernels_dict mapping heuristic labels to specific candidate kernel instances. |
| aiter/ops/gemm_op_a8w8.py | Adds Python-side shape heuristic fallback that chooses a kernelName string when no tuned CSV row matches. |
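The merge described for `gen_instances.py` can be pictured with a minimal sketch. It assumes each kernel instance exposes a unique `name`; the `KernelInstance` dataclass and `merge_kernel_sets` helper are hypothetical stand-ins for illustration, not the code in this PR.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class KernelInstance:
    """Hypothetical stand-in for a CK kernel instance; only the name matters here."""
    name: str


def merge_kernel_sets(tuned: dict, heuristic: dict) -> dict:
    """Build a name-keyed kernel table: tuned kernels first, then heuristic ones.

    A heuristic kernel that the tuned CSV already selected is not duplicated,
    because the table is keyed by kernel name.
    """
    merged = {}
    for kernel in list(tuned.values()) + list(heuristic.values()):
        merged.setdefault(kernel.name, kernel)
    return merged


# Illustrative inputs: one tuned shape, two heuristic labels (one overlapping).
tuned_kernels = {(16, 4096, 4096): KernelInstance("kernel_a")}
heuristic_kernels = {
    "heuristic_small_mn": KernelInstance("kernel_a"),
    "heuristic_medium_mn": KernelInstance("kernel_b"),
}
merged_table = merge_kernel_sets(tuned_kernels, heuristic_kernels)
no_csv_table = merge_kernel_sets({}, heuristic_kernels)
```

With an empty tuned set the table still contains every heuristic kernel, which is the property the name-keyed lookup relies on.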
Comment on lines +127 to +148
    # Kernels always compiled regardless of tuned CSV presence.
    # These are the targets of the shape-aware dispatch heuristic in gemm_op_a8w8.py
    # and must appear in the name-keyed C++ lookup table in every build. Keep in
    # sync with the heuristic if kernel selections change.
    heuristic_kernels_dict = {
        "heuristic_small_mn": candidate_kernels_dict[8],  # 16×64 tile, v1 — small M×N
        "heuristic_medium_mn": candidate_kernels_dict[
            18
        ],  # 64×64 tile, v1 — medium M×N; also K<384 fallback
        "heuristic_large_mn": candidate_kernels_dict[2],  # 64×128 tile, v3 — large M×N
        "heuristic_xlarge_mn": candidate_kernels_dict[0],  # 128×128 tile, v3 — xlarge M×N
        # K=128: KPerBLOCK=128 exact fit avoids KPadding overhead; v3 unsupported at K=128.
        "heuristic_k128_xsmall_mn": candidate_kernels_dict[
            6
        ],  # 16×64 tile, v1 — K=128, small M×N
        "heuristic_k128_medium_large_mn": candidate_kernels_dict[
            11
        ],  # 32×64 tile, v1 — K=128, medium M×N
        "heuristic_k128_xlarge_mn": candidate_kernels_dict[
            16
        ],  # 64×64 tile, v1 — K=128, large M×N
    }
Summary
This PR adds a shape-aware kernel selection heuristic for untuned `a8w8 blockscale` GEMM shapes on gfx942. Without a tuned CSV hit, the runtime falls back to a single hard-coded default (16×128 tile, v1 pipeline), which is well-suited for single-token decode but leaves significant performance on the table for compute-bound prefill and moderate-batch shapes. On a 211-shape diagnostic benchmark spanning 28 (N, K) pairs, the heuristic is 1.61× faster than the upstream fallback on average and closes approximately 96% of the gap to the per-shape tuned optimum. Across a wider dataset of 19,068 shapes from 8 model/TP configurations, the heuristic matches the tuned-optimal kernel in 97.9% of cases.
Shape coverage
The 19,068 MNK shapes are constructed from two components:
`config.json` using vLLM's tensor-parallel sharding rules (`ColumnParallelLinear`, `RowParallelLinear`, `QKVParallelLinear`) and verified against empirical vLLM inference logs. Each model/TP configuration contributes 2–5 distinct NK pairs (28 pairs across 8 configs).
Motivation
The upstream fallback saturates memory bandwidth at small M but does not efficiently utilise XDL compute units for larger tiles. Comparing tuned and default kernels across 19,068 shapes (3 FP8 VLMs, TP=2/4/8, gfx942):
Why a heuristic instead of shipping all tuned tables
Tuned tables are specific to a (model, TP, hardware) triple:
`TORCH_CHECK(false, ...)` at inference time.
A shape-aware heuristic selects from a fixed set of kernels compiled unconditionally on every build (via `heuristic_kernels_dict`), covers any model without a prior tuning sweep, and degrades gracefully rather than failing when tuning is unavailable.
Heuristic design
The heuristic partitions the (M×N, K) space into 7 arms. K=128 is treated separately because KPerBLOCK=128 kernels avoid the KPadding overhead that KPerBLOCK=256 kernels incur at small K, and the v3 pipeline is unsupported at K=128:
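A minimal sketch of how such a seven-arm partition could be written in Python. The arm labels reuse the keys of `heuristic_kernels_dict`. The M×N boundaries (`SMALL_MN`, `MEDIUM_MN`, `LARGE_MN`), the order in which the K checks run, and the function name `select_heuristic_label` are all assumptions made for illustration; the thresholds actually used by the PR are not reproduced here.

```python
# Hypothetical M*N boundaries, chosen only to make the sketch runnable.
SMALL_MN = 1 << 16
MEDIUM_MN = 1 << 20
LARGE_MN = 1 << 23


def select_heuristic_label(m: int, n: int, k: int) -> str:
    """Map an untuned (M, N, K) shape to one of the 7 heuristic arm labels."""
    mn = m * n
    if k == 128:
        # K=128 arms: KPerBlock=128 kernels on the v1 pipeline.
        if mn < SMALL_MN:
            return "heuristic_k128_xsmall_mn"
        if mn < LARGE_MN:
            return "heuristic_k128_medium_large_mn"
        return "heuristic_k128_xlarge_mn"
    if k < 384:
        # The medium arm doubles as the K<384 fallback (check order is assumed).
        return "heuristic_medium_mn"
    if mn < SMALL_MN:
        return "heuristic_small_mn"
    if mn < MEDIUM_MN:
        return "heuristic_medium_mn"
    if mn < LARGE_MN:
        return "heuristic_large_mn"
    return "heuristic_xlarge_mn"


# Example: a single-token decode shape and a large prefill shape.
decode_label = select_heuristic_label(1, 4096, 4096)
prefill_label = select_heuristic_label(8192, 8192, 4096)
```

The returned label would then be resolved to a concrete kernel name through `heuristic_kernels_dict` before calling the CK wrapper.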
The 7 heuristic kernel instances are registered in `heuristic_kernels_dict` in `gemm_a8w8_blockscale_instance.py` and emitted unconditionally by `gen_instances.py`, ensuring they are always present in the C++ name-keyed lookup table regardless of which tuned CSVs are installed.
Validation
Three-way comparison — a tailored diagnostic benchmark of 211 shapes covering the 28 NK pairs of the three production VLMs plus K=128 shapes from Qwen3-Next-80B-A3B, with representative M values:
Match rate against tuned data (19,068 shapes, 8 model/TP configurations):
The 394 mismatches (2.1%) are M×N boundary effects; the performance cost of a mismatch is at most 1.39× based on the three-way comparison above.
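The match and mismatch percentages follow directly from the two counts quoted above. A short check of that arithmetic, using only the 19,068 total and the 394 mismatches (the variable names are illustrative):

```python
total_shapes = 19_068
mismatches = 394

matches = total_shapes - mismatches
match_rate = matches / total_shapes
mismatch_rate = mismatches / total_shapes

# Formatted to one decimal place, as in the summary.
match_pct = f"{match_rate:.1%}"
mismatch_pct = f"{mismatch_rate:.1%}"
print(matches, match_pct, mismatch_pct)
```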
Scope