
[BugFix] Skip MMA shared buffer layout inference when layout already exists#2008

Open
kurisu6912 wants to merge 3 commits into tile-ai:main from kurisu6912:fix/layout-inference-shared-buffer-conflict

Conversation

Collaborator

@kurisu6912 kurisu6912 commented Apr 2, 2026

Summary

  • Fixes [BUG] LayoutInference fails when shared buffer is used in two gemm ops with different transpose modes #1997
  • When a shared memory buffer is consumed by multiple T.gemm operations with different transpose semantics, each gemm infers a different swizzle layout for the shared buffer, causing a layout conflict error
  • For MMA instructions, the swizzle layout is only a bank conflict optimization, not a correctness requirement. Skip layout inference for shared buffers that already have a layout inferred by a prior operator
  • WGMMA / TCGEN5MMA / MFMA retain strict layout enforcement since they have hard shared memory layout requirements
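A minimal Python model of this precedence rule (illustrative only — the instruction names and the `should_publish_layout` helper are hypothetical, not the actual TileLang C++ API):

```python
# Illustrative sketch of the PR's precedence rule; names are hypothetical.
STRICT_INSTS = {"wgmma", "tcgen5mma", "mfma"}  # hard shared-memory layout requirements


def should_publish_layout(gemm_inst, buf, layout_map):
    """Decide whether this gemm op should write a shared-buffer layout."""
    if gemm_inst in STRICT_INSTS:
        # Strict instructions must enforce their exact swizzle layout.
        return True
    # Plain MMA: the swizzle is only a bank-conflict optimization, so skip
    # inference when a prior operator already published a layout.
    return buf not in layout_map
```

An MMA gemm therefore adopts whatever swizzle an earlier operator chose, while a strict instruction still publishes its required layout.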

Test plan

🤖 Generated with Claude Code

Summary by CodeRabbit

  • Bug Fixes
    • Improved layout inference and strictness handling for matrix-multiply paths, fixing incorrect merges and making shared-memory buffer layouts more reliable across targets (including Hopper/WGMMA).
  • Chores
    • Internal bookkeeping now preserves promoted strict buffer claims across inference phases; no public APIs changed.


github-actions bot commented Apr 2, 2026

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run `pre-commit run --all-files` in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

Contributor

coderabbitai bot commented Apr 2, 2026

📝 Walkthrough

Walkthrough

Gemm/GemmSP/GemmSPPy layout inference now reports shared-memory buffers as "strict" for Hopper/WGMMA/TCGEN5MMA paths and unconditionally requires the FFI infer_layout functions; the layout inference engine was extended to track, promote, and enforce strict buffer layouts across merge phases.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| **GEMM core ops**<br>`src/op/gemm.cc`, `src/op/gemm_sp.cc`, `src/op/gemm_sp_py.cc` | Unconditionally resolve FFI infer_layout functions; record inferred raw layouts, then merge them into results. For Hopper/WGMMA/TCGEN5MMA paths, expand inferred shared-memory layouts and append shared A/B buffers to LayoutInferArgs::strict_buffers_out when provided. Added null-checks and stricter validation for thread-range extents. |
| **Layout inference engine**<br>`src/transform/layout_inference.cc` | Add a persistent promoted_strict_ set, thread a per-attempt promoted-strict set, and pass promoted_strict_out through RunInferStep/FinishInferQueue/InferInFreeMode. Merge logic updated to honor strictness (strict + strict requires equality; strict wins over non-strict; promote on overwrite). Alias propagation updated to propagate strictness and promote aliases. |
| **API surface**<br>`src/op/operator.h` | Extend LayoutInferArgs with a nullable out-parameter `Array<Buffer>* strict_buffers_out = nullptr` so that ops can report strict shared-buffer claims. |
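The merge rules described above can be modeled in a few lines of Python (an illustrative sketch — the real implementation in src/transform/layout_inference.cc operates on TVM Layout objects and the promoted_strict_ set):

```python
class LayoutConflict(Exception):
    """Raised when two strict consumers demand different layouts."""


def merge_layout(layout_map, strict_set, buf, layout, is_strict):
    """Merge one inferred layout into the running map, honoring strictness:
    strict + strict must be equal; strict wins over non-strict;
    an overwriting strict layout promotes the buffer to strict."""
    if buf not in layout_map:
        layout_map[buf] = layout
        if is_strict:
            strict_set.add(buf)
    elif buf in strict_set:
        if is_strict and layout_map[buf] != layout:
            raise LayoutConflict(f"conflicting strict layouts for {buf}")
        # Non-strict updates never displace an existing strict layout.
    elif is_strict:
        layout_map[buf] = layout  # strict wins over non-strict
        strict_set.add(buf)       # promote on overwrite
    # else: two non-strict claims -> keep the first (the MMA swizzle is
    # only a bank-conflict optimization, so either layout is correct)
```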

Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Op as Op::InferLayout (Gemm/GemmSP/GemmSPPy)
    participant LayoutArgs as LayoutInferArgs (includes strict_buffers_out)
    participant InferEngine as LayoutInference::RunInferStep
    participant Map as LayoutMap / promoted_strict_
    Op->>LayoutArgs: call InferLayout(..., strict_buffers_out)
    Op-->>LayoutArgs: append shared buffers to strict_buffers_out (Hopper/WGMMA/TCGEN5MMA)
    LayoutArgs->>InferEngine: RunInferStep(results, strict_buffers_out)
    InferEngine->>Map: consult existing layout_map and promoted_strict_
    InferEngine->>Map: apply merge rules (strict vs non-strict cases)
    InferEngine->>Map: promote strict buffers into promoted_strict_ / attempt set
    InferEngine->>Op: return merged results / updated promoted_strict_
```

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested reviewers

  • Gongen-Ali

Poem

🐰 With tiny paws I hop and cheer,
Shared buffers now hold layouts dear.
Hopper, WGMMA, TC5 in tune,
Strictly set beneath the moon.
A rabbit’s nibble fixed the map — hooray! 🥕

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 8.33%, below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title accurately describes the main change: skipping MMA shared buffer layout inference when a layout already exists, addressing the core issue in #1997. |
| Linked Issues check | ✅ Passed | The code changes implement the objective from #1997 by introducing strict buffer tracking and selective enforcement: MMA instructions skip re-inferring existing shared layouts, while WGMMA/TCGEN5MMA/MFMA retain strict enforcement. |
| Out of Scope Changes check | ✅ Passed | All changes directly support the stated objectives of preventing layout conflicts for MMA operations with different transpose modes while maintaining strict enforcement for WGMMA/TCGEN5MMA/MFMA. |


…exists (tile-ai#1997)

When a shared memory buffer is consumed by multiple gemm operations with
different transpose semantics, each gemm infers a different swizzle layout,
causing a layout conflict error. For MMA instructions, the swizzle layout
is only a bank conflict optimization, not a correctness requirement. Skip
layout inference for shared buffers that already have a layout inferred by
a prior operator. WGMMA/TCGEN5MMA/MFMA retain strict layout enforcement.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@kurisu6912 kurisu6912 force-pushed the fix/layout-inference-shared-buffer-conflict branch from b4d5af2 to a9c0cfb on April 2, 2026 at 03:52
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/op/gemm_py.cc`:
- Around line 322-335: The python infer path in gemm_py currently skips setting
a shared-buffer layout if T.layout_map already contains an entry, which allows
an earlier MMA-backed gemm_py to block a later strict-layout consumer
(WGMMA/TCGEN5MMA); update the gemm_py visitor loop (the same block using
getGemmInst, GemmInst, is_mma, IsSharedBuffer and T.layout_map) so that
strict-layout GEMM kinds (WGMMA/TCGEN5MMA) always set/override the shared buffer
layout regardless of an existing T.layout_map entry—i.e., only skip when the
current op is non-strict and an entry already exists, but do not skip (force
set/override) when the current gemm is a strict-layout consumer.

In `@src/op/gemm.cc`:
- Around line 688-695: The current logic only seeds a layout when T.layout_map
lacks a_ (e.g., in the SM100 MMA branch using makeGemmABLayout and
results.Set/ExpandLayoutToMatchBuffer), which permits a non-strict MMA consumer
to win if it visits first; change the behavior so strict-layout consumers
(TCGEN5MMA/WGMMA/other strict paths) always override any existing layout for the
shared buffer regardless of T.layout_map presence — i.e., detect the strict
consumer path and call the same layout creation +
results.Set/ExpandLayoutToMatchBuffer unconditionally (or use an overwrite
variant) for a_ so strict layouts win independent of visit order.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: aea94e31-4616-47ac-bdaf-b699ce1973ec

📥 Commits

Reviewing files that changed from the base of the PR and between a82fa71 and b4d5af2.

📒 Files selected for processing (2)
  • src/op/gemm.cc
  • src/op/gemm_py.cc

Contributor

@coderabbitai coderabbitai bot left a comment


♻️ Duplicate comments (1)
src/op/gemm_py.cc (1)

322-335: ⚠️ Potential issue | 🟠 Major

Skip-on-existing still leaves strict shared layouts visit-order dependent.

If an MMA-backed gemm_py is visited before a later WGMMA/TCGEN5MMA/MFMA consumer of the same shared buffer, T.layout_map is still empty here, so this path still publishes the MMA swizzle. The later strict op can then hit the same conflicting-layout failure during merge. This needs a strict-wins precedence, not just a skip-when-present check.
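To see why, here is a toy Python model (illustrative only, not the gemm_py code) contrasting both visitation orders under a strict-wins merge — the outcome becomes order-independent:

```python
def resolve(ops):
    """Resolve shared-buffer layouts from (buffer, layout, is_strict)
    claims in visitation order, with strict-wins precedence."""
    layout_map, strict = {}, set()
    for buf, layout, is_strict in ops:
        if buf not in layout_map:
            layout_map[buf] = layout   # first claim seeds the map
        elif is_strict and buf not in strict:
            layout_map[buf] = layout   # strict overrides a non-strict claim
        if is_strict:
            strict.add(buf)
    return layout_map


# MMA first or WGMMA first: both orders converge on the strict layout.
mma_first = resolve([("A_shared", "mma_swizzle", False),
                     ("A_shared", "wgmma_swizzle", True)])
wgmma_first = resolve([("A_shared", "wgmma_swizzle", True),
                       ("A_shared", "mma_swizzle", False)])
assert mma_first == wgmma_first == {"A_shared": "wgmma_swizzle"}
```

With a plain skip-when-present check instead, the first branch would keep the MMA swizzle and the later strict consumer would fail the merge.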

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/op/gemm_py.cc` around lines 322 - 335, The current early-skip for MMA (in
the loop using getGemmInst, is_mma, raw_results, IsSharedBuffer and
T.layout_map) makes layout outcome depend on visitation order; change the logic
so "strict" gemm types (WGMMA/TCGEN5MMA/MFMA — identify by GemmInst enum values
or a helper like is_strict_gemm) always win: either (A) when encountering an MMA
(is_mma) skip writing a layout only if you can prove no strict consumer will
touch the same shared buffer later, or preferably (B) make strict gemm handling
idempotent by always setting/overwriting T.layout_map[buf] for shared buffers
when the current gemm is strict so it supersedes any prior MMA-written layout.
Update the decision in the loop (the if using is_mma && IsSharedBuffer(buf) &&
T.layout_map.count(buf)) to implement one of these strict-wins behaviors and
adjust any helpers that compute gemm kind (getGemmInst/GemmInst) accordingly.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: b7095327-ee10-44bc-8da5-29c0918d8ee6

📥 Commits

Reviewing files that changed from the base of the PR and between b4d5af2 and a9c0cfb.

📒 Files selected for processing (2)
  • src/op/gemm.cc
  • src/op/gemm_py.cc
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/op/gemm.cc

…shared-buffer-conflict

# Conflicts:
#	src/op/gemm.cc
#	src/op/gemm_py.cc
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/op/gemm_sp_py.cc`:
- Around line 298-310: The current Hopper-specific strict-buffer marking in
op/gemm_sp_py.cc unconditionally marks non-Fragment shared buffers strict when
TargetIsHopper(T.target) is true; change this to only mark strict when WGMMA
would actually be selected (mirror gemm_sp.cc). Compute block_size from
T.thread_bounds->extent (use T.thread_bounds->extent->value or equivalent),
compute maybe_wgmma using the same conditions as gemm_sp.cc (m_ >= 64 and
(block_size / warp_size) % 4 == 0), and guard the strict-buffer loop with that
maybe_wgmma check; keep existing references to T.strict_buffers_out, results,
Buffer, Layout, Fragment, and IsSharedBuffer. Ensure no changes to behavior on
non-Hopper targets.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 30035e30-aebd-4745-9f39-379cd413651b

📥 Commits

Reviewing files that changed from the base of the PR and between 2209baf and 3bbc4db.

📒 Files selected for processing (5)
  • src/op/gemm.cc
  • src/op/gemm_sp.cc
  • src/op/gemm_sp_py.cc
  • src/op/operator.h
  • src/transform/layout_inference.cc

Comment on lines 298 to 310
```cpp
// On Hopper, sparse gemm uses WGMMA which imposes exact shared-memory
// swizzle requirements. Mark shared A/B strict so the merge site will not
// overwrite them with non-strict updates from other ops. Ampere sparse
// remains non-strict (flexible MMA path).
if (TargetIsHopper(T.target) && T.strict_buffers_out != nullptr) {
  for (auto kv : results) {
    const Buffer &buf = kv.first;
    const Layout &layout = kv.second;
    if (!layout.as<Fragment>() && IsSharedBuffer(buf)) {
      T.strict_buffers_out->push_back(buf);
    }
  }
}
```
Contributor


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

```bash
#!/bin/bash
# Check if GemmSPPy has a method to determine WGMMA usage like GemmSPNode does
rg -n "GetGemmInst|maybe_wgmma|kWGMMA" src/op/gemm_sp_py.cc src/op/gemm_sp_py.h

# Check the FFI function to understand what layouts it infers for Hopper
rg -n "gemm_sp_py.infer_layout" --type py
```

Repository: tile-ai/tilelang

Length of output: 818


🏁 Script executed:

```bash
# First, check the context around lines 298-310 to see the full picture
sed -n '290,320p' src/op/gemm_sp_py.cc

# Then check the CheckWGMMA() implementation
sed -n '168,180p' src/op/gemm_sp_py.cc

# Check GetGemmInst implementation
sed -n '120,140p' src/op/gemm_sp_py.cc
```

Repository: tile-ai/tilelang

Length of output: 2393


🏁 Script executed:

```bash
# Check what function contains lines 298-310
sed -n '250,300p' src/op/gemm_sp_py.cc | head -60

# Also check function signature to see what parameters are available
rg -n "def infer_layout|LayoutMap.*infer_layout" -A 5 src/op/gemm_sp_py.cc
```

Repository: tile-ai/tilelang

Length of output: 2266


🏁 Script executed:

```bash
# Check what LayoutInferArgs contains
rg -n "struct LayoutInferArgs|class LayoutInferArgs" -A 15 

# Check if block_size is computed from thread_bounds anywhere
rg -n "block_size|thread_bounds" src/op/gemm_sp_py.cc | head -20

# Look at gemm_sp.cc InferLayout to see how it handles strict marking
rg -n "InferLayout" src/op/gemm_sp.cc -A 30
```

Repository: tile-ai/tilelang

Length of output: 3128


🏁 Script executed:

```bash
# Check if LayoutInferArgs has block_size or if we can extract it from thread_bounds
rg -n "struct LayoutInferArgs" -A 20 src/op/operator.h

# Check if gemm_sp_py.cc ever computes block_size from thread_bounds
rg -n "thread_bounds.*extent|block_size.*thread_bounds" src/op/gemm_sp_py.cc

# See the complete strict_shared logic in gemm_sp.cc
sed -n '252,285p' src/op/gemm_sp.cc
```

Repository: tile-ai/tilelang

Length of output: 2930


🏁 Script executed:

```bash
# Check if gemm_sp_py.cc ever extracts block_size like gemm_sp.cc does
grep -n "thread_bounds->extent\|as_const_int" src/op/gemm_sp_py.cc

# Look at the complete context of the strict marking code in gemm_sp_py.cc
sed -n '295,315p' src/op/gemm_sp_py.cc
```

Repository: tile-ai/tilelang

Length of output: 777


Fix: Strict marking for Hopper should be conditional on WGMMA actually being used.

The code unconditionally marks all non-Fragment shared buffers as strict for Hopper, but CheckWGMMA() (line 168) returns false unconditionally ("not supported yet"). Since WGMMA is never actually selected, this is overly conservative and differs from gemm_sp.cc's approach.

In gemm_sp.cc, strict marking only occurs when maybe_wgmma is true—computed by checking m_ >= 64 and block_size / warp_size % 4 == 0. Apply the same logic here: extract block_size from T.thread_bounds->extent, compute whether WGMMA would be used, and only mark shared buffers strict when necessary. This prevents unnecessary strict buffers from blocking valid layout merges on Hopper.
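The eligibility condition quoted from gemm_sp.cc can be sketched as follows (Python for illustration; a warp size of 32 and the helper name `maybe_wgmma` are assumptions):

```python
WARP_SIZE = 32  # assumption: CUDA warp size


def maybe_wgmma(m, block_size, warp_size=WARP_SIZE):
    """Mirror of the gemm_sp.cc condition described in the review:
    WGMMA is only a candidate when M >= 64 and the block contains a
    multiple of four warps (i.e. whole warpgroups)."""
    return m >= 64 and (block_size // warp_size) % 4 == 0


# Only mark shared A/B strict when WGMMA could actually be selected.
assert maybe_wgmma(64, 128)        # 4 warps: one full warpgroup
assert not maybe_wgmma(32, 128)    # M too small for WGMMA
assert not maybe_wgmma(64, 96)     # 3 warps: not a warpgroup multiple
```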


@LeiWang1999
Member

@regression-perf

@LeiWang1999 LeiWang1999 force-pushed the fix/layout-inference-shared-buffer-conflict branch from 3bbc4db to 2209baf on April 14, 2026 at 09:00


Development

Successfully merging this pull request may close these issues.

[BUG] LayoutInference fails when shared buffer is used in two gemm ops with different transpose modes

2 participants