[BugFix] Skip MMA shared buffer layout inference when layout already exists #2008
kurisu6912 wants to merge 3 commits into tile-ai:main
Conversation
📝 Walkthrough
Gemm/GemmSP/GemmSPPy layout inference now reports shared-memory buffers as "strict" for Hopper/WGMMA/TCGEN5MMA paths and unconditionally requires the FFI infer_layout functions; the layout inference engine was extended to track, promote, and enforce strict buffer layouts across merge phases.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 4 passed, ❌ 1 failed (warning)
…exists (tile-ai#1997) When a shared memory buffer is consumed by multiple gemm operations with different transpose semantics, each gemm infers a different swizzle layout, causing a layout conflict error. For MMA instructions, the swizzle layout is only a bank conflict optimization, not a correctness requirement. Skip layout inference for shared buffers that already have a layout inferred by a prior operator. WGMMA/TCGEN5MMA/MFMA retain strict layout enforcement. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Force-pushed from b4d5af2 to a9c0cfb
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/op/gemm_py.cc`:
- Around line 322-335: The python infer path in gemm_py currently skips setting
a shared-buffer layout if T.layout_map already contains an entry, which allows
an earlier MMA-backed gemm_py to block a later strict-layout consumer
(WGMMA/TCGEN5MMA); update the gemm_py visitor loop (the same block using
getGemmInst, GemmInst, is_mma, IsSharedBuffer and T.layout_map) so that
strict-layout GEMM kinds (WGMMA/TCGEN5MMA) always set/override the shared buffer
layout regardless of an existing T.layout_map entry—i.e., only skip when the
current op is non-strict and an entry already exists, but do not skip (force
set/override) when the current gemm is a strict-layout consumer.
In `@src/op/gemm.cc`:
- Around line 688-695: The current logic only seeds a layout when T.layout_map
lacks a_ (e.g., in the SM100 MMA branch using makeGemmABLayout and
results.Set/ExpandLayoutToMatchBuffer), which permits a non-strict MMA consumer
to win if it visits first; change the behavior so strict-layout consumers
(TCGEN5MMA/WGMMA/other strict paths) always override any existing layout for the
shared buffer regardless of T.layout_map presence — i.e., detect the strict
consumer path and call the same layout creation +
results.Set/ExpandLayoutToMatchBuffer unconditionally (or use an overwrite
variant) for a_ so strict layouts win independent of visit order.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: aea94e31-4616-47ac-bdaf-b699ce1973ec
📒 Files selected for processing (2)
src/op/gemm.cc, src/op/gemm_py.cc
♻️ Duplicate comments (1)
src/op/gemm_py.cc (1)
322-335: ⚠️ Potential issue | 🟠 Major — Skip-on-existing still leaves strict shared layouts visit-order dependent.
If an MMA-backed `gemm_py` is visited before a later WGMMA/TCGEN5MMA/MFMA consumer of the same shared buffer, `T.layout_map` is still empty here, so this path still publishes the MMA swizzle. The later strict op can then hit the same conflicting-layout failure during merge. This needs a strict-wins precedence, not just a skip-when-present check.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/op/gemm_py.cc` around lines 322 - 335, The current early-skip for MMA (in the loop using getGemmInst, is_mma, raw_results, IsSharedBuffer and T.layout_map) makes layout outcome depend on visitation order; change the logic so "strict" gemm types (WGMMA/TCGEN5MMA/MFMA — identify by GemmInst enum values or a helper like is_strict_gemm) always win: either (A) when encountering an MMA (is_mma) skip writing a layout only if you can prove no strict consumer will touch the same shared buffer later, or preferably (B) make strict gemm handling idempotent by always setting/overwriting T.layout_map[buf] for shared buffers when the current gemm is strict so it supersedes any prior MMA-written layout. Update the decision in the loop (the if using is_mma && IsSharedBuffer(buf) && T.layout_map.count(buf)) to implement one of these strict-wins behaviors and adjust any helpers that compute gemm kind (getGemmInst/GemmInst) accordingly.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: b7095327-ee10-44bc-8da5-29c0918d8ee6
📒 Files selected for processing (2)
src/op/gemm.cc, src/op/gemm_py.cc
🚧 Files skipped from review as they are similar to previous changes (1)
- src/op/gemm.cc
…shared-buffer-conflict
…shared-buffer-conflict # Conflicts: # src/op/gemm.cc # src/op/gemm_py.cc
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/op/gemm_sp_py.cc`:
- Around line 298-310: The current Hopper-specific strict-buffer marking in
op/gemm_sp_py.cc unconditionally marks non-Fragment shared buffers strict when
TargetIsHopper(T.target) is true; change this to only mark strict when WGMMA
would actually be selected (mirror gemm_sp.cc). Compute block_size from
T.thread_bounds->extent (use T.thread_bounds->extent->value or equivalent),
compute maybe_wgmma using the same conditions as gemm_sp.cc (m_ >= 64 and
(block_size / warp_size) % 4 == 0), and guard the strict-buffer loop with that
maybe_wgmma check; keep existing references to T.strict_buffers_out, results,
Buffer, Layout, Fragment, and IsSharedBuffer. Ensure no changes to behavior on
non-Hopper targets.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 30035e30-aebd-4745-9f39-379cd413651b
📒 Files selected for processing (5)
src/op/gemm.cc, src/op/gemm_sp.cc, src/op/gemm_sp_py.cc, src/op/operator.h, src/transform/layout_inference.cc
src/op/gemm_sp_py.cc
Outdated
```cpp
// On Hopper, sparse gemm uses WGMMA which imposes exact shared-memory
// swizzle requirements. Mark shared A/B strict so the merge site will not
// overwrite them with non-strict updates from other ops. Ampere sparse
// remains non-strict (flexible MMA path).
if (TargetIsHopper(T.target) && T.strict_buffers_out != nullptr) {
  for (auto kv : results) {
    const Buffer &buf = kv.first;
    const Layout &layout = kv.second;
    if (!layout.as<Fragment>() && IsSharedBuffer(buf)) {
      T.strict_buffers_out->push_back(buf);
    }
  }
}
```
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Check if GemmSPPy has a method to determine WGMMA usage like GemmSPNode does
rg -n "GetGemmInst|maybe_wgmma|kWGMMA" src/op/gemm_sp_py.cc src/op/gemm_sp_py.h
# Check the FFI function to understand what layouts it infers for Hopper
rg -n "gemm_sp_py.infer_layout" --type py
Repository: tile-ai/tilelang
Length of output: 818
🏁 Script executed:
# First, check the context around lines 298-310 to see the full picture
sed -n '290,320p' src/op/gemm_sp_py.cc
# Then check the CheckWGMMA() implementation
sed -n '168,180p' src/op/gemm_sp_py.cc
# Check GetGemmInst implementation
sed -n '120,140p' src/op/gemm_sp_py.cc
Repository: tile-ai/tilelang
Length of output: 2393
🏁 Script executed:
# Check what function contains lines 298-310
sed -n '250,300p' src/op/gemm_sp_py.cc | head -60
# Also check function signature to see what parameters are available
rg -n "def infer_layout|LayoutMap.*infer_layout" -A 5 src/op/gemm_sp_py.cc
Repository: tile-ai/tilelang
Length of output: 2266
🏁 Script executed:
# Check what LayoutInferArgs contains
rg -n "struct LayoutInferArgs|class LayoutInferArgs" -A 15
# Check if block_size is computed from thread_bounds anywhere
rg -n "block_size|thread_bounds" src/op/gemm_sp_py.cc | head -20
# Look at gemm_sp.cc InferLayout to see how it handles strict marking
rg -n "InferLayout" src/op/gemm_sp.cc -A 30
Repository: tile-ai/tilelang
Length of output: 3128
🏁 Script executed:
# Check if LayoutInferArgs has block_size or if we can extract it from thread_bounds
rg -n "struct LayoutInferArgs" -A 20 src/op/operator.h
# Check if gemm_sp_py.cc ever computes block_size from thread_bounds
rg -n "thread_bounds.*extent|block_size.*thread_bounds" src/op/gemm_sp_py.cc
# See the complete strict_shared logic in gemm_sp.cc
sed -n '252,285p' src/op/gemm_sp.cc
Repository: tile-ai/tilelang
Length of output: 2930
🏁 Script executed:
# Check if gemm_sp_py.cc ever extracts block_size like gemm_sp.cc does
grep -n "thread_bounds->extent\|as_const_int" src/op/gemm_sp_py.cc
# Look at the complete context of the strict marking code in gemm_sp_py.cc
sed -n '295,315p' src/op/gemm_sp_py.cc
Repository: tile-ai/tilelang
Length of output: 777
Fix: Strict marking for Hopper should be conditional on WGMMA actually being used.
The code unconditionally marks all non-Fragment shared buffers as strict for Hopper, but CheckWGMMA() (line 168) returns false unconditionally ("not supported yet"). Since WGMMA is never actually selected, this is overly conservative and differs from gemm_sp.cc's approach.
In gemm_sp.cc, strict marking only occurs when maybe_wgmma is true, computed as m_ >= 64 && (block_size / warp_size) % 4 == 0. Apply the same logic here: extract block_size from T.thread_bounds->extent, compute whether WGMMA would be used, and only mark shared buffers strict when necessary. This prevents unnecessary strict buffers from blocking valid layout merges on Hopper.
@regression-perf
Force-pushed from 3bbc4db to 2209baf
Summary
When a shared memory buffer is consumed by multiple `T.gemm` operations with different transpose semantics, each gemm infers a different swizzle layout for the shared buffer, causing a layout conflict error.
Test plan
🤖 Generated with Claude Code