Skip to content

[ck gemm a8w8 blockscale] shape-aware kernel selection heuristic for untuned shapes#3316

Open
eppaneamd wants to merge 4 commits into
mainfrom
a8w8-blockscale-gemm-fallback-heuristic
Open

[ck gemm a8w8 blockscale] shape-aware kernel selection heuristic for untuned shapes#3316
eppaneamd wants to merge 4 commits into
mainfrom
a8w8-blockscale-gemm-fallback-heuristic

Conversation

@eppaneamd
Copy link
Copy Markdown
Contributor

Summary

This PR adds a shape-aware kernel selection heuristic for untuned a8w8 blockscale GEMM shapes on gfx942. Without a tuned CSV hit, the runtime falls back to a single hard-coded default (16×128 tile, v1 pipeline), which is well-suited for single-token decode but leaves significant performance on the table for compute-bound prefill and moderate-batch shapes. On a 211-shape diagnostic benchmark spanning 28 (N, K) pairs, the heuristic is 1.61× faster than the upstream fallback on average and closes approximately 96% of the gap to the per-shape tuned optimum. Across a wider dataset of 19,068 shapes from 8 model/TP configurations, the heuristic matches the tuned-optimal kernel in 97.9% of cases.


Shape coverage

The 19,068 MNK shapes are constructed from two components:

  • (N, K) pairs are derived analytically from each model's config.json using vLLM's tensor-parallel sharding rules (ColumnParallelLinear, RowParallelLinear, QKVParallelLinear) and verified against empirical vLLM inference logs. Each model/TP configuration contributes 2–5 distinct NK pairs (28 pairs across 8 configs).
  • M values are taken from AITER's gl=0 padded-M grid. At runtime, AITER rounds raw M to the nearest grid point before CSV lookup; covering all 681 grid points (M up to 80,000) provides complete coverage of every runtime M the tuned tables would serve.

Motivation

The upstream fallback saturates memory bandwidth at small M but does not efficiently utilise XDL compute units for larger tiles. Comparing tuned and default kernels across 19,068 shapes (3 FP8 VLMs, TP=2/4/8, gfx942):

Metric Value
Geometric mean speedup (tuned / default) 3.08×
Median speedup 3.17×
Shapes where tuned is >2× faster 95.8%

Why a heuristic instead of shipping all tuned tables

Tuned tables are specific to a (model, TP, hardware) triple:

  • Coverage: any new model, TP value, or hardware revision with different (N, K) pairs falls back to the ~3× slower default.
  • Maintenance: 3 models × up to 4 TP values × 2+ hardware targets already produces over 150K rows; every new model multiplies this linearly.
  • Correctness risk: a tuned CSV row referencing a kernel not compiled in the current build silently hits TORCH_CHECK(false, ...) at inference time.

A shape-aware heuristic selects from a fixed set of kernels compiled unconditionally on every build (via heuristic_kernels_dict), covers any model without a prior tuning sweep, and degrades gracefully rather than failing when tuning is unavailable.


Heuristic design

The heuristic partitions the (M×N, K) space into 7 arms. K=128 is treated separately because KPerBLOCK=128 kernels avoid the KPadding overhead that KPerBLOCK=256 kernels incur at small K, and the v3 pipeline is unsupported at K=128:

k == 128:
    M×N ≤ 307,200   → 16×64  tile, KPerBLOCK=128, v1
    M×N ≤ 2,000,000 → 32×64  tile, KPerBLOCK=128, v1
    else             → 64×64  tile, KPerBLOCK=128, v1
k != 128:
    M×N ≤ 307,200   → 16×64  tile, KPerBLOCK=256, v1
    M×N ≤ 1,050,000 → 64×64  tile, KPerBLOCK=256, v1
    K < 384          → 64×64  tile, KPerBLOCK=256, v1   (v3 requires K ≥ 384)
    M×N ≤ 2,500,000 → 64×128 tile, KPerBLOCK=128, v3
    else             → 128×128 tile, KPerBLOCK=128, v3

The 7 heuristic kernel instances are registered in heuristic_kernels_dict in gemm_a8w8_blockscale_instance.py and emitted unconditionally by gen_instances.py, ensuring they are always present in the C++ name-keyed lookup table regardless of which tuned CSVs are installed.


Validation

Three-way comparison — a tailored diagnostic benchmark of 211 shapes covering the 28 NK pairs of the three production VLMs plus K=128 shapes from Qwen3-Next-80B-A3B, with representative M values:

Comparison GeoMean
PR heuristic vs upstream fallback 1.61×
Tuned vs upstream fallback 1.68×
Tuned vs PR heuristic (residual gap) 1.044×

Match rate against tuned data (19,068 shapes, 8 model/TP configurations):

Config Shapes Match rate
32B TP=8 2,724 97.0%
32B TP=4 2,724 98.4%
27B TP=8 2,724 97.2%
27B TP=4 3,405 98.3%
27B TP=2 3,405 98.6%
235B TP=8 1,362 96.5%
235B TP=4 1,362 98.3%
235B TP=2 1,362 99.0%
Total 19,068 97.9%

The 394 mismatches (2.1%) are M×N boundary effects; the performance cost of a mismatch is at most 1.39× based on the three-way comparison above.


Scope

  • Hardware: gfx942 (MI300X) only. Shapes on gfx950 fall through to the asm path.
  • Data-type: a8w8 blockscale (FP8 inputs with per-block scale factors) only.
  • Models benchmarked: Qwen3-VL-235B-A22B-Instruct-FP8 (TP=2/4/8), Qwen3-VL-32B-Instruct-FP8 (TP=4/8), Qwen3.6-27B-FP8 (TP=2/4/8). Speedup magnitudes are sample estimates; they may differ on other models, TP values, or hardware.

@github-actions
Copy link
Copy Markdown
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label Tests
ci:triton-300x Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X
ci:sglang SGLang integration tests: DeepSeek-R1-MXFP4 accuracy, Qwen 3.5 accuracy
ci:atom ATOM benchmark: DeepSeek-R1-0528, GPT-OSS-120B
ci:atom_full ATOM accuracy suite for PR and main models from ATOM models_accuracy.json
ci:vllm vLLM benchmark: GPT-OSS-120B, DeepSeek-R1-0528, Kimi-K2.5
ci:all All standard extended tests (excludes ci:atom_full)

Only add ci:atom_full for FlyDSL or Triton upgrades.
Add labels via the sidebar or gh pr edit 3316 --add-label <label>

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds a Python-side, shape-aware fallback dispatch for untuned a8w8 blockscale GEMM shapes (intended for gfx942), and updates CK instance codegen to ensure the heuristic target kernels are always compiled and present in the name-keyed C++ lookup table even when tuned CSV coverage is incomplete.

Changes:

  • Extend CK codegen to always include a fixed set of “heuristic kernels” in the generated registry alongside tuned CSV-selected kernels.
  • Introduce heuristic_kernels_dict to pin the exact kernel instances the runtime heuristic will reference.
  • Add a no-tuned-row shape heuristic in gemm_a8w8_blockscale() that selects a kernelName based on (M×N, K) and calls the CK wrapper with that explicit kernelName.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.

File Description
csrc/ck_gemm_a8w8_blockscale/gen_instances.py Merges heuristic kernels into the generated set so name-keyed lookup contains them regardless of tuned CSV install state.
csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_instance.py Defines heuristic_kernels_dict mapping heuristic labels to specific candidate kernel instances.
aiter/ops/gemm_op_a8w8.py Adds Python-side shape heuristic fallback that chooses a kernelName string when no tuned CSV row matches.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread aiter/ops/gemm_op_a8w8.py Outdated
Comment thread aiter/ops/gemm_op_a8w8.py Outdated
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 4 out of 4 changed files in this pull request and generated 1 comment.

Comment on lines +127 to +148
# Kernels always compiled regardless of tuned CSV presence.
# These are the targets of the shape-aware dispatch heuristic in gemm_op_a8w8.py
# and must appear in the name-keyed C++ lookup table in every build. Keep in
# sync with the heuristic if kernel selections change.
heuristic_kernels_dict = {
"heuristic_small_mn": candidate_kernels_dict[8], # 16×64 tile, v1 — small M×N
"heuristic_medium_mn": candidate_kernels_dict[
18
], # 64×64 tile, v1 — medium M×N; also K<384 fallback
"heuristic_large_mn": candidate_kernels_dict[2], # 64×128 tile, v3 — large M×N
"heuristic_xlarge_mn": candidate_kernels_dict[0], # 128×128 tile, v3 — xlarge M×N
# K=128: KPerBLOCK=128 exact fit avoids KPadding overhead; v3 unsupported at K=128.
"heuristic_k128_xsmall_mn": candidate_kernels_dict[
6
], # 16×64 tile, v1 — K=128, small M×N
"heuristic_k128_medium_large_mn": candidate_kernels_dict[
11
], # 32×64 tile, v1 — K=128, medium M×N
"heuristic_k128_xlarge_mn": candidate_kernels_dict[
16
], # 64×64 tile, v1 — K=128, large M×N
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants