[Metal] Add Metal GEMM support with simdgroup intrinsics #1869

oraluben wants to merge 20 commits into tile-ai:main from
Conversation
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run the project's checks before submitting. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in the review settings.
📝 Walkthrough

Adds Apple Metal/MPS support: new GEMM Metal backend, simdgroup buffer scope and allocators, Metal-specific intrinsics emitter, copy/fill/vectorize changes for simdgroup handling, Metal-target dispatch and tests, and dependency adjustments for Darwin.
Sequence Diagram

```mermaid
sequenceDiagram
    participant Host as Host/PyTorch
    participant Dispatch as GemmPy Dispatch
    participant Backend as GemmMetal
    participant Emitter as MPSIntrinEmitter
    participant Metal as Metal Runtime
    Host->>Dispatch: Request GEMM (Target=Metal)
    Dispatch->>Backend: Route to GemmMetal (is_metal)
    Backend->>Backend: Compute warp partitions / prepare buffers
    Backend->>Emitter: Init with tiling & dtypes
    loop per K block
        Backend->>Emitter: ldmatrix_a / ldmatrix_b
        Emitter->>Metal: simdgroup_load / simdgroup_multiply_accumulate
        Metal-->>Emitter: partial accumulators
    end
    Backend->>Emitter: simdgroup_store C_simd -> C_buf
    Backend-->>Host: Return lowered kernel / kernel source
    Host->>Metal: Execute kernel on device
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
🚥 Pre-merge checks: ❌ 1 failed (1 warning) | ✅ 2 passed
Actionable comments posted: 11
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tilelang/transform/decouple_type_cast.py (1)
92-95: ⚠️ Potential issue | 🟡 Minor

Error message doesn't mention `metal.simdgroup` as a known scope.

Since `is_local_buffer` now accepts `metal.simdgroup`, the error message should list it as a valid scope for completeness.

Suggested fix:

```diff
 raise ValueError(
     f"Unknown buffer scope '{buffer.scope()}' for buffer '{buffer.name}'. "
-    f"Expected one of: local, local.fragment, local.var, global, shared, shared.dyn"
+    f"Expected one of: local, local.fragment, local.var, metal.simdgroup, global, shared, shared.dyn"
 )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/transform/decouple_type_cast.py` around lines 92 - 95, The ValueError raised in decouple_type_cast.py lists valid buffer scopes but omits "metal.simdgroup"; update the error message (the raise ValueError that references buffer.scope() and buffer.name) to include "metal.simdgroup" in the Expected one of: list so it matches the scopes accepted by is_local_buffer/other checks and clearly communicates valid scopes.
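For illustration, a runnable sketch of the extended scope check the comment describes; the function name and flat scope tuple below are simplified stand-ins for the real TIR buffer API, not the actual `decouple_type_cast.py` code:

```python
# Hypothetical stand-in for the scope validation in decouple_type_cast.py;
# names here are illustrative, not the real TileLang API.
VALID_SCOPES = (
    "local", "local.fragment", "local.var",
    "metal.simdgroup",  # newly accepted by is_local_buffer in this PR
    "global", "shared", "shared.dyn",
)

def check_buffer_scope(scope: str, name: str) -> None:
    """Raise with a message that lists every accepted scope."""
    if scope not in VALID_SCOPES:
        raise ValueError(
            f"Unknown buffer scope '{scope}' for buffer '{name}'. "
            f"Expected one of: {', '.join(VALID_SCOPES)}")

check_buffer_scope("metal.simdgroup", "C_simd")  # accepted after the fix
```

The point of the fix is that the raised message and the accepted set stay in sync, so a user hitting the error sees `metal.simdgroup` listed.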
🧹 Nitpick comments (7)
src/op/copy.cc (1)

639-645: Method name `CheckSIMDGroupStore` is misleading — it matches any simdgroup↔simdgroup copy.

Both `src` and `dst` are checked for `"metal.simdgroup"` scope, meaning this matches loads and stores (or, more accurately, simdgroup-to-simdgroup transfers). Compare with `CheckLDSMCopy` (shared→fragment) and `CheckSTSMCopy` (fragment→shared), which are directional. Consider renaming to `CheckSIMDGroupCopy` (and correspondingly `LowerSIMDGroupCopy`) for consistency and clarity.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `src/op/copy.cc` around lines 639-645: rename `CheckSIMDGroupStore` to `CheckSIMDGroupCopy` (and its lowering helper `LowerSIMDGroupStore` to `LowerSIMDGroupCopy`), then update all call sites so the intent matches the implementation; keep the same body (checking `src.scope() == "metal.simdgroup" && dst.scope() == "metal.simdgroup"`) but change identifiers for consistency with the `CheckLDSMCopy`/`CheckSTSMCopy` naming conventions.

src/op/gemm_py.cc (1)
136-137: Metal branch placement is correct; consider a minimal `allowMetal()` guard for future-proofing.

The ordering (TCGEN5MMA → WGMMA → CDNA → CUDA → Metal → fallback) is correct — none of the earlier guards can fire on a Metal target. The change is logically sound.

For consistency with the other paths (which all have dedicated `allow*` functions that validate shape, dtype, and scope constraints before returning their instruction type), a lightweight `allowMetal()` gate would help catch unsupported Metal configurations early rather than deferring to downstream errors.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `src/op/gemm_py.cc` around lines 136-137: implement `allowMetal()` similar to the other `allow*` helpers (validating the shapes, dtypes, and memory-scope constraints expected by the Metal path) and call it in the branch that currently checks `TargetIsMetal(target)`; only return `GemmInst::kMetalExp` when `allowMetal()` returns true, otherwise fall back to the existing fallback/error path so unsupported Metal configs are rejected early, following the validation pattern used by `allowCuda()`/`allowCdna()`/`allowWgmma()`.

src/op/gemm.h (1)
42-42: `kMetalExp` naming doesn't align with Python's `METAL`; enum values should be explicit.

Two related concerns:

- The C++ enumerator is named `kMetalExp` (Exp = experimental), but the Python counterpart in `tilelang/tileop/gemm/inst.py` uses the plain name `METAL = 4` with no "experimental" qualifier. This asymmetry makes it unclear whether "experimental" is a meaningful status or just a stale suffix.
- The C++ enum assigns values implicitly (sequential 0–4), while the Python `IntEnum` assigns them explicitly (0–4). These are currently in sync, but inserting a new entry between existing ones would silently break the C++↔Python ABI. Adding explicit values in C++ is a cheap safeguard.

♻️ Suggested: add explicit values and align naming

```diff
-enum class GemmInst : uint8_t { kMMA, kWGMMA, kTCGEN5MMA, kMFMA, kMetalExp };
+enum class GemmInst : uint8_t {
+  kMMA = 0,
+  kWGMMA = 1,
+  kTCGEN5MMA = 2,
+  kMFMA = 3,
+  kMetal = 4,  // align with Python GemmInst.METAL
+};
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `src/op/gemm.h` at line 42: rename the enumerator `kMetalExp` to match the Python name, give all enum members explicit integer values matching the Python `IntEnum`, keep the underlying type `uint8_t`, and update any usages of `GemmInst::kMetalExp` to the new symbol to preserve ABI alignment with `tilelang/tileop/gemm/inst.py`.
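To make the sync concern concrete, here is a sketch of the Python side as the review describes it (values per `tilelang/tileop/gemm/inst.py`; the class body is a reconstruction, not the verbatim source):

```python
from enum import IntEnum

class GemmInst(IntEnum):
    # Explicit values: these must stay in lockstep with the C++
    # `enum class GemmInst : uint8_t` in src/op/gemm.h.
    MMA = 0
    WGMMA = 1
    TCGEN5MMA = 2
    MFMA = 3
    METAL = 4

# Inserting a new entry between existing C++ enumerators would shift the
# implicit C++ values but not these explicit Python ones, silently
# breaking the C++/Python ABI; explicit values on both sides prevent that.
assert GemmInst.METAL == 4
```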
testing/python/metal/test_metal_gemm_v2.py (3)

91-93: Redundant `torch.mps.is_available()` guard silently swallows test output when Metal is absent.

The individual test functions already carry `@tilelang.testing.requires_metal`, which skips them if Metal is unavailable. Wrapping `tilelang.testing.main()` in an additional `torch.mps.is_available()` check means that when run on a non-Metal machine, no tests are registered at all — no "skipped" output, no indication tests exist. Compare with `test_metal_gemm_v2_linux.py`, which calls `tilelang.testing.main()` unconditionally.

🔧 Suggested fix

```diff
 if __name__ == "__main__":
-    if torch.mps.is_available():
-        tilelang.testing.main()
+    tilelang.testing.main()
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `testing/python/metal/test_metal_gemm_v2.py` around lines 91-93: remove the redundant `torch.mps.is_available()` guard in the `__main__` block so that `tilelang.testing.main()` is called unconditionally; the tests themselves use the `@tilelang.testing.requires_metal` decorator to skip when Metal is absent.
80-83: Consider putting `requires_metal` as the outermost decorator for conventional ordering.

Placing `@tilelang.testing.requires_metal` inside `@pytest.mark.xfail` works correctly (the skip propagates), but the conventional pattern is to put guard/skip decorators outermost so the intent is immediately obvious at the call site.

🎨 Suggested reorder

```diff
-@pytest.mark.xfail(reason="TODO: codegen not support float16x8")
 @tilelang.testing.requires_metal
+@pytest.mark.xfail(reason="TODO: codegen not support float16x8")
 def test_gemm_v2_large():
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `testing/python/metal/test_metal_gemm_v2.py` around lines 80-83: swap the two decorators on `test_gemm_v2_large` so `@tilelang.testing.requires_metal` appears above `@pytest.mark.xfail`, keeping the same behavior but following conventional ordering.
88-88: Add a comment explaining `atol=1.0` for the large-K test.

For K=1024 with float16 inputs, the accumulated rounding error per element can approach K × ε_fp16 ≈ 1024 × 9.7e-4 ≈ 1.0, so the tolerance is intentional. A brief inline comment would prevent future readers from tightening it incorrectly.

💬 Suggested change

```diff
-    assert_gemm_v2(1024, 1024, 1024, 16, 16, 16, atol=1.0)
+    # atol=1.0: with K=1024 fp16 inputs, accumulated rounding error ≈ K × ε_fp16 ≈ 1.0
+    assert_gemm_v2(1024, 1024, 1024, 16, 16, 16, atol=1.0)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `testing/python/metal/test_metal_gemm_v2.py` at line 88: add a brief inline comment next to the `assert_gemm_v2(1024, 1024, 1024, 16, 16, 16, atol=1.0)` call explaining that `atol=1.0` is intentional because for K=1024 with float16 inputs the accumulated rounding error per element can approach K × ε_fp16 ≈ 1024 × 9.7e-4 ≈ 1.0, so the larger absolute tolerance avoids false failures on this large-K test.
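The error estimate behind that tolerance can be checked in pure Python (no torch needed); ε_fp16 here is the float16 machine epsilon 2⁻¹⁰:

```python
# Worst-case accumulated absolute error for a length-K fp16 dot product,
# using the rough bound cited in the review: K * eps_fp16.
K = 1024
eps_fp16 = 2.0 ** -10  # fp16 machine epsilon ≈ 9.77e-4
worst_case_abs_error = K * eps_fp16
print(worst_case_abs_error)  # 1.0, which is exactly the atol used
```

This is a coarse bound (it ignores accumulation order and any wider accumulator the hardware may use), but it explains why `atol=1.0` is not arbitrary.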
testing/python/metal/test_metal_gemm_v2_linux.py (1)

50-53: Misleading variable name and redundant target specification.

Two minor issues:

- Line 50 sets `tvm.target.Target("metal")` as a context manager and passes `target="metal"` to `tilelang.lower` on line 51. The context manager is redundant; the explicit `target=` argument alone is sufficient (and matches the pattern in similar test files).
- The return value of `tilelang.lower` is named `artifact`, but `kernel_source` is a property on the JIT kernel object (as shown in `tilelang/jit/kernel.py`), not on a raw artifact. Naming it `kernel` would be more accurate.

🔧 Suggested cleanup

```diff
-    with tvm.transform.PassContext(), tvm.target.Target("metal"):
-        artifact = tilelang.lower(func, target="metal")
-
-    src_code = artifact.kernel_source
+    with tvm.transform.PassContext():
+        kernel = tilelang.lower(func, target="metal")
+
+    src_code = kernel.kernel_source
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `testing/python/metal/test_metal_gemm_v2_linux.py` around lines 50-53: call `tilelang.lower` with the explicit `target` argument only (removing the redundant `tvm.target.Target("metal")` context manager) and rename the returned value from `artifact` to `kernel` so the subsequent access reads `kernel.kernel_source`.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@pyproject.toml`:
- Around line 33-35: The pyproject lower-bound for "apache-tvm-ffi" is too
permissive (">0.1.2") and allows Darwin installs of versions 0.1.3–0.1.5 that
contain the memory regression; update the constraint(s) for "apache-tvm-ffi" to
match requirements.txt by using a >=0.1.6 lower bound (e.g., change ">0.1.2" to
">=0.1.6" and ensure the Darwin-specific constraint "apache-tvm-ffi<0.1.8;
platform_system == 'Darwin'" remains consistent with that lower bound).
In `@src/op/copy.cc`:
- Around line 685-686: GetCopyInst can return CopyInst::kMetalSIMDGroup but
Lower() lacks a matching case, causing a fatal crash; add a case for
CopyInst::kMetalSIMDGroup in Lower() (in the function Lower) before the kNormal
branch that implements the appropriate lowering behavior for SIMD-group metal
stores (mirror the pattern used for other Metal-specific cases in Lower(),
dispatching to the SIMD-group store lowering path and returning the resultant
statement), ensuring the new case references CopyInst::kMetalSIMDGroup so
execution no longer falls through to LOG(FATAL).
In `@src/op/copy.h`:
- Around line 232-235: The Doxygen comment for CheckSIMDGroupStore is
incorrect/copy-pasted from CheckTMemStore; update the comment to describe that
CheckSIMDGroupStore(Target target) determines whether the Metal SIMD-group
(SIMD-group/warp/wavefront) store instruction is supported on the given Target
(i.e., checks for Metal SIMD group store capability), not tensor memory store
support; reference CheckSIMDGroupStore and, for comparison, CheckTMemStore to
ensure the wording distinguishes SIMD group store support from tensor memory
store support.
- Around line 274-277: The CopyNode::Lower() switch must handle
CopyInst::kMetalSIMDGroup: add a case in Lower() that calls
LowerSIMDGroupStore(T, analyzer) when GetCopyInst(...) returns kMetalSIMDGroup
(this path is produced when CheckSIMDGroupStore(target) is true) to avoid the
LOG(FATAL) fallthrough; then implement the missing LowerSIMDGroupStore(const
LowerArgs &T, arith::Analyzer *analyzer) in copy.cc mirroring the pattern of
existing store lowerers (use the same argument handling and memory access
lowering used by other Metal-specific store methods, perform bounds/stride
checks as done in related Lower*Store functions, and produce the appropriate
Stmt), ensuring the declaration in copy.h matches the definition.
In `@src/op/gemm.cc`:
- Around line 150-154: Comments referencing a hardcoded "16" are now stale
because kMPerWarp is target-dependent; update the comments near kMPerWarp,
TargetIsMetal, and the logic that mentions "m_warp*16" and "Each warp needs at
least 16 elements in M" to reflect the runtime variable (e.g., refer to "m_warp
* kMPerWarp" and "kMPerWarp elements") so they accurately describe behavior on
Metal and other targets.
In `@testing/python/metal/test_metal_gemm_v2_linux.py`:
- Around line 70-71: The test test_metal_gemm_v2_larger currently calls
assert_metal_gemm_v2_codegen with parameters known to fail at runtime; update
the test to either mark it as an expected failure or explicitly verify
codegen-only success: add `@pytest.mark.xfail`(reason="TODO: codegen not support
float16x8") above test_metal_gemm_v2_larger to match the runtime test, or
replace the single assert_metal_gemm_v2_codegen call with a two-step check that
runs the codegen path (tilelang.lower/codegen) and asserts it succeeds while
keeping execution separately marked xfail; reference test_metal_gemm_v2_larger
and assert_metal_gemm_v2_codegen when making the change.
- Around line 22-34: The codegen test's matmul_gemm_v2 kernel is structurally
different from the runtime test; update the matmul_gemm_v2 in the Linux codegen
test so its symbols match the runtime test: change C_local to use
T.alloc_shared(..., scope="shared") (instead of T.alloc_fragment), remove/adjust
coalesced_width=2 on T.copy calls to match the runtime test's copies, and use
T.gemm_v2 (or make the runtime test use T.gemm if you choose that canonical API)
so the GEMM operator is identical; ensure these changes are applied to the
matmul_gemm_v2 definition so the codegen pre-flight validates the same kernel
that the runtime test executes.
- Line 32: The test is calling the old lowering path via T.gemm instead of the
new Metal path; update the call to use T.gemm_v2 so the codegen test exercises
the gemm_v2 lowering (replace the invocation T.gemm(A_shared, B_shared, C_local)
with T.gemm_v2(A_shared, B_shared, C_local) in the test function) to align this
codegen test with the runtime test and ensure assertions on
simdgroup_multiply_accumulate, simdgroup_load, and simdgroup_store validate the
correct backend.
In `@tilelang/intrinsics/metal_macro_generator.py`:
- Around line 43-44: The code computes self.warp_rows and self.warp_cols via
integer division of warp_row_tiles//micro_size_x and
warp_col_tiles//micro_size_y which will silently truncate if inputs are not
divisible by the micro sizes (8); add validation in the same initializer or
before these assignments (e.g., in the MetalMacroGenerator constructor or method
that sets warp_row_tiles/warp_col_tiles) that raises a clear error if
warp_row_tiles % micro_size_x != 0 or warp_col_tiles % micro_size_y != 0,
mentioning the offending values, and only then compute self.warp_rows and
self.warp_cols as the integer quotient.
In `@tilelang/jit/adapter/torch/metal.py`:
- Around line 56-58: The method get_kernel_source currently claims to return str
but may return None and ignores the kernel_only flag; change its signature to ->
str | None (or keep -> str but assert/raise if kernel_global_source is None) and
implement the kernel_only branch: if kernel_only is True return
self.kernel_global_source, otherwise return the full Metal source (compose or
return the attribute that holds the complete module/source such as
self.metal_source or self.full_source); ensure you reference get_kernel_source
and kernel_global_source and either assert kernel_global_source is not None
before returning a str or update callers/types to accept Optional[str].
In `@tilelang/tileop/gemm/gemm_metal.py`:
- Around line 22-23: The int() cast on potentially symbolic shapes self.M and
self.N will fail at runtime for PrimExpr; update the computation of
warp_row_tiles and warp_col_tiles (currently int(self.M // m_warp) and
int(self.N // n_warp)) to preserve symbolic expressions instead of forcing
Python ints—either remove the int() and keep self.M // m_warp and self.N //
n_warp, or use tir.floordiv/tvm.tir.floordiv to produce a PrimExpr;
alternatively, if a concrete int is required, guard with an isinstance check for
tir.IntImm before casting. Ensure you change both warp_row_tiles and
warp_col_tiles and keep references to m_warp and n_warp.
ℹ️ Review info
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (18)
- pyproject.toml
- requirements.txt
- src/op/copy.cc
- src/op/copy.h
- src/op/gemm.cc
- src/op/gemm.h
- src/op/gemm_py.cc
- src/op/parallel.cc
- src/op/utils.h
- testing/python/metal/test_metal_gemm_v2.py
- testing/python/metal/test_metal_gemm_v2_linux.py
- tilelang/intrinsics/metal_macro_generator.py
- tilelang/jit/adapter/torch/metal.py
- tilelang/tileop/gemm/__init__.py
- tilelang/tileop/gemm/gemm_metal.py
- tilelang/tileop/gemm/inst.py
- tilelang/transform/decouple_type_cast.py
- tilelang/utils/language.py
```python
T.copy(A[by * block_M, ko * block_K], A_shared, coalesced_width=2)
T.copy(B[ko * block_K, bx * block_N], B_shared, coalesced_width=2)
...
T.gemm(A_shared, B_shared, C_local)
```
🧩 Analysis chain
🏁 Script executed:

```shell
# First, let's check the test file in question to see what's actually being called
find . -name "test_metal_gemm_v2_linux.py" -type f | head -1 | xargs cat -n
```

Repository: tile-ai/tilelang
Length of output: 3353

🏁 Script executed:

```shell
# Search for gemm and gemm_v2 definitions in the language
rg -n "def gemm\b|def gemm_v2\b" --type py tilelang/language/ -A 5
```

Repository: tile-ai/tilelang
Length of output: 743

🏁 Script executed:

```shell
# Check what tl.tileop operations are defined and used
rg -n "tl\.tileop\.gemm" --type py -B 2 -A 2
```

Repository: tile-ai/tilelang
Length of output: 1314

🏁 Script executed:

```shell
# Check runtime test mentioned in review (test_metal_gemm_v2.py)
find . -name "test_metal_gemm_v2.py" -type f | head -1 | xargs cat -n 2>/dev/null || echo "File not found"
```

Repository: tile-ai/tilelang
Length of output: 3454

🏁 Script executed:

```shell
# Search for T.gemm and T.gemm_v2 usage patterns
rg -n "T\.gemm\(|T\.gemm_v2\(" --type py | head -20
```

Repository: tile-ai/tilelang
Length of output: 2497
Replace T.gemm with T.gemm_v2 on line 32.
T.gemm and T.gemm_v2 are distinct operations with different lowering paths (tl.tileop.gemm vs. tl.tileop.gemm_py). This test is meant to verify Metal codegen for the new gemm_v2 path, but it calls T.gemm instead. The assertions for simdgroup_multiply_accumulate, simdgroup_load, and simdgroup_store may be validating against the wrong lowering backend and could pass for the wrong reason or miss actual Metal simdgroup codegen issues.
The runtime test (test_metal_gemm_v2.py line 34) correctly uses T.gemm_v2; align this codegen test with it.
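As a sketch of what the codegen assertions are meant to guarantee once the test calls `T.gemm_v2`, a simple token check over the generated Metal source might look like this (hypothetical helper, not part of the actual test suite):

```python
def uses_simdgroup_intrinsics(src_code: str) -> bool:
    """Return True if the generated Metal source exercises the simdgroup path."""
    required = (
        "simdgroup_load",
        "simdgroup_store",
        "simdgroup_multiply_accumulate",
    )
    return all(token in src_code for token in required)

# A source produced by the old T.gemm lowering path would lack these tokens,
# which is why the review insists the codegen test call T.gemm_v2.
sample = "simdgroup_load(...); simdgroup_multiply_accumulate(...); simdgroup_store(...);"
print(uses_simdgroup_intrinsics(sample))  # True
```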
```python
def test_metal_gemm_v2_larger():
    assert_metal_gemm_v2_codegen(128, 128, 128, 32, 32, 32, dtype=T.float16)
```
test_metal_gemm_v2_larger may silently pass despite a known codegen limitation.

The runtime test `test_gemm_v2_large` in `test_metal_gemm_v2.py` tests the same parameters (128, 128, 128, 32, 32, 32, dtype=T.float16) and is marked `@pytest.mark.xfail(reason="TODO: codegen not support float16x8")`. If the codegen for this block configuration is genuinely broken, the Linux codegen test should also fail (and be expected to fail) for consistency. Either:

- Add `@pytest.mark.xfail(reason="TODO: codegen not support float16x8")` here as well, or
- Verify that the codegen-only path (`tilelang.lower` without execution) succeeds where execution fails.
```python
self.warp_rows = warp_row_tiles // self.micro_size_x
self.warp_cols = warp_col_tiles // self.micro_size_y
```
Add validation that warp_row_tiles and warp_col_tiles are divisible by 8.

If `warp_row_tiles` or `warp_col_tiles` aren't multiples of `micro_size_x`/`micro_size_y` (8), the integer division would silently truncate, leading to incorrect tiling.

Suggested validation

```diff
 # Number of 8x8 tiles per warp
+assert warp_row_tiles % self.micro_size_x == 0, (
+    f"warp_row_tiles ({warp_row_tiles}) must be divisible by micro_size_x ({self.micro_size_x})"
+)
+assert warp_col_tiles % self.micro_size_y == 0, (
+    f"warp_col_tiles ({warp_col_tiles}) must be divisible by micro_size_y ({self.micro_size_y})"
+)
 self.warp_rows = warp_row_tiles // self.micro_size_x
 self.warp_cols = warp_col_tiles // self.micro_size_y
```
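A standalone sketch of the suggested validation (the function name and defaults below are assumed for illustration; the real logic lives in `MetalMacroGenerator`):

```python
def compute_warp_tiles(warp_row_tiles: int, warp_col_tiles: int,
                       micro_size_x: int = 8, micro_size_y: int = 8) -> tuple:
    """Number of 8x8 simdgroup tiles per warp, rejecting non-divisible shapes."""
    if warp_row_tiles % micro_size_x != 0:
        raise ValueError(
            f"warp_row_tiles ({warp_row_tiles}) must be divisible by "
            f"micro_size_x ({micro_size_x})")
    if warp_col_tiles % micro_size_y != 0:
        raise ValueError(
            f"warp_col_tiles ({warp_col_tiles}) must be divisible by "
            f"micro_size_y ({micro_size_y})")
    return warp_row_tiles // micro_size_x, warp_col_tiles // micro_size_y

print(compute_warp_tiles(16, 32))  # (2, 4)
```

Without the guard, `compute_warp_tiles(12, 16)` would silently return `(1, 2)` and drop half a tile; with it, the bad configuration fails loudly at construction time.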
```python
def get_kernel_source(self, kernel_only: bool = True) -> str:
    return self.kernel_global_source
```
Fix the return type annotation and honour the kernel_only flag.

Two issues here:

- Return type mismatch — `kernel_global_source` is declared `str | None` (line 23), so the method can return `None`, contradicting the `-> str` annotation. This will cause silent type errors for callers.
- Unused `kernel_only` parameter — every peer adapter branches on this flag (`base.py`, `nvrtc/adapter.py`, `cython/adapter.py`). Silently ignoring it here means `get_kernel_source(kernel_only=False)` behaves identically to `kernel_only=True`, breaking the expected contract.

🛠️ Proposed fix

```diff
-    def get_kernel_source(self, kernel_only: bool = True) -> str:
-        return self.kernel_global_source
+    def get_kernel_source(self, kernel_only: bool = True) -> str | None:
+        # Metal has a single unified source; kernel_only has no distinct meaning here.
+        return self.kernel_global_source
```

If a non-None guarantee is truly required at call sites, add an explicit assertion:

```diff
-    def get_kernel_source(self, kernel_only: bool = True) -> str:
-        return self.kernel_global_source
+    def get_kernel_source(self, kernel_only: bool = True) -> str | None:
+        assert self.kernel_global_source is not None, "kernel_global_source is not available"
+        return self.kernel_global_source
```

🧰 Tools
🪛 Ruff (0.15.1)
[warning] 56-56: Unused method argument: kernel_only
(ARG002)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tilelang/jit/adapter/torch/metal.py` around lines 56 - 58, The method
get_kernel_source currently claims to return str but may return None and ignores
the kernel_only flag; change its signature to -> str | None (or keep -> str but
assert/raise if kernel_global_source is None) and implement the kernel_only
branch: if kernel_only is True return self.kernel_global_source, otherwise
return the full Metal source (compose or return the attribute that holds the
complete module/source such as self.metal_source or self.full_source); ensure
you reference get_kernel_source and kernel_global_source and either assert
kernel_global_source is not None before returning a str or update callers/types
to accept Optional[str].
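A minimal sketch of what honouring both points could look like. Everything except the `get_kernel_source`/`kernel_global_source` names is a hypothetical stand-in, not the real adapter:

```python
class MetalAdapterSketch:
    """Illustrative stand-in for the Metal torch adapter discussed above."""

    def __init__(self, kernel_global_source, full_source=None):
        self.kernel_global_source = kernel_global_source
        self.full_source = full_source  # hypothetical holder for the complete module source

    def get_kernel_source(self, kernel_only: bool = True) -> str:
        # Fail loudly instead of silently returning None under a `-> str` annotation.
        assert self.kernel_global_source is not None, "kernel_global_source is not available"
        if kernel_only:
            return self.kernel_global_source
        # Honour the flag: fall back to the kernel source if no fuller source exists.
        return self.full_source or self.kernel_global_source
```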
```python
warp_row_tiles = int(self.M // m_warp)
warp_col_tiles = int(self.N // n_warp)
```
`int()` cast on potentially symbolic `self.M` / `self.N`.

`self.M` and `self.N` originate from buffer shapes and may be `tir.IntImm` or symbolic `PrimExpr`. If symbolic, `int(self.M // m_warp)` will raise at runtime. Other GEMM backends (e.g., `GemmMMA`) handle this similarly, so this is likely fine for Metal's concrete-size use case, but worth noting.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tilelang/tileop/gemm/gemm_metal.py` around lines 22 - 23, The int() cast on
potentially symbolic shapes self.M and self.N will fail at runtime for PrimExpr;
update the computation of warp_row_tiles and warp_col_tiles (currently
int(self.M // m_warp) and int(self.N // n_warp)) to preserve symbolic
expressions instead of forcing Python ints—either remove the int() and keep
self.M // m_warp and self.N // n_warp, or use tir.floordiv/tvm.tir.floordiv to
produce a PrimExpr; alternatively, if a concrete int is required, guard with an
isinstance check for tir.IntImm before casting. Ensure you change both
warp_row_tiles and warp_col_tiles and keep references to m_warp and n_warp.
Why do we need to introduce a new storage scope?
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/op/parallel.cc (1)
355-383: ⚠️ Potential issue | 🔴 Critical

Guard all Fragment-dependent paths, not just `source_buffer` selection.

Good fix at Line 381, but this loop can still crash on non-`Fragment` layouts: Line 363 (`IsBufferCompletelyReplicated`) and the `read_source_buffer` path (Line 373-377 → Line 418/610) still assume `as<Fragment>().value()` is valid.

Suggested hardening

```diff
 for (const auto &[buffer, access] : indice_map_) {
   if (T.layout_map.count(buffer)) {
     // skip reducers with rep=ALL
     if (auto info = reducer_info_map_.Get(buffer->data);
         info && info.value()->rep == ReducerRepType::ALL)
       continue;
+    auto frag = T.layout_map[buffer].as<Fragment>();
+    if (!frag.has_value()) {
+      // Non-Fragment layout: cannot be used as infer source here.
+      continue;
+    }
     bool is_fully_replicated = IsBufferCompletelyReplicated(buffer, T.layout_map);
     if (access.is_write) {
       source_buffer = buffer;
     } else {
@@
-      if ((!read_source_buffer.defined() ||
+      if (!is_fully_replicated &&
+          (!read_source_buffer.defined() ||
            access.indices.size() > GetAccessInfo(read_source_buffer).indices.size())) {
         read_source_buffer = buffer;
       }
@@
-      auto frag = T.layout_map[buffer].as<Fragment>();
-      if (frag.has_value() && is_one(frag.value()->ReplicateExtent()) &&
+      if (is_one(frag.value()->ReplicateExtent()) &&
           !source_buffer.defined()) {
         source_buffer = buffer;
       }
     }
   }
 }
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/op/parallel.cc` around lines 355 - 383, The loop over indice_map_ assumes T.layout_map[buffer].as<Fragment>() is valid in multiple places; besides the existing guard around source_buffer selection, also guard calls to IsBufferCompletelyReplicated(buffer, T.layout_map) and the read_source_buffer selection/usage so they only run when the layout is a Fragment: first check T.layout_map.count(buffer) and that T.layout_map[buffer].as<Fragment>().has_value() before calling IsBufferCompletelyReplicated or using frag.value()->ReplicateExtent(); likewise ensure any later uses of read_source_buffer rely on a verified Fragment layout (or skip/handle non-Fragment layouts) so GetAccessInfo(read_source_buffer) and ReplicateExtent() are only invoked on confirmed Fragment instances.
♻️ Duplicate comments (2)
src/op/copy.cc (1)
893-894: ⚠️ Potential issue | 🔴 Critical

`kMetalSIMDGroup` can still fall through to `LOG(FATAL)`.

This branch makes `GetCopyInst()` return the new enum, but `CopyNode::Lower()` below still has no `kMetalSIMDGroup` case, and `LowerSIMDGroupStore()` is only declared in the header. The first matching copy will crash at runtime.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/op/copy.cc` around lines 893 - 894, GetCopyInst() can return CopyInst::kMetalSIMDGroup (when CheckSIMDGroupStore is true) but CopyNode::Lower() lacks a case for kMetalSIMDGroup and LowerSIMDGroupStore is only declared, causing a runtime LOG(FATAL); implement LowerSIMDGroupStore in src/op/copy.cc and add a case in CopyNode::Lower() that handles CopyInst::kMetalSIMDGroup by calling LowerSIMDGroupStore (mirroring the pattern used for other CopyInst enum values), ensuring the new enum path is fully implemented and returns/lowers correctly.

tilelang/intrinsics/metal_macro_generator.py (1)
42-44: ⚠️ Potential issue | 🟠 Major

Fail fast when per-warp tiles are not multiples of 8.

These values are rounded down with `// 8`. A legal Metal partition can still produce per-warp shapes like 12x12, so this silently drops the tail rows/cols instead of rejecting the configuration.
Verify each finding against the current code and only fix it if needed. In `@tilelang/intrinsics/metal_macro_generator.py` around lines 42 - 44, Validate that per-warp tile counts divide evenly rather than truncating: check that warp_row_tiles is divisible by self.micro_size_x and warp_col_tiles is divisible by self.micro_size_y before computing self.warp_rows and self.warp_cols (or immediately after assignment) and raise a clear error / exception (with the offending values) if either modulus is non-zero so illegal Metal partitions like 12x12 are rejected instead of silently dropping tails; reference the variables warp_row_tiles, warp_col_tiles, self.micro_size_x, self.micro_size_y, and the computed warp_rows/warp_cols when adding the check and error.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/op/copy.cc`:
- Around line 813-819: CopyNode::CheckSIMDGroupStore currently returns true only
for local-to-local copies because it requires src.scope() == dst.scope() ==
"metal.simdgroup"; change the predicate to detect a simdgroup source storing to
a non-simdgroup destination so it properly gates LowerSIMDGroupStore.
Specifically, in CopyNode::CheckSIMDGroupStore(Target target) keep the
TargetIsMetal(target) check but replace the scope test with something like
src.scope() == "metal.simdgroup" && dst.scope() != "metal.simdgroup" (or
explicitly check for shared/global destination scopes used by the new path) so
the function matches simdgroup->shared/global store cases rather than
simdgroup->simdgroup copies.
In `@tilelang/tileop/gemm/gemm_metal.py`:
- Around line 59-60: The current assertion only enforces block_K >= micro_size_k
but not that block_K is a multiple of the micro-K step, which allows leftover K
values to be dropped; update the guard in gemm_metal.py to assert that block_K
is an exact multiple of micro_size_k (e.g., assert block_K % micro_size_k == 0)
and include a clear error message referencing block_K and micro_size_k so
callers know it must be a multiple (or explicitly require multiple-of-8 if
micro_size_k==8); keep the existing is_full_region(C_region) check unchanged.
---
Outside diff comments:
In `@src/op/parallel.cc`:
- Around line 355-383: The loop over indice_map_ assumes
T.layout_map[buffer].as<Fragment>() is valid in multiple places; besides the
existing guard around source_buffer selection, also guard calls to
IsBufferCompletelyReplicated(buffer, T.layout_map) and the read_source_buffer
selection/usage so they only run when the layout is a Fragment: first check
T.layout_map.count(buffer) and that
T.layout_map[buffer].as<Fragment>().has_value() before calling
IsBufferCompletelyReplicated or using frag.value()->ReplicateExtent(); likewise
ensure any later uses of read_source_buffer rely on a verified Fragment layout
(or skip/handle non-Fragment layouts) so GetAccessInfo(read_source_buffer) and
ReplicateExtent() are only invoked on confirmed Fragment instances.
---
Duplicate comments:
In `@src/op/copy.cc`:
- Around line 893-894: GetCopyInst() can return CopyInst::kMetalSIMDGroup (when
CheckSIMDGroupStore is true) but CopyNode::Lower() lacks a case for
kMetalSIMDGroup and LowerSIMDGroupStore is only declared, causing a runtime
LOG(FATAL); implement LowerSIMDGroupStore in src/op/copy.cc and add a case in
CopyNode::Lower() that handles CopyInst::kMetalSIMDGroup by calling
LowerSIMDGroupStore (mirroring the pattern used for other CopyInst enum values),
ensuring the new enum path is fully implemented and returns/lowers correctly.
In `@tilelang/intrinsics/metal_macro_generator.py`:
- Around line 42-44: Validate that per-warp tile counts divide evenly rather
than truncating: check that warp_row_tiles is divisible by self.micro_size_x and
warp_col_tiles is divisible by self.micro_size_y before computing self.warp_rows
and self.warp_cols (or immediately after assignment) and raise a clear error /
exception (with the offending values) if either modulus is non-zero so illegal
Metal partitions like 12x12 are rejected instead of silently dropping tails;
reference the variables warp_row_tiles, warp_col_tiles, self.micro_size_x,
self.micro_size_y, and the computed warp_rows/warp_cols when adding the check
and error.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 014ea53f-df65-431b-9c9d-62a27cd42be7
📒 Files selected for processing (18)
- pyproject.toml
- requirements.txt
- src/op/copy.cc
- src/op/copy.h
- src/op/gemm.cc
- src/op/gemm.h
- src/op/gemm_py.cc
- src/op/parallel.cc
- src/op/utils.h
- testing/python/metal/test_metal_gemm_v2.py
- testing/python/metal/test_metal_gemm_v2_linux.py
- tilelang/intrinsics/metal_macro_generator.py
- tilelang/jit/adapter/torch/metal.py
- tilelang/tileop/gemm/__init__.py
- tilelang/tileop/gemm/gemm_metal.py
- tilelang/tileop/gemm/inst.py
- tilelang/transform/decouple_type_cast.py
- tilelang/utils/language.py
✅ Files skipped from review due to trivial changes (2)
- src/op/utils.h
- requirements.txt
🚧 Files skipped from review as they are similar to previous changes (7)
- pyproject.toml
- tilelang/transform/decouple_type_cast.py
- tilelang/utils/language.py
- tilelang/jit/adapter/torch/metal.py
- tilelang/tileop/gemm/inst.py
- tilelang/tileop/gemm/__init__.py
- testing/python/metal/test_metal_gemm_v2_linux.py
```python
assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})"
assert is_full_region(C_region), "Fragment output C must be a full region"
```
Require block_K to be a multiple of 8.
Line 74 uses block_K // micro_size_k, so block_K=12 would only execute one 8-wide MMA step and silently drop the remaining K values.
🧩 Suggested guard

```diff
- assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})"
+ assert block_K >= micro_size_k and block_K % micro_size_k == 0, (
+     f"block_K ({block_K}) must be a multiple of micro_size_k ({micro_size_k})"
+ )
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
assert block_K >= micro_size_k and block_K % micro_size_k == 0, (
    f"block_K ({block_K}) must be a multiple of micro_size_k ({micro_size_k})"
)
assert is_full_region(C_region), "Fragment output C must be a full region"
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tilelang/tileop/gemm/gemm_metal.py` around lines 59 - 60, The current
assertion only enforces block_K >= micro_size_k but not that block_K is a
multiple of the micro-K step, which allows leftover K values to be dropped;
update the guard in gemm_metal.py to assert that block_K is an exact multiple of
micro_size_k (e.g., assert block_K % micro_size_k == 0) and include a clear
error message referencing block_K and micro_size_k so callers know it must be a
multiple (or explicitly require multiple-of-8 if micro_size_k==8); keep the
existing is_full_region(C_region) check unchanged.
Metal Shading Language only supports vector types up to 4 components (e.g. half4, float4). The default 128-bit vectorization produced float16x8, which has no Metal equivalent, causing codegen failure. Limit the vectorize width to 64 bits for the Metal target, matching the largest native half-precision vector (half4 = 4×16 = 64 bits).

Perf vs torch.mm (MPS) on Apple Silicon:
- 1024²: 0.85x (4.34 vs 5.14 TFLOPS)
- 2048²: 0.76x (4.62 vs 6.06 TFLOPS)
- 4096²: 1.37x (3.86 vs 2.81 TFLOPS)
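The width cap can be expressed as a small helper (a sketch of the constraint described in the commit message; the function name and defaults are illustrative):

```python
def max_vector_lanes(dtype_bits: int, max_width_bits: int = 64,
                     max_components: int = 4) -> int:
    """Largest vector length that satisfies both the bit-width cap and
    Metal's 4-component vector-type limit."""
    return min(max_width_bits // dtype_bits, max_components)
```

Under the 64-bit cap, float16 vectorizes to 4 lanes (half4) and float32 to 2 lanes; under a 128-bit cap, float16 would need 8 lanes, which exceeds the 4-component limit unless the codegen packs it into a wider integer type.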
Eliminate C_simd shared memory round-trip by keeping the accumulator in simdgroup registers across K iterations. Add LowerSIMDGroupCopy for direct simdgroup_store to device/shared memory, simdgroup-aware Fill lowering, and T.alloc_simdgroup API. Separates metal.simdgroup from local.fragment in IsFragmentBuffer since they have different semantics (cooperative vs per-thread). Introduces IsRegisterBuffer for cases needing either. 8-15% speedup over shared-memory C path on Apple Silicon.
Codegen tests verify correct Metal shader output (simdgroup_store to device, no C_local simdgroup_load from shared, make_filled init). Correctness tests verify numerical results against torch.matmul across square/non-square blocks and matrix sizes.
🧹 Nitpick comments (3)
testing/python/metal/test_metal_gemm_v2_linux.py (1)
22-34: Codegen test uses different buffer scopes than the runtime test.

This test allocates `C_local` with `T.alloc_simdgroup` (simdgroup register path), while the runtime test (test_metal_gemm_v2.py) allocates `C_local` with `T.alloc_shared(..., scope="shared")` (shared memory path). Per `tilelang/tileop/gemm/gemm_metal.py:65-99`, these trigger different lowering branches:

- Simdgroup scope → `_gemm_ss_simdgroup` (no C load/store round-trips)
- Shared scope → `_gemm_ss_shared` (uses `simd_load`/`simd_store` for C)

This appears intentional to test both paths, but the test names/docstrings don't clarify this distinction. Consider adding a comment noting that `test_metal_simdgroup_store.py` covers the simdgroup accumulator path while `test_metal_gemm_v2.py` covers the shared-memory accumulator path.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@testing/python/metal/test_metal_gemm_v2_linux.py` around lines 22 - 34, The test uses T.alloc_simdgroup to create C_local (simdgroup register path) while the runtime test uses T.alloc_shared for C_local (shared-memory path), which exercise different lowering branches (_gemm_ss_simdgroup vs _gemm_ss_shared); add a short clarifying comment or update the test docstring near the C_local allocation (or at the top of the test) stating that this file intentionally tests the simdgroup-accumulator path (T.alloc_simdgroup / C_local) and that test_metal_gemm_v2.py covers the shared-memory accumulator path (T.alloc_shared / C_local) so readers know both lowering branches are exercised.

testing/python/metal/test_metal_simdgroup_store.py (1)
72-76: String-based assertions may be fragile if variable names change.

The checks for `"simdgroup_store(C_local"` and `"simdgroup_load" in line and "C_local" in line` rely on the generated Metal code using the exact variable name `C_local`. If the codegen changes variable naming (e.g., to `C_local_1` or a different convention), these assertions could fail or pass incorrectly.

Consider using a more robust pattern, such as counting total `simdgroup_store` calls or using a regex that matches the buffer pattern more flexibly.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@testing/python/metal/test_metal_simdgroup_store.py` around lines 72 - 76, The string-based assertions are fragile because they require an exact match of the variable name C_local; update the checks in test_metal_simdgroup_store.py to use regexes that allow optional suffixes (e.g. C_local or C_local_1) and/or count simdgroup_store/load calls more generally: replace the plain substring search for "simdgroup_store(C_local" with a regex like r"simdgroup_store\(\s*C_local(?:_\d+)?\b" (used to compute store_to_device) and replace the load check to search for r"simdgroup_load\(\s*C_local(?:_\d+)?\b" (ensuring len(...) == 0), or alternatively assert total simdgroup_store occurrences and that none of the simdgroup_load matches reference a C_local-like buffer; update variables store_to_device and load_c_from_shared accordingly.

testing/python/metal/test_metal_gemm_v2.py (1)
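The suggested regex patterns can be checked in isolation (the sample shader lines below are made up for illustration, not real codegen output):

```python
import re

# Allow SSA-style suffixes such as C_local_1 that plain substring checks miss.
STORE_PAT = re.compile(r"simdgroup_store\(\s*C_local(?:_\d+)?\b")
LOAD_PAT = re.compile(r"simdgroup_load\(\s*C_local(?:_\d+)?\b")

sample = """
simdgroup_store(C_local_1[i], C + offset, 1024);
simdgroup_load(A_frag, A_shared + k, 32);
"""
store_to_device = STORE_PAT.findall(sample)
load_c_from_shared = LOAD_PAT.findall(sample)
```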
84-86: Loose tolerance (atol=1.0) may mask correctness issues.

For the 1024×1024 GEMM, `atol=1.0` is very permissive: any element could differ by up to 1.0 from the reference. While float16 accumulation errors do compound over larger K dimensions, consider also checking relative error or documenting why this tolerance is acceptable for this configuration.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@testing/python/metal/test_metal_gemm_v2.py` around lines 84 - 86, The test test_gemm_v2_1024 is using a very loose absolute tolerance (assert_gemm_v2 with atol=1.0) which can mask numerical regressions; tighten the check by lowering atol (e.g., to <=1e-2) and/or add a relative tolerance (rtol, e.g., 1e-2) to the assert_gemm_v2 call, or if you must keep a larger absolute tolerance, add a brief comment on why float16 accumulation justifies that specific tolerance; update the assert_gemm_v2 invocation (and its signature use) accordingly.
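The difference between a bare absolute check and a mixed check can be seen with a few numbers. The helper below mirrors the `|a - b| <= atol + rtol * |b|` form used by torch.allclose and numpy.isclose (the function name is illustrative):

```python
def close(actual: float, expected: float, rtol: float = 1e-2, atol: float = 1e-2) -> bool:
    """Mixed tolerance: the allowed error scales with the reference magnitude."""
    return abs(actual - expected) <= atol + rtol * abs(expected)

# A bare atol=1.0 check grants 1.0 of slack even to a reference value of 0.0;
# the mixed form only grants that much slack to large-magnitude entries.
```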
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@testing/python/metal/test_metal_gemm_v2_linux.py`:
- Around line 22-34: The test uses T.alloc_simdgroup to create C_local
(simdgroup register path) while the runtime test uses T.alloc_shared for C_local
(shared-memory path), which exercise different lowering branches
(_gemm_ss_simdgroup vs _gemm_ss_shared); add a short clarifying comment or
update the test docstring near the C_local allocation (or at the top of the
test) stating that this file intentionally tests the simdgroup-accumulator path
(T.alloc_simdgroup / C_local) and that test_metal_gemm_v2.py covers the
shared-memory accumulator path (T.alloc_shared / C_local) so readers know both
lowering branches are exercised.
In `@testing/python/metal/test_metal_gemm_v2.py`:
- Around line 84-86: The test test_gemm_v2_1024 is using a very loose absolute
tolerance (assert_gemm_v2 with atol=1.0) which can mask numerical regressions;
tighten the check by lowering atol (e.g., to <=1e-2) and/or add a relative
tolerance (rtol, e.g., 1e-2) to the assert_gemm_v2 call, or if you must keep a
larger absolute tolerance, add a brief comment on why float16 accumulation
justifies that specific tolerance; update the assert_gemm_v2 invocation (and its
signature use) accordingly.
In `@testing/python/metal/test_metal_simdgroup_store.py`:
- Around line 72-76: The string-based assertions are fragile because they
require an exact match of the variable name C_local; update the checks in
test_metal_simdgroup_store.py to use regexes that allow optional suffixes (e.g.
C_local or C_local_1) and/or count simdgroup_store/load calls more generally:
replace the plain substring search for "simdgroup_store(C_local" with a regex
like r"simdgroup_store\(\s*C_local(?:_\d+)?\b" (used to compute store_to_device)
and replace the load check to search for
r"simdgroup_load\(\s*C_local(?:_\d+)?\b" (ensuring len(...) == 0), or
alternatively assert total simdgroup_store occurrences and that none of the
simdgroup_load matches reference a C_local-like buffer; update variables
store_to_device and load_c_from_shared accordingly.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: c1d29626-6946-4687-ad90-1df7efa113ee
📒 Files selected for processing (11)
- src/op/copy.cc
- src/op/copy.h
- src/op/fill.cc
- src/op/utils.h
- src/transform/loop_vectorize.cc
- testing/python/metal/test_metal_gemm_v2.py
- testing/python/metal/test_metal_gemm_v2_linux.py
- testing/python/metal/test_metal_simdgroup_store.py
- tilelang/language/__init__.py
- tilelang/language/allocate.py
- tilelang/tileop/gemm/gemm_metal.py
✅ Files skipped from review due to trivial changes (1)
- src/op/utils.h
🚧 Files skipped from review as they are similar to previous changes (3)
- src/op/copy.h
- tilelang/tileop/gemm/gemm_metal.py
- src/op/copy.cc
Actionable comments posted: 1
♻️ Duplicate comments (2)
tilelang/tileop/gemm/gemm_metal.py (2)
62-62: ⚠️ Potential issue | 🟠 Major

Require `block_K` to be an exact micro-K multiple.

Lines 76 and 94 iterate with `block_K // micro_size_k`, so any non-multiple silently drops the tail of K instead of computing it. The guard on Line 62 needs to enforce divisibility, not just `>=`.

Suggested guard

```diff
- assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})"
+ assert block_K >= micro_size_k and block_K % micro_size_k == 0, (
+     f"block_K ({block_K}) must be a multiple of micro_size_k ({micro_size_k})"
+ )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/tileop/gemm/gemm_metal.py` at line 62, The assertion currently only enforces block_K >= micro_size_k but must require that block_K is an exact multiple of micro_size_k so iterations using block_K // micro_size_k do not drop a tail; replace the assert on block_K and micro_size_k with a divisibility check (e.g. require block_K % micro_size_k == 0) and update the assertion message to state both values and that block_K must be a multiple of micro_size_k so the subsequent loops over block_K // micro_size_k are correct.
26-27: ⚠️ Potential issue | 🟡 Minor

Don't force potentially symbolic tile counts through `int()`.

Lines 26-27 will blow up during lowering if `self.M` or `self.N` are symbolic `tir.PrimExpr`. Either keep these as TIR expressions end-to-end or make the static-shape requirement explicit before the cast so this fails predictably instead of with a Python `TypeError`.

One safe option if Metal GEMM is intentionally static-shape-only

```diff
  m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, GemmInst.METAL)
- warp_row_tiles = int(self.M // m_warp)
- warp_col_tiles = int(self.N // n_warp)
+ if not isinstance(self.M, (int, tir.IntImm)) or not isinstance(self.N, (int, tir.IntImm)):
+     raise ValueError("Metal GEMM currently requires static M/N tile sizes")
+ warp_row_tiles = int(self.M) // m_warp
+ warp_col_tiles = int(self.N) // n_warp
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/tileop/gemm/gemm_metal.py` around lines 26 - 27, warp_row_tiles and warp_col_tiles are being forced through Python int() which will fail when self.M or self.N are TIR/PrimExpr; instead either keep these as TIR expressions (e.g., use TIR floordiv/indexdiv of self.M by m_warp and self.N by n_warp) so lowering stays symbolic, or add an explicit static-shape guard before casting (validate self.M and self.N are ints and raise a clear error) so the requirement fails predictably; update the expressions that set warp_row_tiles and warp_col_tiles (and any downstream uses) to use the chosen approach.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tilelang/tileop/gemm/gemm_metal.py`:
- Line 64: The assertion in gemm_metal.py that checks "c_in_simdgroup_reg or
is_shared(C_buf)" blocks global-scope outputs before simdgroup_copy/simd_store
can run; update that assertion to permit global/storage-scoped C buffers (e.g.,
allow is_global/is_storage(C_buf)) if Metal GEMM should support direct global
output, and ensure simdgroup_copy/simd_store (which call T.simdgroup_store)
handle the global target correctly, or alternatively replace the assertion with
an explicit error message documenting that only simdgroup or shared scopes are
supported so callers know the limitation; locate the assertion referencing
C_buf.scope() and adjust it and any related scope checks in the
simdgroup_copy/simd_store path accordingly.
---
Duplicate comments:
In `@tilelang/tileop/gemm/gemm_metal.py`:
- Line 62: The assertion currently only enforces block_K >= micro_size_k but
must require that block_K is an exact multiple of micro_size_k so iterations
using block_K // micro_size_k do not drop a tail; replace the assert on block_K
and micro_size_k with a divisibility check (e.g. require block_K % micro_size_k
== 0) and update the assertion message to state both values and that block_K
must be a multiple of micro_size_k so the subsequent loops over block_K //
micro_size_k are correct.
- Around line 26-27: warp_row_tiles and warp_col_tiles are being forced through
Python int() which will fail when self.M or self.N are TIR/PrimExpr; instead
either keep these as TIR expressions (e.g., use TIR floordiv/indexdiv of self.M
by m_warp and self.N by n_warp) so lowering stays symbolic, or add an explicit
static-shape guard before casting (validate self.M and self.N are ints and raise
a clear error) so the requirement fails predictably; update the expressions that
set warp_row_tiles and warp_col_tiles (and any downstream uses) to use the
chosen approach.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 66a64930-ddc8-4e57-a7a5-98309502299a
📒 Files selected for processing (1)
tilelang/tileop/gemm/gemm_metal.py
Copy TVM's codegen_metal.cc/h into tilelang's src/target/ and register as override for target.build.metal via Python-side FFI override at import time. This allows Metal-specific codegen changes without modifying the TVM submodule. Add float16x8 type support (mapped to uint4 for packed storage), with corresponding PrintVecElemLoad/Store using pointer cast and BroadcastNode using as_type<uint>(half2(...)) packing. Restore 128-bit vectorize width for Metal (was limited to 64-bit), enabling half8 global memory loads for ~2% throughput improvement on large matrices (2048²: 94% → 97% of torch).
cc1440b to
f06c241
Compare
# Conflicts:
#	pyproject.toml
#	requirements.txt
This is ready for review @LeiWang1999
I've updated this in the PR description.
# Conflicts:
#	src/op/gemm_py.cc
…fragment Remove the Metal-specific alloc_simdgroup() API so users write the same alloc_fragment() + T.gemm() code for both CUDA and Metal targets. A new MetalFragmentToSimdgroup pass runs before LayoutInference and selectively rewrites local.fragment GEMM accumulators to metal.simdgroup scope, leaving non-GEMM fragment buffers (scalar loops) untouched. This ensures zero performance regression on M4 — the generated Metal shader code is identical to the previous alloc_simdgroup path. Also prepares the dispatch point for future M5 NAX tensor core support, where the same alloc_fragment code can route to a NAX lowering path based on architecture detection.
…e for MPP/NAX support

Introduce 4 new TIR builtins (cooperative_tensor_fill/load/store/multiply_accumulate) that mirror the existing simdgroup builtins but target MetalPerformancePrimitives. Add metal.cooperative_tensor storage scope with full pipeline support:

- TVM: builtin declarations, registrations, Python wrappers, script parser exports, StorageRank enum and scope string parsing
- tilelang: MetalFragmentToCooperativeTensor pass, codegen_metal AllocateNode and builtin emit, fill/copy/layout_inference scope handling
- Runtime: Metal language version bumped to 4.0 for MPP header availability

Currently emits simdgroup_* Metal calls as a functional baseline; actual MPP matmul2d codegen will follow in a subsequent commit.
Replace the simdgroup_matrix (8x8) fallback with MetalPerformancePrimitives matmul2d using 16x16 base fragments and (M=16, N=16, K=32) micro-tiles.

Key changes:
- codegen_metal.cc: emit MPP matmul2d code with manual per-thread coordinate calculation (MLX-style BaseNAXFrag layout) instead of get_multidimensional_index
- cooperative_tensor_load/store builtins now accept mma_M/N/K and operand_role
- micro_size_k raised from 16 to 32 (MPP requires at least one dim >= 32)
- A/B local buffers sized for 16x32 tiles (512 elements = 16 per thread)
- Metal address space qualifiers resolved via expression tree traversal

Performance: 6.3 TFLOPS on 4096x4096x4096 fp16 GEMM (M5), up from 3.4 TFLOPS with simdgroup_matrix. PyTorch MPS reference: 14.2 TFLOPS.
Summary
Add Metal backend support for `T.gemm` using Apple's `simdgroup_multiply_accumulate` intrinsics. On Apple Silicon, this achieves 96-97% of PyTorch MPS (MPSMatrixMultiplication) performance on large matrices, with no TVM submodule modifications.

Why simdgroup intrinsics instead of per-thread layout?
TileLang's CUDA/HIP GEMM path maps matrix tiles to individual threads — each thread owns a fragment of the accumulator and executes FMA instructions independently. This per-thread layout gives fine-grained control over data placement, enabling techniques like complex swizzling patterns to avoid bank conflicts.
Metal offers no equivalent per-thread MMA interface. Apple Silicon's matrix acceleration is exposed exclusively through `simdgroup_multiply_accumulate`, which operates on opaque `simdgroup_matrix<T, 8, 8>` values cooperatively held across a 32-thread SIMD group. There is no documented way to decompose these matrices into per-thread elements. A scalar FMA fallback would bypass the hardware matrix unit entirely, losing an order of magnitude of performance.

This is a deliberate trade-off: we give up per-thread control (e.g. custom swizzle patterns for bank conflict avoidance) in exchange for the only available path to Apple Silicon's matrix hardware. The opaque simdgroup semantics mean the hardware controls the internal data layout, but this is the cost of accessing the accelerator — there is simply no other way on Metal.
This required introducing a new `metal.simdgroup` buffer scope (distinct from `local.fragment`) and corresponding lowering/codegen support throughout the stack.

Performance
Benchmarked on Apple Silicon (M4 Pro, float16 GEMM, float32 accumulator):
Key Changes
Metal GEMM Engine (`tilelang/tileop/gemm/gemm_metal.py`)

- `metal.simdgroup` (fast path): C stays in simdgroup registers across K iterations — zero shared memory round-trip. ~15% faster than the shared path.
- `shared` (compat path): C goes through shared memory each iteration via `simdgroup_load`/`simdgroup_store`. Works but slower.

Simdgroup Register Support
- `T.alloc_simdgroup(shape, dtype)` — new API for allocating Metal simdgroup matrix registers, parallel to `T.alloc_shared`/`T.alloc_fragment`.
- `FillNode::Lower` — generates `make_filled_simdgroup_matrix` for `metal.simdgroup` scope buffers.
- `CopyNode::LowerSIMDGroupCopy` — generates `simdgroup_store` directly to device/shared memory, with warp partition matching the GEMM layout.
- `IsFragmentBuffer`/`IsSIMDGroupBuffer`/`IsRegisterBuffer` — separated `local.fragment` (per-thread SIMT) from `metal.simdgroup` (cooperative simdgroup) semantics.
src/target/codegen_metal.cc)target.build.tilelang_metalto enable Metal-specific changes without modifying the TVM submodule.float16x8 → uint4type mapping for 128-bit vectorized global memory loads.device_codegenanddevice_codegen_without_compiledispatch to the forked codegen.Bug Fixes
parallel.cc: Fixed crash whenlayout_map[buffer]is not aFragment(e.g.metal.simdgroupbuffers). Now uses.has_value()guard before accessingReplicateExtent().decouple_type_cast.py: Treatmetal.simdgroupbuffers as local (register-level) in the type cast decoupling pass, preventing them from being misclassified as shared/global.Dependency Constraints
apache-tvm-ffi<0.1.8on macOS (workaround for [Bug] NPE since 0.1.8 apache/tvm-ffi#464).apache-tvm-ffi>=0.1.6acrosspyproject.toml,requirements.txt, andrequirements-dev.txt(memory fix from tilelang#1502).User Code Example