diff --git a/aiter/ops/gemm_op_a8w8.py b/aiter/ops/gemm_op_a8w8.py index ba0e523e51..1209d03f8a 100644 --- a/aiter/ops/gemm_op_a8w8.py +++ b/aiter/ops/gemm_op_a8w8.py @@ -716,8 +716,59 @@ def gemm_a8w8_blockscale( ) else: assert 0, f"Unsupported libtype {libtype} for gemm_a8w8_blockscale" + # No tuned row: on gfx942 select kernel by shape; elsewhere use the + # C++ default (empty kernelName) to avoid behaviour changes on other + # architectures where these tile choices have not been validated. + # get_gfx() is lru_cache(maxsize=1) — no per-call overhead. + _ck_kwargs = {} + if get_gfx() == "gfx942": + # K=128: KPerBLOCK=128 v1 kernels avoid KPadding vs the KPerBLOCK=256 defaults. + # K<384 (not 128): v3 pipeline requires K≥384; substitute v1 kernel. + _mn = m * n + if k == 128: + # v3 pipeline unsupported at K=128; use v1 throughout this branch. + if _mn <= 307200: + _ck_kwargs["kernelName"] = ( # 16×64, v1 + "a8w8_blockscale_1x128x128_256x16x64x128_8x16_16x16_1x1" + "_16x16x1_8x32x1_1x16x1x16_4_1x1_intrawave_v1" + ) + elif _mn <= 2000000: + _ck_kwargs["kernelName"] = ( # 32×64, v1 + "a8w8_blockscale_1x128x128_256x32x64x128_16x16_16x16_2x1" + "_8x32x1_8x32x1_1x32x1x8_8_2x1_intrawave_v1" + ) + else: + _ck_kwargs["kernelName"] = ( # 64×64, v1 + "a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_1x1" + "_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1" + ) + elif _mn <= 307200: + _ck_kwargs["kernelName"] = ( # 16×64, v1 + "a8w8_blockscale_1x128x128_256x16x64x256_16x16_16x16_1x1" + "_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1" + ) + elif _mn <= 1050000: + _ck_kwargs["kernelName"] = ( # 64×64, v1 + "a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_1x1" + "_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1" + ) + elif k < 384: + _ck_kwargs["kernelName"] = ( # 64×64, v1 — v3 unsupported below K=384 + "a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_1x1" + "_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1" + ) + elif _mn <= 2500000: + _ck_kwargs["kernelName"] = ( # 64×128, v3 + "a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_1x2" + "_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3" + ) + else: + _ck_kwargs["kernelName"] = ( # 128×128, v3 + "a8w8_blockscale_1x128x128_256x128x128x128_16x16_32x32_2x2" + "_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3" + ) try: - return gemm_a8w8_blockscale_ck(XQ, WQ, x_scale, w_scale, Y) + return gemm_a8w8_blockscale_ck(XQ, WQ, x_scale, w_scale, Y, **_ck_kwargs) except RuntimeError as e: raise RuntimeError( f"gemm_a8w8_blockscale failed for shape M={m}, N={n}, K={k}, " diff --git a/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_instance.py b/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_instance.py index 48fe8276de..58888c4347 100644 --- a/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_instance.py +++ b/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_instance.py @@ -123,3 +123,26 @@ def name(self) -> str: # (kernelId, kernelName) desync in the CSV silently producing a .so that # TORCH_CHECK(false, ...) at runtime for the offending shape. candidate_kernels_by_name = {v.name: v for v in candidate_kernels_dict.values()} + +# Kernels always compiled regardless of tuned CSV presence. +# These are the targets of the shape-aware dispatch heuristic in gemm_op_a8w8.py +# and must appear in the name-keyed C++ lookup table in every build. Keep in +# sync with the heuristic if kernel selections change. +heuristic_kernels_dict = { + "heuristic_small_mn": candidate_kernels_dict[8], # 16×64 tile, v1 — small M×N + "heuristic_medium_mn": candidate_kernels_dict[ + 18 + ], # 64×64 tile, v1 — medium M×N; also K<384 fallback + "heuristic_large_mn": candidate_kernels_dict[2], # 64×128 tile, v3 — large M×N + "heuristic_xlarge_mn": candidate_kernels_dict[0], # 128×128 tile, v3 — xlarge M×N + # K=128: KPerBLOCK=128 exact fit avoids KPadding overhead; v3 unsupported at K=128. + "heuristic_k128_xsmall_mn": candidate_kernels_dict[ + 6 + ], # 16×64 tile, v1 — K=128, small M×N + "heuristic_k128_medium_large_mn": candidate_kernels_dict[ + 11 + ], # 32×64 tile, v1 — K=128, medium M×N + "heuristic_k128_xlarge_mn": candidate_kernels_dict[ + 16 + ], # 64×64 tile, v1 — K=128, large M×N +} diff --git a/csrc/ck_gemm_a8w8_blockscale/gen_instances.py b/csrc/ck_gemm_a8w8_blockscale/gen_instances.py index 29de526813..b9d9bbb06c 100644 --- a/csrc/ck_gemm_a8w8_blockscale/gen_instances.py +++ b/csrc/ck_gemm_a8w8_blockscale/gen_instances.py @@ -28,6 +28,7 @@ KernelInstance, candidate_kernels_dict, candidate_kernels_by_name, + heuristic_kernels_dict, ) """ @@ -409,8 +410,11 @@ def run(self): # generate code for default kernels self.gen_code(candidate_kernels_dict) else: - # generate code for tuned kernels from tune_file - self.gen_code(self.get_tune_dict(self.tune_file)) + # generate code for tuned kernels from tune_file, always including + # heuristic kernels so they are in the name-keyed lookup table + # regardless of which tuned CSVs are installed. + tune_dict = self.get_tune_dict(self.tune_file) + self.gen_code({**tune_dict, **heuristic_kernels_dict}) if __name__ == "__main__": diff --git a/op_tests/test_gemm_codegen.py b/op_tests/test_gemm_codegen.py index 8d82453f0c..b6ade94c82 100644 --- a/op_tests/test_gemm_codegen.py +++ b/op_tests/test_gemm_codegen.py @@ -634,36 +634,62 @@ def _override_blockscale_csv(csv_path): f"recorded kwargs={record.get('kwargs')}", ) - # 6.4 No tuned row for the shape → kernelName="" forwarded (default - # heuristic kicks in inside C++). This guards the empty-name fallback - # path that's intentionally distinct from the wrong-name hard error. + # 6.4 No tuned row for the shape → shape-aware heuristic on gfx942, + # empty-string default on other architectures. csv_empty = _make_temp_csv( "gfx,cu_num,M,N,K,kernelId,libtype,splitK,us,kernelName," "tflops,bw,errRatio\n" ) csv_paths.append(csv_empty) - a8w8_mod._CKGEMM_CONFIG_CACHE = {} - a8w8_mod._CKGEMM_HAS_GFX = {} - get_CKGEMM_config.cache_clear() - with _override_blockscale_csv(csv_empty): - record.clear() - a8w8_mod.gemm_a8w8_blockscale( - XQ, WQ, x_scale, w_scale, dtype=torch.bfloat16 - ) - # With no row matched, the dispatcher hits the "no config" fallback, - # which calls gemm_a8w8_blockscale_ck without kernelName= — Python's - # default kwarg ("") then propagates to C++. - _check( - "no tuned row: still routed to gemm_a8w8_blockscale_ck (default path)", - record.get("libtype") == "ck", - f"recorded libtype={record.get('libtype')}", - ) - _check( - "no tuned row: kernelName not explicitly set " - "(C++ sees default empty string)", - "kernelName" not in record.get("kwargs", {}), - f"recorded kwargs={record.get('kwargs')}", - ) + + saved_get_gfx = a8w8_mod.get_gfx + + try: + # 6.4a gfx942: heuristic selects a specific kernelName. + a8w8_mod.get_gfx = lambda: "gfx942" + a8w8_mod._CKGEMM_CONFIG_CACHE = {} + a8w8_mod._CKGEMM_HAS_GFX = {} + get_CKGEMM_config.cache_clear() + with _override_blockscale_csv(csv_empty): + record.clear() + a8w8_mod.gemm_a8w8_blockscale( + XQ, WQ, x_scale, w_scale, dtype=torch.bfloat16 + ) + _check( + "no tuned row (gfx942): routed to gemm_a8w8_blockscale_ck", + record.get("libtype") == "ck", + f"recorded libtype={record.get('libtype')}", + ) + _check( + "no tuned row (gfx942): heuristic kernelName explicitly set", + "kernelName" in record.get("kwargs", {}) + and record["kwargs"]["kernelName"] != "", + f"recorded kwargs={record.get('kwargs')}", + ) + + # 6.4b non-gfx942: no kernelName kwarg — C++ sees its own default. + a8w8_mod.get_gfx = lambda: "gfx950" + a8w8_mod._CKGEMM_CONFIG_CACHE = {} + a8w8_mod._CKGEMM_HAS_GFX = {} + get_CKGEMM_config.cache_clear() + with _override_blockscale_csv(csv_empty): + record.clear() + a8w8_mod.gemm_a8w8_blockscale( + XQ, WQ, x_scale, w_scale, dtype=torch.bfloat16 + ) + _check( + "no tuned row (non-gfx942): routed to gemm_a8w8_blockscale_ck", + record.get("libtype") == "ck", + f"recorded libtype={record.get('libtype')}", + ) + _check( + "no tuned row (non-gfx942): kernelName not explicitly set " + "(C++ sees default empty string)", + "kernelName" not in record.get("kwargs", {}), + f"recorded kwargs={record.get('kwargs')}", + ) + finally: + a8w8_mod.get_gfx = saved_get_gfx finally: a8w8_mod.gemm_a8w8_blockscale_ck = saved["ck"]