Suggestion Description
gfx950 A8W8 blockscale bpreshuffle CK GEMM is nondeterministic for WKV-shaped untuned rows
Summary
gemm_a8w8_blockscale_bpreshuffle can produce nondeterministic outputs on MI35x/gfx950 for fixed inputs and a fixed launch shape.
The repro shape is WKV-like:
M=8192, N=512, K=4096
A: [8192, 4096]
B: [512, 4096]
C: [8192, 512]
Environment
Image: rocm/sgl-dev:rocm720-mi35x-800693e-20260517-DSv4
Image digest: sha256:83680195f7ffba945966bf6104134255e37b3c3ff1d3726a748c0d05545429a7
AITER: 3c5f0ba937861da585ef0f72d20d9c8dba30ef3b
CK submodule: 10cb6916c34f957e81e8472c085603b4427baab9
Hardware:
- AMD MI355 / gfx950
- AMD MI350 / gfx950
Reproduction
Run this from the AITER repo root:
export PYTHONPATH=$PWD:$PYTHONPATH
export AITER_JIT_DIR=/tmp/aiter_wkv_repro_${USER}_$$
export HIP_VISIBLE_DEVICES=0
python3 - <<'PY'
import hashlib
import torch
from aiter import dtypes
from aiter.ops.gemm_op_a8w8 import gemm_a8w8_blockscale_bpreshuffle
from aiter.ops.shuffle import shuffle_weight
M, N, K = 8192, 512, 4096
iters, warmup = 1000, 10
torch.manual_seed(1)
x = (torch.rand((M, K), dtype=torch.float32, device="cuda") / 10).to(dtypes.fp8)
w = (torch.rand((N, K), dtype=torch.float32, device="cuda") / 10).to(dtypes.fp8)
x_scale = torch.rand((M, (K + 127) // 128), dtype=torch.float32, device="cuda")
w_scale = torch.rand(((N + 127) // 128, (K + 127) // 128), dtype=torch.float32, device="cuda")
x_scale_t = x_scale.transpose(0, 1).contiguous().view(*x_scale.shape)
w_shuffle = shuffle_weight(w, layout=(16, 16))
def run_once():
    y = gemm_a8w8_blockscale_bpreshuffle(
        x, w_shuffle, x_scale_t, w_scale, dtypes.bf16
    )
    torch.cuda.synchronize()
    # NumPy has no bfloat16 dtype, so reinterpret the bf16 bits as int16 before hashing.
    return hashlib.sha256(y.view(torch.int16).cpu().numpy().tobytes()).hexdigest()[:16]
for _ in range(warmup):
    run_once()
hashes = [run_once() for _ in range(iters)]
unique = sorted(set(hashes))
print(f"unique_hashes={len(unique)}/{iters}")
print(f"first_hash={hashes[0]}")
print(f"sample_unique={unique[:10]}")
PY
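A hash count only says that outputs differ, not by how much. The sketch below is a hypothetical helper (`summarize_runs` and its field names are not part of AITER) that takes the raw bytes of several outputs and reports how many elements differ from the first run, assuming 2-byte bf16 elements. It uses only the standard library and is meant for a handful of runs; at 8192x512 elements a tensor-level comparison would be faster.

```python
import hashlib


def summarize_runs(outputs, elem_size=2):
    """Compare raw output buffers from repeated launches against the first one.

    outputs   -- list of bytes objects, one per launch (hypothetical input format)
    elem_size -- bytes per element; 2 is an assumption matching bf16 output
    """
    if not outputs:
        raise ValueError("need at least one output buffer")
    ref = outputs[0]
    n_elems = len(ref) // elem_size
    unique = {hashlib.sha256(buf).hexdigest() for buf in outputs}
    worst = 0
    for buf in outputs[1:]:
        if len(buf) != len(ref):
            raise ValueError("output buffers differ in size")
        # Count elements whose bit pattern differs from the reference run.
        mismatched = sum(
            ref[i:i + elem_size] != buf[i:i + elem_size]
            for i in range(0, n_elems * elem_size, elem_size)
        )
        worst = max(worst, mismatched)
    return {
        "unique_hashes": len(unique),
        "runs": len(outputs),
        "elements": n_elems,
        "max_mismatched_elements": worst,
    }
```

A small mismatch count concentrated in a few elements would point to a different kind of problem than a mismatch spread across the whole output, so this number may be worth attaching alongside the hash counts.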
Observed Result
MI355: unstable, 1000/1000 unique hashes
MI350: unstable, 1000/1000 unique hashes
Expected Result
With deterministic kernels, fixed inputs, and a fixed launch path, repeated kernel launches should produce a single output hash, i.e. the script should print unique_hashes=1/1000.
Suggested Fix / Patch for Review
Patch branch:
Repo: ROCm/rocm-libraries
Base: ce3e67b2a634e01ef85b4482feb17ae1983828ba
Branch: sonle5/gemm_a8w8_blockscale_stabilizer
Commit: 408a5bbebd4e1eaec3ddce3c35ef16df37a7acfb
Branch URL:
https://github.com/hdt98/rocm-libraries/tree/sonle5/gemm_a8w8_blockscale_stabilizer
Commit URL:
hdt98/rocm-libraries@408a5bb
PR creation URL:
https://github.com/hdt98/rocm-libraries/pull/new/sonle5/gemm_a8w8_blockscale_stabilizer
The patch touches:
projects/composablekernel/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v1.hpp
projects/composablekernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp
The stabilizer passes a runtime-null accumulator observer through the v1 bpreshuffle blockwise path and adds compile-visible accumulator value anchors after the packed FMA sites in the main loop, even-tail, and odd-tail paths.
auto anchor_accumulator_value = [&](float value) {
    if(p_accum_observer != nullptr)
    {
        asm volatile("" : : "v"(value));
    }
};
auto anchor_accumulator_target_thread = [&]() {
    if(accum_observer_tile && get_thread_local_1d_id() == accum_observer_thread)
    {
        asm volatile("" : : "s"(accum_observer_thread));
    }
};
After the post-MFMA packed FMA into c_thread_buf:
anchor_accumulator_value(
    type_convert<float>(
        c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
            .template AsType<AccDataType>()[Number<0>{}]));
Stabilizer Result
MI355: stable, 1/1000 unique hashes
MI350: stable, 1/1000 unique hashes
Focused additional rows also pass with the stabilizer:
M=8192, N=512, K={384,640,3968,4096,4224}
Clean CK baseline: unstable on MI355 and MI350
Stabilized CK: stable on MI355 and MI350
Triton comparator: stable on MI355 and MI350
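For reference, the scale-tensor shapes for each swept K can be derived from the `(K + 127) // 128` expressions in the repro script. The sketch below assumes that same 128-wide block in both N and K; `scale_shapes` is a hypothetical helper, not an AITER function.

```python
BLOCK = 128  # assumption: block size implied by (K + 127) // 128 in the repro script


def scale_shapes(m, n, k, block=BLOCK):
    """Return the x_scale / w_scale shapes the repro script would allocate."""
    k_blocks = (k + block - 1) // block
    n_blocks = (n + block - 1) // block
    return {
        "x_scale": (m, k_blocks),
        "w_scale": (n_blocks, k_blocks),
        "k_remainder": k % block,  # 0 means K is a whole number of blocks
    }


for k in (384, 640, 3968, 4096, 4224):
    print(k, scale_shapes(8192, 512, k))
```

All five K values are exact multiples of 128, giving 3, 5, 31, 32, and 33 K-blocks, so the sweep covers both odd and even block counts but no partial K block.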
Performance stays in the same CK latency class. For example, clean p50 timing for 8192x512x4096 was roughly:
MI355 main CK / stabilized CK / Triton: 0.0639 / 0.0670 / 1.0229 ms
MI350 main CK / stabilized CK / Triton: 0.0695 / 0.0719 / 1.1023 ms
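To put these numbers in relative terms, the following sketch computes the stabilizer overhead against main CK and the Triton-to-stabilized ratio from the p50 values above. The dictionary layout and function name are illustrative only.

```python
# p50 latencies in ms for M=8192, N=512, K=4096, copied from the rows above.
timings_ms = {
    "MI355": {"main": 0.0639, "stabilized": 0.0670, "triton": 1.0229},
    "MI350": {"main": 0.0695, "stabilized": 0.0719, "triton": 1.1023},
}


def overhead_pct(base, new):
    """Relative slowdown of `new` over `base`, in percent."""
    return (new / base - 1.0) * 100.0


summary = {
    gpu: {
        "stabilizer_overhead_pct": overhead_pct(t["main"], t["stabilized"]),
        "triton_vs_stabilized_x": t["triton"] / t["stabilized"],
    }
    for gpu, t in timings_ms.items()
}

for gpu, row in summary.items():
    print(
        f"{gpu}: stabilizer overhead {row['stabilizer_overhead_pct']:.1f}%, "
        f"Triton {row['triton_vs_stabilized_x']:.1f}x slower than stabilized CK"
    )
```

On these figures the stabilizer costs roughly 4.9% on MI355 and 3.5% on MI350, while the Triton comparator is about 15x slower than stabilized CK on both.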