Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions src/op/gemm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -455,16 +455,27 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
if (completed_)
return {};
LayoutMap results;

if (const auto f = ffi::Function::GetGlobal("tl.gemm.infer_layout")) {
results = Downcast<LayoutMap>(
auto inferred_layouts = Downcast<LayoutMap>(
(*f)(tvm::ffi::GetRef<Gemm>(this), T.target, T.thread_bounds));
// Bind all fragment layouts with the provided thread range
for (auto kv : results) {
// For MMA instructions, skip shared buffer layouts that are already
// inferred by a prior operator to avoid layout conflicts when the same
// shared buffer is consumed by multiple gemm ops with different transpose
// semantics. WGMMA/TCGEN5MMA have strict shared memory layout requirements
// and must always set their layouts.
auto block_size = *as_const_int(T.thread_bounds->extent);
GemmInst gemm_inst = getGemmInst(block_size, T.target);
bool is_mma = (gemm_inst == GemmInst::kMMA);
for (auto kv : inferred_layouts) {
const Buffer &buf = kv.first;
const Layout &layout = kv.second;
if (is_mma && IsSharedBuffer(buf) && T.layout_map.count(buf)) {
continue;
}
if (auto frag = layout.as<Fragment>()) {
results.Set(buf, frag.value()->BindThreadRange(T.thread_bounds));
} else {
results.Set(buf, layout);
}
}
} else {
Expand Down
Loading