Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 52 additions & 1 deletion aiter/ops/gemm_op_a8w8.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,8 +716,59 @@ def gemm_a8w8_blockscale(
)
else:
assert 0, f"Unsupported libtype {libtype} for gemm_a8w8_blockscale"
# No tuned row: on gfx942 select kernel by shape; elsewhere use the
# C++ default (empty kernelName) to avoid behaviour changes on other
# architectures where these tile choices have not been validated.
# get_gfx() is lru_cache(maxsize=1) — no per-call overhead.
_ck_kwargs = {}
if get_gfx() == "gfx942":
# K=128: KPerBLOCK=128 v1 kernels avoid KPadding vs the KPerBLOCK=256 defaults.
# K<384 (not 128): v3 pipeline requires K≥384; substitute v1 kernel.
_mn = m * n
if k == 128:
# v3 pipeline unsupported at K=128; use v1 throughout this branch.
if _mn <= 307200:
_ck_kwargs["kernelName"] = ( # 16×64, v1
"a8w8_blockscale_1x128x128_256x16x64x128_8x16_16x16_1x1"
"_16x16x1_8x32x1_1x16x1x16_4_1x1_intrawave_v1"
)
elif _mn <= 2000000:
_ck_kwargs["kernelName"] = ( # 32×64, v1
"a8w8_blockscale_1x128x128_256x32x64x128_16x16_16x16_2x1"
"_8x32x1_8x32x1_1x32x1x8_8_2x1_intrawave_v1"
)
else:
_ck_kwargs["kernelName"] = ( # 64×64, v1
"a8w8_blockscale_1x128x128_256x64x64x128_16x16_32x32_1x1"
"_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v1"
)
elif _mn <= 307200:
_ck_kwargs["kernelName"] = ( # 16×64, v1
"a8w8_blockscale_1x128x128_256x16x64x256_16x16_16x16_1x1"
"_16x16x1_16x16x1_1x16x1x16_4_1x1_intrawave_v1"
)
elif _mn <= 1050000:
_ck_kwargs["kernelName"] = ( # 64×64, v1
"a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_1x1"
"_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1"
)
elif k < 384:
_ck_kwargs["kernelName"] = ( # 64×64, v1 — v3 unsupported below K=384
"a8w8_blockscale_1x128x128_256x64x64x256_16x16_32x32_1x1"
"_16x16x1_16x16x1_1x32x1x8_8_1x1_intrawave_v1"
)
elif _mn <= 2500000:
_ck_kwargs["kernelName"] = ( # 64×128, v3
"a8w8_blockscale_1x128x128_256x64x128x128_16x16_32x32_1x2"
"_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3"
)
else:
_ck_kwargs["kernelName"] = ( # 128×128, v3
"a8w8_blockscale_1x128x128_256x128x128x128_16x16_32x32_2x2"
"_8x32x1_8x32x1_1x32x1x8_8_1x1_intrawave_v3"
)
try:
return gemm_a8w8_blockscale_ck(XQ, WQ, x_scale, w_scale, Y)
return gemm_a8w8_blockscale_ck(XQ, WQ, x_scale, w_scale, Y, **_ck_kwargs)
except RuntimeError as e:
raise RuntimeError(
f"gemm_a8w8_blockscale failed for shape M={m}, N={n}, K={k}, "
Expand Down
23 changes: 23 additions & 0 deletions csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,26 @@ def name(self) -> str:
# (kernelId, kernelName) desync in the CSV silently producing a .so that
# TORCH_CHECK(false, ...) at runtime for the offending shape.
candidate_kernels_by_name = {v.name: v for v in candidate_kernels_dict.values()}

# Kernel instances that are generated in every build, whether or not a tuned
# CSV is present. The shape-aware fallback dispatch in gemm_op_a8w8.py requests
# these kernels by name, so each of them has to be present in the name-keyed
# C++ lookup table. Whenever that dispatch changes which kernels it selects,
# update this table to match.
heuristic_kernels_dict = {
    label: candidate_kernels_dict[kernel_id]
    for label, kernel_id in (
        # General buckets, listed in order of growing M×N.
        ("heuristic_small_mn", 8),  # 16×64 tile, v1 — small M×N
        ("heuristic_medium_mn", 18),  # 64×64 tile, v1 — medium M×N; also the K<384 fallback
        ("heuristic_large_mn", 2),  # 64×128 tile, v3 — large M×N
        ("heuristic_xlarge_mn", 0),  # 128×128 tile, v3 — xlarge M×N
        # K=128 buckets: a KPerBLOCK=128 tile fits K exactly and so avoids
        # KPadding overhead; the v3 pipeline is unsupported at K=128, hence
        # every entry here is a v1 kernel.
        ("heuristic_k128_xsmall_mn", 6),  # 16×64 tile, v1 — K=128, smallest M×N bucket
        ("heuristic_k128_medium_large_mn", 11),  # 32×64 tile, v1 — K=128, middle M×N bucket
        ("heuristic_k128_xlarge_mn", 16),  # 64×64 tile, v1 — K=128, largest M×N bucket
    )
}
Comment on lines +127 to +148
8 changes: 6 additions & 2 deletions csrc/ck_gemm_a8w8_blockscale/gen_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
KernelInstance,
candidate_kernels_dict,
candidate_kernels_by_name,
heuristic_kernels_dict,
)

"""
Expand Down Expand Up @@ -409,8 +410,11 @@ def run(self):
# generate code for default kernels
self.gen_code(candidate_kernels_dict)
else:
# generate code for tuned kernels from tune_file
self.gen_code(self.get_tune_dict(self.tune_file))
# generate code for tuned kernels from tune_file, always including
# heuristic kernels so they are in the name-keyed lookup table
# regardless of which tuned CSVs are installed.
tune_dict = self.get_tune_dict(self.tune_file)
self.gen_code({**tune_dict, **heuristic_kernels_dict})


if __name__ == "__main__":
Expand Down
76 changes: 51 additions & 25 deletions op_tests/test_gemm_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,36 +634,62 @@ def _override_blockscale_csv(csv_path):
f"recorded kwargs={record.get('kwargs')}",
)

# 6.4 No tuned row for the shape → kernelName="" forwarded (default
# heuristic kicks in inside C++). This guards the empty-name fallback
# path that's intentionally distinct from the wrong-name hard error.
# 6.4 No tuned row for the shape → shape-aware heuristic on gfx942,
# empty-string default on other architectures.
csv_empty = _make_temp_csv(
"gfx,cu_num,M,N,K,kernelId,libtype,splitK,us,kernelName,"
"tflops,bw,errRatio\n"
)
csv_paths.append(csv_empty)
a8w8_mod._CKGEMM_CONFIG_CACHE = {}
a8w8_mod._CKGEMM_HAS_GFX = {}
get_CKGEMM_config.cache_clear()
with _override_blockscale_csv(csv_empty):
record.clear()
a8w8_mod.gemm_a8w8_blockscale(
XQ, WQ, x_scale, w_scale, dtype=torch.bfloat16
)
# With no row matched, the dispatcher hits the "no config" fallback,
# which calls gemm_a8w8_blockscale_ck without kernelName= — Python's
# default kwarg ("") then propagates to C++.
_check(
"no tuned row: still routed to gemm_a8w8_blockscale_ck (default path)",
record.get("libtype") == "ck",
f"recorded libtype={record.get('libtype')}",
)
_check(
"no tuned row: kernelName not explicitly set "
"(C++ sees default empty string)",
"kernelName" not in record.get("kwargs", {}),
f"recorded kwargs={record.get('kwargs')}",
)

saved_get_gfx = a8w8_mod.get_gfx

try:
# 6.4a gfx942: heuristic selects a specific kernelName.
a8w8_mod.get_gfx = lambda: "gfx942"
a8w8_mod._CKGEMM_CONFIG_CACHE = {}
a8w8_mod._CKGEMM_HAS_GFX = {}
get_CKGEMM_config.cache_clear()
with _override_blockscale_csv(csv_empty):
record.clear()
a8w8_mod.gemm_a8w8_blockscale(
XQ, WQ, x_scale, w_scale, dtype=torch.bfloat16
)
_check(
"no tuned row (gfx942): routed to gemm_a8w8_blockscale_ck",
record.get("libtype") == "ck",
f"recorded libtype={record.get('libtype')}",
)
_check(
"no tuned row (gfx942): heuristic kernelName explicitly set",
"kernelName" in record.get("kwargs", {})
and record["kwargs"]["kernelName"] != "",
f"recorded kwargs={record.get('kwargs')}",
)

# 6.4b non-gfx942: no kernelName kwarg — C++ sees its own default.
a8w8_mod.get_gfx = lambda: "gfx950"
a8w8_mod._CKGEMM_CONFIG_CACHE = {}
a8w8_mod._CKGEMM_HAS_GFX = {}
get_CKGEMM_config.cache_clear()
with _override_blockscale_csv(csv_empty):
record.clear()
a8w8_mod.gemm_a8w8_blockscale(
XQ, WQ, x_scale, w_scale, dtype=torch.bfloat16
)
_check(
"no tuned row (non-gfx942): routed to gemm_a8w8_blockscale_ck",
record.get("libtype") == "ck",
f"recorded libtype={record.get('libtype')}",
)
_check(
"no tuned row (non-gfx942): kernelName not explicitly set "
"(C++ sees default empty string)",
"kernelName" not in record.get("kwargs", {}),
f"recorded kwargs={record.get('kwargs')}",
)
finally:
a8w8_mod.get_gfx = saved_get_gfx

finally:
a8w8_mod.gemm_a8w8_blockscale_ck = saved["ck"]
Expand Down
Loading