Skip to content
Merged
10 changes: 10 additions & 0 deletions examples/gemm_sm100/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,16 @@ T.tcgen05_gemm(A_shared, B_shared, C_tmem, trans_A, trans_B, mbar=mbar, clear_ac
T.mbarrier_wait_parity(mbar, k%2) # Manual phase calculation required
```

TileLang now has a conservative `InjectTcgen05Fence` pass on SM100+ that can
insert `tcgen05_before_thread_sync()` / `tcgen05_after_thread_sync()` around:
- `tvm_storage_sync("shared"|"shared.dyn")`
- linear `mbarrier_wait_parity(...) -> tcgen05/TMEM use` regions
- linear `tcgen05/TMEM use -> mbarrier_arrive(...)` regions

This does **not** eliminate the need to structure the mbarrier protocol
explicitly in user code, and the examples in this directory still keep manual
fences where they make the handoff points obvious.

## Examples

### TCGEN5MMA Example (`gemm_tcgen5mma.py`)
Expand Down
4 changes: 0 additions & 4 deletions examples/gemm_sm100/gemm_tcgen5mma_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def gemm(A, B, block_M, block_N, block_K, in_dtype, out_dtype, accum_dtype, num_
elif tx < 64: # warp 1: issue tcgen5
for k in T.serial(k_iters):
T.mbarrier_wait_parity(loaded[k % num_stages], (k // num_stages) & 1)
T.tcgen05_after_thread_sync()
T.tcgen05_gemm(
A_shared[k % num_stages, :, :],
B_shared[k % num_stages, :, :],
Expand All @@ -60,7 +59,6 @@ def gemm(A, B, block_M, block_N, block_K, in_dtype, out_dtype, accum_dtype, num_

# Wait for all tcgen5 to finish
T.mbarrier_wait_parity(tmem_full, 0)
T.tcgen05_after_thread_sync()
T.copy(C_tmem, C_local)
if use_tma_store:
T.copy(C_local, C_shared)
Expand Down Expand Up @@ -115,7 +113,6 @@ def gemm_2cta(A, B, block_M, block_N, block_K, in_dtype, out_dtype, accum_dtype,
elif cta_id == 0 and tx < 64: # Only warp 1 on leader cta issues tcgen5
for k in T.serial(k_iters):
T.mbarrier_wait_parity(loaded[k % num_stages], (k // num_stages) & 1)
T.tcgen05_after_thread_sync()
T.tcgen05_gemm(
A_shared[k % num_stages, :, :],
B_shared[k % num_stages, :, :],
Expand All @@ -128,7 +125,6 @@ def gemm_2cta(A, B, block_M, block_N, block_K, in_dtype, out_dtype, accum_dtype,

# Wait for all tcgen5 to finish
T.mbarrier_wait_parity(tmem_full, 0)
T.tcgen05_after_thread_sync()
T.copy(C_tmem, C_local)
if use_tma_store:
T.copy(C_local, C_shared)
Expand Down
10 changes: 0 additions & 10 deletions examples/gemm_sm100/gemm_tcgen5mma_ws_persistent.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,9 @@ def gemm_persistent(

if bx * block_M < M and by * block_N < N:
T.mbarrier_wait_parity(tmem_empty[w & 1], ((w // 2) & 1) ^ 1)
T.tcgen05_after_thread_sync()
for k in T.serial(k_blocks):
phase = w * k_blocks + k
T.mbarrier_wait_parity(loaded[phase % num_stages], (phase // num_stages) & 1)
T.tcgen05_after_thread_sync()
if w & 1 == 0:
T.tcgen05_gemm(
A_shared[k % num_stages, :, :],
Expand Down Expand Up @@ -116,13 +114,10 @@ def gemm_persistent(

if bx * block_M < M and by * block_N < N:
T.mbarrier_wait_parity(tmem_full[w & 1], (w // 2) & 1)
T.tcgen05_after_thread_sync()
T.sync_threads(1, 128)
if (w & 1) == 0:
T.copy(C_tmem_0, C_local)
else:
T.copy(C_tmem_1, C_local)
T.tcgen05_before_thread_sync()
T.mbarrier_arrive(tmem_empty[w & 1])

if use_tma_store:
Expand Down Expand Up @@ -220,11 +215,9 @@ def gemm_persistent_2cta(

if bx * block_M < M and by * block_N < N:
T.mbarrier_wait_parity(tmem_empty[w & 1], ((w // 2) & 1) ^ 1)
T.tcgen05_after_thread_sync()
for k in T.serial(k_blocks):
phase = w * k_blocks + k
T.mbarrier_wait_parity(loaded[phase % num_stages], (phase // num_stages) & 1)
T.tcgen05_after_thread_sync()
if w & 1 == 0:
T.tcgen05_gemm(
A_shared[phase % num_stages, :, :],
Expand Down Expand Up @@ -256,13 +249,10 @@ def gemm_persistent_2cta(

if bx * block_M < M and by * block_N < N:
T.mbarrier_wait_parity(tmem_full[w & 1], (w // 2) & 1)
T.tcgen05_after_thread_sync()
T.sync_threads(1, 128)
if (w & 1) == 0:
T.copy(C_tmem_0, C_local)
else:
T.copy(C_tmem_1, C_local)
T.tcgen05_before_thread_sync()
T.mbarrier_arrive(tmem_empty[w & 1], 0)

if use_tma_store:
Expand Down
18 changes: 14 additions & 4 deletions src/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,16 @@ TIR_DEFINE_TL_BUILTIN(fence_proxy_async)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(tcgen05_before_thread_sync)
.set_num_inputs(0)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(tcgen05_after_thread_sync)
.set_num_inputs(0)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(tma_store_arrive)
.set_num_inputs(0)
.set_attr<TCallEffectKind>("TCallEffectKind",
Expand Down Expand Up @@ -552,13 +562,13 @@ TIR_DEFINE_TL_BUILTIN(tcgen05_mma_arrive)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(tcgen05_before_thread_sync)
.set_num_inputs(0)
TIR_DEFINE_TL_BUILTIN(tcgen05_ld)
.set_num_inputs(6)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(tcgen05_after_thread_sync)
.set_num_inputs(0)
TIR_DEFINE_TL_BUILTIN(tcgen05_st)
.set_num_inputs(6)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

Expand Down
16 changes: 16 additions & 0 deletions src/op/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -874,6 +874,22 @@ TVM_DLL const Op &initialize_tcgen05_descriptor();
*/
TVM_DLL const Op &tcgen05_mma_arrive();

/*!
* \brief tilelang intrinsic for lowered TCGEN05 tensor-memory load.
*
* Internal lowering op used by LowerTmemCopy to represent
* `tl::tcgen05_ld_*` calls without routing through `call_extern`.
*/
TVM_DLL const Op &tcgen05_ld();

/*!
* \brief tilelang intrinsic for lowered TCGEN05 tensor-memory store.
*
* Internal lowering op used by LowerTmemCopy to represent
* `tl::tcgen05_st_*` calls without routing through `call_extern`.
*/
TVM_DLL const Op &tcgen05_st();

/*!
* \brief TCGEN05 fence before a thread-block-wide sync (__syncthreads /
* bar.sync). Matches PTX \c tcgen05.fence::before_thread_sync (DeepGEMM /
Expand Down
49 changes: 32 additions & 17 deletions src/op/copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1486,7 +1486,6 @@ Stmt CopyNode::LowerTmemCopy(const LowerArgs &T,
// unpack::16b) so MMA TS reads correctly packed bf16 from TMEM columns.
// For tcgen05_ld, pack::16b is still needed when reading unpacked data.
bool use_pack_unpack_modifier = is_ld ? needs_pack_unpack : false;
const char *bool_str = use_pack_unpack_modifier ? "true" : "false";
int effective_chunks =
needs_pack_unpack ? num_chunks_each_wg / 2 : num_chunks_each_wg;
PrimExpr relative_wg_idx =
Expand All @@ -1497,22 +1496,38 @@ Stmt CopyNode::LowerTmemCopy(const LowerArgs &T,
: relative_wg_idx * (effective_chunks * meta.width);
have_succeeded = true;
Array<PrimExpr> args;
args.push_back(StringImm(meta.intrinsics_name + "<" +
std::to_string(effective_chunks) + ", " +
bool_str + ">"));
args.push_back(
BufferLoad(tmem_buf, {(int)logical_row_min,
(int)logical_col_min})); // Will be translated
// later in
// lower_shared_tmem
// pass
args.push_back(col_offset);
int reg_access_mode = is_ld ? 2 : 1;
args.push_back(reg_buf.access_ptr(reg_access_mode, DataType::Handle(), 1,
0, PrimExpr(tmem_phy_col_extent)));

Stmt call =
Evaluate(Call(DataType::Handle(), builtin::call_extern(), args));
Stmt call;
if (is_ld) {
args.push_back(IntImm(DataType::Int(32), meta.width * 32));
args.push_back(IntImm(DataType::Int(32), effective_chunks));
args.push_back(Bool(use_pack_unpack_modifier));
args.push_back(
BufferLoad(tmem_buf, {(int)logical_row_min,
(int)logical_col_min})); // Will be translated
// later in
// lower_shared_tmem
// pass
args.push_back(col_offset);
args.push_back(reg_buf.access_ptr(/*access_mask=*/2, DataType::Handle(),
/*content_lanes=*/1, /*offset=*/0,
PrimExpr(tmem_phy_col_extent)));
call = Evaluate(Call(DataType::Handle(), tcgen05_ld(), args));
} else {
args.push_back(IntImm(DataType::Int(32), meta.width * 32));
args.push_back(IntImm(DataType::Int(32), effective_chunks));
args.push_back(Bool(use_pack_unpack_modifier));
args.push_back(
BufferLoad(tmem_buf, {(int)logical_row_min,
(int)logical_col_min})); // Will be translated
// later in
// lower_shared_tmem
// pass
args.push_back(col_offset);
int reg_access_mode = 1;
args.push_back(reg_buf.access_ptr(reg_access_mode, DataType::Handle(),
1, 0, PrimExpr(tmem_phy_col_extent)));
call = Evaluate(Call(DataType::Handle(), tcgen05_st(), args));
}
if (num_useful_threads != num_threads) {
body =
IfThenElse(T.thread_var < T.thread_bounds->min + num_useful_threads,
Expand Down
27 changes: 27 additions & 0 deletions src/target/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2767,6 +2767,33 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
replacer.register_rule("(mask3)", mask3);
tcgen05_call = replacer.rewrite(tcgen05_call);
this->stream << tcgen05_call;
} else if (op->op.same_as(tl::tcgen05_ld())) {
ICHECK_EQ(op->args.size(), 6U) << "tcgen05_ld expects 6 arguments";
need_tcgen05_common_h_ = true;
int inst_bits = Downcast<IntImm>(op->args[0])->value;
int chunks = Downcast<IntImm>(op->args[1])->value;
bool pack16 = Downcast<Bool>(op->args[2])->value;
std::string tmem_start_col = this->PrintExpr(op->args[3]);
std::string col_offset = this->PrintExpr(op->args[4]);
std::string dst_ptr = this->PrintExpr(op->args[5]);
this->PrintIndent();
this->stream << "tl::tcgen05_ld_32dp" << inst_bits << "bNx<" << chunks
<< ", " << (pack16 ? "true" : "false") << ">("
<< tmem_start_col << ", " << col_offset << ", " << dst_ptr
<< ");\n";
} else if (op->op.same_as(tl::tcgen05_st())) {
ICHECK_EQ(op->args.size(), 6U) << "tcgen05_st expects 6 arguments";
int inst_bits = Downcast<IntImm>(op->args[0])->value;
int chunks = Downcast<IntImm>(op->args[1])->value;
bool unpack16 = Downcast<Bool>(op->args[2])->value;
std::string tmem_start_col = this->PrintExpr(op->args[3]);
std::string col_offset = this->PrintExpr(op->args[4]);
std::string src_ptr = this->PrintExpr(op->args[5]);
this->PrintIndent();
this->stream << "tl::tcgen05_st_32dp" << inst_bits << "bNx<" << chunks
<< ", " << (unpack16 ? "true" : "false") << ">("
<< tmem_start_col << ", " << col_offset << ", " << src_ptr
<< ");\n";
} else if (op->op.same_as(tl::tcgen05_mma_arrive())) {
ICHECK_EQ(op->args.size(), 1U) << "tcgen05_mma_arrive expects 1 argument";
need_tcgen05_common_h_ = true;
Expand Down
24 changes: 24 additions & 0 deletions src/target/codegen_cutedsl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -899,6 +899,30 @@ void CodeGenTileLangCuTeDSL::VisitExpr_(const CallNode *op,
<< "[0] + " << c_offset << ", " << desc_val << ", " << scale_out
<< ", " << mask0 << ", " << mask1 << ", " << mask2 << ", " << mask3
<< ")\n";
} else if (op->op.same_as(tl::tcgen05_ld())) {
ICHECK_EQ(op->args.size(), 6U) << "tcgen05_ld expects 6 arguments";
int inst_bits = Downcast<IntImm>(op->args[0])->value;
int chunks = Downcast<IntImm>(op->args[1])->value;
bool pack16 = Downcast<Bool>(op->args[2])->value;
std::string tmem_start_col = PrintExpr_(op->args[3]);
std::string col_offset = PrintExpr_(op->args[4]);
std::string dst_ptr = PrintExpr_(op->args[5]);
PrintIndent();
stream << "tl.tcgen05_ld_32dp" << inst_bits << "bNx(" << chunks << ", "
<< (pack16 ? "True" : "False") << ", " << tmem_start_col << ", "
<< col_offset << ", " << dst_ptr << ")\n";
} else if (op->op.same_as(tl::tcgen05_st())) {
ICHECK_EQ(op->args.size(), 6U) << "tcgen05_st expects 6 arguments";
int inst_bits = Downcast<IntImm>(op->args[0])->value;
int chunks = Downcast<IntImm>(op->args[1])->value;
bool unpack16 = Downcast<Bool>(op->args[2])->value;
std::string tmem_start_col = PrintExpr_(op->args[3]);
std::string col_offset = PrintExpr_(op->args[4]);
std::string src_ptr = PrintExpr_(op->args[5]);
PrintIndent();
stream << "tl.tcgen05_st_32dp" << inst_bits << "bNx(" << chunks << ", "
<< (unpack16 ? "True" : "False") << ", " << tmem_start_col << ", "
<< col_offset << ", " << src_ptr << ")\n";
} else if (op->op.same_as(tl::tcgen05_mma_arrive())) {
ICHECK_EQ(op->args.size(), 1U) << "tcgen05_mma_arrive expects 1 argument";
PrintIndent();
Expand Down
Loading
Loading