Required prerequisites
What version of TileLang are you using?
0.1.8
System information
>>> print(sys.version, sys.platform)
3.12.12 | packaged by Anaconda, Inc. | (main, Oct 21 2025, 20:16:04) [GCC 11.2.0] linux
>>> print(tilelang.__version__)
0.1.8+cuda.git7ab6a7c5
>>> print(torch.__version__)
2.9.1+cu128
Problem description
The buffer dsT_shared (shape [64, 16]) is consumed by two T.gemm operations with different transpose semantics. Each gemm infers a different layout for dsT_shared, and the two layouts conflict.
Reproducible example code
The Python snippets:
import tilelang
import tilelang.language as T


@tilelang.jit()
def func(dim_qk: int, block_qo: int, block_kv: int):
    dtype = T.bfloat16
    accum_dtype = T.float32

    @T.prim_func
    def prim_func():
        with T.Kernel():
            K_shared = T.alloc_shared([block_kv, dim_qk], dtype)
            dsT_shared = T.alloc_shared([block_kv, block_qo], dtype)
            q = T.alloc_shared([block_qo, dim_qk], dtype)
            dk = T.alloc_fragment([block_kv, dim_qk], accum_dtype)
            dq = T.alloc_fragment([block_qo, dim_qk], accum_dtype)
            # dsT_shared is operand A of both gemms: once as-is, once transposed.
            T.gemm(dsT_shared, q, dk)
            T.gemm(dsT_shared, K_shared, dq, transpose_A=True)

    return prim_func


def main(*args, **kwargs):
    DIM_QK = 96
    BLOCK_QO = 16
    BLOCK_KV = 64
    func(dim_qk=DIM_QK, block_qo=BLOCK_QO, block_kv=BLOCK_KV)()


if __name__ == "__main__":
    main()
Traceback
tvm.error.InternalError: Get different layout for dsT_shared
current layout: Layout([64, 16] -> [1528], transform: [_i, _j] -> [_i * 24 + _j])
previous layout: Layout([64, 16] -> [1, 8, 128], transform: [_i, _j] -> [0, _i // 8, _i % 8 * 16 + (_j // 8 + _i % 8 // 4) % 2 * 8 + _j % 8])
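To make the conflict concrete, the two index transforms printed in the traceback can be evaluated in plain Python. This is only an illustrative sketch: the helper names current_offset and previous_offset are made up here, and flattening the [1, 8, 128] layout in row-major order is an assumption made for comparison. The traceback does not say which T.gemm call produced which layout.

```python
# Index transforms copied from the traceback; helper names are hypothetical.
ROWS, COLS = 64, 16


def current_offset(i, j):
    # Layout([64, 16] -> [1528]): linear layout with a row stride of 24 elements.
    return i * 24 + j


def previous_offset(i, j):
    # Layout([64, 16] -> [1, 8, 128]): swizzled layout, flattened row-major
    # here (an assumption) so it can be compared with the linear offset.
    inner = i % 8 * 16 + (j // 8 + i % 8 // 4) % 2 * 8 + j % 8
    return (i // 8) * 128 + inner


# Collect every element that the two layouts place at different offsets.
mismatches = [
    (i, j)
    for i in range(ROWS)
    for j in range(COLS)
    if current_offset(i, j) != previous_offset(i, j)
]

print("element (1, 0):", current_offset(1, 0), "vs", previous_offset(1, 0))
print("elements placed differently:", len(mismatches), "of", ROWS * COLS)
```

Since the same logical element maps to different shared-memory offsets under the two layouts, a single allocation of dsT_shared cannot satisfy both at once, which matches the "Get different layout" error.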
Expected behavior
No response
Additional context
No response