From 7d021490c39e6022829022027268bb072ee84961 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rony=20Lepp=C3=A4nen?= Date: Fri, 22 May 2026 16:14:28 +0000 Subject: [PATCH 1/3] [ck gemm a8w8 blockscale] add shape-aware fallback heuristic --- aiter/ops/gemm_op_a8w8.py | 49 ++++++++++++++++++- .../gemm_a8w8_blockscale_instance.py | 15 ++++++ csrc/ck_gemm_a8w8_blockscale/gen_instances.py | 8 ++- 3 files changed, 69 insertions(+), 3 deletions(-) diff --git a/aiter/ops/gemm_op_a8w8.py b/aiter/ops/gemm_op_a8w8.py index ba0e523e51..63eba77270 100644 --- a/aiter/ops/gemm_op_a8w8.py +++ b/aiter/ops/gemm_op_a8w8.py @@ -716,8 +716,55 @@ def gemm_a8w8_blockscale( ) else: assert 0, f"Unsupported libtype {libtype} for gemm_a8w8_blockscale" + # No tuned row: select kernel by shape. + # K=128: KPerBLOCK=128 v1 kernels avoid KPadding vs the KPerBLOCK=256 defaults. + # K<384 (not 128): v3 pipeline requires K≥384; substitute v1 kernel. + _mn = m * n + if k == 128: + # v3 pipeline unsupported at K=128; use v1 throughout this branch. + if _mn <= 307200: + _heuristic_kernel = ( # 16×64, v1 + "a8w8_blockscale_1x128x128_256x16x64x128_8x16_16x16_1x1" + "_16x16x1_8x32x1_1x16x1x16_4_1x1_intrawave_v1" + ) + elif _mn <= 2000000: + _heuristic_kernel = ( # 32×64, v1 + "a8w8_blockscale_1x128x128_256x32x64x128_16x16_16x16_2x1" + "_8x32x1_8x32x1_1x32x1x8_8_2x1_intrawave_v1" + ) + else: + _heuristic_kernel = ( # 64×64, v1 + "a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_1x1" + "_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1" + ) + elif _mn <= 307200: + _heuristic_kernel = ( # 16×64, v1 + "a8w8_blockscale_1x128x128_256x16x64x256_16x16_16x16_1x1" + "_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1" + ) + elif _mn <= 1050000: + _heuristic_kernel = ( # 64×64, v1 + "a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_1x1" + "_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1" + ) + elif k < 384: + _heuristic_kernel = ( # 64×64, v1 — v3 unsupported below K=384 + "a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_1x1" + "_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1" + ) + elif _mn <= 2500000: + _heuristic_kernel = ( # 64×128, v3 + "a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_1x2" + "_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3" + ) + else: + _heuristic_kernel = ( # 128×128, v3 + "a8w8_blockscale_1x128x128_256x128x128x128_16x16_32x32_2x2" + "_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3" + ) try: - return gemm_a8w8_blockscale_ck(XQ, WQ, x_scale, w_scale, Y) + return gemm_a8w8_blockscale_ck(XQ, WQ, x_scale, w_scale, Y, + kernelName=_heuristic_kernel) except RuntimeError as e: raise RuntimeError( f"gemm_a8w8_blockscale failed for shape M={m}, N={n}, K={k}, " diff --git a/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_instance.py b/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_instance.py index 48fe8276de..9b658d292d 100644 --- a/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_instance.py +++ b/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_instance.py @@ -123,3 +123,18 @@ def name(self) -> str: # (kernelId, kernelName) desync in the CSV silently producing a .so that # TORCH_CHECK(false, ...) at runtime for the offending shape. candidate_kernels_by_name = {v.name: v for v in candidate_kernels_dict.values()} + +# Kernels always compiled regardless of tuned CSV presence. +# These are the targets of the shape-aware dispatch heuristic in gemm_op_a8w8.py +# and must appear in the name-keyed C++ lookup table in every build. Keep in +# sync with the heuristic if kernel selections change. +heuristic_kernels_dict = { + "heuristic_small_mn": candidate_kernels_dict[8], # 16×64 tile, v1 — small M×N + "heuristic_medium_mn": candidate_kernels_dict[18], # 64×64 tile, v1 — medium M×N; also K<384 fallback + "heuristic_large_mn": candidate_kernels_dict[2], # 64×128 tile, v3 — large M×N + "heuristic_xlarge_mn": candidate_kernels_dict[0], # 128×128 tile, v3 — xlarge M×N + # K=128: KPerBLOCK=128 exact fit avoids KPadding overhead; v3 unsupported at K=128. + "heuristic_k128_xsmall_mn": candidate_kernels_dict[6], # 16×64 tile, v1 — K=128, small M×N + "heuristic_k128_medium_large_mn": candidate_kernels_dict[11], # 32×64 tile, v1 — K=128, medium M×N + "heuristic_k128_xlarge_mn": candidate_kernels_dict[16], # 64×64 tile, v1 — K=128, large M×N +} diff --git a/csrc/ck_gemm_a8w8_blockscale/gen_instances.py b/csrc/ck_gemm_a8w8_blockscale/gen_instances.py index 24a5f63fdb..27fa47fca6 100644 --- a/csrc/ck_gemm_a8w8_blockscale/gen_instances.py +++ b/csrc/ck_gemm_a8w8_blockscale/gen_instances.py @@ -28,6 +28,7 @@ KernelInstance, candidate_kernels_dict, candidate_kernels_by_name, + heuristic_kernels_dict, ) """ @@ -406,8 +407,11 @@ def run(self): # generate code for default kernels self.gen_code(candidate_kernels_dict) else: - # generate code for tuned kernels from tune_file - self.gen_code(self.get_tune_dict(self.tune_file)) + # generate code for tuned kernels from tune_file, always including + # heuristic kernels so they are in the name-keyed lookup table + # regardless of which tuned CSVs are installed. + tune_dict = self.get_tune_dict(self.tune_file) + self.gen_code({**tune_dict, **heuristic_kernels_dict}) if __name__ == "__main__": From 8520a7d3e2e5afd53ef26a998e82b3f028ad09b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rony=20Lepp=C3=A4nen?= Date: Fri, 22 May 2026 17:01:44 +0000 Subject: [PATCH 2/3] style: apply black formatting --- aiter/ops/gemm_op_a8w8.py | 21 +++++++++--------- .../gemm_a8w8_blockscale_instance.py | 22 +++++++++++++------ 2 files changed, 26 insertions(+), 17 deletions(-) diff --git a/aiter/ops/gemm_op_a8w8.py b/aiter/ops/gemm_op_a8w8.py index 63eba77270..cb8a3b056b 100644 --- a/aiter/ops/gemm_op_a8w8.py +++ b/aiter/ops/gemm_op_a8w8.py @@ -723,48 +723,49 @@ def gemm_a8w8_blockscale( if k == 128: # v3 pipeline unsupported at K=128; use v1 throughout this branch. if _mn <= 307200: - _heuristic_kernel = ( # 16×64, v1 + _heuristic_kernel = ( # 16×64, v1 "a8w8_blockscale_1x128x128_256x16x64x128_8x16_16x16_1x1" "_16x16x1_8x32x1_1x16x1x16_4_1x1_intrawave_v1" ) elif _mn <= 2000000: - _heuristic_kernel = ( # 32×64, v1 + _heuristic_kernel = ( # 32×64, v1 "a8w8_blockscale_1x128x128_256x32x64x128_16x16_16x16_2x1" "_8x32x1_8x32x1_1x32x1x8_8_2x1_intrawave_v1" ) else: - _heuristic_kernel = ( # 64×64, v1 + _heuristic_kernel = ( # 64×64, v1 "a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_1x1" "_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1" ) elif _mn <= 307200: - _heuristic_kernel = ( # 16×64, v1 + _heuristic_kernel = ( # 16×64, v1 "a8w8_blockscale_1x128x128_256x16x64x256_16x16_16x16_1x1" "_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1" ) elif _mn <= 1050000: - _heuristic_kernel = ( # 64×64, v1 + _heuristic_kernel = ( # 64×64, v1 "a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_1x1" "_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1" ) elif k < 384: - _heuristic_kernel = ( # 64×64, v1 — v3 unsupported below K=384 + _heuristic_kernel = ( # 64×64, v1 — v3 unsupported below K=384 "a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_1x1" "_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1" ) elif _mn <= 2500000: - _heuristic_kernel = ( # 64×128, v3 + _heuristic_kernel = ( # 64×128, v3 "a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_1x2" "_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3" ) else: - _heuristic_kernel = ( # 128×128, v3 + _heuristic_kernel = ( # 128×128, v3 "a8w8_blockscale_1x128x128_256x128x128x128_16x16_32x32_2x2" "_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3" ) try: - return gemm_a8w8_blockscale_ck(XQ, WQ, x_scale, w_scale, Y, - kernelName=_heuristic_kernel) + return gemm_a8w8_blockscale_ck( + XQ, WQ, x_scale, w_scale, Y, kernelName=_heuristic_kernel + ) except RuntimeError as e: raise RuntimeError( f"gemm_a8w8_blockscale failed for shape M={m}, N={n}, K={k}, " diff --git a/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_instance.py b/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_instance.py index 9b658d292d..58888c4347 100644 --- a/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_instance.py +++ b/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_instance.py @@ -129,12 +129,20 @@ def name(self) -> str: # and must appear in the name-keyed C++ lookup table in every build. Keep in # sync with the heuristic if kernel selections change. heuristic_kernels_dict = { - "heuristic_small_mn": candidate_kernels_dict[8], # 16×64 tile, v1 — small M×N - "heuristic_medium_mn": candidate_kernels_dict[18], # 64×64 tile, v1 — medium M×N; also K<384 fallback - "heuristic_large_mn": candidate_kernels_dict[2], # 64×128 tile, v3 — large M×N - "heuristic_xlarge_mn": candidate_kernels_dict[0], # 128×128 tile, v3 — xlarge M×N + "heuristic_small_mn": candidate_kernels_dict[8], # 16×64 tile, v1 — small M×N + "heuristic_medium_mn": candidate_kernels_dict[ + 18 + ], # 64×64 tile, v1 — medium M×N; also K<384 fallback + "heuristic_large_mn": candidate_kernels_dict[2], # 64×128 tile, v3 — large M×N + "heuristic_xlarge_mn": candidate_kernels_dict[0], # 128×128 tile, v3 — xlarge M×N # K=128: KPerBLOCK=128 exact fit avoids KPadding overhead; v3 unsupported at K=128. - "heuristic_k128_xsmall_mn": candidate_kernels_dict[6], # 16×64 tile, v1 — K=128, small M×N - "heuristic_k128_medium_large_mn": candidate_kernels_dict[11], # 32×64 tile, v1 — K=128, medium M×N - "heuristic_k128_xlarge_mn": candidate_kernels_dict[16], # 64×64 tile, v1 — K=128, large M×N + "heuristic_k128_xsmall_mn": candidate_kernels_dict[ + 6 + ], # 16×64 tile, v1 — K=128, small M×N + "heuristic_k128_medium_large_mn": candidate_kernels_dict[ + 11 + ], # 32×64 tile, v1 — K=128, medium M×N + "heuristic_k128_xlarge_mn": candidate_kernels_dict[ + 16 + ], # 64×64 tile, v1 — K=128, large M×N } From f68dc9da9612d1a988e1e505cf2f2b326f8b098d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rony=20Lepp=C3=A4nen?= Date: Fri, 22 May 2026 17:11:16 +0000 Subject: [PATCH 3/3] [ck gemm a8w8 blockscale] gate heuristic dispatch on gfx942 --- aiter/ops/gemm_op_a8w8.py | 93 ++++++++++++++++++----------------- op_tests/test_gemm_codegen.py | 76 ++++++++++++++++++---------- 2 files changed, 99 insertions(+), 70 deletions(-) diff --git a/aiter/ops/gemm_op_a8w8.py b/aiter/ops/gemm_op_a8w8.py index cb8a3b056b..1209d03f8a 100644 --- a/aiter/ops/gemm_op_a8w8.py +++ b/aiter/ops/gemm_op_a8w8.py @@ -716,56 +716,59 @@ def gemm_a8w8_blockscale( ) else: assert 0, f"Unsupported libtype {libtype} for gemm_a8w8_blockscale" - # No tuned row: select kernel by shape. - # K=128: KPerBLOCK=128 v1 kernels avoid KPadding vs the KPerBLOCK=256 defaults. - # K<384 (not 128): v3 pipeline requires K≥384; substitute v1 kernel. - _mn = m * n - if k == 128: - # v3 pipeline unsupported at K=128; use v1 throughout this branch. - if _mn <= 307200: - _heuristic_kernel = ( # 16×64, v1 - "a8w8_blockscale_1x128x128_256x16x64x128_8x16_16x16_1x1" - "_16x16x1_8x32x1_1x16x1x16_4_1x1_intrawave_v1" + # No tuned row: on gfx942 select kernel by shape; elsewhere use the + # C++ default (empty kernelName) to avoid behaviour changes on other + # architectures where these tile choices have not been validated. + # get_gfx() is lru_cache(maxsize=1) — no per-call overhead. + _ck_kwargs = {} + if get_gfx() == "gfx942": + # K=128: KPerBLOCK=128 v1 kernels avoid KPadding vs the KPerBLOCK=256 defaults. + # K<384 (not 128): v3 pipeline requires K≥384; substitute v1 kernel. + _mn = m * n + if k == 128: + # v3 pipeline unsupported at K=128; use v1 throughout this branch. + if _mn <= 307200: + _ck_kwargs["kernelName"] = ( # 16×64, v1 + "a8w8_blockscale_1x128x128_256x16x64x128_8x16_16x16_1x1" + "_16x16x1_8x32x1_1x16x1x16_4_1x1_intrawave_v1" + ) + elif _mn <= 2000000: + _ck_kwargs["kernelName"] = ( # 32×64, v1 + "a8w8_blockscale_1x128x128_256x32x64x128_16x16_16x16_2x1" + "_8x32x1_8x32x1_1x32x1x8_8_2x1_intrawave_v1" + ) + else: + _ck_kwargs["kernelName"] = ( # 64×64, v1 + "a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_1x1" + "_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1" + ) + elif _mn <= 307200: + _ck_kwargs["kernelName"] = ( # 16×64, v1 + "a8w8_blockscale_1x128x128_256x16x64x256_16x16_16x16_1x1" + "_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1" ) - elif _mn <= 2000000: - _heuristic_kernel = ( # 32×64, v1 - "a8w8_blockscale_1x128x128_256x32x64x128_16x16_16x16_2x1" - "_8x32x1_8x32x1_1x32x1x8_8_2x1_intrawave_v1" + elif _mn <= 1050000: + _ck_kwargs["kernelName"] = ( # 64×64, v1 + "a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_1x1" + "_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1" + ) + elif k < 384: + _ck_kwargs["kernelName"] = ( # 64×64, v1 — v3 unsupported below K=384 + "a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_1x1" + "_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1" + ) + elif _mn <= 2500000: + _ck_kwargs["kernelName"] = ( # 64×128, v3 + "a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_1x2" + "_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3" ) else: - _heuristic_kernel = ( # 64×64, v1 - "a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_1x1" - "_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1" + _ck_kwargs["kernelName"] = ( # 128×128, v3 + "a8w8_blockscale_1x128x128_256x128x128x128_16x16_32x32_2x2" + "_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3" ) - elif _mn <= 307200: - _heuristic_kernel = ( # 16×64, v1 - "a8w8_blockscale_1x128x128_256x16x64x256_16x16_16x16_1x1" - "_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1" - ) - elif _mn <= 1050000: - _heuristic_kernel = ( # 64×64, v1 - "a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_1x1" - "_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1" - ) - elif k < 384: - _heuristic_kernel = ( # 64×64, v1 — v3 unsupported below K=384 - "a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_1x1" - "_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1" - ) - elif _mn <= 2500000: - _heuristic_kernel = ( # 64×128, v3 - "a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_1x2" - "_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3" - ) - else: - _heuristic_kernel = ( # 128×128, v3 - "a8w8_blockscale_1x128x128_256x128x128x128_16x16_32x32_2x2" - "_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3" - ) try: - return gemm_a8w8_blockscale_ck( - XQ, WQ, x_scale, w_scale, Y, kernelName=_heuristic_kernel - ) + return gemm_a8w8_blockscale_ck(XQ, WQ, x_scale, w_scale, Y, **_ck_kwargs) except RuntimeError as e: raise RuntimeError( f"gemm_a8w8_blockscale failed for shape M={m}, N={n}, K={k}, " diff --git a/op_tests/test_gemm_codegen.py b/op_tests/test_gemm_codegen.py index 8d82453f0c..b6ade94c82 100644 --- a/op_tests/test_gemm_codegen.py +++ b/op_tests/test_gemm_codegen.py @@ -634,36 +634,62 @@ def _override_blockscale_csv(csv_path): f"recorded kwargs={record.get('kwargs')}", ) - # 6.4 No tuned row for the shape → kernelName="" forwarded (default - # heuristic kicks in inside C++). This guards the empty-name fallback - # path that's intentionally distinct from the wrong-name hard error. + # 6.4 No tuned row for the shape → shape-aware heuristic on gfx942, + # empty-string default on other architectures. csv_empty = _make_temp_csv( "gfx,cu_num,M,N,K,kernelId,libtype,splitK,us,kernelName," "tflops,bw,errRatio\n" ) csv_paths.append(csv_empty) - a8w8_mod._CKGEMM_CONFIG_CACHE = {} - a8w8_mod._CKGEMM_HAS_GFX = {} - get_CKGEMM_config.cache_clear() - with _override_blockscale_csv(csv_empty): - record.clear() - a8w8_mod.gemm_a8w8_blockscale( - XQ, WQ, x_scale, w_scale, dtype=torch.bfloat16 - ) - # With no row matched, the dispatcher hits the "no config" fallback, - # which calls gemm_a8w8_blockscale_ck without kernelName= — Python's - # default kwarg ("") then propagates to C++. - _check( - "no tuned row: still routed to gemm_a8w8_blockscale_ck (default path)", - record.get("libtype") == "ck", - f"recorded libtype={record.get('libtype')}", - ) - _check( - "no tuned row: kernelName not explicitly set " - "(C++ sees default empty string)", - "kernelName" not in record.get("kwargs", {}), - f"recorded kwargs={record.get('kwargs')}", - ) + + saved_get_gfx = a8w8_mod.get_gfx + + try: + # 6.4a gfx942: heuristic selects a specific kernelName. + a8w8_mod.get_gfx = lambda: "gfx942" + a8w8_mod._CKGEMM_CONFIG_CACHE = {} + a8w8_mod._CKGEMM_HAS_GFX = {} + get_CKGEMM_config.cache_clear() + with _override_blockscale_csv(csv_empty): + record.clear() + a8w8_mod.gemm_a8w8_blockscale( + XQ, WQ, x_scale, w_scale, dtype=torch.bfloat16 + ) + _check( + "no tuned row (gfx942): routed to gemm_a8w8_blockscale_ck", + record.get("libtype") == "ck", + f"recorded libtype={record.get('libtype')}", + ) + _check( + "no tuned row (gfx942): heuristic kernelName explicitly set", + "kernelName" in record.get("kwargs", {}) + and record["kwargs"]["kernelName"] != "", + f"recorded kwargs={record.get('kwargs')}", + ) + + # 6.4b non-gfx942: no kernelName kwarg — C++ sees its own default. + a8w8_mod.get_gfx = lambda: "gfx950" + a8w8_mod._CKGEMM_CONFIG_CACHE = {} + a8w8_mod._CKGEMM_HAS_GFX = {} + get_CKGEMM_config.cache_clear() + with _override_blockscale_csv(csv_empty): + record.clear() + a8w8_mod.gemm_a8w8_blockscale( + XQ, WQ, x_scale, w_scale, dtype=torch.bfloat16 + ) + _check( + "no tuned row (non-gfx942): routed to gemm_a8w8_blockscale_ck", + record.get("libtype") == "ck", + f"recorded libtype={record.get('libtype')}", + ) + _check( + "no tuned row (non-gfx942): kernelName not explicitly set " + "(C++ sees default empty string)", + "kernelName" not in record.get("kwargs", {}), + f"recorded kwargs={record.get('kwargs')}", + ) + finally: + a8w8_mod.get_gfx = saved_get_gfx finally: a8w8_mod.gemm_a8w8_blockscale_ck = saved["ck"]