[WIP][Feature]: add NV FP4 GEMM support for SM100 & SM120#1918

Closed
Hale423 wants to merge 13 commits into tile-ai:main from Hale423:feat/gemm-nv-fp4

Conversation

Contributor

@Hale423 Hale423 commented Mar 9, 2026

GEMM NV FP4 Feature – Design & Progress

Addresses #1592

Summary

Add float4_e2m1fn (NV FP4) GEMM support on both SM100 (TCGEN05 path) and SM120 (fragment MMA path).

SM100 (B100/B200) uses tcgen05.mma with TMEM; SM120 (RTX 5080/5090) uses mma.sync.aligned.kind::f8f6f4.m16n8k32 with register fragments. The CUTE library already provides MMA atoms for both architectures.
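For reference, `float4_e2m1fn` packs a sign bit, two exponent bits, and one mantissa bit into 4 bits, giving 16 codes over the values ±{0, 0.5, 1, 1.5, 2, 3, 4, 6}. A minimal Python sketch of the decode rule (the function name is illustrative, not a TileLang API):

```python
def decode_e2m1(code: int) -> float:
    """Decode a 4-bit float4_e2m1fn code (illustrative sketch)."""
    sign = -1.0 if code & 0b1000 else 1.0
    exp = (code >> 1) & 0b11  # 2 exponent bits, bias 1
    man = code & 0b1          # 1 mantissa bit
    if exp == 0:
        magnitude = man * 0.5                      # subnormal: 0 or 0.5
    else:
        magnitude = (1 + man * 0.5) * 2.0 ** (exp - 1)
    return sign * magnitude

# Positive codes 0b0000..0b0111, smallest to largest:
assert [decode_e2m1(c) for c in range(8)] == [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
```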

Changes in this PR

WIP: SM120 fragment-MMA FP4 support

| File | Change | Status |
|------|--------|--------|
| `src/tl_templates/cuda/common.h` | Add `tl::float_e2m1_t` (inherits `cute::float_e2m1_t`) + `to_cute_type` specialization | Done |
| `src/tl_templates/cuda/cuda_fp4.h` | Change `fp4_e2_t` from a custom struct to `using fp4_e2_t = tl::float_e2m1_t` (aligns with CUTE MMA atoms) | Done |
| `src/tl_templates/cuda/gemm_mma.h` | Add `TL_DISPATCH_MMA_TEMPLATE(fp4_e2_t, fp4_e2_t, float, SM120_16x8x32_TN)` in the SM120 section | Done |
| `tilelang/intrinsics/mma_macro_generator.py` | Cap `k_dim` at 32 for sub-byte types (FP4 MMA is m16n8k32, same as FP8) | Done |
| `tilelang/intrinsics/utils.py` | Add a 4-bit case to `get_ldmatrix_offset` (reuses the 8-bit layout) | Done |
| `examples/gemm_fp4/example_gemm_fp4_sm120.py` | FP4 GEMM example using `T.alloc_fragment` (no TMEM) | Done |
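The `mma_macro_generator.py` change above caps the MMA K extent: the FP4 instruction shape is still m16n8k32, the same as FP8, not the k64 a naive bit-width extrapolation would suggest. A rough sketch of that guard (the function and its derivation are illustrative, not the actual TileLang code):

```python
def mma_k_dim(dtype_bits: int) -> int:
    """Pick the MMA K extent from the element bit width (sketch).

    Naive extrapolation gives 16-bit -> k16, 8-bit -> k32, 4-bit -> k64,
    but sub-byte types are capped at k32 because FP4 MMA is m16n8k32.
    """
    k = 256 // dtype_bits
    return min(k, 32)

assert mma_k_dim(4) == 32  # capped: FP4 uses the same k32 shape as FP8
```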

Python-side pipeline (LayoutInference + LowerTileOp) passes successfully. CUDA kernel source is generated. Three C++ compilation issues remain:

Known remaining issues (CUDA compile stage)

  1. `cuda_fp4.h`: `__x` member access. The `fp4_e2_2_t` helper functions reference `val.__x`, but `tl::float_e2m1_t` (via `cutlass::float_e2m1_t`) uses a different storage accessor. The pack/unpack helpers need to be updated to use the CUTLASS storage API.

  2. Codegen variable naming for sub-byte types — TVM codegen renames sub-byte fragment buffers with a _packed suffix (e.g. A_local_packed), but MMA lowering emits references to the original name (A_local). This is a TVM codegen interaction issue for 4-bit types.

  3. `mma.h`: missing `MmaDispatcher` specialization. `tl::mma_sync<kFloat4_e2m1fn, kFloat4_e2m1fn, kFloat32, 16, 8, 32>` has no matching dispatch in `instruction/mma.h`; the FP4 PTX inline-assembly specialization still needs to be added.
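The packing that issue 1 touches is just nibble arithmetic: two FP4 codes share one byte. A hedged Python sketch of the intended bit layout (low nibble first is an assumption here; the real C++ helpers must read the code via the CUTLASS storage accessor rather than `__x`):

```python
def pack_fp4_pair(lo: int, hi: int) -> int:
    """Pack two 4-bit FP4 codes into one byte, low nibble first (sketch)."""
    assert 0 <= lo < 16 and 0 <= hi < 16
    return (hi << 4) | lo

def unpack_fp4_pair(byte: int) -> tuple[int, int]:
    """Inverse of pack_fp4_pair: return (low code, high code)."""
    return byte & 0xF, (byte >> 4) & 0xF

assert unpack_fp4_pair(pack_fp4_pair(0b0101, 0b1110)) == (0b0101, 0b1110)
```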

Architecture: SM100 vs SM120

| Feature | SM100 (B100/B200) | SM120 (RTX 5080/5090) |
|---------|-------------------|------------------------|
| MMA instruction | `tcgen05.mma` (async) | `mma.sync.aligned` (sync) |
| Accumulator | TMEM (`T.alloc_tmem`) | fragment (`T.alloc_fragment`) |
| Code path | `tcgen05mma.h` | `gemm_mma.h` |
| CUTE atom | `SM100_MMA_*` | `SM120_16x8x32_TN` |
| FP4 MMA shape | varies | `m16n8k32` |

How to test

```shell
# SM120 (RTX 5080/5090)
python examples/gemm_fp4/example_gemm_fp4_sm120.py

# SM100 (B100/B200) — requires SM100 hardware
python examples/gemm_sm100/gemm_tcgen5mma.py  # with in_dtype=T.float4_e2m1fn
```

Future work

  • Add `ldmatrix b4x16_p64` support
    • ldmatrix loads FP4 data packed two codes per byte, but SM120's f8f6f4 MMA expects each 4-bit value unpacked into its own 8-bit container; codegen needs to emit the `ldmatrix` `.b4x16_p64` PTX variant.
  • Add block-scaled FP4 (mxf4nvf4.block_scale) support as a follow-up
  • Numerical verification against reference implementation
  • CI integration
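The first bullet above can be illustrated in a few lines: ldmatrix delivers FP4 packed two-per-byte, while the f8f6f4 operand format wants one 4-bit code per 8-bit container (upper four bits zero). A hypothetical sketch of the widening step, assuming low-nibble-first packing:

```python
def widen_fp4_to_bytes(packed: bytes) -> bytes:
    """Expand packed FP4 (two codes per byte) so each 4-bit code
    occupies its own 8-bit container, as kind::f8f6f4 expects (sketch)."""
    out = []
    for b in packed:
        out.append(b & 0xF)         # low nibble first (assumed order)
        out.append((b >> 4) & 0xF)  # high nibble second
    return bytes(out)

assert widen_fp4_to_bytes(bytes([0xE5])) == bytes([0x05, 0x0E])
```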

Summary by CodeRabbit

  • New Features

    • End-to-end FP4 (float4_e2m1fn) support: dtype recognition, codegen, GPU instruction dispatch, and SM120 kernels with FP4 inputs and FP32 accumulation; mixed FP4/FP8 GEMM paths enabled.
  • Examples

    • New FP4 examples: high-performance SM120 GEMM, A8W4 GEMM, and fused MoE demo with unpack/convert utilities, validation checks, and basic benchmarks.
  • Documentation

    • Added FP4 feature guide and usage notes.

@Hale423 Hale423 changed the title WIP] feat: add NV FP4 GEMM support for SM100 & SM120 [WIP] feat: add NV FP4 GEMM support for SM100 & SM120 Mar 9, 2026
@github-actions

github-actions bot commented Mar 9, 2026

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Contributor

coderabbitai bot commented Mar 9, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Adds end-to-end FP4 (float4_e2m1fn) support for CUDA SM120: new dtype enums and mappings, TCGEN5/TCGEN05 MMA specializations and ldmatrix helpers, MMA/dispatcher changes, codegen packing updates, and Python/CUDA GEMM examples and tests demonstrating FP4 GEMM and fused-MoE paths.

Changes

| Cohort / File(s) | Summary |
|------------------|---------|
| **Core dtype enums & mappings**<br>`src/target/ptx.h`, `src/target/ptx.cc`, `src/tl_templates/cuda/common.h` | Adds `kFloat4_e2m1fn` to PTX/TL enums, updates string↔enum, bit-width tables, and `DTypeFromString` aliases. |
| **FP4 numeric type and CUDA helpers**<br>`src/tl_templates/cuda/common.h`, `src/tl_templates/cuda/cuda_fp4.h`, `src/tl_templates/cuda/ldsm.h` | Introduces the `tl::float_e2m1_t` mapping, refactors `fp4_e2_t` to alias the new type, updates packing/access via `raw()`, and adds `ptx_ldmatrix` `.b4x16` helpers for 4-bit loads. |
| **MMA dispatch & instruction specializations**<br>`src/tl_templates/cuda/instruction/mma.h`, `src/tl_templates/cuda/instruction/tcgen05mma.h`, `src/tl_templates/cuda/gemm_mma.h` | Registers SM120 FP4 MMA dispatchers (FP4/FP4, mixed FP4/FP8) and adds tcgen05mma specializations for `float4_e2m1fn` delegating to the FP8 emission path. |
| **TileLang intrinsics & layouts**<br>`tilelang/intrinsics/mma_macro_generator.py`, `tilelang/intrinsics/utils.py`, `tilelang/intrinsics/mma_layout.py` | Adds a dtype abbreviation for `float4_e2m1fn`, constrains the MMA K-dim for sub-byte types, and implements FP4-specific ldmatrix offset selection plus two 32x16→16x32 FP4 layouts. |
| **TileLang GEMM plumbing**<br>`tilelang/tileop/gemm/gemm_base.py`, `tilelang/tileop/gemm/gemm_mma.py` | Removes the A==B dtype assertion, adds `in_dtype_b` to expose the B-side dtype (maps uint8→float4_e2m1fn for non-tensor A), and uses `in_dtype_b` in emitter/lowering so B fragments can be FP4. |
| **CUDA codegen packing changes**<br>`src/target/codegen_cuda.cc` | Restricts FP4 packing to packed scopes (`is_packed_scope`) and removes the local-scope packed allocation path to avoid packed locals on the stack. |
| **Examples & docs**<br>`examples/gemm_fp4/example_gemm_fp4_sm120.py`, `examples/gemm_fp4/example_gemm_a8w4_sm120.py`, `examples/gemm_fp4/example_fusedmoe_a8w4_sm120.py`, `examples/gemm_fp4/gemm_fp4_sm120.cu`, `examples/gemm_fp4/GEMM_NV_FP4_FEATURE_STEPS.md` | Adds Python examples (FP4 GEMM, A8W4 mixed GEMM, fused MoE with FP4 weights), FP4 LUT helpers and unpacking, test/benchmark harnesses, and a generated CUDA kernel `gemm_fp4_sm120.cu` implementing an SM120 FP4 GEMM. |
| **PTX/TGEN5 descriptor mapping**<br>`src/op/tcgen5_meta.h`, `src/tl_templates/cuda/instruction/tcgen05mma.h` | Encodes the FP4 dtype mapping into TCGEN5 instruction-descriptor logic and adds FP4 specializations that forward to FP8 emission, ensuring instruction encoding recognizes FP4. |
| **Misc. MMA layout/ldmatrix indexing**<br>`tilelang/intrinsics/mma_layout.py`, `tilelang/intrinsics/utils.py` | Adds FP4-specific layout/offset functions used by ldmatrix selection for A/B tiled loads. |

Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant TL as TileLang (frontend)
    participant Emitter as Intrinsics Emitter
    participant Codegen as CUDA Codegen
    participant PTX as PTX/TCGEN5 layer
    participant GPU as CUDA Kernel / SM120
    Note over TL,Emitter: FP4 dtype flows into lowering
    TL->>Emitter: request MMA emission (float4_e2m1fn)
    Emitter->>Codegen: emit ldmatrix/layout + dispatcher selection
    Codegen->>PTX: encode dtype, select tcgen05mma path
    PTX->>GPU: generate CUDA kernel (gemm_fp4_sm120.cu)
    GPU-->>TL: runtime execution & benchmark results (via example harness)
```

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested reviewers

  • LJC00118
  • lucifer1004
  • SiriusNEO

Poem

🐰 In four small bits I hop and play,

FP4 tiles stitch night to day.
Dispatch aligns, the kernels sing,
SM120 hums — a tiny spring.
Puff of code, a carrot ring.

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|------------|--------|-------------|------------|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 33.96%, below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
|------------|--------|-------------|
| Description Check | ✅ Passed | Check skipped: CodeRabbit's high-level summary is enabled. |
| Title check | ✅ Passed | The title '[WIP][Feature]: add NV FP4 GEMM support for SM100 & SM120' clearly describes the main feature being added (FP4 GEMM support for specific GPU architectures) and accurately reflects the comprehensive changes across the codebase. |


@Hale423 Hale423 changed the title [WIP] feat: add NV FP4 GEMM support for SM100 & SM120 [WIP][Feature]: add NV FP4 GEMM support for SM100 & SM120 Mar 9, 2026
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 5

🧹 Nitpick comments (2)
tilelang/intrinsics/mma_macro_generator.py (1)

57-57: Consider using a shorter abbreviation for consistency.

Other FP8 variants use abbreviated forms like "e4m3" and "e5m2", but FP4 uses the full "float4_e2m1fn". While this may be intentional for PTX codegen compatibility, consider whether a shorter form like "e2m1" would be more consistent with the existing abbreviation pattern.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/intrinsics/mma_macro_generator.py` at line 57, The mapping key
"float4_e2m1fn" should be shortened to "e2m1" for consistency with other FP
abbreviations ("e4m3", "e5m2"); update the mapping entry in
mma_macro_generator.py and replace all references/usages that lookup or expect
"float4_e2m1fn" (search for the symbol float4_e2m1fn) to use "e2m1" instead,
ensuring any PTX compatibility layers or codegen consumers that relied on the
old name are adjusted to the new key.
examples/gemm_fp4/example_gemm_fp4_sm120.py (1)

93-96: Add one numerical check before benchmarking.

This example only reports latency/TFLOPS. Please compare the kernel output against a Torch reference once and use calc_diff() so FP4 lowering regressions are caught instead of just benchmarked.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/gemm_fp4/example_gemm_fp4_sm120.py` around lines 93 - 96, Before
running the benchmark, run a single numerical correctness check comparing the
JIT kernel output to a PyTorch reference and fail or log if differences are
large: invoke the kernel (jit_kernel) once with representative inputs used for
benchmarking, compute a reference result with Torch (same inputs/precision),
call calc_diff(reference, kernel_output) and assert or log the diff if it
exceeds tolerance, then proceed to profiler = jit_kernel.get_profiler() /
profiler.do_bench(); this ensures FP4 lowering regressions are caught before
measuring latency/TFLOPS.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@docs/GEMM_NV_FP4_FEATURE_STEPS.md`:
- Line 106: The dtype example string is malformed: change the example
`in_dtype=accum_dtype=T.float4_e2m1fn, T.float` to the correct pairing such as
`in_dtype=T.float4_e2m1fn, accum_dtype=T.float` (or the intended
in_dtype/accum_dtype pair) so it clearly indicates the input is FP4 and the
accumulator is float; update the example text where the current
`in_dtype=accum_dtype=T.float4_e2m1fn, T.float` occurs to use the corrected
`in_dtype=T.float4_e2m1fn, accum_dtype=T.float` form.
- Around line 90-102: Update the declarative language in the
"五、本次已完成的修改(CuTeDSL/TCGEN5 路径)" section so it no longer claims end-to-end FP4
GEMM readiness; replace phrases like "能正确走 `kind::f8f6f4` 的 FP4 路径" and "已落地"
with scoped wording such as "dtype/codegen plumbing landed", "kernel source can
be generated", or "codegen support added" and add a short caveat noting
unresolved CUDA/C++ compile blockers and that numerical validation is pending;
ensure mentions of symbols `DataType::kFloat4_e2m1fn`, `encode_dtype`,
`tcgen05mma_ss/ts/ws_ss`, and files `ptx.h`, `ptx.cc`, `tcgen05mma.h`,
`mma_macro_generator.py` remain factual but change tense to indicate plumbing
implemented rather than fully validated.

In `@examples/gemm_fp4/example_gemm_fp4_sm120.py`:
- Around line 12-24: Add explicit validation at the start of matmul_fp4 to
reject K and block_K values that don't meet SM120 FP4 k32 granularity: check
that K % 32 == 0 and block_K % 32 == 0 (and that block_K <= K), and if not raise
a clear ValueError mentioning the function name and the required 32-granularity
constraint; this ensures callers see a Python error instead of a backend failure
when using matmul_fp4 with unsupported K/block_K.
- Around line 80-91: The example assumes an SM120 FP4 GEMM backend TileLang does
not support; update example_gemm_fp4_sm120.py to avoid falsely reporting success
by gating or removing the compile/print block: replace the unconditional
tilelang.compile(...) / print("Compilation succeeded!") /
jit_kernel.get_kernel_source() sequence with a runtime check for the
backend/feature (e.g., a TileLang API like a capability query or a check of
target="cuda"/SM version or a feature flag tied to the SM120 FP4 GEMM path) and
if the feature is missing, skip the compilation and emit a clear message
explaining the missing "sm120 nvfp4 gemm support" (issue `#1592`) so the example
doesn't run; alternatively remove the runnable compile/print code entirely until
the backend specialization (sm120 nvfp4 gemm) is implemented. Ensure you
reference the existing tilelang.compile call and jit_kernel.get_kernel_source in
your change.

In `@src/tl_templates/cuda/cuda_fp4.h`:
- Line 9: The helper functions that manipulate fp4_e2_t (aliased to
tl::float_e2m1_t) are directly accessing the private encoded bits via .__x which
is incompatible with CUTLASS storage API; update all such accesses in functions
like set_x(), set_y(), and make_fp4_e2_2_t() to use the public storage member or
the raw() accessor (e.g., replace any `.__x` reads/writes with `.storage` or
call the `.raw()` getters/setters) so that encoded bits are accessed via the
supported API.

---

Nitpick comments:
In `@examples/gemm_fp4/example_gemm_fp4_sm120.py`:
- Around line 93-96: Before running the benchmark, run a single numerical
correctness check comparing the JIT kernel output to a PyTorch reference and
fail or log if differences are large: invoke the kernel (jit_kernel) once with
representative inputs used for benchmarking, compute a reference result with
Torch (same inputs/precision), call calc_diff(reference, kernel_output) and
assert or log the diff if it exceeds tolerance, then proceed to profiler =
jit_kernel.get_profiler() / profiler.do_bench(); this ensures FP4 lowering
regressions are caught before measuring latency/TFLOPS.

In `@tilelang/intrinsics/mma_macro_generator.py`:
- Line 57: The mapping key "float4_e2m1fn" should be shortened to "e2m1" for
consistency with other FP abbreviations ("e4m3", "e5m2"); update the mapping
entry in mma_macro_generator.py and replace all references/usages that lookup or
expect "float4_e2m1fn" (search for the symbol float4_e2m1fn) to use "e2m1"
instead, ensuring any PTX compatibility layers or codegen consumers that relied
on the old name are adjusted to the new key.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 2a82ee87-b628-487c-b8b1-0acfd5470bb5

📥 Commits

Reviewing files that changed from the base of the PR and between 0765d1d and e040bd3.

📒 Files selected for processing (11)
  • docs/GEMM_NV_FP4_FEATURE_STEPS.md
  • examples/gemm_fp4/example_gemm_fp4_sm120.py
  • src/op/tcgen5_meta.h
  • src/target/ptx.cc
  • src/target/ptx.h
  • src/tl_templates/cuda/common.h
  • src/tl_templates/cuda/cuda_fp4.h
  • src/tl_templates/cuda/gemm_mma.h
  • src/tl_templates/cuda/instruction/tcgen05mma.h
  • tilelang/intrinsics/mma_macro_generator.py
  • tilelang/intrinsics/utils.py

Comment thread docs/GEMM_NV_FP4_FEATURE_STEPS.md Outdated
Comment on lines +90 to +102
## V. Changes Completed in This Pass (CuTeDSL/TCGEN5 Path)

The following changes have landed so that the **CuTeDSL-generated SM100 GEMM kernel** correctly takes the `kind::f8f6f4` FP4 path when A/B are `float4_e2m1fn`. The CUTLASS `DispatchInstruction` in `gemm_sm100.h` is unchanged (that path is instantiated from CUTLASS templates and is separate from the CuTeDSL generate-TIR-then-codegen path).

| File | Change |
|------|--------|
| **src/target/ptx.h** | Add `kFloat4_e2m1fn = 23` to the `DataType` enum. |
| **src/target/ptx.cc** | Add a 24th entry to `enum_to_str` / `dtype_str` / `num_bits`; add `"float4_e2m1fn"` and `".e2m1"` → `kFloat4_e2m1fn` to `DTypeFromString`. |
| **src/tl_templates/cuda/common.h** | Add `kFloat4_e2m1fn = 23` to the `DataType` enum. |
| **src/op/tcgen5_meta.h** | Add a `dtype.is_float4_e2m1fn()` branch to `encode_dtype` in `GetTCGEN5InstrDesc()`, returning `2` (the FP4 format encoding). |
| **src/tl_templates/cuda/instruction/tcgen05mma.h** | Add `tcgen05mma_ss`, `tcgen05mma_ts`, and `tcgen05mma_ws_ss` specializations for `DataType::kFloat4_e2m1fn`, all forwarding to the `kFloat8_e4m3` f8f6f4 implementation. |
| **tilelang/intrinsics/mma_macro_generator.py** | Add `"float4_e2m1fn": "float4_e2m1fn"` to `dtype_abbrv` so that lowering passes `a_dtype_abbrv` into `ptx_tcgen05_mma_ss("float4_e2m1fn", ...)`. |

Contributor


⚠️ Potential issue | 🟡 Minor

Tone down the readiness claim in this section.

This reads as if the FP4 GEMM path is already working end-to-end, but the PR summary still calls out unresolved CUDA/C++ compile blockers. Please scope the wording to “dtype/codegen plumbing landed” or “kernel source can be generated” until build + numerical validation are actually passing.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@docs/GEMM_NV_FP4_FEATURE_STEPS.md` around lines 90 - 102, Update the
declarative language in the "五、本次已完成的修改(CuTeDSL/TCGEN5 路径)" section so it no
longer claims end-to-end FP4 GEMM readiness; replace phrases like "能正确走
`kind::f8f6f4` 的 FP4 路径" and "已落地" with scoped wording such as "dtype/codegen
plumbing landed", "kernel source can be generated", or "codegen support added"
and add a short caveat noting unresolved CUDA/C++ compile blockers and that
numerical validation is pending; ensure mentions of symbols
`DataType::kFloat4_e2m1fn`, `encode_dtype`, `tcgen05mma_ss/ts/ws_ss`, and files
`ptx.h`, `ptx.cc`, `tcgen05mma.h`, `mma_macro_generator.py` remain factual but
change tense to indicate plumbing implemented rather than fully validated.

Comment thread docs/GEMM_NV_FP4_FEATURE_STEPS.md Outdated
**Follow-up suggestions** (once an SM100 / nvcc environment is available):

1. Run a full build locally or in CI and run the BF16/F8 cases under `examples/gemm_sm100/` to confirm no regressions.
2. Add an FP4 GEMM example: `in_dtype=accum_dtype=T.float4_e2m1fn, T.float`, `block_K=128` (K%32 is already guaranteed by the meta), with a numerical comparison against a reference implementation.
Contributor


⚠️ Potential issue | 🟡 Minor

Fix the dtype example syntax here.

in_dtype=accum_dtype=T.float4_e2m1fn, T.float is malformed and reads as if the accumulator were FP4. This should describe in_dtype=T.float4_e2m1fn, accum_dtype=T.float (or the intended pair).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@docs/GEMM_NV_FP4_FEATURE_STEPS.md` at line 106, The dtype example string is
malformed: change the example `in_dtype=accum_dtype=T.float4_e2m1fn, T.float` to
the correct pairing such as `in_dtype=T.float4_e2m1fn, accum_dtype=T.float` (or
the intended in_dtype/accum_dtype pair) so it clearly indicates the input is FP4
and the accumulator is float; update the example text where the current
`in_dtype=accum_dtype=T.float4_e2m1fn, T.float` occurs to use the corrected
`in_dtype=T.float4_e2m1fn, accum_dtype=T.float` form.

Comment on lines +12 to +24
def matmul_fp4(
M,
N,
K,
block_M,
block_N,
block_K,
in_dtype,
out_dtype,
accum_dtype,
num_stages=2,
threads=128,
):
Contributor


⚠️ Potential issue | 🟡 Minor

Validate SM120 FP4 k32 constraints before building the kernel.

matmul_fp4() accepts arbitrary K and block_K, but this path lowers to FP4 MMA with k32 granularity. Reject unsupported values here so callers get a clear Python error instead of a backend failure later.

🔧 Proposed fix
 def matmul_fp4(
     M,
     N,
     K,
@@
     accum_dtype,
     num_stages=2,
     threads=128,
 ):
+    if K % 32 != 0 or block_K % 32 != 0:
+        raise ValueError("SM120 FP4 GEMM requires K and block_K to be multiples of 32")
+
     A_shape = (M, K)
     B_shape = (N, K)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/gemm_fp4/example_gemm_fp4_sm120.py` around lines 12 - 24, Add
explicit validation at the start of matmul_fp4 to reject K and block_K values
that don't meet SM120 FP4 k32 granularity: check that K % 32 == 0 and block_K %
32 == 0 (and that block_K <= K), and if not raise a clear ValueError mentioning
the function name and the required 32-granularity constraint; this ensures
callers see a Python error instead of a backend failure when using matmul_fp4
with unsupported K/block_K.

Comment on lines +80 to +91
jit_kernel = tilelang.compile(
func,
out_idx=[2],
target="cuda",
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
},
)

print("Compilation succeeded!")
print(jit_kernel.get_kernel_source())
Contributor


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

```shell
#!/bin/bash
set -euo pipefail

# Inspect likely lowering/dispatch headers for an SM120 FP4 mma_sync specialization.
fd 'mma\.h$|gemm_mma\.h$' src | while read -r f; do
  echo "== $f =="
  rg -n -C2 'kFloat4_e2m1fn|MmaDispatcher|mma_sync<' "$f" || true
done
```

Repository: tile-ai/tilelang

Length of output: 1798


🌐 Web query:

tilelang SM120 FP4 mma_sync MmaDispatcher specialization support

💡 Result:

TileLang’s CUDA SM120 path exists, but FP4 (NVFP4 / MXF4) GEMM via mma.sync + a specialized “MmaDispatcher” is not currently implemented/exposed in TileLang.

  • What TileLang does for SM120 today: gemm_sm120.h is essentially a thin wrapper that just includes the generic MMA GEMM implementation (gemm_mma.h)—there’s no SM120-specific FP4 dispatch/specialization visible there. [1]
  • Project status for SM120 NVFP4: TileLang has an open feature request explicitly asking for “sm120 nvfp4 gemm support,” indicating it’s not supported yet (at least not in mainline TileLang). [2]
  • Where SM120 FP4 support does exist: NVIDIA CUTLASS documents SM120 “mixed-precision MMA” support including FP4/FP6/FP8 (“f8f6f4” families) and dispatch policies/valid tile shapes for SM120 GEMMs. [3]
  • Practical implication: if you need SM120 FP4 now, the viable route is typically CUTLASS (or TileLang’s newer CUTLASS CuTe DSL backend) rather than TileLang’s internal MMA templates. TileLang added a CuTeDSL backend (compile-to-CUTLASS CuTe DSL) in late 2025. [4]

Answer to your specific terms:

  • mma_sync: TileLang’s SM120 GEMM template layer doesn’t currently show FP4 mma.sync-based specializations/dispatch in the SM120 header path. [1]
  • MmaDispatcher specialization support”: not present as a TileLang-supported/public mechanism for SM120 FP4 today; the open NVFP4 request suggests it’s pending. [2]

Sources:
[1] TileLang gemm_sm120.h (main branch) https://raw.githubusercontent.com/tile-ai/tilelang/main/src/tl_templates/cuda/gemm_sm120.h
[2] TileLang issue “Request for sm120 nvfp4 gemm support” (#1592) #1592
[3] NVIDIA CUTLASS Blackwell/SM120 functionality docs (FP4/FP6/FP8 MMA families and SM120 GEMM dispatch/tile-shape tables) https://docs.nvidia.com/cutlass/4.3.4/media/docs/cpp/blackwell_functionality.html
[4] TileLang README “Added CuTeDSL backend support” (12/18/2025) https://github.com/tile-ai/tilelang


This example still assumes an end-to-end SM120 FP4 backend that TileLang does not currently support.

tilelang.compile(...) is followed by "Compilation succeeded!", but TileLang's SM120 GEMM path does not include FP4 (mma.sync/MmaDispatcher) specialization. TileLang has an open feature request for "sm120 nvfp4 gemm support" (issue #1592), confirming the missing implementation. Either land that backend specialization in this PR or gate/remove this runnable example for now.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/gemm_fp4/example_gemm_fp4_sm120.py` around lines 80 - 91, The
example assumes an SM120 FP4 GEMM backend TileLang does not support; update
example_gemm_fp4_sm120.py to avoid falsely reporting success by gating or
removing the compile/print block: replace the unconditional
tilelang.compile(...) / print("Compilation succeeded!") /
jit_kernel.get_kernel_source() sequence with a runtime check for the
backend/feature (e.g., a TileLang API like a capability query or a check of
target="cuda"/SM version or a feature flag tied to the SM120 FP4 GEMM path) and
if the feature is missing, skip the compilation and emit a clear message
explaining the missing "sm120 nvfp4 gemm support" (issue `#1592`) so the example
doesn't run; alternatively remove the runnable compile/print code entirely until
the backend specialization (sm120 nvfp4 gemm) is implemented. Ensure you
reference the existing tilelang.compile call and jit_kernel.get_kernel_source in
your change.

Comment thread src/tl_templates/cuda/cuda_fp4.h
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

♻️ Duplicate comments (6)
examples/gemm_fp4/GEMM_NV_FP4_FEATURE_STEPS.md (2)

106-106: ⚠️ Potential issue | 🟡 Minor

Fix the dtype example syntax.

in_dtype=accum_dtype=T.float4_e2m1fn, T.float is malformed and reads as if the accumulator were FP4. Spell the pair out explicitly, e.g. in_dtype=T.float4_e2m1fn, accum_dtype=T.float.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/gemm_fp4/GEMM_NV_FP4_FEATURE_STEPS.md` at line 106, The dtype
example is malformed: change the combined expression
`in_dtype=accum_dtype=T.float4_e2m1fn, T.float` to explicitly spell out both
fields so it reads e.g. `in_dtype=T.float4_e2m1fn, accum_dtype=T.float`; update
the example text to use the separate symbols in_dtype and accum_dtype with the
correct types T.float4_e2m1fn and T.float so it’s unambiguous that the
accumulator is T.float.

90-101: ⚠️ Potential issue | 🟡 Minor

This section still reads as more validated than the PR state.

The PR summary still calls out unresolved CUDA/C++ compile blockers, so this should be framed as dtype/codegen plumbing or source-generation progress rather than a working FP4 GEMM path.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/gemm_fp4/GEMM_NV_FP4_FEATURE_STEPS.md` around lines 90 - 101, The
section currently overstates completion; rephrase GEMM_NV_FP4_FEATURE_STEPS.md
so it describes these changes as dtype/codegen plumbing and source-generation
progress rather than a fully working FP4 GEMM path: change the heading and
summary sentence that starts "五、本次已完成的修改" to indicate these are implemented
plumbing/codegen changes (mentioning DataType::kFloat4_e2m1fn,
GetTCGEN5InstrDesc encode_dtype branch, tcgen05mma_ss/ts/ws_ss specializations,
dtype_abbrv addition) and add a short note that CUDA/C++ compile blockers remain
and CUTLASS DispatchInstruction in gemm_sm100.h is unchanged. Ensure the wording
signals work-in-progress, not validated runtime behavior.
examples/gemm_fp4/example_gemm_fp4_sm120.py (2)

80-103: ⚠️ Potential issue | 🟡 Minor

Tone down the success logging here.

The PR still lists unresolved CUDA/C++ blockers, so Compilation succeeded! and [OK] ... compiled ... read like end-to-end readiness. Please phrase this as lowering/source-generation success until the C++ path and runtime execution actually pass.

🔧 Suggested wording
-print("Compilation succeeded!")
+print("Lowering/source generation succeeded.")
@@
-print("\n[OK] FP4 GEMM kernel compiled and CUDA source generated successfully.")
+print("\n[OK] FP4 GEMM CUDA source generated successfully.")
 print("[TODO] Runtime execution pending TVM sub-byte pointer arithmetic fix.")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/gemm_fp4/example_gemm_fp4_sm120.py` around lines 80 - 103, Update
the user-facing print messages to tone down end-to-end claims: change the
"Compilation succeeded!" and the two "[OK] FP4 GEMM kernel compiled and CUDA
source generated successfully." / "[TODO] Runtime execution pending..." lines so
they clearly state that only compilation/source-generation succeeded and
runtime/C++ execution is still pending; update the prints that reference
jit_kernel.get_kernel_source() and the file write to say "CUDA source generated
(runtime execution and C++ integration remain pending)" or similar, and remove
any definitive "OK" language to avoid implying full end-to-end readiness.

12-24: ⚠️ Potential issue | 🟡 Minor

Reject unsupported K/block_K values up front.

This path lowers to FP4 m16n8k32, so callers need a clear Python error when K or block_K is not 32-aligned, or when block_K > K, instead of a backend failure later.

🔧 Suggested guard
 def matmul_fp4(
     M,
     N,
     K,
@@
     num_stages=2,
     threads=128,
 ):
+    if K % 32 != 0 or block_K % 32 != 0 or block_K > K:
+        raise ValueError(
+            "matmul_fp4 requires K and block_K to be multiples of 32 and block_K <= K for SM120 FP4 MMA"
+        )
+
     A_shape = (M, K)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/gemm_fp4/example_gemm_fp4_sm120.py` around lines 12 - 24, In
matmul_fp4, add upfront validation to raise a clear Python error when K or
block_K are unsupported: verify K % 32 == 0, block_K % 32 == 0, and block_K <= K
(and raise ValueError with a descriptive message if any check fails) so callers
get a meaningful error instead of a backend failure; place these checks at the
start of the matmul_fp4 function before any lowering or backend calls.
src/tl_templates/cuda/cuda_fp4.h (1)

9-9: ⚠️ Potential issue | 🔴 Critical

Verify that raw() is actually the supported FP4 storage accessor.

After fp4_e2_t becomes tl::float_e2m1_t, these helpers rely on val.raw(). The PR notes still call out this header as a CUTLASS/CuTe storage-API blocker, so please confirm the accessor for the FP4 type in the version you're targeting before landing this alias.

In the CUTLASS/CuTe version used by TileLang, what is the supported raw-storage accessor for the FP4 type `cute::float_e2m1_t` / `cutlass::float_e2m1_t`? Does it expose `raw()`, or should callers use a different storage API for packing/unpacking 4-bit values?

Also applies to: 30-34, 73-73

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/tl_templates/cuda/cuda_fp4.h` at line 9, Confirm the correct low-level
storage accessor for the FP4 type after aliasing fp4_e2_t = tl::float_e2m1_t:
check the CUTLASS/CuTe version used by the project and verify whether
cute::float_e2m1_t / cutlass::float_e2m1_t exposes raw() or a different method
(e.g., bits(), as_int(), data(), u(), etc.); then update all code that calls
val.raw() (references around the fp4_e2_t alias and the usages noted near the
other occurrences) to call the actual supported accessor and adjust any
packing/unpacking helpers to use that API, and add a short comment by the
fp4_e2_t typedef noting the chosen accessor for future reference.
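Whatever accessor CUTLASS ends up exposing, the bit layout it returns is fixed: NV FP4 (`float4_e2m1fn`) is a 1-sign / 2-exponent / 1-mantissa code with no inf/NaN encodings. A standalone Python sketch of the decode (standard E2M1 semantics, independent of whichever CUTLASS storage API is chosen):

```python
def decode_e2m1(nibble: int) -> float:
    """Decode a 4-bit NV FP4 (E2M1) pattern: bit 3 sign, bits 2:1 exponent
    (bias 1), bit 0 mantissa. Exponent 0 is subnormal (no implicit 1)."""
    sign = -1.0 if (nibble >> 3) & 1 else 1.0
    exp = (nibble >> 1) & 0b11
    man = nibble & 1
    if exp == 0:                              # subnormal: value = mantissa * 0.5
        return sign * man * 0.5
    return sign * (1.0 + 0.5 * man) * 2.0 ** (exp - 1)

# All positive magnitudes representable in E2M1
magnitudes = sorted({decode_e2m1(i) for i in range(8)})
print(magnitudes)  # [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
```

This also explains why the verification tolerances below are coarse: the format only has eight magnitudes, topping out at 6.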
examples/gemm_fp4/gemm_fp4_sm120.cu (1)

49-52: ⚠️ Potential issue | 🔴 Critical

These mma_sync calls still depend on a missing FP4 dispatcher.

src/tl_templates/cuda/instruction/mma.h still needs the MmaDispatcher specialization for tl::mma_sync<tl::DataType::kFloat4_e2m1fn, tl::DataType::kFloat4_e2m1fn, tl::DataType::kFloat32, 16, 8, 32, ...>, so this checked-in SM120 kernel remains a compile blocker.

Expect to find a specialization matching the instantiated FP4 signature; if it is absent, this kernel cannot compile.

#!/bin/bash
set -euo pipefail

fd 'mma\.h$|gemm_mma\.h$' src | while read -r f; do
  echo "== $f =="
  rg -n -C2 'MmaDispatcher|kFloat4_e2m1fn|mma_sync<' "$f" || true
done

Also applies to: 64-65

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/gemm_fp4/gemm_fp4_sm120.cu` around lines 49 - 52, The kernel
instantiates tl::mma_sync with FP4 (tl::DataType::kFloat4_e2m1fn) operand types
and dims (16,8,32,false,true) but there is no corresponding MmaDispatcher
specialization, which prevents compilation; add a MmaDispatcher specialization
in the MmaDispatcher template (in mma.h) that exactly matches the signature used
by the kernel (template parameters: tl::DataType::kFloat4_e2m1fn,
tl::DataType::kFloat4_e2m1fn, tl::DataType::kFloat32, 16, 8, 32, false, true)
and implement the required interface (the same static method/operator used by
other dispatchers) to perform the FP4 packing/unpacking and call the correct
low-level mma intrinsic or shared implementation; ensure the specialization
forwards to any existing FP4 helper routine or implements the correct
accumulation into float32 so the tl::mma_sync calls in gemm_fp4_sm120.cu
compile.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/target/ptx.cc`:
- Around line 43-51: DTypeFromString now accepts kFloat4_e2m1fn but the PTX
backend lacks full FP4 support; update CheckMMADTypeCompatible, GetFragAttrs,
and the MMA/WGMMA configuration tables so they recognize and handle
kFloat4_e2m1fn end-to-end (including any frag attribute mapping used by
PrintMMAAssembly and PrintWGMMAAssembly) or, as a safer short-term fix, make
these functions explicitly reject kFloat4_e2m1fn (return false or raise the same
parse-time error) so callers like PrintMMAAssembly/PrintWGMMAAssembly never
receive an unsupported dtype; locate the symbols CheckMMADTypeCompatible,
GetFragAttrs, PrintMMAAssembly, PrintWGMMAAssembly and the MMA/WGMMA config
tables in this file and either add FP4 cases mirroring other float types or add
explicit guards that block FP4 from proceeding to assembly generation.

In `@tilelang/intrinsics/mma_macro_generator.py`:
- Around line 118-119: The ladder-transform emitter still computes k_dim as 256
// a_dtype.bits which yields 64 for float4_e2m1fn and breaks
_initialize_mma_prefix; update
TensorCoreIntrinEmitterWithLadderTransform._initialize_k_dim to apply the same
clamp (k_dim = min(256 // a_dtype.bits, 32)) or refactor both emitters to call a
shared helper (e.g., a new compute_k_dim(a_dtype) used by _initialize_k_dim in
both TensorCoreIntrinEmitter and TensorCoreIntrinEmitterWithLadderTransform) so
FP4 k_dim is capped at 32.

---

Duplicate comments:
In `@examples/gemm_fp4/example_gemm_fp4_sm120.py`:
- Around line 80-103: Update the user-facing print messages to tone down
end-to-end claims: change the "Compilation succeeded!" and the two "[OK] FP4
GEMM kernel compiled and CUDA source generated successfully." / "[TODO] Runtime
execution pending..." lines so they clearly state that only
compilation/source-generation succeeded and runtime/C++ execution is still
pending; update the prints that reference jit_kernel.get_kernel_source() and the
file write to say "CUDA source generated (runtime execution and C++ integration
remain pending)" or similar, and remove any definitive "OK" language to avoid
implying full end-to-end readiness.
- Around line 12-24: In matmul_fp4, add upfront validation to raise a clear
Python error when K or block_K are unsupported: verify K % 32 == 0, block_K % 32
== 0, and block_K <= K (and raise ValueError with a descriptive message if any
check fails) so callers get a meaningful error instead of a backend failure;
place these checks at the start of the matmul_fp4 function before any lowering
or backend calls.

In `@examples/gemm_fp4/gemm_fp4_sm120.cu`:
- Around line 49-52: The kernel instantiates tl::mma_sync with FP4
(tl::DataType::kFloat4_e2m1fn) operand types and dims (16,8,32,false,true) but
there is no corresponding MmaDispatcher specialization, which prevents
compilation; add a MmaDispatcher specialization in the MmaDispatcher template
(in mma.h) that exactly matches the signature used by the kernel (template
parameters: tl::DataType::kFloat4_e2m1fn, tl::DataType::kFloat4_e2m1fn,
tl::DataType::kFloat32, 16, 8, 32, false, true) and implement the required
interface (the same static method/operator used by other dispatchers) to perform
the FP4 packing/unpacking and call the correct low-level mma intrinsic or shared
implementation; ensure the specialization forwards to any existing FP4 helper
routine or implements the correct accumulation into float32 so the tl::mma_sync
calls in gemm_fp4_sm120.cu compile.

In `@examples/gemm_fp4/GEMM_NV_FP4_FEATURE_STEPS.md`:
- Line 106: The dtype example is malformed: change the combined expression
`in_dtype=accum_dtype=T.float4_e2m1fn, T.float` to explicitly spell out both
fields so it reads e.g. `in_dtype=T.float4_e2m1fn, accum_dtype=T.float`; update
the example text to use the separate symbols in_dtype and accum_dtype with the
correct types T.float4_e2m1fn and T.float so it’s unambiguous that the
accumulator is T.float.
- Around line 90-101: The section currently overstates completion; rephrase
GEMM_NV_FP4_FEATURE_STEPS.md so it describes these changes as dtype/codegen
plumbing and source-generation progress rather than a fully working FP4 GEMM
path: change the heading and summary sentence that starts "五、本次已完成的修改" to
indicate these are implemented plumbing/codegen changes (mentioning
DataType::kFloat4_e2m1fn, GetTCGEN5InstrDesc encode_dtype branch,
tcgen05mma_ss/ts/ws_ss specializations, dtype_abbrv addition) and add a short
note that CUDA/C++ compile blockers remain and CUTLASS DispatchInstruction in
gemm_sm100.h is unchanged. Ensure the wording signals work-in-progress, not
validated runtime behavior.

In `@src/tl_templates/cuda/cuda_fp4.h`:
- Line 9: Confirm the correct low-level storage accessor for the FP4 type after
aliasing fp4_e2_t = tl::float_e2m1_t: check the CUTLASS/CuTe version used by the
project and verify whether cute::float_e2m1_t / cutlass::float_e2m1_t exposes
raw() or a different method (e.g., bits(), as_int(), data(), u(), etc.); then
update all code that calls val.raw() (references around the fp4_e2_t alias and
the usages noted near the other occurrences) to call the actual supported
accessor and adjust any packing/unpacking helpers to use that API, and add a
short comment by the fp4_e2_t typedef noting the chosen accessor for future
reference.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 4e8ca97b-bcfd-44fd-9624-46066d0f3e9c

📥 Commits

Reviewing files that changed from the base of the PR and between e040bd3 and 7cf06fa.

📒 Files selected for processing (15)
  • docs/GEMM_NV_FP4_FEATURE_STEPS.md
  • examples/gemm_fp4/GEMM_NV_FP4_FEATURE_STEPS.md
  • examples/gemm_fp4/example_gemm_fp4_sm120.py
  • examples/gemm_fp4/gemm_fp4_sm120.cu
  • src/op/tcgen5_meta.h
  • src/target/codegen_cuda.cc
  • src/target/ptx.cc
  • src/target/ptx.h
  • src/tl_templates/cuda/common.h
  • src/tl_templates/cuda/cuda_fp4.h
  • src/tl_templates/cuda/gemm_mma.h
  • src/tl_templates/cuda/instruction/mma.h
  • src/tl_templates/cuda/instruction/tcgen05mma.h
  • tilelang/intrinsics/mma_macro_generator.py
  • tilelang/intrinsics/utils.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • src/tl_templates/cuda/gemm_mma.h
  • src/target/ptx.h

Comment thread src/target/ptx.cc
Comment on lines +43 to +51
"kBit16", "kBit32", "kBit64", "kFloat4_e2m1fn"};

static const char *dtype_str[] = {
".s4", ".u4", ".s8", ".u8", ".s16", ".u16", ".s32", ".u32",
".s64", ".u64", ".e4m3", ".e5m2", ".f16", ".bf16", ".f16x2", ".f32",
".tf32", ".f64", ".b1", ".b8", ".b16", ".b32", ".b64"};
".tf32", ".f64", ".b1", ".b8", ".b16", ".b32", ".b64", ".e2m1"};
static const uint32_t num_bits[] = {4, 4, 8, 8, 16, 16, 32, 32,
64, 64, 8, 8, 16, 16, 32, 32,
32, 64, 1, 8, 16, 32, 64};
32, 64, 1, 8, 16, 32, 64, 4};

⚠️ Potential issue | 🟠 Major

Finish the PTX-side FP4 wiring before accepting the new dtype string.

DTypeFromString() now parses FP4, but this file still has no kFloat4_e2m1fn handling in CheckMMADTypeCompatible(), GetFragAttrs(), or the valid MMA/WGMMA config tables. That means any FP4 caller that reaches PrintMMAAssembly() / PrintWGMMAAssembly() will parse successfully and then hit a fatal error later in the same file. Either complete those cases end-to-end or keep the parser rejecting FP4 on the PTX-assembly path for now.

Also applies to: 83-84

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/target/ptx.cc` around lines 43 - 51, DTypeFromString now accepts
kFloat4_e2m1fn but the PTX backend lacks full FP4 support; update
CheckMMADTypeCompatible, GetFragAttrs, and the MMA/WGMMA configuration tables so
they recognize and handle kFloat4_e2m1fn end-to-end (including any frag
attribute mapping used by PrintMMAAssembly and PrintWGMMAAssembly) or, as a
safer short-term fix, make these functions explicitly reject kFloat4_e2m1fn
(return false or raise the same parse-time error) so callers like
PrintMMAAssembly/PrintWGMMAAssembly never receive an unsupported dtype; locate
the symbols CheckMMADTypeCompatible, GetFragAttrs, PrintMMAAssembly,
PrintWGMMAAssembly and the MMA/WGMMA config tables in this file and either add
FP4 cases mirroring other float types or add explicit guards that block FP4 from
proceeding to assembly generation.

Comment on lines +118 to +119
# MMA k_dim caps at 32 (m16n8k32 is the widest K for FP8/FP4)
self.k_dim = min(256 // a_dtype.bits, 32)

⚠️ Potential issue | 🟠 Major

Apply the FP4 k32 clamp to the ladder-transform emitter too.

TensorCoreIntrinEmitterWithLadderTransform._initialize_k_dim() still computes 256 // bits, so float4_e2m1fn becomes 64 there and later trips _initialize_mma_prefix() with "Unsupported k_dim". Please apply the same clamp there, or route both classes through one shared helper.

🔧 Suggested fix
     def _initialize_k_dim(self, a_dtype=T.float16):
-        self.k_dim = 256 // DataType(a_dtype).bits
+        self.k_dim = min(256 // DataType(a_dtype).bits, 32)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/intrinsics/mma_macro_generator.py` around lines 118 - 119, The
ladder-transform emitter still computes k_dim as 256 // a_dtype.bits which
yields 64 for float4_e2m1fn and breaks _initialize_mma_prefix; update
TensorCoreIntrinEmitterWithLadderTransform._initialize_k_dim to apply the same
clamp (k_dim = min(256 // a_dtype.bits, 32)) or refactor both emitters to call a
shared helper (e.g., a new compute_k_dim(a_dtype) used by _initialize_k_dim in
both TensorCoreIntrinEmitter and TensorCoreIntrinEmitterWithLadderTransform) so
FP4 k_dim is capped at 32.
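The shared helper this comment suggests could look like the following sketch (`compute_k_dim` is a hypothetical name; the cap encodes the m16n8k32 limit discussed above):

```python
def compute_k_dim(dtype_bits: int) -> int:
    """MMA k dimension: 256 bits of K per thread slice, capped at k=32
    because m16n8k32 is the widest K the relevant MMA shapes support
    (FP8 and FP4 both dispatch to k=32 instructions)."""
    return min(256 // dtype_bits, 32)

assert compute_k_dim(16) == 16  # float16 -> m16n8k16
assert compute_k_dim(8) == 32   # int8/fp8 -> m16n8k32
assert compute_k_dim(4) == 32   # fp4 capped: 256 // 4 = 64 has no MMA shape
```

Both emitters' `_initialize_k_dim` methods could then delegate to this one function, so the clamp cannot drift between them.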

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🧹 Nitpick comments (2)
examples/gemm_fp4/example_gemm_fp4_sm120.py (1)

118-119: Consider documenting the tolerance threshold rationale.

A max_abs_diff < 1.0 threshold is quite loose. While FP4's coarse quantization may justify this, a brief comment explaining the expected precision bounds would help future maintainers understand whether this is intentional or a placeholder.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/gemm_fp4/example_gemm_fp4_sm120.py` around lines 118 - 119, The pass
condition uses a very loose tolerance (max_abs_diff < 1.0) which needs an
explanatory comment: update the block that checks the variable max_diff and
prints "[PASS] numerical verification (max_abs_diff < 1.0)" to include a short
comment describing why 1.0 is acceptable for FP4 (e.g., expected coarse
quantization error for FP4 on sm120, any empirical basis or citation, and that
this is intentional rather than a placeholder), or replace the literal with a
named constant (e.g., FP4_TOLERANCE) and comment that constant explaining the
expected precision bounds and how it was chosen.
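One way to replace the bare literal, as the comment suggests, is a named constant whose derivation is spelled out. This sketch (constant names are illustrative, not from the PR) documents the per-element quantization bound; accumulated GEMM error would additionally scale with K and input magnitudes:

```python
# E2M1 magnitudes: 0, 0.5, 1, 1.5, 2, 3, 4, 6. The largest gap between
# consecutive representable values is 2 (between 4 and 6), so round-to-nearest
# quantization error per element is at most half that gap, i.e. 1.0.
FP4_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
FP4_MAX_QUANT_ERROR = max(b - a for a, b in zip(FP4_MAGNITUDES, FP4_MAGNITUDES[1:])) / 2

# Placeholder check bound; a GEMM-aware tolerance would also grow with K.
FP4_TOLERANCE = 1.0

assert FP4_MAX_QUANT_ERROR == 1.0
```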
src/tl_templates/cuda/instruction/mma.h (1)

168-182: Narrow FP4 preprocessing guard to avoid future mixed-type mispacking.

At Lines 168-178, both operands are shifted when either side is FP4. That’s fine today (only FP4/FP4 is dispatched), but it can silently break if mixed FP4 dispatchers are introduced later. Prefer guarding this path with AType == kFloat4_e2m1fn && BType == kFloat4_e2m1fn (or shifting each operand conditionally).

Suggested minimal guard tightening
-  if constexpr (AType == DataType::kFloat4_e2m1fn ||
-                BType == DataType::kFloat4_e2m1fn) {
+  if constexpr (AType == DataType::kFloat4_e2m1fn &&
+                BType == DataType::kFloat4_e2m1fn) {
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/tl_templates/cuda/instruction/mma.h` around lines 168 - 182, The current
branch shifts both operands when either AType or BType equals
DataType::kFloat4_e2m1fn, which can mispack mixed-type calls; change the guard
so it only takes this path when both types are FP4 (i.e., AType ==
DataType::kFloat4_e2m1fn && BType == DataType::kFloat4_e2m1fn) or alternatively
apply shifts per-operand (shift a[] only when AType is FP4 and shift b[] only
when BType is FP4) before calling Dispatcher::exec; update the if constexpr
condition and the uses of as/bs (or keep original a/b when not shifted) to
ensure correct typing and avoid silent mispacking.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/gemm_fp4/example_gemm_fp4_sm120.py`:
- Around line 62-64: The make_fp4_tensor function silently truncates when K is
odd because it uses K // 2; add a validation at the start of make_fp4_tensor to
check that K % 2 == 0 and raise a clear ValueError (or AssertionError) if not,
including K in the message, so callers know K must be even before creating the
packed (M, K//2) int8 tensor.
- Line 121: The print statement uses an unnecessary f-string prefix; remove the
leading "f" from the call so the literal is a normal string (replace
print(f"[WARN] large diff — may indicate layout or data flow issue") with
print("[WARN] large diff — may indicate layout or data flow issue") to eliminate
the unused f-string).

---

Nitpick comments:
In `@examples/gemm_fp4/example_gemm_fp4_sm120.py`:
- Around line 118-119: The pass condition uses a very loose tolerance
(max_abs_diff < 1.0) which needs an explanatory comment: update the block that
checks the variable max_diff and prints "[PASS] numerical verification
(max_abs_diff < 1.0)" to include a short comment describing why 1.0 is
acceptable for FP4 (e.g., expected coarse quantization error for FP4 on sm120,
any empirical basis or citation, and that this is intentional rather than a
placeholder), or replace the literal with a named constant (e.g., FP4_TOLERANCE)
and comment that constant explaining the expected precision bounds and how it
was chosen.

In `@src/tl_templates/cuda/instruction/mma.h`:
- Around line 168-182: The current branch shifts both operands when either AType
or BType equals DataType::kFloat4_e2m1fn, which can mispack mixed-type calls;
change the guard so it only takes this path when both types are FP4 (i.e., AType
== DataType::kFloat4_e2m1fn && BType == DataType::kFloat4_e2m1fn) or
alternatively apply shifts per-operand (shift a[] only when AType is FP4 and
shift b[] only when BType is FP4) before calling Dispatcher::exec; update the if
constexpr condition and the uses of as/bs (or keep original a/b when not
shifted) to ensure correct typing and avoid silent mispacking.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: db4fd421-ef6b-49ea-b04c-562028950504

📥 Commits

Reviewing files that changed from the base of the PR and between 7cf06fa and f037550.

📒 Files selected for processing (5)
  • examples/gemm_fp4/example_gemm_fp4_sm120.py
  • src/target/codegen_cuda.cc
  • src/tl_templates/cuda/instruction/mma.h
  • tilelang/intrinsics/mma_layout.py
  • tilelang/intrinsics/utils.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • tilelang/intrinsics/utils.py
  • src/target/codegen_cuda.cc

Comment on lines +62 to +64
def make_fp4_tensor(M: int, K: int, device="cuda") -> torch.Tensor:
    """Create random packed FP4 tensor as (M, K//2) int8."""
    return torch.randint(0, 256, (M, K // 2), dtype=torch.uint8, device=device).to(torch.int8)

⚠️ Potential issue | 🟡 Minor

Add validation that K is even.

K // 2 silently truncates if K is odd, which would create a tensor with fewer elements than expected.

🛡️ Proposed fix
 def make_fp4_tensor(M: int, K: int, device="cuda") -> torch.Tensor:
     """Create random packed FP4 tensor as (M, K//2) int8."""
+    if K % 2 != 0:
+        raise ValueError("K must be even for packed FP4 representation")
     return torch.randint(0, 256, (M, K // 2), dtype=torch.uint8, device=device).to(torch.int8)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/gemm_fp4/example_gemm_fp4_sm120.py` around lines 62 - 64, The
make_fp4_tensor function silently truncates when K is odd because it uses K //
2; add a validation at the start of make_fp4_tensor to check that K % 2 == 0 and
raise a clear ValueError (or AssertionError) if not, including K in the message,
so callers know K must be even before creating the packed (M, K//2) int8 tensor.
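For reference, the packed representation the example assumes (two FP4 codes per int8 byte) can be sketched in plain Python, including the evenness check; the low-nibble-first ordering here is an assumption, not verified against the kernel:

```python
def pack_fp4_pair(lo: int, hi: int) -> int:
    """Pack two 4-bit codes into one byte: element 2k in the low nibble,
    element 2k+1 in the high nibble (ordering is an assumption)."""
    assert 0 <= lo < 16 and 0 <= hi < 16
    return (hi << 4) | lo

def unpack_fp4_pair(byte: int):
    return byte & 0xF, (byte >> 4) & 0xF

def pack_row(codes):
    """Pack a row of K 4-bit codes into K//2 bytes, rejecting odd K."""
    if len(codes) % 2 != 0:
        raise ValueError(f"K must be even for packed FP4, got K={len(codes)}")
    return [pack_fp4_pair(codes[i], codes[i + 1]) for i in range(0, len(codes), 2)]

packed = pack_row([1, 7, 0, 15])
assert packed == [0x71, 0xF0]
assert unpack_fp4_pair(packed[0]) == (1, 7)
```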

if max_diff < 1.0:
    print("[PASS] numerical verification (max_abs_diff < 1.0)")
else:
    print(f"[WARN] large diff — may indicate layout or data flow issue")

⚠️ Potential issue | 🟡 Minor

Remove extraneous f-string prefix.

The string has no placeholders, so the f prefix is unnecessary.

🔧 Proposed fix
-    print(f"[WARN] large diff — may indicate layout or data flow issue")
+    print("[WARN] large diff — may indicate layout or data flow issue")
🧰 Tools
🪛 Ruff (0.15.5)

[error] 121-121: f-string without any placeholders

Remove extraneous f prefix

(F541)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/gemm_fp4/example_gemm_fp4_sm120.py` at line 121, The print statement
uses an unnecessary f-string prefix; remove the leading "f" from the call so the
literal is a normal string (replace print(f"[WARN] large diff — may indicate
layout or data flow issue") with print("[WARN] large diff — may indicate layout
or data flow issue") to eliminate the unused f-string).

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

🧹 Nitpick comments (2)
tilelang/intrinsics/utils.py (1)

54-64: Update the rest of the FP4 helper contract in this module.

get_ldmatrix_offset() now accepts FP4, but the public dtype annotations here still advertise only float16/int8, and get_mma_micro_size() below still leaves FP4 on the default k=16 path. Keeping those helpers in sync avoids future FP4 callers selecting the wrong micro-k.

♻️ Suggested follow-up
 def get_ldmatrix_offset(
     matrix: Literal["A", "B"],
     row_idx,
     col_idx,
     stride,
-    dtype: Literal["float16", "int8"] = "float16",
+    dtype: Literal["float16", "int8", "float4_e2m1fn"] = "float16",
     transposed: bool = False,
 ):
@@
-def get_mma_micro_size(dtype: Literal["float16", "int8"]):
+def get_mma_micro_size(
+    dtype: Literal["float16", "int8", "float8_e4m3", "float8_e5m2", "float4_e2m1fn"]
+):
@@
-    if dtype in {"float8_e4m3", "float8_e5m2", "int8"}:
+    if dtype in {"float8_e4m3", "float8_e5m2", "float4_e2m1fn", "int8"}:
         micro_size_k = 32
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/intrinsics/utils.py` around lines 54 - 64, Update this module's
public dtype annotations and helper logic to include FP4 alongside float16/int8:
add "fp4" (or the canonical FP4 dtype name used elsewhere) to the type
hints/docs for functions like get_ldmatrix_offset and any public signatures that
currently list only float16/int8, and modify get_mma_micro_size to handle the
FP4 case explicitly instead of falling through to the default k=16 path (choose
the correct micro-k for FP4 consistent with hardware/other code paths). Also
scan for other helper functions in this file that claim only float16/int8 and
update their contracts/branches so FP4 uses the correct transform and micro-k
behavior (reference get_ldmatrix_offset and get_mma_micro_size to locate where
to change).
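The FP4-aware micro-k selection this comment asks for could be sketched as follows (helper name and dtype strings follow the suggested follow-up diff above; treat it as an illustration, not the module's actual code):

```python
def mma_micro_k(dtype: str) -> int:
    """Micro-k per MMA instruction: 16 for 16-bit inputs (m16n8k16),
    32 for 8-bit and, by this sketch's assumption, 4-bit inputs
    (m16n8k32; FP4 reuses the f8f6f4 k=32 shape)."""
    if dtype in {"float8_e4m3", "float8_e5m2", "float4_e2m1fn", "int8"}:
        return 32
    return 16

assert mma_micro_k("float16") == 16
assert mma_micro_k("float4_e2m1fn") == 32
```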
src/tl_templates/cuda/instruction/mma.h (1)

168-179: Tighten the FP4 pre-shift to the supported FP4/FP4 case.

This branch fires when either operand is FP4, but this file only registers an FP4/FP4 dispatcher. If a mixed FP4 kernel is added later, the non-FP4 operand will be shifted too.

♻️ Suggested change
-  if constexpr (AType == DataType::kFloat4_e2m1fn ||
-                BType == DataType::kFloat4_e2m1fn) {
+  if constexpr (AType == DataType::kFloat4_e2m1fn &&
+                BType == DataType::kFloat4_e2m1fn) {
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/tl_templates/cuda/instruction/mma.h` around lines 168 - 179, The branch
currently triggers if either AType or BType is FP4 (DataType::kFloat4_e2m1fn)
causing both operands to be pre-shifted even though only FP4/FP4 dispatchers are
registered; change the condition to require both operands be FP4 (use AType ==
DataType::kFloat4_e2m1fn && BType == DataType::kFloat4_e2m1fn) and only perform
the left-shift on AReg as[] and BReg bs[] inside that tightened branch
(references: AType, BType, DataType::kFloat4_e2m1fn, Dispatcher::ARegType,
Dispatcher::BRegType, Dispatcher::exec, detail::MmaImplTraits::kARegs/kBRegs).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/gemm_fp4/example_gemm_fp4_sm120.py`:
- Around line 67-92: Add an early hardware guard before calling matmul_fp4 /
tilelang.compile: check that CUDA is available (e.g., torch.cuda.is_available())
and that the GPU's compute capability is SM>=12.0 by querying
torch.cuda.get_device_properties(0).major/minor (or equivalent), and if the
check fails raise/print a clear error and exit (e.g., "requires CUDA SM >=
12.0"). Place this check right before the current calls to matmul_fp4(...) and
tilelang.compile(...) so the example fails fast on non-CUDA or pre-SM120
hardware.

In `@src/target/codegen_cuda.cc`:
- Around line 3355-3357: GetBufferRef is still halving scalar FP4 indices
(dividing by 2) even though the buffer is now declared as unpacked
fp4_e2_t[constant_size] (see where PrintStorageScope and PrintType emit the
declaration); update GetBufferRef to detect unpacked/raw-byte FP4 buffers and
stop dividing indices for fp4 scalar accesses so each scalar maps to its own
fp4_e2_t element. Specifically, in GetBufferRef, identify buffers whose dtype is
fp4 and whose storage/declared element type is fp4_e2_t (or where PrintType
emitted fp4_e2_t[]) and bypass the index /= 2 path, ensuring the same fix is
applied to the similar handling referenced around the other occurrence (the
block at the 3383-3387 equivalent).

In `@tilelang/intrinsics/mma_layout.py`:
- Around line 42-54: The FP4 layout functions
ldmatrix_32x16_to_shared_16x32_fp4_layout_a and
ldmatrix_32x16_to_shared_16x32_fp4_layout_b only emit columns 0..15 (col =
local_id) and thus never produce columns 16..31; update both functions to
compute a half-row offset (e.g., half = (thread_id // 16) % 2 or derived from
the lane group bit) and set col = local_id + half * 16 so each lane group covers
the other half-row; if FP4 truly changes the per-thread fragment shape, also
rename the helpers (or update their contract) to reflect the 16x32 output shape.
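The half-row fix described in that item can be checked in isolation. This sketch (lane arithmetic taken from the comment's suggestion, not from the actual layout helpers) verifies that adding the `half * 16` offset makes the two lane groups cover all 32 columns:

```python
def fp4_layout_col(thread_id: int, local_id: int) -> int:
    """Column index for a 16x32 FP4 fragment layout: lanes 0-15 cover
    columns 0-15, lanes 16-31 cover columns 16-31 via a half-row offset."""
    half = (thread_id // 16) % 2
    return local_id + half * 16

# Together, 32 lanes x 16 local ids cover every column 0..31.
cols = {fp4_layout_col(t, l) for t in range(32) for l in range(16)}
assert cols == set(range(32))
```

Without the offset (`col = local_id` alone), the same enumeration only reaches columns 0..15, which is exactly the gap the review comment flags.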

---

Nitpick comments:
In `@src/tl_templates/cuda/instruction/mma.h`:
- Around line 168-179: The branch currently triggers if either AType or BType is
FP4 (DataType::kFloat4_e2m1fn) causing both operands to be pre-shifted even
though only FP4/FP4 dispatchers are registered; change the condition to require
both operands be FP4 (use AType == DataType::kFloat4_e2m1fn && BType ==
DataType::kFloat4_e2m1fn) and only perform the left-shift on AReg as[] and BReg
bs[] inside that tightened branch (references: AType, BType,
DataType::kFloat4_e2m1fn, Dispatcher::ARegType, Dispatcher::BRegType,
Dispatcher::exec, detail::MmaImplTraits::kARegs/kBRegs).

In `@tilelang/intrinsics/utils.py`:
- Around line 54-64: Update this module's public dtype annotations and helper
logic to include FP4 alongside float16/int8: add "fp4" (or the canonical FP4
dtype name used elsewhere) to the type hints/docs for functions like
get_ldmatrix_offset and any public signatures that currently list only
float16/int8, and modify get_mma_micro_size to handle the FP4 case explicitly
instead of falling through to the default k=16 path (choose the correct micro-k
for FP4 consistent with hardware/other code paths). Also scan for other helper
functions in this file that claim only float16/int8 and update their
contracts/branches so FP4 uses the correct transform and micro-k behavior
(reference get_ldmatrix_offset and get_mma_micro_size to locate where to
change).

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 4243d9a7-46b5-4708-ba56-4d7bcdf2cc73

📥 Commits

Reviewing files that changed from the base of the PR and between f037550 and f068f2f.

📒 Files selected for processing (5)
  • examples/gemm_fp4/example_gemm_fp4_sm120.py
  • src/target/codegen_cuda.cc
  • src/tl_templates/cuda/instruction/mma.h
  • tilelang/intrinsics/mma_layout.py
  • tilelang/intrinsics/utils.py

Comment on lines +67 to +92
M, N, K = 256, 256, 256
block_M, block_N, block_K = 128, 128, 128
in_dtype = T.float4_e2m1fn
out_dtype = T.float32
accum_dtype = T.float32

print(f"Running FP4 GEMM: M={M}, N={N}, K={K}")
print(f" block_M={block_M}, block_N={block_N}, block_K={block_K}")

func = matmul_fp4(
    M, N, K, block_M, block_N, block_K,
    in_dtype, out_dtype, accum_dtype,
    num_stages=2, threads=128,
)

jit_kernel = tilelang.compile(
    func,
    out_idx=[2],
    target="cuda",
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    },
)

print("Compilation succeeded!")

⚠️ Potential issue | 🟡 Minor

Fail fast on unsupported hardware.

This example immediately compiles and allocates CUDA tensors. On non-CUDA or pre-SM120 machines the first failure will come from deep inside compile/runtime instead of a clear example-specific error.

🛡️ Suggested guard
+if not torch.cuda.is_available():
+    raise RuntimeError("example_gemm_fp4_sm120.py requires CUDA")
+
+major, minor = torch.cuda.get_device_capability()
+if (major, minor) < (12, 0):
+    raise RuntimeError("example_gemm_fp4_sm120.py requires an SM120 GPU")
+
 M, N, K = 256, 256, 256
 block_M, block_N, block_K = 128, 128, 128
 in_dtype = T.float4_e2m1fn
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
+if not torch.cuda.is_available():
+    raise RuntimeError("example_gemm_fp4_sm120.py requires CUDA")
+
+major, minor = torch.cuda.get_device_capability()
+if (major, minor) < (12, 0):
+    raise RuntimeError("example_gemm_fp4_sm120.py requires an SM120 GPU")
+
 M, N, K = 256, 256, 256
 block_M, block_N, block_K = 128, 128, 128
 in_dtype = T.float4_e2m1fn
 out_dtype = T.float32
 accum_dtype = T.float32

 print(f"Running FP4 GEMM: M={M}, N={N}, K={K}")
 print(f" block_M={block_M}, block_N={block_N}, block_K={block_K}")

 func = matmul_fp4(
     M, N, K, block_M, block_N, block_K,
     in_dtype, out_dtype, accum_dtype,
     num_stages=2, threads=128,
 )

 jit_kernel = tilelang.compile(
     func,
     out_idx=[2],
     target="cuda",
     pass_configs={
         tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
         tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
     },
 )

 print("Compilation succeeded!")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/gemm_fp4/example_gemm_fp4_sm120.py` around lines 67 - 92, Add an
early hardware guard before calling matmul_fp4 / tilelang.compile: check that
CUDA is available (e.g., torch.cuda.is_available()) and that the GPU's compute
capability is SM>=12.0 by querying
torch.cuda.get_device_properties(0).major/minor (or equivalent), and if the
check fails raise/print a clear error and exit (e.g., "requires CUDA SM >=
12.0"). Place this check right before the current calls to matmul_fp4(...) and
tilelang.compile(...) so the example fails fast on non-CUDA or pre-SM120
hardware.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 4

🧹 Nitpick comments (6)
examples/gemm_fp4/example_gemm_a8w4_sm120.py (4)

118-126: Consider adding warmup iterations before timing.

The benchmark starts timing immediately without warmup. The first few kernel invocations may include JIT/caching overhead, skewing the average latency upward.

Suggested fix
 # --- Benchmark ---
+# Warmup
+for _ in range(10):
+    jit_kernel(a_fp8, b_uint8)
 torch.cuda.synchronize()
 start = time.perf_counter()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/gemm_fp4/example_gemm_a8w4_sm120.py` around lines 118 - 126, The
benchmark measures jit_kernel but lacks warmup runs, so initial JIT/caching
skews timings; add a small warmup loop (e.g., 10–20 iterations) that calls
jit_kernel(a_fp8, b_uint8) with torch.cuda.synchronize() before starting the
timed section, then proceed with the existing timing loop and final
synchronize/latency calculation to ensure measured runs reflect steady-state
performance.
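For reference, the warmup-then-time pattern this comment describes can be factored into a small helper. This is an illustrative sketch, not code from the PR: `benchmark`, its `sync` hook, and the dummy workload are assumptions; on a GPU you would pass `torch.cuda.synchronize` as `sync` and time `jit_kernel(a_fp8, b_uint8)`.

```python
import time

def benchmark(kernel, *args, warmup=10, iters=100, sync=lambda: None):
    """Time kernel(*args), discarding warmup runs; returns avg latency in ms."""
    # Warmup: absorb JIT compilation / cache-population overhead.
    for _ in range(warmup):
        kernel(*args)
    sync()  # on GPU: torch.cuda.synchronize
    start = time.perf_counter()
    for _ in range(iters):
        kernel(*args)
    sync()
    return (time.perf_counter() - start) / iters * 1e3

# Host-side demo with a dummy workload (no GPU required):
latency_ms = benchmark(lambda n: sum(range(n)), 10_000, warmup=3, iters=20)
print(f"avg latency: {latency_ms:.3f} ms")
```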

66-66: Remove extraneous f prefix from string without placeholders.

This f-string contains no placeholder expressions.

Suggested fix
-print(f"  A: float8_e4m3fn, B: FP4 (unpacked uint8)")
+print("  A: float8_e4m3fn, B: FP4 (unpacked uint8)")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/gemm_fp4/example_gemm_a8w4_sm120.py` at line 66, The print statement
uses an unnecessary f-string prefix; update the print call (the line containing
print(f"  A: float8_e4m3fn, B: FP4 (unpacked uint8)")) to use a regular string
(remove the leading f) so it reads print("  A: float8_e4m3fn, B: FP4 (unpacked
uint8)") to avoid misleading f-string usage.

11-11: Unused import os.

The os module is imported but never used in this file.

Suggested fix
-import os
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/gemm_fp4/example_gemm_a8w4_sm120.py` at line 11, Remove the unused
top-level import "import os" from the module (the lone import statement) in
examples/gemm_fp4/example_gemm_a8w4_sm120.py so there’s no unused dependency;
simply delete the line "import os".

116-116: Remove extraneous f prefix from string without placeholders.

This f-string contains no placeholder expressions.

Suggested fix
-    print(f"[WARN] large diff -- may indicate layout or data flow issue")
+    print("[WARN] large diff -- may indicate layout or data flow issue")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/gemm_fp4/example_gemm_a8w4_sm120.py` at line 116, The print
statement using an f-string contains no placeholders; remove the unnecessary f
prefix in the print call (the print(...) expression in
examples/gemm_fp4/example_gemm_a8w4_sm120.py that prints "[WARN] large diff --
may indicate layout or data flow issue") so it becomes a plain string literal.
tilelang/intrinsics/mma_layout.py (1)

42-54: FP4 helpers duplicate existing INT8 layout logic; consider delegating to avoid drift.

Line 45-54 mirrors ldmatrix_32x16_to_shared_16x32_layout_a/b. Reusing the existing helpers will keep future layout fixes in one place.

♻️ Proposed simplification
 def ldmatrix_32x16_to_shared_16x32_fp4_layout_a(thread_id, local_id):
     """FP4 with unpacked shared memory (1 byte/element) uses the same
     layout as INT8 — shared memory rows are 32 bytes for K=32."""
-    row = thread_id % 16
-    col = local_id + (thread_id // 16) * 16
-    return row, col
+    return ldmatrix_32x16_to_shared_16x32_layout_a(thread_id, local_id)
@@
 def ldmatrix_32x16_to_shared_16x32_fp4_layout_b(thread_id, local_id):
     """FP4 with unpacked shared memory — same as INT8."""
-    row = (thread_id // 16) * 8 + (thread_id % 8)
-    col = local_id + 16 * ((thread_id % 16) // 8)
-    return row, col
+    return ldmatrix_32x16_to_shared_16x32_layout_b(thread_id, local_id)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/intrinsics/mma_layout.py` around lines 42 - 54, The two FP4 helpers
ldmatrix_32x16_to_shared_16x32_fp4_layout_a and
ldmatrix_32x16_to_shared_16x32_fp4_layout_b duplicate INT8 logic; replace their
bodies to delegate to the canonical INT8 helpers
ldmatrix_32x16_to_shared_16x32_layout_a and
ldmatrix_32x16_to_shared_16x32_layout_b (i.e., return the result of calling the
corresponding INT8 function with thread_id and local_id) and update their
docstrings to state they delegate to the INT8 implementations to avoid future
drift.
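Since the suggested delegation is pure index arithmetic, it is easy to verify on the host. The sketch below reproduces the layout functions standalone (bodies copied from the diff above; the INT8 helper names are assumed to match tilelang's) and checks that each layout is a bijection onto the 16x32 tile, so delegating cannot change behavior:

```python
def ldmatrix_32x16_to_shared_16x32_layout_a(thread_id, local_id):
    row = thread_id % 16
    col = local_id + (thread_id // 16) * 16
    return row, col

def ldmatrix_32x16_to_shared_16x32_layout_b(thread_id, local_id):
    row = (thread_id // 16) * 8 + (thread_id % 8)
    col = local_id + 16 * ((thread_id % 16) // 8)
    return row, col

# FP4 variants delegate to the INT8 helpers, as the review suggests.
def ldmatrix_32x16_to_shared_16x32_fp4_layout_a(thread_id, local_id):
    return ldmatrix_32x16_to_shared_16x32_layout_a(thread_id, local_id)

def ldmatrix_32x16_to_shared_16x32_fp4_layout_b(thread_id, local_id):
    return ldmatrix_32x16_to_shared_16x32_layout_b(thread_id, local_id)

# Every (thread_id, local_id) pair maps to a unique (row, col) in 16x32.
coords_a = {ldmatrix_32x16_to_shared_16x32_fp4_layout_a(t, l)
            for t in range(32) for l in range(16)}
coords_b = {ldmatrix_32x16_to_shared_16x32_fp4_layout_b(t, l)
            for t in range(32) for l in range(16)}
assert coords_a == {(r, c) for r in range(16) for c in range(32)}
assert coords_b == {(r, c) for r in range(16) for c in range(32)}
```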
src/tl_templates/cuda/instruction/mma.h (1)

176-192: SM100 and SM120 handle FP4 preprocessing differently; verify this aligns with instruction semantics.

The SM100 path (tcgen05mma.h:145–270) delegates FP4 directly to the FP8 instruction handler without preprocessing, while this SM120 code applies a 2-bit left shift before the mma.sync call. This divergence may be correct if the two instruction variants have different semantic requirements, but should be documented with an explanatory comment clarifying why the shift is necessary for SM120 but not SM100.

The compile-time ternary expressions can be made more idiomatic using if constexpr:

♻️ Idiomatic refactor suggestion
     #pragma unroll
     for (int i = 0; i < nA; ++i)
-      as[i] = (AType == DataType::kFloat4_e2m1fn) ? (a[i] << 2) : a[i];
+      if constexpr (AType == DataType::kFloat4_e2m1fn)
+        as[i] = a[i] << 2;
+      else
+        as[i] = a[i];
     #pragma unroll
     for (int i = 0; i < nB; ++i)
-      bs[i] = (BType == DataType::kFloat4_e2m1fn) ? (b[i] << 2) : b[i];
+      if constexpr (BType == DataType::kFloat4_e2m1fn)
+        bs[i] = b[i] << 2;
+      else
+        bs[i] = b[i];
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/tl_templates/cuda/instruction/mma.h` around lines 176 - 192, The SM120
FP4 preprocessing applies a 2-bit left shift before calling Dispatcher::exec
while SM100 routes FP4 to the FP8 handler without shifting; add a concise
comment above the SM120 branch in mma.h explaining the semantic reason for the
shift (i.e., why SM120 requires the <<2 adjustment but SM100 does not) and
ensure the behavior is intentional, and at the same time make the compile-time
selection idiomatic by replacing the ternary expressions in the loops that set
as[i] and bs[i] (which check AType/BType against DataType::kFloat4_e2m1fn) with
if constexpr branches so the choice is resolved at compile time and the intent
is clearer while retaining the existing Dispatcher::exec(c, as, bs, c) and
Dispatcher::exec(c, a, b, c) calls.
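For background on the operand handling discussed above: float4_e2m1fn packs 1 sign bit, 2 exponent bits (bias 1), and 1 mantissa bit. A host-side Python decoder (illustrative only, not the CUTLASS type) makes the eight representable magnitudes explicit:

```python
def decode_e2m1(nibble: int) -> float:
    """Decode a 4-bit float4_e2m1fn value: 1 sign, 2 exponent (bias 1), 1 mantissa."""
    assert 0 <= nibble < 16
    sign = -1.0 if nibble & 0b1000 else 1.0
    exp = (nibble >> 1) & 0b11
    man = nibble & 0b1
    if exp == 0:  # subnormal: 2^(1-1) * (man / 2)
        return sign * 0.5 * man
    return sign * 2.0 ** (exp - 1) * (1.0 + 0.5 * man)

magnitudes = sorted(decode_e2m1(n) for n in range(8))
print(magnitudes)
# → [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
```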


📥 Commits

Reviewing files that changed from the base of the PR and between f068f2f and f13a6b7.

📒 Files selected for processing (10)
  • examples/gemm_fp4/example_fusedmoe_a8w4_sm120.py
  • examples/gemm_fp4/example_gemm_a8w4_sm120.py
  • examples/gemm_fp4/example_gemm_fp4_sm120.py
  • src/target/codegen_cuda.cc
  • src/tl_templates/cuda/gemm_mma.h
  • src/tl_templates/cuda/instruction/mma.h
  • src/tl_templates/cuda/ldsm.h
  • tilelang/intrinsics/mma_layout.py
  • tilelang/tileop/gemm/gemm_base.py
  • tilelang/tileop/gemm/gemm_mma.py

Comment on lines +34 to +38
def moe_shared_expert_a8w4(
num_tokens, d_hidden, d_expert,
block_token=128, block_hidden=128, block_expert=128,
threads=128, num_stages=1,
):
Contributor

⚠️ Potential issue | 🟠 Major

Output indexing is keyed by d_expert, but output width is declared as d_hidden.

Line 53 iterates by over expert tiles, while Line 48 declares output width as d_hidden and Line 80 writes at by * block_expert. This is only valid when d_hidden == d_expert; otherwise writes are incorrect/out-of-bounds. Enforce the equality or change this stage’s output shape to d_expert.

🛡️ Minimal guard to prevent invalid configs
 def moe_shared_expert_a8w4(
     num_tokens, d_hidden, d_expert,
     block_token=128, block_hidden=128, block_expert=128,
     threads=128, num_stages=1,
 ):
+    if d_hidden != d_expert:
+        raise ValueError(
+            "moe_shared_expert_a8w4 currently emits gate/up stage only; require d_hidden == d_expert"
+        )
     """Single shared expert: gate_up GEMM -> SiLU*up -> down GEMM."""

Also applies to: 48-55, 80-80

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/gemm_fp4/example_fusedmoe_a8w4_sm120.py` around lines 34 - 38, The
stage moe_shared_expert_a8w4 declares its output width as d_hidden but the
output indexing and writes (iteration variable by over expert tiles and writes
at by * block_expert) are keyed by d_expert, which will mis-index unless
d_hidden == d_expert; either enforce d_hidden == d_expert at the start of
moe_shared_expert_a8w4 (raise/validate) or change the stage’s output shape to
use d_expert instead of d_hidden and update any output shape/width declarations
and dependent indexing (references: moe_shared_expert_a8w4, iteration variable
by, block_expert, d_hidden, d_expert) so writes at by * block_expert map
correctly and avoid out-of-bounds access.


@T.prim_func
def main(
input: T.Tensor((num_tokens, d_hidden), "float8_e4m3fn"),
Contributor

⚠️ Potential issue | 🟡 Minor

Avoid shadowing Python’s built-in input.

Line 44 uses input as a parameter name; rename to input_tensor (or similar) to satisfy lint and reduce ambiguity.

🧰 Tools
🪛 Ruff (0.15.5)

[error] 44-44: Function argument input is shadowing a Python builtin

(A002)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/gemm_fp4/example_fusedmoe_a8w4_sm120.py` at line 44, Rename the
parameter named "input" to "input_tensor" to avoid shadowing Python's built-in;
update the function signature where input: T.Tensor((num_tokens, d_hidden),
"float8_e4m3fn") appears and rename every use of that symbol inside the same
function (and any callers) to the new name (input_tensor) so references like
indexing, shape checks, or passes use the updated identifier consistently.

Comment on lines +54 to +85
// ldmatrix for 4-bit sub-byte types (FP4/INT4 on SM120+).
// Reads packed 4-bit data from shared memory and unpacks each 4-bit value
// into the low 4 bits of an 8-bit container (upper 4 bits zeroed).
TL_DEVICE void ptx_ldmatrix_b4x16_x1(void const *const smem_ptr,
void *const local_ptr) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
asm volatile(
"ldmatrix.sync.aligned.m8n16.x1.shared.b8x16.b4x16_p64 {%0}, [%1];\n"
: "=r"(value[0])
: "r"(smem_int_ptr));
}

TL_DEVICE void ptx_ldmatrix_b4x16_x2(void const *const smem_ptr,
void *const local_ptr) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
asm volatile(
"ldmatrix.sync.aligned.m8n16.x2.shared.b8x16.b4x16_p64 {%0, %1}, [%2];\n"
: "=r"(value[0]), "=r"(value[1])
: "r"(smem_int_ptr));
}

TL_DEVICE void ptx_ldmatrix_b4x16_x4(void const *const smem_ptr,
void *const local_ptr) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
asm volatile(
"ldmatrix.sync.aligned.m8n16.x4.shared.b8x16.b4x16_p64 {%0, %1, %2, %3}, [%4];\n"
: "=r"(value[0]), "=r"(value[1]), "=r"(value[2]), "=r"(value[3])
: "r"(smem_int_ptr));
}
Contributor

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

FILE="$(fd 'ldsm\.h$' src | head -n1)"
echo "Inspecting: $FILE"
rg -n -C3 'ptx_ldmatrix_b4x16|__CUDA_ARCH__|__CUDA_ARCH_LIST__' "$FILE"

echo
echo "Call sites:"
rg -n -C2 'ptx_ldmatrix_b4x16_(x1|x2|x4)' src tilelang



🏁 Script executed:

cat -n src/tl_templates/cuda/ldsm.h | head -100



🏁 Script executed:

# Check if TL_DEVICE macro or other guards exist
rg -n 'TL_DEVICE|#if.*CUDA|#ifdef.*CUDA' src/tl_templates/cuda/ldsm.h | head -20



🏁 Script executed:

# Check the entire file for any architectural gating patterns
wc -l src/tl_templates/cuda/ldsm.h



🏁 Script executed:

# Check for call sites of these new functions
rg -r 'ptx_ldmatrix_b4x16_(x1|x2|x4)' --type cpp --type h --type cu



🏁 Script executed:

# Check TL_DEVICE macro definition
rg -n 'define TL_DEVICE' src/



🏁 Script executed:

# Check common.h to understand the macro
head -50 src/tl_templates/cuda/common.h



🏁 Script executed:

rg 'ptx_ldmatrix_b4x16' src/ tilelang/ 2>/dev/null | grep -v 'void ptx_ldmatrix_b4x16'



🏁 Script executed:

# Look for any architecture guards in the build system or other headers
rg '__CUDA_ARCH__|SM[0-9]+|cudaArchitecture' src/tl_templates/cuda/ | head -20



🏁 Script executed:

# Check if there are any compilation flags or CMake guards for SM120
fd -e cmake -e CMakeLists | head -5 | xargs cat 2>/dev/null | grep -i 'sm\|arch\|compute' | head -10



🌐 Web query:

CUDA b4x16_p64 ldmatrix SM120 architecture requirement

💡 Result:

For the ldmatrix variant that uses 4-bit packed source format (i.e., .src_fmt = .b4x16_p64, producing .dst_fmt = .b8x16 for shapes like .m8n16 / .m16n16):

  • It is not enough to just be on “SM120” in general; this form requires the architecture-specific Blackwell target sm_120a (likewise sm_100a / sm_101a for other Blackwell chips). The PTX ISA states that the .m16n16 / .m8n16 shapes (which are the ones that use .b4x16_p64 / .b6x16_p32) are supported on sm_120a. [1]
  • If you are using family targets (the ...f suffix), PTX ISA notes these shapes are supported on sm_120f (or higher in the same family) starting PTX ISA 8.8. [1]
  • Background/versioning: sm_120 and sm_120a targets were added in PTX ISA 8.7. [2]

So, the practical requirement for ldmatrix ... .b8x16 .b4x16_p64 on “SM120” hardware is: compile for sm_120a (or sm_120f with a new-enough PTX toolchain), not plain sm_120. [1]

Sources: PTX ISA ldmatrix target notes and format description [1]; PTX ISA release history noting addition of sm_120/sm_120a [2].


Guard SM120-specific ldmatrix ... b4x16_p64 PTX behind an SM120+ compile check.

Lines 57–85 emit inline PTX that is only available on sm_120a and later. Without an architecture gate, CUDA compilation fails on lower SM targets (e.g., sm_80, sm_90) even if these helpers are never called. Gate these three functions with #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1200) to match the pattern used elsewhere in the codebase (see reduce.h).

🔧 Suggested guard shape
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1200)
 TL_DEVICE void ptx_ldmatrix_b4x16_x1(void const *const smem_ptr,
                                      void *const local_ptr) {
   ...
 }
 ...
 TL_DEVICE void ptx_ldmatrix_b4x16_x4(void const *const smem_ptr,
                                      void *const local_ptr) {
   ...
 }
+#endif
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/tl_templates/cuda/ldsm.h` around lines 54 - 85, Wrap the SM120-only
inline PTX functions ptx_ldmatrix_b4x16_x1, ptx_ldmatrix_b4x16_x2, and
ptx_ldmatrix_b4x16_x4 with an architecture guard so they are only compiled for
SM120+; specifically surround the entire definitions with `#if`
defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1200) ... `#endif` (matching the
pattern used in reduce.h) to prevent compilation on older SM targets.
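The unpack semantics the header comment describes (each 4-bit value landing in the low 4 bits of an 8-bit container, upper bits zeroed) can be sketched on the host. Note the low-nibble-first lane ordering below is an illustrative assumption; the actual lane mapping for b4x16_p64 is defined by the PTX ISA:

```python
def unpack_b4_to_b8(packed: bytes) -> list:
    """Unpack 4-bit values (two per byte, low nibble first here) into 8-bit
    containers with the upper 4 bits zeroed — the destination format that
    the b8x16.b4x16_p64 ldmatrix variant produces."""
    out = []
    for byte in packed:
        out.append(byte & 0x0F)         # low nibble
        out.append((byte >> 4) & 0x0F)  # high nibble
    return out

print(unpack_b4_to_b8(bytes([0x21, 0xF3])))
# → [1, 2, 3, 15]
```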

Comment on lines +95 to +97
dtype = self.B.dtype
if dtype == "uint8" and self.C.dtype in ("float32", "float16"):
return "float4_e2m1fn"
Contributor

⚠️ Potential issue | 🟠 Major

in_dtype_b remap is too broad and can misroute non-FP4 kernels.

Line 96-97 maps any uint8 B with float accumulation to float4_e2m1fn. That can also match non-FP4 uint8 storage paths and incorrectly force FP4 MMA lowering. Please gate this remap behind an explicit FP4 signal/capability instead of the current heuristic.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/tileop/gemm/gemm_base.py` around lines 95 - 97, The current remap in
in_dtype_b (checking self.B.dtype == "uint8" and self.C.dtype in
("float32","float16")) incorrectly forces FP4 lowering via "float4_e2m1fn" for
any uint8 storage; change the condition to require an explicit FP4 capability
flag in addition to the dtype checks (e.g., require self.target.supports_fp4()
or an explicit boolean like self.use_fp4/self.has_fp4) before returning
"float4_e2m1fn", update the remap logic in the method that reads
self.B.dtype/self.C.dtype accordingly, and leave other uint8 paths unchanged so
non-FP4 kernels are not remapped.
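The gated remap this comment asks for might look like the following sketch; `use_fp4` is a hypothetical flag standing in for whatever explicit FP4 signal (e.g. a target-capability check such as the suggested `self.target.supports_fp4()`) the maintainers choose:

```python
def remap_b_dtype(b_dtype: str, c_dtype: str, use_fp4: bool) -> str:
    """Remap uint8 B storage to the FP4 logical dtype only when the kernel
    explicitly opted into FP4 (use_fp4 is a hypothetical gate flag)."""
    if use_fp4 and b_dtype == "uint8" and c_dtype in ("float32", "float16"):
        return "float4_e2m1fn"
    return b_dtype

assert remap_b_dtype("uint8", "float32", use_fp4=True) == "float4_e2m1fn"
# Non-FP4 uint8 paths are left unchanged:
assert remap_b_dtype("uint8", "float32", use_fp4=False) == "uint8"
assert remap_b_dtype("int8", "float32", use_fp4=True) == "int8"
```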

Hale423 added 9 commits March 19, 2026 23:11
Add the plumbing required to route float4_e2m1fn through the TCGEN5 MMA
code-generation path so that FP4 GEMM kernels can be emitted on SM100.

Changes:
- ptx.h / ptx.cc: add kFloat4_e2m1fn enum, string tables, DTypeFromString
- common.h: add kFloat4_e2m1fn to device-side DataType enum
- tcgen5_meta.h: add FP4 branch in encode_dtype (format code 2)
- tcgen05mma.h: add kFloat4_e2m1fn specializations for SS/TS/WS_SS
  (delegates to the existing f8f6f4 PTX kind)
- mma_macro_generator.py: add dtype_abbrv mapping for float4_e2m1fn
- docs/GEMM_NV_FP4_FEATURE_STEPS.md: design doc and progress tracker

Addresses tile-ai#1592

Made-with: Cursor
… raw bytes, get_ldmatrix_offset add fp4 special layout, add SM120_FP4_FP4_F32_TN MmaDispatcher specialization + fp4 << 2 bit-shift.
…escriptor, cp.async copy, and packed SMEM addressing
Switch FP4 TMA data type from CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B to
ALIGN16B so TMA hardware unpacks FP4 to 1 byte/element in SMEM, matching
the SMEM pointer arithmetic (sizeof(fp4_e2_t)==1). Remove all sub-byte
special-case branches in copy.cc since dtype.bytes()==1 now correctly
represents the SMEM element size. Fix tcgen05_macro_generator to use
elem_bytes instead of elem_bits//8 for byte offset and descriptor
calculations in both SS and TS variants. Enable TMA path and disable
warp-specialization for T.Pipelined FP4 GEMM example.

Made-with: Cursor
@LeiWang1999 LeiWang1999 self-requested a review April 19, 2026 02:08
@LeiWang1999
Member

May I take this PR?

@LeiWang1999
Member

Thanks for your contribution! But this is closed, as we should use T.float4_e2m1fn instead of a packed dtype; we will create another PR for FP4 support. Thanks!

@Hale423
Contributor Author

Hale423 commented Apr 20, 2026

Hi @LeiWang1999
Sorry for the late reply; I was tied up with some internal work.
The intent was ultimately to achieve NVFP4 (with block scale), where T.float4_e2m1fn is already supported logically (though the front end uses packed int8).
I'd like to learn about your plan for this, or could I open a separate PR for full float4_e2m1fn support on SM100 & SM120?

