[Feature]: gfx950 A8W8 blockscale bpreshuffle CK GEMM is nondeterministic for WKV-shaped untuned rows #3261

@hdt98

Suggestion Description

gfx950 A8W8 blockscale bpreshuffle CK GEMM is nondeterministic for WKV-shaped untuned rows

Summary

gemm_a8w8_blockscale_bpreshuffle can produce nondeterministic outputs on MI35x/gfx950 for fixed inputs and a fixed launch shape.

The repro shape is WKV-like:

M=8192, N=512, K=4096
A: [8192, 4096]
B: [512, 4096]
C: [8192, 512]

Environment

Image: rocm/sgl-dev:rocm720-mi35x-800693e-20260517-DSv4
Image digest: sha256:83680195f7ffba945966bf6104134255e37b3c3ff1d3726a748c0d05545429a7

AITER: 3c5f0ba937861da585ef0f72d20d9c8dba30ef3b
CK submodule: 10cb6916c34f957e81e8472c085603b4427baab9

Hardware:
- AMD MI355 / gfx950
- AMD MI350 / gfx950

Reproduction

Run this from the AITER repo root:

export PYTHONPATH=$PWD:$PYTHONPATH
export AITER_JIT_DIR=/tmp/aiter_wkv_repro_${USER}_$$
export HIP_VISIBLE_DEVICES=0

python3 - <<'PY'
import hashlib
import torch

from aiter import dtypes
from aiter.ops.gemm_op_a8w8 import gemm_a8w8_blockscale_bpreshuffle
from aiter.ops.shuffle import shuffle_weight

M, N, K = 8192, 512, 4096
iters, warmup = 1000, 10

torch.manual_seed(1)

x = (torch.rand((M, K), dtype=torch.float32, device="cuda") / 10).to(dtypes.fp8)
w = (torch.rand((N, K), dtype=torch.float32, device="cuda") / 10).to(dtypes.fp8)

x_scale = torch.rand((M, (K + 127) // 128), dtype=torch.float32, device="cuda")
w_scale = torch.rand(((N + 127) // 128, (K + 127) // 128), dtype=torch.float32, device="cuda")

x_scale_t = x_scale.transpose(0, 1).contiguous().view(*x_scale.shape)
w_shuffle = shuffle_weight(w, layout=(16, 16))

def run_once():
    y = gemm_a8w8_blockscale_bpreshuffle(
        x, w_shuffle, x_scale_t, w_scale, dtypes.bf16
    )
    torch.cuda.synchronize()
    # NumPy has no bfloat16 dtype, so hash the raw bf16 bits reinterpreted as int16.
    return hashlib.sha256(y.view(torch.int16).cpu().numpy().tobytes()).hexdigest()[:16]

for _ in range(warmup):
    run_once()

hashes = [run_once() for _ in range(iters)]
unique = sorted(set(hashes))

print(f"unique_hashes={len(unique)}/{iters}")
print(f"first_hash={hashes[0]}")
print(f"sample_unique={unique[:10]}")
PY

Observed Result

MI355: unstable, 1000/1000 unique hashes
MI350: unstable, 1000/1000 unique hashes
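
A hash shows only that two runs differ, not how much of the output differs. As a minimal triage sketch, assuming the raw output bytes of two runs are kept around, the number of differing elements can be counted directly. The helper name `count_differing_elements` and the 2-byte element size (bf16) are illustrative and not part of the repro above:

```python
def count_differing_elements(a: bytes, b: bytes, itemsize: int = 2) -> int:
    """Count fixed-size elements whose raw bytes differ between two buffers."""
    if len(a) != len(b):
        raise ValueError("buffers must have the same length")
    if len(a) % itemsize != 0:
        raise ValueError("buffer length must be a multiple of itemsize")
    return sum(
        a[i:i + itemsize] != b[i:i + itemsize]
        for i in range(0, len(a), itemsize)
    )


# Two toy buffers of four 2-byte elements each; only the second element differs.
run_a = bytes([0x80, 0x3F, 0x00, 0x40, 0x40, 0x40, 0x80, 0x40])
run_b = bytes([0x80, 0x3F, 0x01, 0x40, 0x40, 0x40, 0x80, 0x40])
print(count_differing_elements(run_a, run_b))
```

For the full 8192x512 output a pure-Python loop is slow; an elementwise comparison on the tensors themselves would likely be the practical choice.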

Expected Result

With deterministic kernels, for fixed inputs and a fixed launch path, repeated kernel launches should produce one output hash:

unique_hashes=1/1000
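
To tell "every run differs" apart from "a rare outlier run", the hash list can be summarized a little further than a unique count. The following is a sketch only; `summarize_hashes` is a hypothetical helper, not part of AITER or the repro script:

```python
from collections import Counter


def summarize_hashes(hashes):
    """Return (is_deterministic, unique_count, most_common_hash, its_count)."""
    if not hashes:
        raise ValueError("need at least one hash")
    counts = Counter(hashes)
    top_hash, top_count = counts.most_common(1)[0]
    return len(counts) == 1, len(counts), top_hash, top_count


# Toy input: three runs agree and one run differs.
print(summarize_hashes(["aa11", "aa11", "bb22", "aa11"]))
```

A deterministic kernel would give `is_deterministic=True` with a single hash covering all iterations.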

Suggested Fix / Patch for Review

Patch branch:

Repo: ROCm/rocm-libraries
Base: ce3e67b2a634e01ef85b4482feb17ae1983828ba
Branch: sonle5/gemm_a8w8_blockscale_stabilizer
Commit: 408a5bbebd4e1eaec3ddce3c35ef16df37a7acfb

Branch URL:

https://github.com/hdt98/rocm-libraries/tree/sonle5/gemm_a8w8_blockscale_stabilizer

Commit URL:

hdt98/rocm-libraries@408a5bb

PR creation URL:

https://github.com/hdt98/rocm-libraries/pull/new/sonle5/gemm_a8w8_blockscale_stabilizer

The patch touches:

projects/composablekernel/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v1.hpp
projects/composablekernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp

The stabilizer passes an accumulator observer pointer, which is null at runtime, through the v1 bpreshuffle blockwise path. It also adds compiler-visible anchors on accumulator values after the packed FMA sites in the main loop, even-tail, and odd-tail paths.

auto anchor_accumulator_value = [&](float value) {
    if(p_accum_observer != nullptr)
    {
        asm volatile("" : : "v"(value));
    }
};

auto anchor_accumulator_target_thread = [&]() {
    if(accum_observer_tile && get_thread_local_1d_id() == accum_observer_thread)
    {
        asm volatile("" : : "s"(accum_observer_thread));
    }
};

After the post-MFMA packed FMA into c_thread_buf:

anchor_accumulator_value(
    type_convert<float>(
        c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
            .template AsType<AccDataType>()[Number<0>{}]));

Stabilizer Result

MI355: stable, 1/1000 unique hashes
MI350: stable, 1/1000 unique hashes

Focused additional rows also pass with the stabilizer:

M=8192, N=512, K={384,640,3968,4096,4224}

Clean CK baseline: unstable on MI355 and MI350
Stabilized CK: stable on MI355 and MI350
Triton comparator: stable on MI355 and MI350
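
For reference, the focused rows above cover both odd and even K-block counts at the 128-wide scale granularity used by the repro script (`(K + 127) // 128`). The sketch below only recomputes the scale-tensor shapes; `blockscale_shapes` is a hypothetical helper. Whether this parity maps onto the kernel's even-tail and odd-tail paths depends on the kernel's own K tile size, which is not stated here:

```python
def blockscale_shapes(m, n, k, block=128):
    """Scale-tensor shapes using the same ceil-division as the repro script."""
    k_blocks = (k + block - 1) // block
    n_blocks = (n + block - 1) // block
    return {"x_scale": (m, k_blocks), "w_scale": (n_blocks, k_blocks)}


for k in (384, 640, 3968, 4096, 4224):
    shapes = blockscale_shapes(8192, 512, k)
    parity = "even" if shapes["x_scale"][1] % 2 == 0 else "odd"
    print(k, shapes, parity)
```

The K-block counts come out as 3, 5, 31, 32, and 33.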

The stabilized kernel stays in the same latency class as the main CK kernel. For example, clean p50 latency for M=8192, N=512, K=4096 was roughly:

MI355 main CK / stabilized CK / Triton: 0.0639 / 0.0670 / 1.0229 ms
MI350 main CK / stabilized CK / Triton: 0.0695 / 0.0719 / 1.1023 ms
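
The relative cost of the stabilizer follows directly from these numbers. A small arithmetic sketch, with the latencies copied from the timings above (the helper name `relative_overhead` is illustrative):

```python
# p50 latencies in ms, copied from the timings above.
timings_ms = {
    "MI355": {"main": 0.0639, "stabilized": 0.0670, "triton": 1.0229},
    "MI350": {"main": 0.0695, "stabilized": 0.0719, "triton": 1.1023},
}


def relative_overhead(baseline_ms, candidate_ms):
    """Fractional slowdown of candidate relative to baseline."""
    return candidate_ms / baseline_ms - 1.0


for gpu, t in timings_ms.items():
    overhead = relative_overhead(t["main"], t["stabilized"])
    speedup = t["triton"] / t["stabilized"]
    print(f"{gpu}: stabilizer overhead {overhead:.1%}, {speedup:.1f}x faster than Triton")
```

This works out to roughly 4.9% overhead on MI355 and 3.5% on MI350, with the stabilized CK kernel still about 15x faster than the Triton comparator on both.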
