The Python snippets (modified examples/gemm_streamk/example_tilelang_gemm_streamk.py with tl.disable_tma_lower):
@tilelang.jit(pass_configs={"tl.disable_tma_lower": True})
def tl_matmul_streamk(
    M,
    N,
    K,
    streamk_tiles,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    dtypeAB,
    dtypeC,
    accum_dtype,
    num_stages,
    threads,
):
    assert not trans_A
    A_shape = (M, K) if not trans_A else (K, M)
    B_shape = (K, N) if not trans_B else (N, K)
    A_shared_shape = (block_M, block_K) if not trans_A else (block_K, block_M)
    B_shared_shape = (block_K, block_N) if not trans_B else (block_N, block_K)

    @T.prim_func
    def main(
        A: T.Tensor(A_shape, dtypeAB),
        B: T.Tensor(B_shape, dtypeAB),
        C: T.Tensor((M, N), dtypeC),
    ):
        with T.Kernel(streamk_programs, threads=threads) as pid:
            A_shared = T.alloc_shared(A_shared_shape, dtypeAB)
            B_shared = T.alloc_shared(B_shared_shape, dtypeAB)
            A_shared_full_tiles = T.alloc_shared(A_shared_shape, dtypeAB)
            B_shared_full_tiles = T.alloc_shared(B_shared_shape, dtypeAB)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            ...
tensor([[ 1.1520e+03, -2.0117e-01, -2.6820e+03, ..., 3.0000e+00,
2.3824e+04, 3.2188e+00],
[-1.5700e+03, 7.1289e+00, -9.5050e+02, ..., 8.3203e+00,
nan, 3.8418e+00],
[ 1.2728e+04, -1.5918e-01, nan, ..., 2.9648e+00,
-3.5328e+04, -4.8340e-01],
...,
[-1.7354e+00, 1.1758e+00, -5.3828e+00, ..., 3.0332e+00,
1.2836e+01, -5.7266e+00],
[ 1.0992e+01, 8.7188e+00, 1.5047e+01, ..., 3.6777e+00,
1.9893e+00, 1.0555e+01],
[ 2.1562e+00, 1.0615e+00, 4.3633e+00, ..., 6.5547e+00,
-1.5117e+01, 4.1719e+00]], device='cuda:0', dtype=torch.float16)
tensor([[ -1.3701, 8.3672, -11.4219, ..., -12.4219, 0.7490, 8.2500],
[ 3.7031, 1.3291, -9.6016, ..., -5.9922, 1.1436, 10.1641],
[ -0.1477, -5.0977, -10.4219, ..., 3.6152, -3.8203, 9.1719],
...,
[ -1.7354, 1.1758, -5.3828, ..., 3.0332, 12.8359, -5.7266],
[ 10.9922, 8.7188, 15.0469, ..., 3.6777, 1.9893, 10.5547],
[ 2.1562, 1.0615, 4.3633, ..., 6.5547, -15.1172, 4.1719]],
device='cuda:0', dtype=torch.float16)
Traceback (most recent call last):
File "/mnt/ssd/gaoxuchen.gxc/tilelang/examples/gemm_streamk/example_tilelang_gemm_streamk.py", line 211, in <module>
main()
File "/mnt/ssd/gaoxuchen.gxc/tilelang/examples/gemm_streamk/example_tilelang_gemm_streamk.py", line 182, in main
torch.testing.assert_close(C, b_c, rtol=1e-2, atol=1e-2)
File "/usr/local/lib/python3.12/dist-packages/torch/testing/_comparison.py", line 1587, in assert_close
raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not close!
Mismatched elements: 40866 / 262144 (15.6%)
Greatest absolute difference: nan at index (0, 4) (up to 0.01 allowed)
Greatest relative difference: nan at index (0, 4) (up to 0.01 allowed)
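For anyone reading the failure message: torch.testing.assert_close documents its check as |actual - expected| <= atol + rtol * |expected|, and with the default equal_nan=False a NaN never counts as close. The pure-Python sketch below mirrors that rule on a few values copied from the printed tensors. It assumes the first tensor (the one containing nan) is the kernel output and the second is the reference; the helper name is_close is hypothetical and not part of PyTorch.

```python
import math


def is_close(actual, expected, rtol=1e-2, atol=1e-2):
    """Hypothetical helper mirroring the documented assert_close rule."""
    if math.isnan(actual) or math.isnan(expected):
        return False  # equal_nan=False (the default): NaN never matches
    if math.isinf(actual) or math.isinf(expected):
        return actual == expected  # non-finite values must be exactly equal
    return abs(actual - expected) <= atol + rtol * abs(expected)


# (assumed kernel output, assumed reference), copied from the printout above
pairs = [
    (1.1520e+03, -1.3701),     # first row, first column
    (-2.0117e-01, 8.3672),     # first row, second column
    (float("nan"), 1.1436),    # second row, second-to-last column
    (2.1562e+00, 2.1562),      # last row, first column
    (-1.5117e+01, -15.1172),   # last row, second-to-last column
]

for actual, expected in pairs:
    print(f"{actual!r:>12} vs {expected!r:>10}: close={is_close(actual, expected)}")
```

The last printed rows pass this check while the first rows fail it, which is consistent with only part of the output being wrong rather than all of it.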
Required prerequisites
What version of TileLang are you using?
main
System information
cuda:12.9
arch: sm90 (H20)
Problem description
I followed issue #972 to run the gemm_streamk example with
disable_tma_lower, but the results do not match the reference. This was run on the main branch at commit 94959b8 (HEAD -> main, origin/main, origin/HEAD).
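Since the last printed rows match the reference and the first rows do not, it may help to check which block_M x block_N output tiles contain the mismatches. The following is a minimal pure-Python sketch of that idea, not code from the example: mismatched_tiles is a hypothetical helper, it works on nested lists, and it uses the same tolerance rule as the failing assert_close call (rtol=1e-2, atol=1e-2).

```python
import math


def mismatched_tiles(actual, expected, block_m, block_n, rtol=1e-2, atol=1e-2):
    """Return {(tile_row, tile_col): mismatch_count} for nested-list matrices.

    Hypothetical diagnostic helper. NaN entries always count as mismatches;
    for simplicity, infinities are also counted as mismatches in this sketch.
    """
    counts = {}
    for i, (row_a, row_e) in enumerate(zip(actual, expected)):
        for j, (a, e) in enumerate(zip(row_a, row_e)):
            ok = (
                math.isfinite(a)
                and math.isfinite(e)
                and abs(a - e) <= atol + rtol * abs(e)
            )
            if not ok:
                key = (i // block_m, j // block_n)
                counts[key] = counts.get(key, 0) + 1
    return counts


# Toy 4x4 matrices with 2x2 tiles, standing in for C and the reference.
expected = [[1.0] * 4 for _ in range(4)]
actual = [row[:] for row in expected]
actual[0][0] = float("nan")  # tile (0, 0)
actual[1][1] = 100.0         # tile (0, 0)
actual[0][3] = 5.0           # tile (0, 1)

print(mismatched_tiles(actual, expected, block_m=2, block_n=2))
# prints {(0, 0): 2, (0, 1): 1}
```

On the real tensors, the inputs could be produced with C.cpu().tolist() and the reference tensor's .cpu().tolist(), using the block_M and block_N values passed to tl_matmul_streamk.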
Reproducible example code
The modified snippet appears at the top of this report.
Traceback
The printed tensors and the traceback follow the snippet above.
Expected behavior
No response
Additional context
No response