[Metal] Add Metal GEMM support with simdgroup intrinsics #1869

Open
oraluben wants to merge 20 commits into tile-ai:main from oraluben:metal-gemm

Conversation


@oraluben oraluben commented Feb 23, 2026

Summary

Add Metal backend support for T.gemm using Apple's simdgroup_multiply_accumulate intrinsics. On Apple Silicon, it achieves 96-97% of PyTorch MPS (MPSMatrixMultiplication) performance on large matrices, with no TVM submodule modifications.

Why simdgroup intrinsics instead of per-thread layout?

TileLang's CUDA/HIP GEMM path maps matrix tiles to individual threads — each thread owns a fragment of the accumulator and executes FMA instructions independently. This per-thread layout gives fine-grained control over data placement, enabling techniques like complex swizzling patterns to avoid bank conflicts.

Metal offers no equivalent per-thread MMA interface. Apple Silicon's matrix acceleration is exposed exclusively through simdgroup_multiply_accumulate, which operates on opaque simdgroup_matrix<T, 8, 8> values cooperatively held across a 32-thread SIMD group. There is no documented way to decompose these matrices into per-thread elements. A scalar FMA fallback would bypass the hardware matrix unit entirely, losing an order of magnitude of performance.
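For scale, the opacity trade-off can be quantified: a simdgroup_matrix<T, 8, 8> holds 64 elements across a 32-thread SIMD group, i.e. two elements per thread on average, but Metal gives no documented way to address which two a given thread holds. A trivial sketch of the arithmetic:

```python
SIMDGROUP_WIDTH = 32   # threads in one Metal SIMD group
TILE_ELEMS = 8 * 8     # elements in one simdgroup_matrix<T, 8, 8>

# On average each thread backs two elements, but the lane-to-element
# mapping is undocumented and controlled by the hardware.
elems_per_thread = TILE_ELEMS // SIMDGROUP_WIDTH
```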

This is a deliberate trade-off: we give up per-thread control (e.g. custom swizzle patterns for bank conflict avoidance) in exchange for the only available path to Apple Silicon's matrix hardware. The opaque simdgroup semantics mean the hardware controls the internal data layout, but this is the cost of accessing the accelerator — there is simply no other way on Metal.

This required introducing a new metal.simdgroup buffer scope (distinct from local.fragment) and corresponding lowering/codegen support throughout the stack.

Performance

Benchmarked on Apple Silicon (M4 Pro, float16 GEMM, float32 accumulator):

| Matrix Size | PyTorch MPS | TileLang | Ratio |
|-------------|-------------|----------|-------|
| 1024×1024   | 5.6 TFLOPS  | 5.4 TFLOPS | 96% |
| 2048×2048   | 6.0 TFLOPS  | 5.8 TFLOPS | 97% |
| 4096×4096   | 6.0 TFLOPS  | 5.8 TFLOPS | 97% |
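For reference, the throughput figures follow the usual 2·M·N·K flop count for GEMM; a hypothetical helper (not part of this PR) reproduces the ratio column:

```python
def gemm_tflops(m: int, n: int, k: int, seconds: float) -> float:
    # A GEMM performs 2*M*N*K floating-point ops
    # (one multiply + one add per inner-product term).
    return (2 * m * n * k) / seconds / 1e12

# Ratio column = TileLang TFLOPS / MPS TFLOPS, e.g. the 1024x1024 row:
ratio_pct = round(5.4 / 5.6 * 100)  # 96
```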

Key Changes

Metal GEMM Engine (tilelang/tileop/gemm/gemm_metal.py)

  • Two code paths based on C accumulator scope:
    • metal.simdgroup (fast path): C stays in simdgroup registers across K iterations, with no shared-memory round-trip; ~15% faster than the shared path.
    • shared (compat path): C goes through shared memory on each iteration via simdgroup_load/simdgroup_store. Functional, but slower.
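The scope-based dispatch described above can be sketched as follows (illustrative Python; the actual selection lives in gemm_metal.py and may differ in detail):

```python
def select_accumulator_path(c_scope: str) -> str:
    # Fast path: the accumulator lives in simdgroup registers for the
    # whole K loop, avoiding shared-memory round-trips (~15% faster).
    if c_scope == "metal.simdgroup":
        return "simdgroup-resident"
    # Compat path: C is staged through shared memory every iteration
    # via simdgroup_load / simdgroup_store.
    if c_scope in ("shared", "shared.dyn"):
        return "shared-roundtrip"
    raise ValueError(f"unsupported accumulator scope: {c_scope}")
```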

Simdgroup Register Support

  • T.alloc_simdgroup(shape, dtype) — new API for allocating Metal simdgroup matrix registers, parallel to T.alloc_shared/T.alloc_fragment.
  • FillNode::Lower — generates make_filled_simdgroup_matrix for metal.simdgroup scope buffers.
  • CopyNode::LowerSIMDGroupCopy — generates simdgroup_store directly to device/shared memory, with warp partition matching the GEMM layout.
  • IsFragmentBuffer / IsSIMDGroupBuffer / IsRegisterBuffer — separated local.fragment (per-thread SIMT) from metal.simdgroup (cooperative simdgroup) semantics.
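A rough Python rendering of the predicate split (the real helpers are C++ in src/op/utils.h; the names are mirrored here for illustration only):

```python
def is_fragment_buffer(scope: str) -> bool:
    # Per-thread SIMT fragment, as on CUDA/HIP.
    return scope == "local.fragment"

def is_simdgroup_buffer(scope: str) -> bool:
    # Cooperatively held Metal simdgroup matrix; per-thread layout is opaque.
    return scope == "metal.simdgroup"

def is_register_buffer(scope: str) -> bool:
    # Register-resident in either sense.
    return is_fragment_buffer(scope) or is_simdgroup_buffer(scope)
```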

Metal Codegen Fork (src/target/codegen_metal.cc)

  • Forked TVM's Metal codegen into tilelang as target.build.tilelang_metal to enable Metal-specific changes without modifying the TVM submodule.
  • Added float16x8 → uint4 type mapping for 128-bit vectorized global memory loads.
  • Both device_codegen and device_codegen_without_compile dispatch to the forked codegen.
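The float16x8 → uint4 mapping works because both types carry the same 128-bit payload; a quick size check (assumed layout, for illustration):

```python
def vector_bits(lane_bits: int, lanes: int) -> int:
    # Total payload of a packed vector type in bits.
    return lane_bits * lanes

f16x8_bits = vector_bits(16, 8)   # 8 half-precision lanes
uint4_bits = vector_bits(32, 4)   # Metal uint4: 4 x 32-bit components
# Equal widths make the reinterpretation valid for a single 128-bit load.
```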

Bug Fixes

  • parallel.cc: Fixed crash when layout_map[buffer] is not a Fragment (e.g. metal.simdgroup buffers). Now uses .has_value() guard before accessing ReplicateExtent().
  • decouple_type_cast.py: Treat metal.simdgroup buffers as local (register-level) in the type cast decoupling pass, preventing them from being misclassified as shared/global.

Dependency Constraints

  • Pin apache-tvm-ffi<0.1.8 on macOS (workaround for [Bug] NPE since 0.1.8 apache/tvm-ffi#464).
  • Unify apache-tvm-ffi>=0.1.6 across pyproject.toml, requirements.txt, and requirements-dev.txt (memory fix from tilelang#1502).
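Combined, the two constraints amount to requirement lines along these lines (illustrative; the exact environment markers are in the PR's pyproject.toml/requirements.txt):

```
apache-tvm-ffi>=0.1.6
apache-tvm-ffi<0.1.8; platform_system == "Darwin"
```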

User Code Example

@T.prim_func
def gemm_kernel(A: T.Tensor((M, K), T.float16), B: T.Tensor((K, N), T.float16), C: T.Tensor((M, N), T.float32)):
    with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
        A_shared = T.alloc_shared((block_M, block_K), T.float16)
        B_shared = T.alloc_shared((block_K, block_N), T.float16)
        C_local  = T.alloc_simdgroup((block_M, block_N), T.float32)  # simdgroup registers
        T.clear(C_local)
        for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
            T.copy(A[by * block_M, ko * block_K], A_shared)
            T.copy(B[ko * block_K, bx * block_N], B_shared)
            T.gemm(A_shared, B_shared, C_local)
        T.copy(C_local, C[by * block_M, bx * block_N])

@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀


coderabbitai bot commented Feb 23, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.


Walkthrough

Adds Apple Metal/MPS support: new GEMM Metal backend, simdgroup buffer scope and allocators, Metal-specific intrinsics emitter, copy/fill/vectorize changes for simdgroup handling, Metal-target dispatch and tests, and dependency adjustments for Darwin.

Changes

  • Dependency Management (pyproject.toml, requirements.txt): Tightened apache-tvm-ffi bounds: raised the minimum constraint and added a macOS-specific upper bound (<0.1.8) for Darwin.
  • Copy Operation (src/op/copy.h, src/op/copy.cc): Added CopyInst::kMetalSIMDGroup; new CheckSIMDGroupCopy, selection in GetCopyInst, and LowerSIMDGroupCopy emitting simdgroup_store sequences with tiling/warp logic.
  • GEMM Core & Selection (src/op/gemm.h, src/op/gemm.cc, src/op/gemm_py.cc): Added Metal GEMM enum/mapping, a Metal branch in instruction selection, and a Metal-specific warp-partition adjustment (kMPerWarp=8 for Metal targets).
  • Buffer/Layout Utilities (src/op/parallel.cc, src/op/utils.h): Guarded Fragment extraction in ParallelOpNode::InferLayout; added IsSIMDGroupBuffer and IsRegisterBuffer predicates.
  • Fill/Vectorize Adjustments (src/op/fill.cc, src/transform/loop_vectorize.cc): FillNode::Lower handles metal.simdgroup via make_filled_simdgroup_matrix calls; the vectorize planner chooses 64-element vectors for Metal targets.
  • TileLang Metal Intrinsics (tilelang/intrinsics/metal_macro_generator.py): Added MPSIntrinEmitter implementing simdgroup load/store, ldmatrix-like loads, mma accumulate macros, thread/warp indexing, and Buffer/BufferRegion normalization.
  • TileLang GEMM Backend & Dispatch (tilelang/tileop/gemm/inst.py, tilelang/tileop/gemm/__init__.py, tilelang/tileop/gemm/gemm_metal.py): Added METAL GemmInst + is_metal(); exported and implemented GemmMetal lowering that emits Metal-target prim_funcs using MPSIntrinEmitter and metal.simdgroup buffers.
  • TileLang Language & Allocation (tilelang/language/allocate.py, tilelang/language/__init__.py): Added alloc_simdgroup(shape, dtype) and re-exported it via tilelang.language.
  • TileLang Utilities & Transforms (tilelang/utils/language.py, tilelang/transform/decouple_type_cast.py): Added is_metal_simdgroup(...) helper and updated is_local_buffer to treat metal.simdgroup as local.
  • JIT Adapter (tilelang/jit/adapter/torch/metal.py): Added MetalKernelAdapter.get_kernel_source(kernel_only: bool=True) -> str accessor returning the stored kernel source.
  • Tests (testing/python/metal/test_metal_simdgroup_store.py, testing/python/metal/test_metal_gemm_v2.py, testing/python/metal/test_metal_gemm_v2_linux.py): Added Metal runtime and codegen tests covering simdgroup stores, gemm_v2 lowering, and kernel-source assertions for simdgroup intrinsics.

Sequence Diagram

sequenceDiagram
    participant Host as Host/PyTorch
    participant Dispatch as GemmPy Dispatch
    participant Backend as GemmMetal
    participant Emitter as MPSIntrinEmitter
    participant Metal as Metal Runtime

    Host->>Dispatch: Request GEMM (Target=Metal)
    Dispatch->>Backend: Route to GemmMetal (is_metal)
    Backend->>Backend: Compute warp partitions / prepare buffers
    Backend->>Emitter: Init with tiling & dtypes
    loop per K block
        Backend->>Emitter: ldmatrix_a / ldmatrix_b
        Emitter->>Metal: simdgroup_load / simdgroup_multiply_accumulate
        Metal-->>Emitter: partial accumulators
    end
    Backend->>Emitter: simdgroup_store C_simd -> C_buf
    Backend-->>Host: Return lowered kernel / kernel source
    Host->>Metal: Execute kernel on device

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

Possibly related PRs

Suggested reviewers

  • LeiWang1999

Poem

🐰 In simdgroup rows my whiskers weave,

Threads hop, matrices spin and cleave,
MPS hums as warps align in tune,
TileLang hops beneath Metal's moon,
A rabbit cheers — gemm's bright bloom.

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 25.29%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (2 passed)

  • Description Check: ✅ Passed. Check skipped; CodeRabbit’s high-level summary is enabled.
  • Title Check: ✅ Passed. The title '[Metal] Add Metal GEMM support with simdgroup intrinsics' accurately and clearly summarizes the main change: implementing a Metal-specific GEMM backend using simdgroup intrinsics. It is concise, specific, and directly corresponds to the primary objective of the PR.



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 11

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tilelang/transform/decouple_type_cast.py (1)

92-95: ⚠️ Potential issue | 🟡 Minor

Error message doesn't mention metal.simdgroup as a known scope.

Since is_local_buffer now accepts metal.simdgroup, the error message should list it as a valid scope for completeness.

Suggested fix
         raise ValueError(
             f"Unknown buffer scope '{buffer.scope()}' for buffer '{buffer.name}'. "
-            f"Expected one of: local, local.fragment, local.var, global, shared, shared.dyn"
+            f"Expected one of: local, local.fragment, local.var, metal.simdgroup, global, shared, shared.dyn"
         )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/transform/decouple_type_cast.py` around lines 92 - 95, The
ValueError raised in decouple_type_cast.py lists valid buffer scopes but omits
"metal.simdgroup"; update the error message (the raise ValueError that
references buffer.scope() and buffer.name) to include "metal.simdgroup" in the
Expected one of: list so it matches the scopes accepted by is_local_buffer/other
checks and clearly communicates valid scopes.
🧹 Nitpick comments (7)
src/op/copy.cc (1)

639-645: Method name CheckSIMDGroupStore is misleading — it matches any simdgroup↔simdgroup copy.

Both src and dst are checked for "metal.simdgroup" scope, meaning this matches loads and stores (or more accurately, simdgroup-to-simdgroup transfers). Compare with CheckLDSMCopy (shared→fragment) and CheckSTSMCopy (fragment→shared) which are directional. Consider renaming to CheckSIMDGroupCopy (and correspondingly LowerSIMDGroupCopy) for consistency and clarity.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/op/copy.cc` around lines 639 - 645, The method CheckSIMDGroupStore
incorrectly implies a directional store but actually matches any
simdgroup↔simdgroup transfer; rename CheckSIMDGroupStore to CheckSIMDGroupCopy
(and rename its lowering helper LowerSIMDGroupStore to LowerSIMDGroupCopy or
similarly) and update all call sites to use the new names so the intent matches
the implementation; keep the same body (checking src.scope() ==
"metal.simdgroup" && dst.scope() == "metal.simdgroup") but change identifiers
for consistency with CheckLDSMCopy/CheckSTSMCopy naming conventions.
src/op/gemm_py.cc (1)

136-137: Metal branch placement is correct; consider a minimal allowMetal() guard for future-proofing

The ordering (TCGEN5MMA → WGMMA → CDNA → CUDA → Metal → fallback) is correct — none of the earlier guards can fire on a Metal target. The change is logically sound.

For consistency with the other paths (which all have dedicated allow* functions that validate shape, dtype, and scope constraints before returning their instruction type), a lightweight allowMetal() gate would help catch unsupported Metal configurations early rather than deferring to downstream errors.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/op/gemm_py.cc` around lines 136 - 137, Add a lightweight allowMetal()
guard before returning GemmInst::kMetalExp: implement allowMetal() similar to
the other allow* helpers (validate shapes, dtypes, and memory scope constraints
expected by the Metal path) and call it in the branch that currently checks
TargetIsMetal(target); only return GemmInst::kMetalExp when allowMetal() returns
true, otherwise fall back to the existing fallback/error path so unsupported
Metal configs are rejected early. Reference: the TargetIsMetal(target) branch
and GemmInst::kMetalExp; follow the validation pattern used by functions like
allowCuda()/allowCdna()/allowWgmma() to ensure consistency.
src/op/gemm.h (1)

42-42: kMetalExp naming doesn't align with Python's METAL; enum values should be explicit

Two related concerns:

  1. The C++ enumerator is named kMetalExp (Exp = experimental), but the Python counterpart in tilelang/tileop/gemm/inst.py uses the plain name METAL = 4 with no "experimental" qualifier. This asymmetry makes it unclear whether "experimental" is a meaningful status or just a stale suffix.

  2. The C++ enum assigns values implicitly (sequential 0–4). The Python IntEnum assigns values explicitly (0–4). These are currently in sync, but inserting a new entry between existing ones would silently break the C++↔Python ABI. Adding explicit values in C++ is a cheap safeguard:

♻️ Suggested: add explicit values and align naming
-enum class GemmInst : uint8_t { kMMA, kWGMMA, kTCGEN5MMA, kMFMA, kMetalExp };
+enum class GemmInst : uint8_t {
+  kMMA = 0,
+  kWGMMA = 1,
+  kTCGEN5MMA = 2,
+  kMFMA = 3,
+  kMetal = 4,  // align with Python GemmInst.METAL
+};
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/op/gemm.h` at line 42, The GemmInst enum currently uses an asymmetrical
name and implicit values; change the enumerator kMetalExp to match the Python
name (e.g., kMETAL) and make all enum members have explicit integer values
matching the Python IntEnum (e.g., enum class GemmInst : uint8_t { kMMA = 0,
kWGMMA = 1, kTCGEN5MMA = 2, kMFMA = 3, kMETAL = 4 };), keep the underlying type
uint8_t, and update any usages of GemmInst::kMetalExp to the new symbol to
preserve ABI alignment with tilelang/tileop/gemm/inst.py.
testing/python/metal/test_metal_gemm_v2.py (3)

91-93: Redundant torch.mps.is_available() guard silently swallows test output when Metal is absent.

The individual test functions already carry @tilelang.testing.requires_metal, which skips them if Metal is unavailable. Wrapping tilelang.testing.main() in an additional torch.mps.is_available() check means that when run on a non-Metal machine, no tests are registered at all — no "skipped" output, no indication tests exist. Compare with test_metal_gemm_v2_linux.py which calls tilelang.testing.main() unconditionally.

🔧 Suggested fix
 if __name__ == "__main__":
-    if torch.mps.is_available():
-        tilelang.testing.main()
+    tilelang.testing.main()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@testing/python/metal/test_metal_gemm_v2.py` around lines 91 - 93, Remove the
redundant torch.mps.is_available() guard in the __main__ block so that
tilelang.testing.main() is called unconditionally; the tests themselves use the
`@tilelang.testing.requires_metal` decorator (in this file's test functions) to
mark/skips tests when Metal is absent, so update the __main__ section (the block
checking __name__ == "__main__") to simply invoke tilelang.testing.main()
without wrapping it in torch.mps.is_available().

80-83: Consider putting requires_metal as the outermost decorator for conventional ordering.

Placing @tilelang.testing.requires_metal inside @pytest.mark.xfail works correctly (skip propagates), but the conventional pattern is to put guard/skip decorators outermost so the intent is immediately obvious at the call site.

🎨 Suggested reorder
-@pytest.mark.xfail(reason="TODO: codegen not support float16x8")
 @tilelang.testing.requires_metal
+@pytest.mark.xfail(reason="TODO: codegen not support float16x8")
 def test_gemm_v2_large():
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@testing/python/metal/test_metal_gemm_v2.py` around lines 80 - 83, Move the
decorator order so the platform guard is outermost: swap the two decorators on
test_gemm_v2_large so `@tilelang.testing.requires_metal` appears above
`@pytest.mark.xfail`; update the decorators on the test_gemm_v2_large function
accordingly to keep the same behavior but follow conventional ordering.

88-88: Add a comment explaining atol=1.0 for the large-K test.

For K=1024 with float16 inputs, accumulated rounding error per element can approach K × ε_fp16 ≈ 1024 × 9.7e-4 ≈ 1.0, so the tolerance is intentional. A brief inline comment would prevent future readers from tightening it incorrectly.

💬 Suggested change
-    assert_gemm_v2(1024, 1024, 1024, 16, 16, 16, atol=1.0)
+    # atol=1.0: with K=1024 fp16 inputs, accumulated rounding error ≈ K × ε_fp16 ≈ 1.0
+    assert_gemm_v2(1024, 1024, 1024, 16, 16, 16, atol=1.0)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@testing/python/metal/test_metal_gemm_v2.py` at line 88, Add a brief inline
comment next to the call to assert_gemm_v2(1024, 1024, 1024, 16, 16, 16,
atol=1.0) explaining that atol=1.0 is intentional because for K=1024 with
float16 inputs the accumulated rounding error per element can approach K *
ε_fp16 ≈ 1024 * 9.7e-4 ≈ 1.0, so the larger absolute tolerance is required for
this large-K test to avoid false failures.
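The back-of-envelope bound behind that tolerance is easy to verify (a first-order estimate, not a rigorous error analysis):

```python
# Unit roundoff of IEEE binary16: 2**-10, approximately 9.77e-4.
EPS_FP16 = 2.0 ** -10
K = 1024

# Worst-case first-order accumulation error for a length-K fp16 dot
# product grows roughly like K * eps, which here is exactly 1.0.
error_bound = K * EPS_FP16
```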
testing/python/metal/test_metal_gemm_v2_linux.py (1)

50-53: Misleading variable name and redundant target specification.

Two minor issues:

  1. Line 50 sets tvm.target.Target("metal") as a context manager and passes target="metal" to tilelang.lower on line 51. The context manager is redundant; the explicit target= arg alone is sufficient (and matches the pattern in similar test files).
  2. The return value of tilelang.lower is named artifact, but kernel_source is a property on the JIT kernel object (as shown in tilelang/jit/kernel.py), not on a raw artifact. Naming it kernel would be more accurate.
🔧 Suggested cleanup
-    with tvm.transform.PassContext(), tvm.target.Target("metal"):
-        artifact = tilelang.lower(func, target="metal")
-
-    src_code = artifact.kernel_source
+    with tvm.transform.PassContext():
+        kernel = tilelang.lower(func, target="metal")
+
+    src_code = kernel.kernel_source
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@testing/python/metal/test_metal_gemm_v2_linux.py` around lines 50 - 53,
Remove the redundant PassContext target context and rename the returned value
from tilelang.lower to reflect it's a JIT kernel: call tilelang.lower with the
explicit target argument only (remove the tvm.transform.PassContext(),
tvm.target.Target("metal") context manager) and rename the variable from
artifact to kernel so you access kernel.kernel_source (i.e., locate the
tilelang.lower call and the subsequent kernel_source access and update the
target usage and variable name accordingly).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@pyproject.toml`:
- Around line 33-35: The pyproject lower-bound for "apache-tvm-ffi" is too
permissive (">0.1.2") and allows Darwin installs of versions 0.1.3–0.1.5 that
contain the memory regression; update the constraint(s) for "apache-tvm-ffi" to
match requirements.txt by using a >=0.1.6 lower bound (e.g., change ">0.1.2" to
">=0.1.6" and ensure the Darwin-specific constraint "apache-tvm-ffi<0.1.8;
platform_system == 'Darwin'" remains consistent with that lower bound).

In `@src/op/copy.cc`:
- Around line 685-686: GetCopyInst can return CopyInst::kMetalSIMDGroup but
Lower() lacks a matching case, causing a fatal crash; add a case for
CopyInst::kMetalSIMDGroup in Lower() (in the function Lower) before the kNormal
branch that implements the appropriate lowering behavior for SIMD-group metal
stores (mirror the pattern used for other Metal-specific cases in Lower(),
dispatching to the SIMD-group store lowering path and returning the resultant
statement), ensuring the new case references CopyInst::kMetalSIMDGroup so
execution no longer falls through to LOG(FATAL).

In `@src/op/copy.h`:
- Around line 232-235: The Doxygen comment for CheckSIMDGroupStore is
incorrect/copy-pasted from CheckTMemStore; update the comment to describe that
CheckSIMDGroupStore(Target target) determines whether the Metal SIMD-group
(SIMD-group/warp/wavefront) store instruction is supported on the given Target
(i.e., checks for Metal SIMD group store capability), not tensor memory store
support; reference CheckSIMDGroupStore and, for comparison, CheckTMemStore to
ensure the wording distinguishes SIMD group store support from tensor memory
store support.
- Around line 274-277: The CopyNode::Lower() switch must handle
CopyInst::kMetalSIMDGroup: add a case in Lower() that calls
LowerSIMDGroupStore(T, analyzer) when GetCopyInst(...) returns kMetalSIMDGroup
(this path is produced when CheckSIMDGroupStore(target) is true) to avoid the
LOG(FATAL) fallthrough; then implement the missing LowerSIMDGroupStore(const
LowerArgs &T, arith::Analyzer *analyzer) in copy.cc mirroring the pattern of
existing store lowerers (use the same argument handling and memory access
lowering used by other Metal-specific store methods, perform bounds/stride
checks as done in related Lower*Store functions, and produce the appropriate
Stmt), ensuring the declaration in copy.h matches the definition.

In `@src/op/gemm.cc`:
- Around line 150-154: Comments referencing a hardcoded "16" are now stale
because kMPerWarp is target-dependent; update the comments near kMPerWarp,
TargetIsMetal, and the logic that mentions "m_warp*16" and "Each warp needs at
least 16 elements in M" to reflect the runtime variable (e.g., refer to "m_warp
* kMPerWarp" and "kMPerWarp elements") so they accurately describe behavior on
Metal and other targets.

In `@testing/python/metal/test_metal_gemm_v2_linux.py`:
- Around line 70-71: The test test_metal_gemm_v2_larger currently calls
assert_metal_gemm_v2_codegen with parameters known to fail at runtime; update
the test to either mark it as an expected failure or explicitly verify
codegen-only success: add `@pytest.mark.xfail`(reason="TODO: codegen not support
float16x8") above test_metal_gemm_v2_larger to match the runtime test, or
replace the single assert_metal_gemm_v2_codegen call with a two-step check that
runs the codegen path (tilelang.lower/codegen) and asserts it succeeds while
keeping execution separately marked xfail; reference test_metal_gemm_v2_larger
and assert_metal_gemm_v2_codegen when making the change.
- Around line 22-34: The codegen test's matmul_gemm_v2 kernel is structurally
different from the runtime test; update the matmul_gemm_v2 in the Linux codegen
test so its symbols match the runtime test: change C_local to use
T.alloc_shared(..., scope="shared") (instead of T.alloc_fragment), remove/adjust
coalesced_width=2 on T.copy calls to match the runtime test's copies, and use
T.gemm_v2 (or make the runtime test use T.gemm if you choose that canonical API)
so the GEMM operator is identical; ensure these changes are applied to the
matmul_gemm_v2 definition so the codegen pre-flight validates the same kernel
that the runtime test executes.
- Line 32: The test is calling the old lowering path via T.gemm instead of the
new Metal path; update the call to use T.gemm_v2 so the codegen test exercises
the gemm_v2 lowering (replace the invocation T.gemm(A_shared, B_shared, C_local)
with T.gemm_v2(A_shared, B_shared, C_local) in the test function) to align this
codegen test with the runtime test and ensure assertions on
simdgroup_multiply_accumulate, simdgroup_load, and simdgroup_store validate the
correct backend.

In `@tilelang/intrinsics/metal_macro_generator.py`:
- Around line 43-44: The code computes self.warp_rows and self.warp_cols via
integer division of warp_row_tiles//micro_size_x and
warp_col_tiles//micro_size_y which will silently truncate if inputs are not
divisible by the micro sizes (8); add validation in the same initializer or
before these assignments (e.g., in the MetalMacroGenerator constructor or method
that sets warp_row_tiles/warp_col_tiles) that raises a clear error if
warp_row_tiles % micro_size_x != 0 or warp_col_tiles % micro_size_y != 0,
mentioning the offending values, and only then compute self.warp_rows and
self.warp_cols as the integer quotient.

In `@tilelang/jit/adapter/torch/metal.py`:
- Around line 56-58: The method get_kernel_source currently claims to return str
but may return None and ignores the kernel_only flag; change its signature to ->
str | None (or keep -> str but assert/raise if kernel_global_source is None) and
implement the kernel_only branch: if kernel_only is True return
self.kernel_global_source, otherwise return the full Metal source (compose or
return the attribute that holds the complete module/source such as
self.metal_source or self.full_source); ensure you reference get_kernel_source
and kernel_global_source and either assert kernel_global_source is not None
before returning a str or update callers/types to accept Optional[str].

In `@tilelang/tileop/gemm/gemm_metal.py`:
- Around line 22-23: The int() cast on potentially symbolic shapes self.M and
self.N will fail at runtime for PrimExpr; update the computation of
warp_row_tiles and warp_col_tiles (currently int(self.M // m_warp) and
int(self.N // n_warp)) to preserve symbolic expressions instead of forcing
Python ints—either remove the int() and keep self.M // m_warp and self.N //
n_warp, or use tir.floordiv/tvm.tir.floordiv to produce a PrimExpr;
alternatively, if a concrete int is required, guard with an isinstance check for
tir.IntImm before casting. Ensure you change both warp_row_tiles and
warp_col_tiles and keep references to m_warp and n_warp.

---

Outside diff comments:
In `@tilelang/transform/decouple_type_cast.py`:
- Around line 92-95: The ValueError raised in decouple_type_cast.py lists valid
buffer scopes but omits "metal.simdgroup"; update the error message (the raise
ValueError that references buffer.scope() and buffer.name) to include
"metal.simdgroup" in the Expected one of: list so it matches the scopes accepted
by is_local_buffer/other checks and clearly communicates valid scopes.

---

Nitpick comments:
In `@src/op/copy.cc`:
- Around line 639-645: The method CheckSIMDGroupStore incorrectly implies a
directional store but actually matches any simdgroup↔simdgroup transfer; rename
CheckSIMDGroupStore to CheckSIMDGroupCopy (and rename its lowering helper
LowerSIMDGroupStore to LowerSIMDGroupCopy or similarly) and update all call
sites to use the new names so the intent matches the implementation; keep the
same body (checking src.scope() == "metal.simdgroup" && dst.scope() ==
"metal.simdgroup") but change identifiers for consistency with
CheckLDSMCopy/CheckSTSMCopy naming conventions.

In `@src/op/gemm_py.cc`:
- Around line 136-137: Add a lightweight allowMetal() guard before returning
GemmInst::kMetalExp: implement allowMetal() similar to the other allow* helpers
(validate shapes, dtypes, and memory scope constraints expected by the Metal
path) and call it in the branch that currently checks TargetIsMetal(target);
only return GemmInst::kMetalExp when allowMetal() returns true, otherwise fall
back to the existing fallback/error path so unsupported Metal configs are
rejected early. Reference: the TargetIsMetal(target) branch and
GemmInst::kMetalExp; follow the validation pattern used by functions like
allowCuda()/allowCdna()/allowWgmma() to ensure consistency.

In `@src/op/gemm.h`:
- Line 42: The GemmInst enum currently uses an asymmetrical name and implicit
values; change the enumerator kMetalExp to match the Python name (e.g., kMETAL)
and make all enum members have explicit integer values matching the Python
IntEnum (e.g., enum class GemmInst : uint8_t { kMMA = 0, kWGMMA = 1, kTCGEN5MMA
= 2, kMFMA = 3, kMETAL = 4 };), keep the underlying type uint8_t, and update any
usages of GemmInst::kMetalExp to the new symbol to preserve ABI alignment with
tilelang/tileop/gemm/inst.py.

In `@testing/python/metal/test_metal_gemm_v2_linux.py`:
- Around line 50-53: Remove the redundant PassContext target context and rename
the returned value from tilelang.lower to reflect it's a JIT kernel: call
tilelang.lower with the explicit target argument only (remove the
tvm.transform.PassContext(), tvm.target.Target("metal") context manager) and
rename the variable from artifact to kernel so you access kernel.kernel_source
(i.e., locate the tilelang.lower call and the subsequent kernel_source access
and update the target usage and variable name accordingly).

In `@testing/python/metal/test_metal_gemm_v2.py`:
- Around line 91-93: Remove the redundant torch.mps.is_available() guard in the
__main__ block so that tilelang.testing.main() is called unconditionally; the
tests themselves use the `@tilelang.testing.requires_metal` decorator (in this
file's test functions) to mark/skips tests when Metal is absent, so update the
__main__ section (the block checking __name__ == "__main__") to simply invoke
tilelang.testing.main() without wrapping it in torch.mps.is_available().
- Around line 80-83: Reorder the decorators so the platform guard is
outermost: swap the two decorators on test_gemm_v2_large so
`@tilelang.testing.requires_metal` appears above `@pytest.mark.xfail`; update the
decorators on the test_gemm_v2_large function accordingly to keep the same
behavior but follow conventional ordering.
- Line 88: Add a brief inline comment next to the call to assert_gemm_v2(1024,
1024, 1024, 16, 16, 16, atol=1.0) explaining that atol=1.0 is intentional
because for K=1024 with float16 inputs the accumulated rounding error per
element can approach K * ε_fp16 ≈ 1024 * 9.7e-4 ≈ 1.0, so the larger absolute
tolerance is required for this large-K test to avoid false failures.
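
The bound referenced above can be checked numerically (illustrative sketch; the `fp16_accum_error_bound` helper is hypothetical, not part of the codebase):

```python
def fp16_accum_error_bound(k: int) -> float:
    """Rough worst-case absolute error for summing k float16 products."""
    eps_fp16 = 2.0 ** -10  # float16 machine epsilon, ~9.77e-4
    return k * eps_fp16

print(fp16_accum_error_bound(1024))  # 1.0
```

For K=1024 the bound lands at exactly 1.0, which is why the test uses atol=1.0 rather than a tighter tolerance.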

ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9f25954 and 696f703.

📒 Files selected for processing (18)
  • pyproject.toml
  • requirements.txt
  • src/op/copy.cc
  • src/op/copy.h
  • src/op/gemm.cc
  • src/op/gemm.h
  • src/op/gemm_py.cc
  • src/op/parallel.cc
  • src/op/utils.h
  • testing/python/metal/test_metal_gemm_v2.py
  • testing/python/metal/test_metal_gemm_v2_linux.py
  • tilelang/intrinsics/metal_macro_generator.py
  • tilelang/jit/adapter/torch/metal.py
  • tilelang/tileop/gemm/__init__.py
  • tilelang/tileop/gemm/gemm_metal.py
  • tilelang/tileop/gemm/inst.py
  • tilelang/transform/decouple_type_cast.py
  • tilelang/utils/language.py

Comment thread pyproject.toml Outdated
Comment thread src/op/copy.cc Outdated
Comment thread src/op/copy.h Outdated
Comment thread src/op/copy.h Outdated
Comment thread src/op/gemm.cc Outdated
T.copy(A[by * block_M, ko * block_K], A_shared, coalesced_width=2)
T.copy(B[ko * block_K, bx * block_N], B_shared, coalesced_width=2)

T.gemm(A_shared, B_shared, C_local)
Contributor

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, let's check the test file in question to see what's actually being called
find . -name "test_metal_gemm_v2_linux.py" -type f | head -1 | xargs cat -n

Repository: tile-ai/tilelang

Length of output: 3353


🏁 Script executed:

# Search for gemm and gemm_v2 definitions in the language
rg -n "def gemm\b|def gemm_v2\b" --type py tilelang/language/ -A 5

Repository: tile-ai/tilelang

Length of output: 743


🏁 Script executed:

# Check what tl.tileop operations are defined and used
rg -n "tl\.tileop\.gemm" --type py -B 2 -A 2

Repository: tile-ai/tilelang

Length of output: 1314


🏁 Script executed:

# Check runtime test mentioned in review (test_metal_gemm_v2.py)
find . -name "test_metal_gemm_v2.py" -type f | head -1 | xargs cat -n 2>/dev/null || echo "File not found"

Repository: tile-ai/tilelang

Length of output: 3454


🏁 Script executed:

# Search for T.gemm and T.gemm_v2 usage patterns
rg -n "T\.gemm\(|T\.gemm_v2\(" --type py | head -20

Repository: tile-ai/tilelang

Length of output: 2497


Replace T.gemm with T.gemm_v2 on line 32.

T.gemm and T.gemm_v2 are distinct operations with different lowering paths (tl.tileop.gemm vs. tl.tileop.gemm_py). This test is meant to verify Metal codegen for the new gemm_v2 path, but it calls T.gemm instead. The assertions for simdgroup_multiply_accumulate, simdgroup_load, and simdgroup_store may be validating against the wrong lowering backend and could pass for the wrong reason or miss actual Metal simdgroup codegen issues.

The runtime test (test_metal_gemm_v2.py line 34) correctly uses T.gemm_v2; align this codegen test with it.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@testing/python/metal/test_metal_gemm_v2_linux.py` at line 32, The test is
calling the old lowering path via T.gemm instead of the new Metal path; update
the call to use T.gemm_v2 so the codegen test exercises the gemm_v2 lowering
(replace the invocation T.gemm(A_shared, B_shared, C_local) with
T.gemm_v2(A_shared, B_shared, C_local) in the test function) to align this
codegen test with the runtime test and ensure assertions on
simdgroup_multiply_accumulate, simdgroup_load, and simdgroup_store validate the
correct backend.

Comment on lines +70 to +71
def test_metal_gemm_v2_larger():
assert_metal_gemm_v2_codegen(128, 128, 128, 32, 32, 32, dtype=T.float16)
Contributor

⚠️ Potential issue | 🟡 Minor

test_metal_gemm_v2_larger may silently pass despite a known codegen limitation.

The runtime test test_gemm_v2_large in test_metal_gemm_v2.py tests the same parameters (128, 128, 128, 32, 32, 32, dtype=T.float16) and is marked @pytest.mark.xfail(reason="TODO: codegen not support float16x8"). If the codegen for this block configuration is genuinely broken, the Linux codegen test should also fail (and be expected to fail) for consistency. Either:

  • Add @pytest.mark.xfail(reason="TODO: codegen not support float16x8") here as well, or
  • Verify that the codegen-only path (tilelang.lower without execution) succeeds where execution fails.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@testing/python/metal/test_metal_gemm_v2_linux.py` around lines 70 - 71, The
test test_metal_gemm_v2_larger currently calls assert_metal_gemm_v2_codegen with
parameters known to fail at runtime; update the test to either mark it as an
expected failure or explicitly verify codegen-only success: add
`@pytest.mark.xfail`(reason="TODO: codegen not support float16x8") above
test_metal_gemm_v2_larger to match the runtime test, or replace the single
assert_metal_gemm_v2_codegen call with a two-step check that runs the codegen
path (tilelang.lower/codegen) and asserts it succeeds while keeping execution
separately marked xfail; reference test_metal_gemm_v2_larger and
assert_metal_gemm_v2_codegen when making the change.

Comment on lines +43 to +44
self.warp_rows = warp_row_tiles // self.micro_size_x
self.warp_cols = warp_col_tiles // self.micro_size_y
Contributor

⚠️ Potential issue | 🟡 Minor

Add validation that warp_row_tiles and warp_col_tiles are divisible by 8.

If warp_row_tiles or warp_col_tiles aren't multiples of micro_size_x/micro_size_y (8), the integer division would silently truncate, leading to incorrect tiling.

Suggested validation
         # Number of 8x8 tiles per warp
+        assert warp_row_tiles % self.micro_size_x == 0, (
+            f"warp_row_tiles ({warp_row_tiles}) must be divisible by micro_size_x ({self.micro_size_x})"
+        )
+        assert warp_col_tiles % self.micro_size_y == 0, (
+            f"warp_col_tiles ({warp_col_tiles}) must be divisible by micro_size_y ({self.micro_size_y})"
+        )
         self.warp_rows = warp_row_tiles // self.micro_size_x
         self.warp_cols = warp_col_tiles // self.micro_size_y
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
self.warp_rows = warp_row_tiles // self.micro_size_x
self.warp_cols = warp_col_tiles // self.micro_size_y
# Number of 8x8 tiles per warp
assert warp_row_tiles % self.micro_size_x == 0, (
f"warp_row_tiles ({warp_row_tiles}) must be divisible by micro_size_x ({self.micro_size_x})"
)
assert warp_col_tiles % self.micro_size_y == 0, (
f"warp_col_tiles ({warp_col_tiles}) must be divisible by micro_size_y ({self.micro_size_y})"
)
self.warp_rows = warp_row_tiles // self.micro_size_x
self.warp_cols = warp_col_tiles // self.micro_size_y
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/intrinsics/metal_macro_generator.py` around lines 43 - 44, The code
computes self.warp_rows and self.warp_cols via integer division of
warp_row_tiles//micro_size_x and warp_col_tiles//micro_size_y which will
silently truncate if inputs are not divisible by the micro sizes (8); add
validation in the same initializer or before these assignments (e.g., in the
MetalMacroGenerator constructor or method that sets
warp_row_tiles/warp_col_tiles) that raises a clear error if warp_row_tiles %
micro_size_x != 0 or warp_col_tiles % micro_size_y != 0, mentioning the
offending values, and only then compute self.warp_rows and self.warp_cols as the
integer quotient.

Comment on lines +56 to +58
def get_kernel_source(self, kernel_only: bool = True) -> str:
return self.kernel_global_source

Contributor

⚠️ Potential issue | 🟡 Minor

Fix the return type annotation and honour the kernel_only flag.

Two issues here:

  1. Return type mismatch: kernel_global_source is declared str | None (Line 23), so the method can return None, contradicting the -> str annotation. This will cause silent type errors for callers.

  2. Unused kernel_only parameter — Every peer adapter branches on this flag (base.py, nvrtc/adapter.py, cython/adapter.py). Silently ignoring it here means get_kernel_source(kernel_only=False) behaves identically to kernel_only=True, breaking the expected contract.

🛠️ Proposed fix
-    def get_kernel_source(self, kernel_only: bool = True) -> str:
-        return self.kernel_global_source
+    def get_kernel_source(self, kernel_only: bool = True) -> str | None:
+        # Metal has a single unified source; kernel_only has no distinct meaning here.
+        return self.kernel_global_source

If a non-None guarantee is truly required at call sites, add an explicit assertion:

-    def get_kernel_source(self, kernel_only: bool = True) -> str:
-        return self.kernel_global_source
+    def get_kernel_source(self, kernel_only: bool = True) -> str | None:
+        assert self.kernel_global_source is not None, "kernel_global_source is not available"
+        return self.kernel_global_source
🧰 Tools
🪛 Ruff (0.15.1)

[warning] 56-56: Unused method argument: kernel_only

(ARG002)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/jit/adapter/torch/metal.py` around lines 56 - 58, The method
get_kernel_source currently claims to return str but may return None and ignores
the kernel_only flag; change its signature to -> str | None (or keep -> str but
assert/raise if kernel_global_source is None) and implement the kernel_only
branch: if kernel_only is True return self.kernel_global_source, otherwise
return the full Metal source (compose or return the attribute that holds the
complete module/source such as self.metal_source or self.full_source); ensure
you reference get_kernel_source and kernel_global_source and either assert
kernel_global_source is not None before returning a str or update callers/types
to accept Optional[str].

Comment on lines +22 to +23
warp_row_tiles = int(self.M // m_warp)
warp_col_tiles = int(self.N // n_warp)
Contributor

⚠️ Potential issue | 🟡 Minor

int() cast on potentially symbolic self.M / self.N.

self.M and self.N originate from buffer shapes and may be tir.IntImm or symbolic PrimExpr. If symbolic, int(self.M // m_warp) will raise at runtime. Other GEMM backends (e.g., GemmMMA) handle this similarly, so this is likely fine for Metal's concrete-size use case, but worth noting.
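
One way to make the intent explicit is a small guard (sketch with hypothetical names; `tir.IntImm` exposes its constant via `.value`, which the logic below relies on):

```python
def as_concrete_int(expr):
    """Return a Python int for concrete shapes; reject symbolic PrimExprs."""
    if isinstance(expr, int):
        return expr
    value = getattr(expr, "value", None)  # tir.IntImm carries .value
    if isinstance(value, int):
        return value
    raise TypeError(f"Metal GEMM requires concrete shapes, got {expr!r}")

# warp_row_tiles = as_concrete_int(self.M) // m_warp  # fails loudly if symbolic
```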

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/tileop/gemm/gemm_metal.py` around lines 22 - 23, The int() cast on
potentially symbolic shapes self.M and self.N will fail at runtime for PrimExpr;
update the computation of warp_row_tiles and warp_col_tiles (currently
int(self.M // m_warp) and int(self.N // n_warp)) to preserve symbolic
expressions instead of forcing Python ints—either remove the int() and keep
self.M // m_warp and self.N // n_warp, or use tir.floordiv/tvm.tir.floordiv to
produce a PrimExpr; alternatively, if a concrete int is required, guard with an
isinstance check for tir.IntImm before casting. Ensure you change both
warp_row_tiles and warp_col_tiles and keep references to m_warp and n_warp.

@oraluben oraluben requested a review from LeiWang1999 February 23, 2026 11:30
@LeiWang1999
Member

Why do we need to introduce the storage scope metal.simdgroup instead of just using local?

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/op/parallel.cc (1)

355-383: ⚠️ Potential issue | 🔴 Critical

Guard all Fragment-dependent paths, not just source_buffer selection.

Good fix at Line 381, but this loop can still crash on non-Fragment layouts: Line 363 (IsBufferCompletelyReplicated) and the read_source_buffer path (Line 373-377 → Line 418/610) still assume as<Fragment>().value() is valid.

Suggested hardening
   for (const auto &[buffer, access] : indice_map_) {
     if (T.layout_map.count(buffer)) {
       // skip reducers with rep=ALL
       if (auto info = reducer_info_map_.Get(buffer->data);
           info && info.value()->rep == ReducerRepType::ALL)
         continue;
+      auto frag = T.layout_map[buffer].as<Fragment>();
+      if (!frag.has_value()) {
+        // Non-Fragment layout: cannot be used as infer source here.
+        continue;
+      }

       bool is_fully_replicated =
           IsBufferCompletelyReplicated(buffer, T.layout_map);

       if (access.is_write) {
         source_buffer = buffer;
       } else {
@@
-        if ((!read_source_buffer.defined() ||
+        if (!is_fully_replicated &&
+            (!read_source_buffer.defined() ||
              access.indices.size() >
                  GetAccessInfo(read_source_buffer).indices.size())) {
           read_source_buffer = buffer;
         }
@@
-        auto frag = T.layout_map[buffer].as<Fragment>();
-        if (frag.has_value() && is_one(frag.value()->ReplicateExtent()) &&
+        if (is_one(frag.value()->ReplicateExtent()) &&
             !source_buffer.defined()) {
           source_buffer = buffer;
         }
       }
     }
   }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/op/parallel.cc` around lines 355 - 383, The loop over indice_map_ assumes
T.layout_map[buffer].as<Fragment>() is valid in multiple places; besides the
existing guard around source_buffer selection, also guard calls to
IsBufferCompletelyReplicated(buffer, T.layout_map) and the read_source_buffer
selection/usage so they only run when the layout is a Fragment: first check
T.layout_map.count(buffer) and that
T.layout_map[buffer].as<Fragment>().has_value() before calling
IsBufferCompletelyReplicated or using frag.value()->ReplicateExtent(); likewise
ensure any later uses of read_source_buffer rely on a verified Fragment layout
(or skip/handle non-Fragment layouts) so GetAccessInfo(read_source_buffer) and
ReplicateExtent() are only invoked on confirmed Fragment instances.
♻️ Duplicate comments (2)
src/op/copy.cc (1)

893-894: ⚠️ Potential issue | 🔴 Critical

kMetalSIMDGroup can still fall through to LOG(FATAL).

This branch makes GetCopyInst() return the new enum, but CopyNode::Lower() below still has no kMetalSIMDGroup case, and LowerSIMDGroupStore() is only declared in the header. The first matching copy will crash at runtime.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/op/copy.cc` around lines 893 - 894, GetCopyInst() can return
CopyInst::kMetalSIMDGroup (when CheckSIMDGroupStore is true) but
CopyNode::Lower() lacks a case for kMetalSIMDGroup and LowerSIMDGroupStore is
only declared, causing a runtime LOG(FATAL); implement LowerSIMDGroupStore in
src/op/copy.cc and add a case in CopyNode::Lower() that handles
CopyInst::kMetalSIMDGroup by calling LowerSIMDGroupStore (mirroring the pattern
used for other CopyInst enum values), ensuring the new enum path is fully
implemented and returns/lowers correctly.
tilelang/intrinsics/metal_macro_generator.py (1)

42-44: ⚠️ Potential issue | 🟠 Major

Fail fast when per-warp tiles are not multiples of 8.

These values are rounded down with // 8. A legal Metal partition can still produce per-warp shapes like 12x12, so this silently drops the tail rows/cols instead of rejecting the configuration.
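
The truncation is easy to reproduce (illustrative arithmetic only):

```python
# 12x12 per-warp tiles with 8x8 micro tiles: // silently rounds down,
# so rows/cols 8..11 are never covered by any simdgroup matrix.
warp_row_tiles, micro_size_x = 12, 8
warp_rows = warp_row_tiles // micro_size_x
print(warp_rows, warp_rows * micro_size_x)  # 1 8  -> 4 rows dropped
```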

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/intrinsics/metal_macro_generator.py` around lines 42 - 44, Validate
that per-warp tile counts divide evenly rather than truncating: check that
warp_row_tiles is divisible by self.micro_size_x and warp_col_tiles is divisible
by self.micro_size_y before computing self.warp_rows and self.warp_cols (or
immediately after assignment) and raise a clear error / exception (with the
offending values) if either modulus is non-zero so illegal Metal partitions like
12x12 are rejected instead of silently dropping tails; reference the variables
warp_row_tiles, warp_col_tiles, self.micro_size_x, self.micro_size_y, and the
computed warp_rows/warp_cols when adding the check and error.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/op/copy.cc`:
- Around line 813-819: CopyNode::CheckSIMDGroupStore currently returns true only
for local-to-local copies because it requires src.scope() == dst.scope() ==
"metal.simdgroup"; change the predicate to detect a simdgroup source storing to
a non-simdgroup destination so it properly gates LowerSIMDGroupStore.
Specifically, in CopyNode::CheckSIMDGroupStore(Target target) keep the
TargetIsMetal(target) check but replace the scope test with something like
src.scope() == "metal.simdgroup" && dst.scope() != "metal.simdgroup" (or
explicitly check for shared/global destination scopes used by the new path) so
the function matches simdgroup->shared/global store cases rather than
simdgroup->simdgroup copies.

In `@tilelang/tileop/gemm/gemm_metal.py`:
- Around line 59-60: The current assertion only enforces block_K >= micro_size_k
but not that block_K is a multiple of the micro-K step, which allows leftover K
values to be dropped; update the guard in gemm_metal.py to assert that block_K
is an exact multiple of micro_size_k (e.g., assert block_K % micro_size_k == 0)
and include a clear error message referencing block_K and micro_size_k so
callers know it must be a multiple (or explicitly require multiple-of-8 if
micro_size_k==8); keep the existing is_full_region(C_region) check unchanged.

---

Outside diff comments:
In `@src/op/parallel.cc`:
- Around line 355-383: The loop over indice_map_ assumes
T.layout_map[buffer].as<Fragment>() is valid in multiple places; besides the
existing guard around source_buffer selection, also guard calls to
IsBufferCompletelyReplicated(buffer, T.layout_map) and the read_source_buffer
selection/usage so they only run when the layout is a Fragment: first check
T.layout_map.count(buffer) and that
T.layout_map[buffer].as<Fragment>().has_value() before calling
IsBufferCompletelyReplicated or using frag.value()->ReplicateExtent(); likewise
ensure any later uses of read_source_buffer rely on a verified Fragment layout
(or skip/handle non-Fragment layouts) so GetAccessInfo(read_source_buffer) and
ReplicateExtent() are only invoked on confirmed Fragment instances.

---

Duplicate comments:
In `@src/op/copy.cc`:
- Around line 893-894: GetCopyInst() can return CopyInst::kMetalSIMDGroup (when
CheckSIMDGroupStore is true) but CopyNode::Lower() lacks a case for
kMetalSIMDGroup and LowerSIMDGroupStore is only declared, causing a runtime
LOG(FATAL); implement LowerSIMDGroupStore in src/op/copy.cc and add a case in
CopyNode::Lower() that handles CopyInst::kMetalSIMDGroup by calling
LowerSIMDGroupStore (mirroring the pattern used for other CopyInst enum values),
ensuring the new enum path is fully implemented and returns/lowers correctly.

In `@tilelang/intrinsics/metal_macro_generator.py`:
- Around line 42-44: Validate that per-warp tile counts divide evenly rather
than truncating: check that warp_row_tiles is divisible by self.micro_size_x and
warp_col_tiles is divisible by self.micro_size_y before computing self.warp_rows
and self.warp_cols (or immediately after assignment) and raise a clear error /
exception (with the offending values) if either modulus is non-zero so illegal
Metal partitions like 12x12 are rejected instead of silently dropping tails;
reference the variables warp_row_tiles, warp_col_tiles, self.micro_size_x,
self.micro_size_y, and the computed warp_rows/warp_cols when adding the check
and error.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 014ea53f-df65-431b-9c9d-62a27cd42be7

📥 Commits

Reviewing files that changed from the base of the PR and between 0976055 and 78ca6ba.

📒 Files selected for processing (18)
  • pyproject.toml
  • requirements.txt
  • src/op/copy.cc
  • src/op/copy.h
  • src/op/gemm.cc
  • src/op/gemm.h
  • src/op/gemm_py.cc
  • src/op/parallel.cc
  • src/op/utils.h
  • testing/python/metal/test_metal_gemm_v2.py
  • testing/python/metal/test_metal_gemm_v2_linux.py
  • tilelang/intrinsics/metal_macro_generator.py
  • tilelang/jit/adapter/torch/metal.py
  • tilelang/tileop/gemm/__init__.py
  • tilelang/tileop/gemm/gemm_metal.py
  • tilelang/tileop/gemm/inst.py
  • tilelang/transform/decouple_type_cast.py
  • tilelang/utils/language.py
✅ Files skipped from review due to trivial changes (2)
  • src/op/utils.h
  • requirements.txt
🚧 Files skipped from review as they are similar to previous changes (7)
  • pyproject.toml
  • tilelang/transform/decouple_type_cast.py
  • tilelang/utils/language.py
  • tilelang/jit/adapter/torch/metal.py
  • tilelang/tileop/gemm/inst.py
  • tilelang/tileop/gemm/__init__.py
  • testing/python/metal/test_metal_gemm_v2_linux.py

Comment thread src/op/copy.cc Outdated
Comment on lines +59 to +60
assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})"
assert is_full_region(C_region), "Fragment output C must be a full region"
Contributor

⚠️ Potential issue | 🟠 Major

Require block_K to be a multiple of 8.

Line 74 uses block_K // micro_size_k, so block_K=12 would only execute one 8-wide MMA step and silently drop the remaining K values.

🧩 Suggested guard
-        assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})"
+        assert block_K >= micro_size_k and block_K % micro_size_k == 0, (
+            f"block_K ({block_K}) must be a multiple of micro_size_k ({micro_size_k})"
+        )
📝 Committable suggestion


Suggested change
assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})"
assert is_full_region(C_region), "Fragment output C must be a full region"
assert block_K >= micro_size_k and block_K % micro_size_k == 0, (
f"block_K ({block_K}) must be a multiple of micro_size_k ({micro_size_k})"
)
assert is_full_region(C_region), "Fragment output C must be a full region"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/tileop/gemm/gemm_metal.py` around lines 59 - 60, The current
assertion only enforces block_K >= micro_size_k but not that block_K is a
multiple of the micro-K step, which allows leftover K values to be dropped;
update the guard in gemm_metal.py to assert that block_K is an exact multiple of
micro_size_k (e.g., assert block_K % micro_size_k == 0) and include a clear
error message referencing block_K and micro_size_k so callers know it must be a
multiple (or explicitly require multiple-of-8 if micro_size_k==8); keep the
existing is_full_region(C_region) check unchanged.

oraluben added 3 commits April 2, 2026 18:34
Metal Shading Language only supports vector types up to 4 components
(e.g. half4, float4). The default 128-bit vectorization produced
float16x8 which has no Metal equivalent, causing codegen failure.

Limit vectorize width to 64 bits for Metal target, matching the
native vector type maximum (half4 = 4×16 = 64 bits).

Perf vs torch.mm (MPS) on Apple Silicon:
- 1024²: 0.85x (4.34 vs 5.14 TFLOPS)
- 2048²: 0.76x (4.62 vs 6.06 TFLOPS)
- 4096²: 1.37x (3.86 vs 2.81 TFLOPS)
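
The cap described in this commit amounts to clamping the vector lane count by target (sketch; `max_vectorize_lanes` is a hypothetical distillation of the pass in src/transform/loop_vectorize.cc):

```python
def max_vectorize_lanes(dtype_bits: int, target_is_metal: bool) -> int:
    # Metal tops out at 64-bit vectors here (half4 = 4 x 16 bits); other
    # targets keep the default 128-bit vectorization.
    max_bits = 64 if target_is_metal else 128
    return max(1, max_bits // dtype_bits)

print(max_vectorize_lanes(16, target_is_metal=True))   # 4  -> half4
print(max_vectorize_lanes(16, target_is_metal=False))  # 8  -> float16x8, no Metal equivalent
```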
Eliminate C_simd shared memory round-trip by keeping the accumulator
in simdgroup registers across K iterations. Add LowerSIMDGroupCopy
for direct simdgroup_store to device/shared memory, simdgroup-aware
Fill lowering, and T.alloc_simdgroup API.

Separates metal.simdgroup from local.fragment in IsFragmentBuffer
since they have different semantics (cooperative vs per-thread).
Introduces IsRegisterBuffer for cases needing either.

8-15% speedup over shared-memory C path on Apple Silicon.
Codegen tests verify correct Metal shader output (simdgroup_store to
device, no C_local simdgroup_load from shared, make_filled init).
Correctness tests verify numerical results against torch.matmul
across square/non-square blocks and matrix sizes.
@coderabbitai coderabbitai bot left a comment

🧹 Nitpick comments (3)
testing/python/metal/test_metal_gemm_v2_linux.py (1)

22-34: Codegen test uses different buffer scopes than the runtime test.

This test allocates C_local with T.alloc_simdgroup (simdgroup register path), while the runtime test (test_metal_gemm_v2.py) allocates C_local with T.alloc_shared(..., scope="shared") (shared memory path). Per tilelang/tileop/gemm/gemm_metal.py:65-99, these trigger different lowering branches:

  • Simdgroup scope → _gemm_ss_simdgroup (no C load/store round-trips)
  • Shared scope → _gemm_ss_shared (uses simd_load/simd_store for C)

This appears intentional to test both paths, but the test names/docstrings don't clarify this distinction. Consider adding a comment noting that test_metal_simdgroup_store.py covers the simdgroup accumulator path while test_metal_gemm_v2.py covers the shared-memory accumulator path.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@testing/python/metal/test_metal_gemm_v2_linux.py` around lines 22 - 34, The
test uses T.alloc_simdgroup to create C_local (simdgroup register path) while
the runtime test uses T.alloc_shared for C_local (shared-memory path), which
exercise different lowering branches (_gemm_ss_simdgroup vs _gemm_ss_shared);
add a short clarifying comment or update the test docstring near the C_local
allocation (or at the top of the test) stating that this file intentionally
tests the simdgroup-accumulator path (T.alloc_simdgroup / C_local) and that
test_metal_gemm_v2.py covers the shared-memory accumulator path (T.alloc_shared
/ C_local) so readers know both lowering branches are exercised.
testing/python/metal/test_metal_simdgroup_store.py (1)

72-76: String-based assertions may be fragile if variable names change.

The checks for "simdgroup_store(C_local" and "simdgroup_load" in line and "C_local" in line rely on the generated Metal code using the exact variable name C_local. If the codegen changes variable naming (e.g., to C_local_1 or a different convention), these assertions could fail or pass incorrectly.

Consider using a more robust pattern, such as counting total simdgroup_store calls or using a regex that matches the buffer pattern more flexibly.
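
The suggested regexes behave as follows on representative (made-up) generated lines:

```python
import re

store_pat = re.compile(r"simdgroup_store\(\s*C_local(?:_\d+)?\b")
load_pat = re.compile(r"simdgroup_load\(\s*C_local(?:_\d+)?\b")

# Hypothetical generated Metal lines, for illustration only.
assert store_pat.search("simdgroup_store(C_local[i], &C[off], N);")
assert store_pat.search("simdgroup_store(C_local_1[i], &C[off], N);")  # suffix tolerated
assert load_pat.search("simdgroup_load(A_frag[j], A_shared, K);") is None
```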

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@testing/python/metal/test_metal_simdgroup_store.py` around lines 72 - 76, The
string-based assertions are fragile because they require an exact match of the
variable name C_local; update the checks in test_metal_simdgroup_store.py to use
regexes that allow optional suffixes (e.g. C_local or C_local_1) and/or count
simdgroup_store/load calls more generally: replace the plain substring search
for "simdgroup_store(C_local" with a regex like
r"simdgroup_store\(\s*C_local(?:_\d+)?\b" (used to compute store_to_device) and
replace the load check to search for r"simdgroup_load\(\s*C_local(?:_\d+)?\b"
(ensuring len(...) == 0), or alternatively assert total simdgroup_store
occurrences and that none of the simdgroup_load matches reference a C_local-like
buffer; update variables store_to_device and load_c_from_shared accordingly.
testing/python/metal/test_metal_gemm_v2.py (1)

84-86: Loose tolerance (atol=1.0) may mask correctness issues.

For the 1024×1024 GEMM, atol=1.0 is very permissive—any element could differ by up to 1.0 from the reference. While float16 accumulation errors do compound over larger K dimensions, consider also checking relative error or documenting why this tolerance is acceptable for this configuration.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@testing/python/metal/test_metal_gemm_v2.py` around lines 84 - 86, The test
test_gemm_v2_1024 is using a very loose absolute tolerance (assert_gemm_v2 with
atol=1.0) which can mask numerical regressions; tighten the check by lowering
atol (e.g., to <=1e-2) and/or add a relative tolerance (rtol, e.g., 1e-2) to the
assert_gemm_v2 call, or if you must keep a larger absolute tolerance, add a
brief comment on why float16 accumulation justifies that specific tolerance;
update the assert_gemm_v2 invocation (and its signature use) accordingly.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@testing/python/metal/test_metal_gemm_v2_linux.py`:
- Around line 22-34: The test uses T.alloc_simdgroup to create C_local
(simdgroup register path) while the runtime test uses T.alloc_shared for C_local
(shared-memory path), which exercise different lowering branches
(_gemm_ss_simdgroup vs _gemm_ss_shared); add a short clarifying comment or
update the test docstring near the C_local allocation (or at the top of the
test) stating that this file intentionally tests the simdgroup-accumulator path
(T.alloc_simdgroup / C_local) and that test_metal_gemm_v2.py covers the
shared-memory accumulator path (T.alloc_shared / C_local) so readers know both
lowering branches are exercised.

In `@testing/python/metal/test_metal_simdgroup_store.py`:
- Around line 72-76: The string-based assertions are fragile because they
require an exact match of the variable name C_local; update the checks in
test_metal_simdgroup_store.py to use regexes that allow optional suffixes (e.g.
C_local or C_local_1) and/or count simdgroup_store/load calls more generally:
replace the plain substring search for "simdgroup_store(C_local" with a regex
like r"simdgroup_store\(\s*C_local(?:_\d+)?\b" (used to compute store_to_device)
and replace the load check to search for
r"simdgroup_load\(\s*C_local(?:_\d+)?\b" (ensuring len(...) == 0), or
alternatively assert total simdgroup_store occurrences and that none of the
simdgroup_load matches reference a C_local-like buffer; update variables
store_to_device and load_c_from_shared accordingly.
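The suffix-tolerant regexes the reviewer proposes can be sketched like this (the kernel_src string is a fabricated stand-in for the generated Metal source; store_to_device and load_c_from_shared mirror the variable names from the review):

```python
import re

# Generated Metal source may rename buffers with numeric suffixes
# (C_local -> C_local_1), so allow an optional "_<digits>" suffix.
STORE_RE = re.compile(r"simdgroup_store\(\s*C_local(?:_\d+)?\b")
LOAD_RE = re.compile(r"simdgroup_load\(\s*C_local(?:_\d+)?\b")

kernel_src = """
simdgroup_store(C_local_1, C + offset, N);
simdgroup_load(A_frag, A + offset, K);
"""

store_to_device = len(STORE_RE.findall(kernel_src))
load_c_from_shared = len(LOAD_RE.findall(kernel_src))

assert store_to_device == 1      # C_local_1 still matches despite the suffix
assert load_c_from_shared == 0   # the only load targets A_frag, not C_local
```

A plain substring search for "simdgroup_store(C_local" would also match C_local_1, but the `\b` anchor additionally rejects unrelated names such as C_localized, which is what makes the regex form more robust than the exact-match check.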

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: c1d29626-6946-4687-ad90-1df7efa113ee

📥 Commits

Reviewing files that changed from the base of the PR and between 78ca6ba and 110f120.

📒 Files selected for processing (11)
  • src/op/copy.cc
  • src/op/copy.h
  • src/op/fill.cc
  • src/op/utils.h
  • src/transform/loop_vectorize.cc
  • testing/python/metal/test_metal_gemm_v2.py
  • testing/python/metal/test_metal_gemm_v2_linux.py
  • testing/python/metal/test_metal_simdgroup_store.py
  • tilelang/language/__init__.py
  • tilelang/language/allocate.py
  • tilelang/tileop/gemm/gemm_metal.py
✅ Files skipped from review due to trivial changes (1)
  • src/op/utils.h
🚧 Files skipped from review as they are similar to previous changes (3)
  • src/op/copy.h
  • tilelang/tileop/gemm/gemm_metal.py
  • src/op/copy.cc

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (2)
tilelang/tileop/gemm/gemm_metal.py (2)

62-62: ⚠️ Potential issue | 🟠 Major

Require block_K to be an exact micro-K multiple.

Lines 76 and 94 iterate with block_K // micro_size_k, so any non-multiple silently drops the tail of K instead of computing it. The guard on Line 62 needs to enforce divisibility, not just >=.

Suggested guard
-        assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})"
+        assert block_K >= micro_size_k and block_K % micro_size_k == 0, (
+            f"block_K ({block_K}) must be a multiple of micro_size_k ({micro_size_k})"
+        )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/tileop/gemm/gemm_metal.py` at line 62, The assertion currently only
enforces block_K >= micro_size_k but must require that block_K is an exact
multiple of micro_size_k so iterations using block_K // micro_size_k do not drop
a tail; replace the assert on block_K and micro_size_k with a divisibility check
(e.g. require block_K % micro_size_k == 0) and update the assertion message to
state both values and that block_K must be a multiple of micro_size_k so the
subsequent loops over block_K // micro_size_k are correct.

26-27: ⚠️ Potential issue | 🟡 Minor

Don’t force potentially symbolic tile counts through int().

Lines 26-27 will blow up during lowering if self.M or self.N are symbolic tir.PrimExpr. Either keep these as TIR expressions end-to-end or make the static-shape requirement explicit before the cast so this fails predictably instead of with a Python TypeError.

One safe option if Metal GEMM is intentionally static-shape-only
         m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, GemmInst.METAL)
-        warp_row_tiles = int(self.M // m_warp)
-        warp_col_tiles = int(self.N // n_warp)
+        if not isinstance(self.M, (int, tir.IntImm)) or not isinstance(self.N, (int, tir.IntImm)):
+            raise ValueError("Metal GEMM currently requires static M/N tile sizes")
+        warp_row_tiles = int(self.M) // m_warp
+        warp_col_tiles = int(self.N) // n_warp
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/tileop/gemm/gemm_metal.py` around lines 26 - 27, warp_row_tiles and
warp_col_tiles are being forced through Python int() which will fail when self.M
or self.N are TIR/PrimExpr; instead either keep these as TIR expressions (e.g.,
use TIR floordiv/indexdiv of self.M by m_warp and self.N by n_warp) so lowering
stays symbolic, or add an explicit static-shape guard before casting (validate
self.M and self.N are ints and raise a clear error) so the requirement fails
predictably; update the expressions that set warp_row_tiles and warp_col_tiles
(and any downstream uses) to use the chosen approach.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tilelang/tileop/gemm/gemm_metal.py`:
- Line 64: The assertion in gemm_metal.py that checks "c_in_simdgroup_reg or
is_shared(C_buf)" blocks global-scope outputs before simdgroup_copy/simd_store
can run; update that assertion to permit global/storage-scoped C buffers (e.g.,
allow is_global/is_storage(C_buf)) if Metal GEMM should support direct global
output, and ensure simdgroup_copy/simd_store (which call T.simdgroup_store)
handle the global target correctly, or alternatively replace the assertion with
an explicit error message documenting that only simdgroup or shared scopes are
supported so callers know the limitation; locate the assertion referencing
C_buf.scope() and adjust it and any related scope checks in the
simdgroup_copy/simd_store path accordingly.


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 66a64930-ddc8-4e57-a7a5-98309502299a

📥 Commits

Reviewing files that changed from the base of the PR and between 110f120 and 7d2a4fa.

📒 Files selected for processing (1)
  • tilelang/tileop/gemm/gemm_metal.py

Comment thread tilelang/tileop/gemm/gemm_metal.py Outdated
Copy TVM's codegen_metal.cc/h into tilelang's src/target/ and register
as override for target.build.metal via Python-side FFI override at
import time. This allows Metal-specific codegen changes without
modifying the TVM submodule.

Add float16x8 type support (mapped to uint4 for packed storage),
with corresponding PrintVecElemLoad/Store using pointer cast and
BroadcastNode using as_type<uint>(half2(...)) packing.

Restore 128-bit vectorize width for Metal (was limited to 64-bit),
enabling half8 global memory loads for ~2% throughput improvement
on large matrices (2048²: 94% → 97% of torch).
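The `as_type<uint>(half2(...))` packing mentioned in the commit reinterprets two 16-bit float lanes as one 32-bit word. A numpy sketch of the same bit-level reinterpretation (a stand-in for the Metal cast, not the codegen itself; assumes a little-endian host):

```python
import numpy as np

# Two float16 values occupy 16 bits each; viewing a contiguous pair as
# uint32 mirrors Metal's as_type<uint>(half2(x, y)) reinterpret-cast.
pair = np.array([1.0, 2.0], dtype=np.float16)   # bits 0x3C00, 0x4000
packed = pair.view(np.uint32)[0]                 # 0x40003C00 (little-endian)

# Unpacking is the inverse view; the bit patterns survive the round-trip,
# which is why the cast is free — no conversion instructions are emitted.
unpacked = np.array([packed], dtype=np.uint32).view(np.float16)
assert unpacked.tolist() == [1.0, 2.0]
```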
@oraluben oraluben force-pushed the metal-gemm branch 3 times, most recently from cc1440b to f06c241 Compare April 8, 2026 03:31
oraluben added 2 commits April 9, 2026 10:19
# Conflicts:
#	pyproject.toml
#	requirements.txt
@oraluben
Collaborator Author

oraluben commented Apr 9, 2026

This is ready for review @LeiWang1999

why we need to introduce storage scope metal.simdgroup instead of just using local.

I've updated this in the PR description.

@oraluben oraluben changed the title Add Metal T.gemm_v2 using simdgroup_multiply_accumulate [Metal] Add Metal GEMM support with simdgroup intrinsics Apr 9, 2026
…fragment

Remove the Metal-specific alloc_simdgroup() API so users write the same
alloc_fragment() + T.gemm() code for both CUDA and Metal targets.

A new MetalFragmentToSimdgroup pass runs before LayoutInference and
selectively rewrites local.fragment GEMM accumulators to metal.simdgroup
scope, leaving non-GEMM fragment buffers (scalar loops) untouched.
This ensures zero performance regression on M4 — the generated Metal
shader code is identical to the previous alloc_simdgroup path.

Also prepares the dispatch point for future M5 NAX tensor core support,
where the same alloc_fragment code can route to a NAX lowering path
based on architecture detection.
…e for MPP/NAX support

Introduce 4 new TIR builtins (cooperative_tensor_fill/load/store/multiply_accumulate)
that mirror the existing simdgroup builtins but target MetalPerformancePrimitives.

Add metal.cooperative_tensor storage scope with full pipeline support:
- TVM: builtin declarations, registrations, Python wrappers, script parser exports,
  StorageRank enum and scope string parsing
- tilelang: MetalFragmentToCooperativeTensor pass, codegen_metal AllocateNode and
  builtin emit, fill/copy/layout_inference scope handling
- Runtime: Metal language version bumped to 4.0 for MPP header availability

Currently emits simdgroup_* Metal calls as functional baseline; actual MPP matmul2d
codegen will follow in a subsequent commit.
Replace the simdgroup_matrix (8x8) fallback with MetalPerformancePrimitives
matmul2d using 16x16 base fragments and (M=16, N=16, K=32) micro-tiles.

Key changes:
- codegen_metal.cc: emit MPP matmul2d code with manual per-thread coordinate
  calculation (MLX-style BaseNAXFrag layout) instead of get_multidimensional_index
- cooperative_tensor_load/store builtins now accept mma_M/N/K and operand_role
- micro_size_k raised from 16 to 32 (MPP requires at least one dim >= 32)
- A/B local buffers sized for 16x32 tiles (512 elements = 16 per thread)
- Metal address space qualifiers resolved via expression tree traversal

Performance: 6.3 TFLOPS on 4096x4096x4096 fp16 GEMM (M5), up from 3.4 TFLOPS
with simdgroup_matrix. PyTorch MPS reference: 14.2 TFLOPS.
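The TFLOPS figures quoted above follow from the standard dense-GEMM operation count of 2·M·N·K (one multiply plus one add per output contribution). A minimal sketch of the conversion; the elapsed time below is back-derived from the quoted 6.3 TFLOPS, not a measured value:

```python
def gemm_tflops(M, N, K, seconds):
    # A dense GEMM performs one multiply and one add per (m, n, k) triple.
    flops = 2 * M * N * K
    return flops / seconds / 1e12

# 4096^3 fp16 GEMM: 2 * 4096^3 ≈ 137.4 GFLOP of work. At 6.3 TFLOPS the
# kernel takes roughly 21.8 ms (illustrative, derived from the quoted rate).
t = 2 * 4096**3 / 6.3e12
print(round(gemm_tflops(4096, 4096, 4096, t), 1))  # → 6.3
```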


3 participants