[WIP][Feature]: add NV FP4 GEMM support for SM100 & SM120 (#1918)
Hale423 wants to merge 13 commits into tile-ai:main.
Conversation
📝 Walkthrough
Adds end-to-end FP4 (float4_e2m1fn) support for CUDA SM120: new dtype enums and mappings, TCGEN5/TCGEN05 MMA specializations and ldmatrix helpers, MMA/dispatcher changes, codegen packing updates, and Python/CUDA GEMM examples and tests demonstrating FP4 GEMM and fused-MoE paths.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 5
🧹 Nitpick comments (2)
tilelang/intrinsics/mma_macro_generator.py (1)
57-57: Consider using a shorter abbreviation for consistency. Other FP8 variants use abbreviated forms like `"e4m3"` and `"e5m2"`, but FP4 uses the full `"float4_e2m1fn"`. While this may be intentional for PTX codegen compatibility, consider whether a shorter form like `"e2m1"` would be more consistent with the existing abbreviation pattern.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/intrinsics/mma_macro_generator.py` at line 57, the mapping key "float4_e2m1fn" should be shortened to "e2m1" for consistency with other FP abbreviations ("e4m3", "e5m2"); update the mapping entry in mma_macro_generator.py and replace all references/usages that look up or expect "float4_e2m1fn" (search for the symbol float4_e2m1fn) to use "e2m1" instead, ensuring any PTX compatibility layers or codegen consumers that relied on the old name are adjusted to the new key.
examples/gemm_fp4/example_gemm_fp4_sm120.py (1)
93-96: Add one numerical check before benchmarking. This example only reports latency/TFLOPS. Please compare the kernel output against a Torch reference once and use `calc_diff()` so FP4 lowering regressions are caught instead of just benchmarked.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/gemm_fp4/example_gemm_fp4_sm120.py` around lines 93-96, before running the benchmark, run a single numerical correctness check comparing the JIT kernel output to a PyTorch reference and fail or log if differences are large: invoke the kernel (jit_kernel) once with representative inputs used for benchmarking, compute a reference result with Torch (same inputs/precision), call calc_diff(reference, kernel_output) and assert or log the diff if it exceeds tolerance, then proceed to profiler = jit_kernel.get_profiler() / profiler.do_bench(); this ensures FP4 lowering regressions are caught before measuring latency/TFLOPS.
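The suggested pre-benchmark guard can be sketched as a small standalone helper. This is illustrative only: `calc_diff` here is a stand-in approximating TileLang's helper of the same name, and the tolerance is a placeholder, not a validated FP4 error bound.

```python
def calc_diff(ref, out):
    """Relative L2-style difference between a reference and a kernel output."""
    num = sum((r - o) ** 2 for r, o in zip(ref, out))
    den = sum(r * r for r in ref) or 1.0
    return (num / den) ** 0.5


def check_before_bench(ref, out, tol=1e-2):
    # FP4 inputs quantize heavily, so the tolerance is deliberately loose.
    diff = calc_diff(ref, out)
    if diff > tol:
        raise AssertionError(f"FP4 GEMM mismatch: diff={diff:.4f} > tol={tol}")
    return diff
```

In the example this would run once on the benchmark inputs before `profiler.do_bench()`.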
📒 Files selected for processing (11)
- docs/GEMM_NV_FP4_FEATURE_STEPS.md
- examples/gemm_fp4/example_gemm_fp4_sm120.py
- src/op/tcgen5_meta.h
- src/target/ptx.cc
- src/target/ptx.h
- src/tl_templates/cuda/common.h
- src/tl_templates/cuda/cuda_fp4.h
- src/tl_templates/cuda/gemm_mma.h
- src/tl_templates/cuda/instruction/tcgen05mma.h
- tilelang/intrinsics/mma_macro_generator.py
- tilelang/intrinsics/utils.py
> ## V. Changes completed in this PR (CuTeDSL/TCGEN5 path)
>
> The following changes have landed so that **CuTeDSL-generated SM100 GEMM kernels** correctly take the `kind::f8f6f4` FP4 path when A/B are `float4_e2m1fn`. The CUTLASS `DispatchInstruction` in `gemm_sm100.h` is not changed (that path is instantiated from CUTLASS templates and is separate from the CuTeDSL generate-TIR-then-codegen path).
>
> | File | Change |
> |------|--------|
> | **src/target/ptx.h** | Add `kFloat4_e2m1fn = 23` to the `DataType` enum. |
> | **src/target/ptx.cc** | Add a 24th entry to `enum_to_str` / `dtype_str` / `num_bits`; map `"float4_e2m1fn"` and `".e2m1"` to `kFloat4_e2m1fn` in `DTypeFromString`. |
> | **src/tl_templates/cuda/common.h** | Add `kFloat4_e2m1fn = 23` to the `DataType` enum. |
> | **src/op/tcgen5_meta.h** | Add a `dtype.is_float4_e2m1fn()` branch to `encode_dtype` in `GetTCGEN5InstrDesc()`, returning `2` (the FP4 format encoding). |
> | **src/tl_templates/cuda/instruction/tcgen05mma.h** | Add `tcgen05mma_ss`, `tcgen05mma_ts`, and `tcgen05mma_ws_ss` specializations for `DataType::kFloat4_e2m1fn`, all forwarding to the `kFloat8_e4m3` f8f6f4 implementation. |
> | **tilelang/intrinsics/mma_macro_generator.py** | Add `"float4_e2m1fn": "float4_e2m1fn"` to `dtype_abbrv` so that lowering passes `a_dtype_abbrv` into `ptx_tcgen05_mma_ss("float4_e2m1fn", ...)`. |
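The `encode_dtype` row in that table can be illustrated with a minimal sketch. Only the FP4-maps-to-2 fact comes from the doc; the FP8 values 0 and 1 are assumed placeholders, and the real code in `tcgen5_meta.h` branches on TVM `DataType` predicates such as `dtype.is_float4_e2m1fn()` rather than string keys.

```python
# Illustrative TCGEN5 instruction-descriptor format encodings.
TCGEN5_FORMAT_ENCODING = {
    "float8_e4m3": 0,     # assumed
    "float8_e5m2": 1,     # assumed
    "float4_e2m1fn": 2,   # FP4 format encoding added by this PR
}


def encode_dtype(name):
    """Look up the format encoding for a dtype name (stand-in for the C++ branch)."""
    return TCGEN5_FORMAT_ENCODING[name]
```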
Tone down the readiness claim in this section.
This reads as if the FP4 GEMM path is already working end-to-end, but the PR summary still calls out unresolved CUDA/C++ compile blockers. Please scope the wording to “dtype/codegen plumbing landed” or “kernel source can be generated” until build + numerical validation are actually passing.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@docs/GEMM_NV_FP4_FEATURE_STEPS.md` around lines 90 - 102, Update the
declarative language in the "五、本次已完成的修改(CuTeDSL/TCGEN5 路径)" section so it no
longer claims end-to-end FP4 GEMM readiness; replace phrases like "能正确走
`kind::f8f6f4` 的 FP4 路径" and "已落地" with scoped wording such as "dtype/codegen
plumbing landed", "kernel source can be generated", or "codegen support added"
and add a short caveat noting unresolved CUDA/C++ compile blockers and that
numerical validation is pending; ensure mentions of symbols
`DataType::kFloat4_e2m1fn`, `encode_dtype`, `tcgen05mma_ss/ts/ws_ss`, and files
`ptx.h`, `ptx.cc`, `tcgen05mma.h`, `mma_macro_generator.py` remain factual but
change tense to indicate plumbing implemented rather than fully validated.
> **Follow-up suggestions** (once an SM100 / nvcc environment is available):
>
> 1. Run a full build locally or in CI and execute the BF16/F8 cases under `examples/gemm_sm100/` to confirm there are no regressions.
> 2. Add an FP4 GEMM example: `in_dtype=accum_dtype=T.float4_e2m1fn, T.float`, `block_K=128` (K%32 is already guaranteed by the meta), and compare numerically against a reference implementation.
Fix the dtype example syntax here.
in_dtype=accum_dtype=T.float4_e2m1fn, T.float is malformed and reads as if the accumulator were FP4. This should describe in_dtype=T.float4_e2m1fn, accum_dtype=T.float (or the intended pair).
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@docs/GEMM_NV_FP4_FEATURE_STEPS.md` at line 106, The dtype example string is
malformed: change the example `in_dtype=accum_dtype=T.float4_e2m1fn, T.float` to
the correct pairing such as `in_dtype=T.float4_e2m1fn, accum_dtype=T.float` (or
the intended in_dtype/accum_dtype pair) so it clearly indicates the input is FP4
and the accumulator is float; update the example text where the current
`in_dtype=accum_dtype=T.float4_e2m1fn, T.float` occurs to use the corrected
`in_dtype=T.float4_e2m1fn, accum_dtype=T.float` form.
```python
def matmul_fp4(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages=2,
    threads=128,
):
```
Validate SM120 FP4 k32 constraints before building the kernel.
matmul_fp4() accepts arbitrary K and block_K, but this path lowers to FP4 MMA with k32 granularity. Reject unsupported values here so callers get a clear Python error instead of a backend failure later.
🔧 Proposed fix
```diff
 def matmul_fp4(
     M,
     N,
     K,
@@
     accum_dtype,
     num_stages=2,
     threads=128,
 ):
+    if K % 32 != 0 or block_K % 32 != 0:
+        raise ValueError("SM120 FP4 GEMM requires K and block_K to be multiples of 32")
+
     A_shape = (M, K)
     B_shape = (N, K)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/gemm_fp4/example_gemm_fp4_sm120.py` around lines 12 - 24, Add
explicit validation at the start of matmul_fp4 to reject K and block_K values
that don't meet SM120 FP4 k32 granularity: check that K % 32 == 0 and block_K %
32 == 0 (and that block_K <= K), and if not raise a clear ValueError mentioning
the function name and the required 32-granularity constraint; this ensures
callers see a Python error instead of a backend failure when using matmul_fp4
with unsupported K/block_K.
```python
jit_kernel = tilelang.compile(
    func,
    out_idx=[2],
    target="cuda",
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    },
)

print("Compilation succeeded!")
print(jit_kernel.get_kernel_source())
```
🧩 Analysis chain
🏁 Script executed:
```bash
#!/bin/bash
set -euo pipefail
# Inspect likely lowering/dispatch headers for an SM120 FP4 mma_sync specialization.
fd 'mma\.h$|gemm_mma\.h$' src | while read -r f; do
  echo "== $f =="
  rg -n -C2 'kFloat4_e2m1fn|MmaDispatcher|mma_sync<' "$f" || true
done
```
Repository: tile-ai/tilelang
🌐 Web query:
tilelang SM120 FP4 mma_sync MmaDispatcher specialization support
💡 Result:
TileLang’s CUDA SM120 path exists, but FP4 (NVFP4 / MXF4) GEMM via mma.sync + a specialized “MmaDispatcher” is not currently implemented/exposed in TileLang.
- What TileLang does for SM120 today: `gemm_sm120.h` is essentially a thin wrapper that just includes the generic MMA GEMM implementation (`gemm_mma.h`); there is no SM120-specific FP4 dispatch/specialization visible there. [1]
- Project status for SM120 NVFP4: TileLang has an open feature request explicitly asking for "sm120 nvfp4 gemm support," indicating it is not supported yet (at least not in mainline TileLang). [2]
- Where SM120 FP4 support does exist: NVIDIA CUTLASS documents SM120 "mixed-precision MMA" support including FP4/FP6/FP8 ("f8f6f4" families) and dispatch policies/valid tile shapes for SM120 GEMMs. [3]
- Practical implication: if you need SM120 FP4 now, the viable route is typically CUTLASS (or TileLang’s newer CUTLASS CuTe DSL backend) rather than TileLang’s internal MMA templates. TileLang added a CuTeDSL backend (compile-to-CUTLASS CuTe DSL) in late 2025. [4]
Answer to your specific terms:
- `mma_sync`: TileLang's SM120 GEMM template layer doesn't currently show FP4 `mma.sync`-based specializations/dispatch in the SM120 header path. [1]
- `MmaDispatcher` specialization support: not present as a TileLang-supported/public mechanism for SM120 FP4 today; the open NVFP4 request suggests it's pending. [2]
Sources:
[1] TileLang gemm_sm120.h (main branch) https://raw.githubusercontent.com/tile-ai/tilelang/main/src/tl_templates/cuda/gemm_sm120.h
[2] TileLang issue "Request for sm120 nvfp4 gemm support" (#1592)
[3] NVIDIA CUTLASS Blackwell/SM120 functionality docs (FP4/FP6/FP8 MMA families and SM120 GEMM dispatch/tile-shape tables) https://docs.nvidia.com/cutlass/4.3.4/media/docs/cpp/blackwell_functionality.html
[4] TileLang README “Added CuTeDSL backend support” (12/18/2025) https://github.com/tile-ai/tilelang
This example still assumes an end-to-end SM120 FP4 backend that TileLang does not currently support.
tilelang.compile(...) is followed by "Compilation succeeded!", but TileLang's SM120 GEMM path does not include FP4 (mma.sync/MmaDispatcher) specialization. TileLang has an open feature request for "sm120 nvfp4 gemm support" (issue #1592), confirming the missing implementation. Either land that backend specialization in this PR or gate/remove this runnable example for now.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/gemm_fp4/example_gemm_fp4_sm120.py` around lines 80 - 91, The
example assumes an SM120 FP4 GEMM backend TileLang does not support; update
example_gemm_fp4_sm120.py to avoid falsely reporting success by gating or
removing the compile/print block: replace the unconditional
tilelang.compile(...) / print("Compilation succeeded!") /
jit_kernel.get_kernel_source() sequence with a runtime check for the
backend/feature (e.g., a TileLang API like a capability query or a check of
target="cuda"/SM version or a feature flag tied to the SM120 FP4 GEMM path) and
if the feature is missing, skip the compilation and emit a clear message
explaining the missing "sm120 nvfp4 gemm support" (issue `#1592`) so the example
doesn't run; alternatively remove the runnable compile/print code entirely until
the backend specialization (sm120 nvfp4 gemm) is implemented. Ensure you
reference the existing tilelang.compile call and jit_kernel.get_kernel_source in
your change.
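One way to sketch the suggested gate is below. The compute-capability check is an assumption about how an SM120 device reports itself (12.x); `tilelang.compile` and `jit_kernel.get_kernel_source` are the example's real calls, shown only in comments.

```python
def should_run_fp4_example(major, minor):
    """True only on SM120-class GPUs (compute capability 12.x)."""
    return (major, minor) >= (12, 0)


# In the example, roughly:
#   import torch
#   if not should_run_fp4_example(*torch.cuda.get_device_capability()):
#       print("Skipping: sm120 nvfp4 gemm support is not implemented (issue #1592)")
#   else:
#       jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda")
#       print(jit_kernel.get_kernel_source())
```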
(force-pushed e040bd3 to 7cf06fa)
Actionable comments posted: 2
♻️ Duplicate comments (6)
examples/gemm_fp4/GEMM_NV_FP4_FEATURE_STEPS.md (2)
106-106: ⚠️ Potential issue | 🟡 Minor
Fix the dtype example syntax. `in_dtype=accum_dtype=T.float4_e2m1fn, T.float` is malformed and reads as if the accumulator were FP4. Spell the pair out explicitly, e.g. `in_dtype=T.float4_e2m1fn, accum_dtype=T.float`.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/gemm_fp4/GEMM_NV_FP4_FEATURE_STEPS.md` at line 106, the dtype example is malformed: change the combined expression `in_dtype=accum_dtype=T.float4_e2m1fn, T.float` to explicitly spell out both fields so it reads e.g. `in_dtype=T.float4_e2m1fn, accum_dtype=T.float`; update the example text to use the separate symbols in_dtype and accum_dtype with the correct types T.float4_e2m1fn and T.float so it's unambiguous that the accumulator is T.float.
90-101: ⚠️ Potential issue | 🟡 Minor
This section still reads as more validated than the PR state. The PR summary still calls out unresolved CUDA/C++ compile blockers, so this should be framed as dtype/codegen plumbing or source-generation progress rather than a working FP4 GEMM path.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/gemm_fp4/GEMM_NV_FP4_FEATURE_STEPS.md` around lines 90-101, the section currently overstates completion; rephrase GEMM_NV_FP4_FEATURE_STEPS.md so it describes these changes as dtype/codegen plumbing and source-generation progress rather than a fully working FP4 GEMM path: change the heading and summary sentence that starts "五、本次已完成的修改" to indicate these are implemented plumbing/codegen changes (mentioning DataType::kFloat4_e2m1fn, GetTCGEN5InstrDesc encode_dtype branch, tcgen05mma_ss/ts/ws_ss specializations, dtype_abbrv addition) and add a short note that CUDA/C++ compile blockers remain and CUTLASS DispatchInstruction in gemm_sm100.h is unchanged. Ensure the wording signals work-in-progress, not validated runtime behavior.
examples/gemm_fp4/example_gemm_fp4_sm120.py (2)
80-103: ⚠️ Potential issue | 🟡 Minor
Tone down the success logging here. The PR still lists unresolved CUDA/C++ blockers, so `Compilation succeeded!` and `[OK] ... compiled ...` read like end-to-end readiness. Please phrase this as lowering/source-generation success until the C++ path and runtime execution actually pass.
🔧 Suggested wording
```diff
-print("Compilation succeeded!")
+print("Lowering/source generation succeeded.")
@@
-print("\n[OK] FP4 GEMM kernel compiled and CUDA source generated successfully.")
+print("\n[OK] FP4 GEMM CUDA source generated successfully.")
 print("[TODO] Runtime execution pending TVM sub-byte pointer arithmetic fix.")
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/gemm_fp4/example_gemm_fp4_sm120.py` around lines 80-103, update the user-facing print messages to tone down end-to-end claims: change the "Compilation succeeded!" and the two "[OK] FP4 GEMM kernel compiled and CUDA source generated successfully." / "[TODO] Runtime execution pending..." lines so they clearly state that only compilation/source-generation succeeded and runtime/C++ execution is still pending; update the prints that reference jit_kernel.get_kernel_source() and the file write to say "CUDA source generated (runtime execution and C++ integration remain pending)" or similar, and remove any definitive "OK" language to avoid implying full end-to-end readiness.
12-24: ⚠️ Potential issue | 🟡 Minor
Reject unsupported `K`/`block_K` values up front. This path lowers to FP4 `m16n8k32`, so callers need a clear Python error when `K` or `block_K` is not 32-aligned, or when `block_K > K`, instead of a backend failure later.
🔧 Suggested guard
```diff
 def matmul_fp4(
     M,
     N,
     K,
@@
     num_stages=2,
     threads=128,
 ):
+    if K % 32 != 0 or block_K % 32 != 0 or block_K > K:
+        raise ValueError(
+            "matmul_fp4 requires K and block_K to be multiples of 32 and block_K <= K for SM120 FP4 MMA"
+        )
+
     A_shape = (M, K)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/gemm_fp4/example_gemm_fp4_sm120.py` around lines 12-24, in matmul_fp4, add upfront validation to raise a clear Python error when K or block_K are unsupported: verify K % 32 == 0, block_K % 32 == 0, and block_K <= K (and raise ValueError with a descriptive message if any check fails) so callers get a meaningful error instead of a backend failure; place these checks at the start of the matmul_fp4 function before any lowering or backend calls.
src/tl_templates/cuda/cuda_fp4.h (1)
9-9: ⚠️ Potential issue | 🔴 Critical
Verify that `raw()` is actually the supported FP4 storage accessor. After `fp4_e2_t` becomes `tl::float_e2m1_t`, these helpers rely on `val.raw()`. The PR notes still call out this header as a CUTLASS/CuTe storage-API blocker, so please confirm the accessor for the FP4 type in the version you're targeting before landing this alias.
In the CUTLASS/CuTe version used by TileLang, what is the supported raw-storage accessor for the FP4 type `cute::float_e2m1_t` / `cutlass::float_e2m1_t`? Does it expose `raw()`, or should callers use a different storage API for packing/unpacking 4-bit values?
Also applies to: 30-34, 73-73
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/tl_templates/cuda/cuda_fp4.h` at line 9, confirm the correct low-level storage accessor for the FP4 type after aliasing fp4_e2_t = tl::float_e2m1_t: check the CUTLASS/CuTe version used by the project and verify whether cute::float_e2m1_t / cutlass::float_e2m1_t exposes raw() or a different method (e.g., bits(), as_int(), data(), u(), etc.); then update all code that calls val.raw() (references around the fp4_e2_t alias and the usages noted near the other occurrences) to call the actual supported accessor and adjust any packing/unpacking helpers to use that API, and add a short comment by the fp4_e2_t typedef noting the chosen accessor for future reference.
examples/gemm_fp4/gemm_fp4_sm120.cu (1)
49-52: ⚠️ Potential issue | 🔴 Critical
These `mma_sync` calls still depend on a missing FP4 dispatcher. `src/tl_templates/cuda/instruction/mma.h` still needs the `MmaDispatcher` specialization for `tl::mma_sync<tl::DataType::kFloat4_e2m1fn, tl::DataType::kFloat4_e2m1fn, tl::DataType::kFloat32, 16, 8, 32, ...>`, so this checked-in SM120 kernel remains a compile blocker.
Expect to find a specialization matching the instantiated FP4 signature; if it is absent, this kernel cannot compile.
```bash
#!/bin/bash
set -euo pipefail
fd 'mma\.h$|gemm_mma\.h$' src | while read -r f; do
  echo "== $f =="
  rg -n -C2 'MmaDispatcher|kFloat4_e2m1fn|mma_sync<' "$f" || true
done
```
Also applies to: 64-65
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/gemm_fp4/gemm_fp4_sm120.cu` around lines 49 - 52, The kernel instantiates tl::mma_sync with FP4 (tl::DataType::kFloat4_e2m1fn) operand types and dims (16,8,32,false,true) but there is no corresponding MmaDispatcher specialization, which prevents compilation; add a MmaDispatcher specialization in the MmaDispatcher template (in mma.h) that exactly matches the signature used by the kernel (template parameters: tl::DataType::kFloat4_e2m1fn, tl::DataType::kFloat4_e2m1fn, tl::DataType::kFloat32, 16, 8, 32, false, true) and implement the required interface (the same static method/operator used by other dispatchers) to perform the FP4 packing/unpacking and call the correct low-level mma intrinsic or shared implementation; ensure the specialization forwards to any existing FP4 helper routine or implements the correct accumulation into float32 so the tl::mma_sync calls in gemm_fp4_sm120.cu compile.
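The missing-specialization failure mode can be sketched in Python as a dispatch registry keyed by operand dtypes and tile shape. This is only an analogy: in `mma.h` the dispatch is a C++ template specialization resolved at compile time, and every name and return label below is illustrative, not TileLang's actual API.

```python
# Registry keyed by (a_dtype, b_dtype, c_dtype, m, n, k).
_MMA_DISPATCH = {}


def register_mma(a, b, c, m, n, k):
    def deco(fn):
        _MMA_DISPATCH[(a, b, c, m, n, k)] = fn
        return fn
    return deco


@register_mma("float4_e2m1fn", "float4_e2m1fn", "float32", 16, 8, 32)
def _fp4_mma_m16n8k32():
    # A real implementation would unpack e2m1 nibbles and issue the PTX
    # mma.sync instruction; here we only record which path was selected.
    return "m16n8k32.f8f6f4"


def mma_sync(a, b, c, m, n, k):
    key = (a, b, c, m, n, k)
    if key not in _MMA_DISPATCH:
        # Mirrors the C++ compile error for a missing specialization.
        raise NotImplementedError(f"no MmaDispatcher specialization for {key}")
    return _MMA_DISPATCH[key]()
```

Without the FP4 entry registered, the lookup fails exactly the way the unspecialized template does.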
📒 Files selected for processing (15)
- docs/GEMM_NV_FP4_FEATURE_STEPS.md
- examples/gemm_fp4/GEMM_NV_FP4_FEATURE_STEPS.md
- examples/gemm_fp4/example_gemm_fp4_sm120.py
- examples/gemm_fp4/gemm_fp4_sm120.cu
- src/op/tcgen5_meta.h
- src/target/codegen_cuda.cc
- src/target/ptx.cc
- src/target/ptx.h
- src/tl_templates/cuda/common.h
- src/tl_templates/cuda/cuda_fp4.h
- src/tl_templates/cuda/gemm_mma.h
- src/tl_templates/cuda/instruction/mma.h
- src/tl_templates/cuda/instruction/tcgen05mma.h
- tilelang/intrinsics/mma_macro_generator.py
- tilelang/intrinsics/utils.py
🚧 Files skipped from review as they are similar to previous changes (2)
- src/tl_templates/cuda/gemm_mma.h
- src/target/ptx.h
```diff
     "kBit16", "kBit32", "kBit64", "kFloat4_e2m1fn"};

 static const char *dtype_str[] = {
     ".s4",   ".u4",  ".s8",  ".u8",  ".s16",   ".u16", ".s32", ".u32",
     ".s64",  ".u64", ".e4m3", ".e5m2", ".f16", ".bf16", ".f16x2", ".f32",
-    ".tf32", ".f64", ".b1", ".b8", ".b16", ".b32", ".b64"};
+    ".tf32", ".f64", ".b1", ".b8", ".b16", ".b32", ".b64", ".e2m1"};
 static const uint32_t num_bits[] = {4,  4,  8, 8, 16, 16, 32, 32,
                                     64, 64, 8, 8, 16, 16, 32, 32,
-                                    32, 64, 1, 8, 16, 32, 64};
+                                    32, 64, 1, 8, 16, 32, 64, 4};
```
Finish the PTX-side FP4 wiring before accepting the new dtype string.
DTypeFromString() now parses FP4, but this file still has no kFloat4_e2m1fn handling in CheckMMADTypeCompatible(), GetFragAttrs(), or the valid MMA/WGMMA config tables. That means any FP4 caller that reaches PrintMMAAssembly() / PrintWGMMAAssembly() will parse successfully and then fatal later in the same file. Either complete those cases end-to-end or keep the parser rejecting FP4 on the PTX-assembly path for now.
Also applies to: 83-84
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/target/ptx.cc` around lines 43 - 51, DTypeFromString now accepts
kFloat4_e2m1fn but the PTX backend lacks full FP4 support; update
CheckMMADTypeCompatible, GetFragAttrs, and the MMA/WGMMA configuration tables so
they recognize and handle kFloat4_e2m1fn end-to-end (including any frag
attribute mapping used by PrintMMAAssembly and PrintWGMMAAssembly) or, as a
safer short-term fix, make these functions explicitly reject kFloat4_e2m1fn
(return false or raise the same parse-time error) so callers like
PrintMMAAssembly/PrintWGMMAAssembly never receive an unsupported dtype; locate
the symbols CheckMMADTypeCompatible, GetFragAttrs, PrintMMAAssembly,
PrintWGMMAAssembly and the MMA/WGMMA config tables in this file and either add
FP4 cases mirroring other float types or add explicit guards that block FP4 from
proceeding to assembly generation.
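The table drift this comment guards against is mechanical, so it can be checked mechanically. Below is a hypothetical Python mirror of the two parallel arrays in ptx.cc; the names `DTYPE_STR`, `NUM_BITS`, and the `bits_of` helper are illustrative, not the actual C++ API:

```python
# Hypothetical mirror of the parallel dtype tables in src/target/ptx.cc.
# The invariant: both arrays grow together, so ".e2m1" and its width 4
# must be appended to both (as the diff above does).
DTYPE_STR = [
    ".s4",   ".u4",  ".s8",   ".u8",   ".s16", ".u16",  ".s32",   ".u32",
    ".s64",  ".u64", ".e4m3", ".e5m2", ".f16", ".bf16", ".f16x2", ".f32",
    ".tf32", ".f64", ".b1",   ".b8",   ".b16", ".b32",  ".b64",   ".e2m1",
]
NUM_BITS = [4,  4,  8, 8, 16, 16, 32, 32,
            64, 64, 8, 8, 16, 16, 32, 32,
            32, 64, 1, 8, 16, 32, 64, 4]

def bits_of(dtype_suffix: str) -> int:
    """Look up the PTX bit width for a dtype suffix, e.g. '.e2m1' -> 4."""
    return NUM_BITS[DTYPE_STR.index(dtype_suffix)]

# Length check: catches an entry added to one table but not the other.
assert len(DTYPE_STR) == len(NUM_BITS)
```

A unit test asserting the two arrays stay the same length would catch exactly the kind of partial wiring the review flags.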
# MMA k_dim caps at 32 (m16n8k32 is the widest K for FP8/FP4)
self.k_dim = min(256 // a_dtype.bits, 32)
Apply the FP4 k32 clamp to the ladder-transform emitter too.
TensorCoreIntrinEmitterWithLadderTransform._initialize_k_dim() still computes 256 // bits, so float4_e2m1fn becomes 64 there and later trips _initialize_mma_prefix() with Unsupported k_dim. Please share this clamp or route both classes through one helper.
🔧 Suggested fix
def _initialize_k_dim(self, a_dtype=T.float16):
- self.k_dim = 256 // DataType(a_dtype).bits
+    self.k_dim = min(256 // DataType(a_dtype).bits, 32)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tilelang/intrinsics/mma_macro_generator.py` around lines 118 - 119, The
ladder-transform emitter still computes k_dim as 256 // a_dtype.bits which
yields 64 for float4_e2m1fn and breaks _initialize_mma_prefix; update
TensorCoreIntrinEmitterWithLadderTransform._initialize_k_dim to apply the same
clamp (k_dim = min(256 // a_dtype.bits, 32)) or refactor both emitters to call a
shared helper (e.g., a new compute_k_dim(a_dtype) used by _initialize_k_dim in
both TensorCoreIntrinEmitter and TensorCoreIntrinEmitterWithLadderTransform) so
FP4 k_dim is capped at 32.
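One way to share the clamp, as the prompt suggests, is a single helper that both emitters call. A minimal sketch (the name `compute_k_dim` is hypothetical):

```python
def compute_k_dim(dtype_bits: int) -> int:
    """K per mma.sync tile: one 256-bit fragment per thread, clamped at 32
    because m16n8k32 is the widest K the instruction offers for FP8/FP4."""
    return min(256 // dtype_bits, 32)

# float16 -> k16, int8/fp8 -> k32; fp4 would be 64 without the cap
```

Both `TensorCoreIntrinEmitter` and `TensorCoreIntrinEmitterWithLadderTransform` could then delegate to it, so a future sub-byte dtype only needs one change.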
f037550 to
f068f2f
Compare
Actionable comments posted: 2
🧹 Nitpick comments (2)
examples/gemm_fp4/example_gemm_fp4_sm120.py (1)
118-119: Consider documenting the tolerance threshold rationale.

A `max_abs_diff < 1.0` threshold is quite loose. While FP4's coarse quantization may justify this, a brief comment explaining the expected precision bounds would help future maintainers understand whether this is intentional or a placeholder.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/gemm_fp4/example_gemm_fp4_sm120.py` around lines 118 - 119, The pass condition uses a very loose tolerance (max_abs_diff < 1.0) which needs an explanatory comment: update the block that checks the variable max_diff and prints "[PASS] numerical verification (max_abs_diff < 1.0)" to include a short comment describing why 1.0 is acceptable for FP4 (e.g., expected coarse quantization error for FP4 on sm120, any empirical basis or citation, and that this is intentional rather than a placeholder), or replace the literal with a named constant (e.g., FP4_TOLERANCE) and comment that constant explaining the expected precision bounds and how it was chosen.

src/tl_templates/cuda/instruction/mma.h (1)
168-182: Narrow FP4 preprocessing guard to avoid future mixed-type mispacking.

At Lines 168-178, both operands are shifted when either side is FP4. That's fine today (only FP4/FP4 is dispatched), but it can silently break if mixed FP4 dispatchers are introduced later. Prefer guarding this path with `AType == kFloat4_e2m1fn && BType == kFloat4_e2m1fn` (or shifting each operand conditionally).

Suggested minimal guard tightening

-  if constexpr (AType == DataType::kFloat4_e2m1fn ||
-                BType == DataType::kFloat4_e2m1fn) {
+  if constexpr (AType == DataType::kFloat4_e2m1fn &&
+                BType == DataType::kFloat4_e2m1fn) {

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/tl_templates/cuda/instruction/mma.h` around lines 168 - 182, The current branch shifts both operands when either AType or BType equals DataType::kFloat4_e2m1fn, which can mispack mixed-type calls; change the guard so it only takes this path when both types are FP4 (i.e., AType == DataType::kFloat4_e2m1fn && BType == DataType::kFloat4_e2m1fn) or alternatively apply shifts per-operand (shift a[] only when AType is FP4 and shift b[] only when BType is FP4) before calling Dispatcher::exec; update the if constexpr condition and the uses of as/bs (or keep original a/b when not shifted) to ensure correct typing and avoid silent mispacking.
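A Python model of the per-operand alternative mentioned above: it shifts each side independently, so a hypothetical mixed FP4/FP8 dispatch would leave the non-FP4 operand untouched. Register values are modeled as 32-bit ints; the function name is illustrative:

```python
FP4 = "float4_e2m1fn"

def preshift_operands(a_regs, b_regs, a_dtype, b_dtype):
    """Apply the 2-bit left shift only to operands that actually hold FP4
    nibbles, emulating a per-operand guard instead of an either-side guard."""
    shift = lambda regs: [(r << 2) & 0xFFFFFFFF for r in regs]  # 32-bit wrap
    a_out = shift(a_regs) if a_dtype == FP4 else list(a_regs)
    b_out = shift(b_regs) if b_dtype == FP4 else list(b_regs)
    return a_out, b_out
```

With this shape, adding a mixed FP4/FP8 dispatcher later would not silently mispack the FP8 side.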
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/gemm_fp4/example_gemm_fp4_sm120.py`:
- Around line 62-64: The make_fp4_tensor function silently truncates when K is
odd because it uses K // 2; add a validation at the start of make_fp4_tensor to
check that K % 2 == 0 and raise a clear ValueError (or AssertionError) if not,
including K in the message, so callers know K must be even before creating the
packed (M, K//2) int8 tensor.
- Line 121: The print statement uses an unnecessary f-string prefix; remove the
leading "f" from the call so the literal is a normal string (replace
print(f"[WARN] large diff — may indicate layout or data flow issue") with
print("[WARN] large diff — may indicate layout or data flow issue") to eliminate
the unused f-string).
---
Nitpick comments:
In `@examples/gemm_fp4/example_gemm_fp4_sm120.py`:
- Around line 118-119: The pass condition uses a very loose tolerance
(max_abs_diff < 1.0) which needs an explanatory comment: update the block that
checks the variable max_diff and prints "[PASS] numerical verification
(max_abs_diff < 1.0)" to include a short comment describing why 1.0 is
acceptable for FP4 (e.g., expected coarse quantization error for FP4 on sm120,
any empirical basis or citation, and that this is intentional rather than a
placeholder), or replace the literal with a named constant (e.g., FP4_TOLERANCE)
and comment that constant explaining the expected precision bounds and how it
was chosen.
In `@src/tl_templates/cuda/instruction/mma.h`:
- Around line 168-182: The current branch shifts both operands when either AType
or BType equals DataType::kFloat4_e2m1fn, which can mispack mixed-type calls;
change the guard so it only takes this path when both types are FP4 (i.e., AType
== DataType::kFloat4_e2m1fn && BType == DataType::kFloat4_e2m1fn) or
alternatively apply shifts per-operand (shift a[] only when AType is FP4 and
shift b[] only when BType is FP4) before calling Dispatcher::exec; update the if
constexpr condition and the uses of as/bs (or keep original a/b when not
shifted) to ensure correct typing and avoid silent mispacking.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: db4fd421-ef6b-49ea-b04c-562028950504
📒 Files selected for processing (5)
examples/gemm_fp4/example_gemm_fp4_sm120.py
src/target/codegen_cuda.cc
src/tl_templates/cuda/instruction/mma.h
tilelang/intrinsics/mma_layout.py
tilelang/intrinsics/utils.py
🚧 Files skipped from review as they are similar to previous changes (2)
- tilelang/intrinsics/utils.py
- src/target/codegen_cuda.cc
def make_fp4_tensor(M: int, K: int, device="cuda") -> torch.Tensor:
    """Create random packed FP4 tensor as (M, K//2) int8."""
    return torch.randint(0, 256, (M, K // 2), dtype=torch.uint8, device=device).to(torch.int8)
Add validation that K is even.
K // 2 silently truncates if K is odd, which would create a tensor with fewer elements than expected.
🛡️ Proposed fix
 def make_fp4_tensor(M: int, K: int, device="cuda") -> torch.Tensor:
     """Create random packed FP4 tensor as (M, K//2) int8."""
+    if K % 2 != 0:
+        raise ValueError("K must be even for packed FP4 representation")
     return torch.randint(0, 256, (M, K // 2), dtype=torch.uint8, device=device).to(torch.int8)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/gemm_fp4/example_gemm_fp4_sm120.py` around lines 62 - 64, The
make_fp4_tensor function silently truncates when K is odd because it uses K //
2; add a validation at the start of make_fp4_tensor to check that K % 2 == 0 and
raise a clear ValueError (or AssertionError) if not, including K in the message,
so callers know K must be even before creating the packed (M, K//2) int8 tensor.
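A torch-free sketch of the validated constructor (a pure-Python stand-in for the example's `make_fp4_tensor`; each byte models two packed e2m1 nibbles):

```python
import random

def make_fp4_packed(M: int, K: int, seed: int = 0):
    """Random packed FP4 payload: two e2m1 nibbles per byte, shape (M, K // 2).
    Rejects odd K up front instead of silently truncating via K // 2."""
    if K % 2 != 0:
        raise ValueError(f"K must be even for packed FP4 representation, got K={K}")
    rng = random.Random(seed)
    return [[rng.randrange(256) for _ in range(K // 2)] for _ in range(M)]
```

The message includes the offending K, so callers see immediately why the packed shape cannot be formed.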
if max_diff < 1.0:
    print("[PASS] numerical verification (max_abs_diff < 1.0)")
else:
    print(f"[WARN] large diff — may indicate layout or data flow issue")
Remove extraneous f-string prefix.
The string has no placeholders, so the f prefix is unnecessary.
🔧 Proposed fix
- print(f"[WARN] large diff — may indicate layout or data flow issue")
+ print("[WARN] large diff — may indicate layout or data flow issue")

🧰 Tools
🪛 Ruff (0.15.5)
[error] 121-121: f-string without any placeholders
Remove extraneous f prefix
(F541)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/gemm_fp4/example_gemm_fp4_sm120.py` at line 121, The print statement
uses an unnecessary f-string prefix; remove the leading "f" from the call so the
literal is a normal string (replace print(f"[WARN] large diff — may indicate
layout or data flow issue") with print("[WARN] large diff — may indicate layout
or data flow issue") to eliminate the unused f-string).
Actionable comments posted: 3
🧹 Nitpick comments (2)
tilelang/intrinsics/utils.py (1)
54-64: Update the rest of the FP4 helper contract in this module.
`get_ldmatrix_offset()` now accepts FP4, but the public `dtype` annotations here still advertise only float16/int8, and `get_mma_micro_size()` below still leaves FP4 on the default k=16 path. Keeping those helpers in sync avoids future FP4 callers selecting the wrong micro-k.

♻️ Suggested follow-up
 def get_ldmatrix_offset(
     matrix: Literal["A", "B"],
     row_idx,
     col_idx,
     stride,
-    dtype: Literal["float16", "int8"] = "float16",
+    dtype: Literal["float16", "int8", "float4_e2m1fn"] = "float16",
     transposed: bool = False,
 ):
@@
-def get_mma_micro_size(dtype: Literal["float16", "int8"]):
+def get_mma_micro_size(
+    dtype: Literal["float16", "int8", "float8_e4m3", "float8_e5m2", "float4_e2m1fn"]
+):
@@
-    if dtype in {"float8_e4m3", "float8_e5m2", "int8"}:
+    if dtype in {"float8_e4m3", "float8_e5m2", "float4_e2m1fn", "int8"}:
         micro_size_k = 32

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/intrinsics/utils.py` around lines 54 - 64, Update this module's public dtype annotations and helper logic to include FP4 alongside float16/int8: add "fp4" (or the canonical FP4 dtype name used elsewhere) to the type hints/docs for functions like get_ldmatrix_offset and any public signatures that currently list only float16/int8, and modify get_mma_micro_size to handle the FP4 case explicitly instead of falling through to the default k=16 path (choose the correct micro-k for FP4 consistent with hardware/other code paths). Also scan for other helper functions in this file that claim only float16/int8 and update their contracts/branches so FP4 uses the correct transform and micro-k behavior (reference get_ldmatrix_offset and get_mma_micro_size to locate where to change).

src/tl_templates/cuda/instruction/mma.h (1)
168-179: Tighten the FP4 pre-shift to the supported FP4/FP4 case.

This branch fires when either operand is FP4, but this file only registers an FP4/FP4 dispatcher. If a mixed FP4 kernel is added later, the non-FP4 operand will be shifted too.
♻️ Suggested change
-  if constexpr (AType == DataType::kFloat4_e2m1fn ||
-                BType == DataType::kFloat4_e2m1fn) {
+  if constexpr (AType == DataType::kFloat4_e2m1fn &&
+                BType == DataType::kFloat4_e2m1fn) {

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/tl_templates/cuda/instruction/mma.h` around lines 168 - 179, The branch currently triggers if either AType or BType is FP4 (DataType::kFloat4_e2m1fn) causing both operands to be pre-shifted even though only FP4/FP4 dispatchers are registered; change the condition to require both operands be FP4 (use AType == DataType::kFloat4_e2m1fn && BType == DataType::kFloat4_e2m1fn) and only perform the left-shift on AReg as[] and BReg bs[] inside that tightened branch (references: AType, BType, DataType::kFloat4_e2m1fn, Dispatcher::ARegType, Dispatcher::BRegType, Dispatcher::exec, detail::MmaImplTraits::kARegs/kBRegs).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/gemm_fp4/example_gemm_fp4_sm120.py`:
- Around line 67-92: Add an early hardware guard before calling matmul_fp4 /
tilelang.compile: check that CUDA is available (e.g., torch.cuda.is_available())
and that the GPU's compute capability is SM>=12.0 by querying
torch.cuda.get_device_properties(0).major/minor (or equivalent), and if the
check fails raise/print a clear error and exit (e.g., "requires CUDA SM >=
12.0"). Place this check right before the current calls to matmul_fp4(...) and
tilelang.compile(...) so the example fails fast on non-CUDA or pre-SM120
hardware.
In `@src/target/codegen_cuda.cc`:
- Around line 3355-3357: GetBufferRef is still halving scalar FP4 indices
(dividing by 2) even though the buffer is now declared as unpacked
fp4_e2_t[constant_size] (see where PrintStorageScope and PrintType emit the
declaration); update GetBufferRef to detect unpacked/raw-byte FP4 buffers and
stop dividing indices for fp4 scalar accesses so each scalar maps to its own
fp4_e2_t element. Specifically, in GetBufferRef, identify buffers whose dtype is
fp4 and whose storage/declared element type is fp4_e2_t (or where PrintType
emitted fp4_e2_t[]) and bypass the index /= 2 path, ensuring the same fix is
applied to the similar handling referenced around the other occurrence (the
block at the 3383-3387 equivalent).
In `@tilelang/intrinsics/mma_layout.py`:
- Around line 42-54: The FP4 layout functions
ldmatrix_32x16_to_shared_16x32_fp4_layout_a and
ldmatrix_32x16_to_shared_16x32_fp4_layout_b only emit columns 0..15 (col =
local_id) and thus never produce columns 16..31; update both functions to
compute a half-row offset (e.g., half = (thread_id // 16) % 2 or derived from
the lane group bit) and set col = local_id + half * 16 so each lane group covers
the other half-row; if FP4 truly changes the per-thread fragment shape, also
rename the helpers (or update their contract) to reflect the 16x32 output shape.
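The coverage gap and its fix are easy to check numerically. Below is a small model of the column mapping, using the formulas from the suggestion above (function names are hypothetical):

```python
def col_without_offset(thread_id, local_id):
    # original behavior: columns 16..31 are never produced
    return local_id

def col_with_half_offset(thread_id, local_id):
    # suggested fix: lanes 0-15 cover columns 0-15, lanes 16-31 cover 16-31
    half = (thread_id // 16) % 2
    return local_id + half * 16

before = {col_without_offset(t, l) for t in range(32) for l in range(16)}
after = {col_with_half_offset(t, l) for t in range(32) for l in range(16)}
```

Enumerating all 32 lanes times 16 local ids shows `before` reaches only half the columns while `after` reaches all 32.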
---
Nitpick comments:
In `@src/tl_templates/cuda/instruction/mma.h`:
- Around line 168-179: The branch currently triggers if either AType or BType is
FP4 (DataType::kFloat4_e2m1fn) causing both operands to be pre-shifted even
though only FP4/FP4 dispatchers are registered; change the condition to require
both operands be FP4 (use AType == DataType::kFloat4_e2m1fn && BType ==
DataType::kFloat4_e2m1fn) and only perform the left-shift on AReg as[] and BReg
bs[] inside that tightened branch (references: AType, BType,
DataType::kFloat4_e2m1fn, Dispatcher::ARegType, Dispatcher::BRegType,
Dispatcher::exec, detail::MmaImplTraits::kARegs/kBRegs).
In `@tilelang/intrinsics/utils.py`:
- Around line 54-64: Update this module's public dtype annotations and helper
logic to include FP4 alongside float16/int8: add "fp4" (or the canonical FP4
dtype name used elsewhere) to the type hints/docs for functions like
get_ldmatrix_offset and any public signatures that currently list only
float16/int8, and modify get_mma_micro_size to handle the FP4 case explicitly
instead of falling through to the default k=16 path (choose the correct micro-k
for FP4 consistent with hardware/other code paths). Also scan for other helper
functions in this file that claim only float16/int8 and update their
contracts/branches so FP4 uses the correct transform and micro-k behavior
(reference get_ldmatrix_offset and get_mma_micro_size to locate where to
change).
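The micro-k choice described above amounts to a small dispatch table. A sketch follows; the (16, 8) tile shape is assumed from the m16n8kX instruction family, and the function name is illustrative rather than the actual tilelang API:

```python
def mma_micro_size(dtype: str):
    """(m, n, k) per mma.sync tile: k=16 for 16-bit inputs, k=32 for
    FP8/INT8 and for FP4 (which rides the same m16n8k32 path)."""
    k32_dtypes = {"float8_e4m3", "float8_e5m2", "float4_e2m1fn", "int8"}
    micro_size_k = 32 if dtype in k32_dtypes else 16
    return 16, 8, micro_size_k
```

Routing FP4 through the same set as the 8-bit types is exactly the one-line change the follow-up diff proposes.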
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 4243d9a7-46b5-4708-ba56-4d7bcdf2cc73
📒 Files selected for processing (5)
examples/gemm_fp4/example_gemm_fp4_sm120.py
src/target/codegen_cuda.cc
src/tl_templates/cuda/instruction/mma.h
tilelang/intrinsics/mma_layout.py
tilelang/intrinsics/utils.py
M, N, K = 256, 256, 256
block_M, block_N, block_K = 128, 128, 128
in_dtype = T.float4_e2m1fn
out_dtype = T.float32
accum_dtype = T.float32

print(f"Running FP4 GEMM: M={M}, N={N}, K={K}")
print(f"  block_M={block_M}, block_N={block_N}, block_K={block_K}")

func = matmul_fp4(
    M, N, K, block_M, block_N, block_K,
    in_dtype, out_dtype, accum_dtype,
    num_stages=2, threads=128,
)

jit_kernel = tilelang.compile(
    func,
    out_idx=[2],
    target="cuda",
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    },
)

print("Compilation succeeded!")
Fail fast on unsupported hardware.
This example immediately compiles and allocates CUDA tensors. On non-CUDA or pre-SM120 machines the first failure will come from deep inside compile/runtime instead of a clear example-specific error.
🛡️ Suggested guard
+if not torch.cuda.is_available():
+    raise RuntimeError("example_gemm_fp4_sm120.py requires CUDA")
+
+major, minor = torch.cuda.get_device_capability()
+if (major, minor) < (12, 0):
+    raise RuntimeError("example_gemm_fp4_sm120.py requires an SM120 GPU")
+
 M, N, K = 256, 256, 256
 block_M, block_N, block_K = 128, 128, 128
 in_dtype = T.float4_e2m1fn

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
if not torch.cuda.is_available():
    raise RuntimeError("example_gemm_fp4_sm120.py requires CUDA")

major, minor = torch.cuda.get_device_capability()
if (major, minor) < (12, 0):
    raise RuntimeError("example_gemm_fp4_sm120.py requires an SM120 GPU")

M, N, K = 256, 256, 256
block_M, block_N, block_K = 128, 128, 128
in_dtype = T.float4_e2m1fn
out_dtype = T.float32
accum_dtype = T.float32

print(f"Running FP4 GEMM: M={M}, N={N}, K={K}")
print(f"  block_M={block_M}, block_N={block_N}, block_K={block_K}")

func = matmul_fp4(
    M, N, K, block_M, block_N, block_K,
    in_dtype, out_dtype, accum_dtype,
    num_stages=2, threads=128,
)

jit_kernel = tilelang.compile(
    func,
    out_idx=[2],
    target="cuda",
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    },
)

print("Compilation succeeded!")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/gemm_fp4/example_gemm_fp4_sm120.py` around lines 67 - 92, Add an
early hardware guard before calling matmul_fp4 / tilelang.compile: check that
CUDA is available (e.g., torch.cuda.is_available()) and that the GPU's compute
capability is SM>=12.0 by querying
torch.cuda.get_device_properties(0).major/minor (or equivalent), and if the
check fails raise/print a clear error and exit (e.g., "requires CUDA SM >=
12.0"). Place this check right before the current calls to matmul_fp4(...) and
tilelang.compile(...) so the example fails fast on non-CUDA or pre-SM120
hardware.
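A minimal sketch of that guard as a reusable helper. At runtime it would be fed `torch.cuda.get_device_capability()`; the comparison logic itself is hardware-independent, which keeps it testable off-GPU:

```python
def require_sm(major: int, minor: int, required=(12, 0)) -> None:
    """Fail fast with an example-specific message on pre-SM120 hardware.
    Pass in the (major, minor) compute capability of the active device."""
    if (major, minor) < required:
        raise RuntimeError(
            f"requires CUDA SM >= {required[0]}.{required[1]}, "
            f"got SM {major}.{minor}"
        )
```

Calling it right before `matmul_fp4(...)` and `tilelang.compile(...)` turns a deep compile-time failure into a one-line diagnosis.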
Actionable comments posted: 4
🧹 Nitpick comments (6)
examples/gemm_fp4/example_gemm_a8w4_sm120.py (4)
118-126: Consider adding warmup iterations before timing.

The benchmark starts timing immediately without warmup. The first few kernel invocations may include JIT/caching overhead, skewing the average latency upward.
Suggested fix
 # --- Benchmark ---
+# Warmup
+for _ in range(10):
+    jit_kernel(a_fp8, b_uint8)
 torch.cuda.synchronize()
 start = time.perf_counter()

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/gemm_fp4/example_gemm_a8w4_sm120.py` around lines 118 - 126, The benchmark measures jit_kernel but lacks warmup runs, so initial JIT/caching skews timings; add a small warmup loop (e.g., 10–20 iterations) that calls jit_kernel(a_fp8, b_uint8) with torch.cuda.synchronize() before starting the timed section, then proceed with the existing timing loop and final synchronize/latency calculation to ensure measured runs reflect steady-state performance.
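The warmup-then-measure pattern can live in a small harness. A sketch follows; for a CUDA kernel, device synchronization calls would bracket the timed loop, omitted here so the harness stays dependency-free:

```python
import time

def bench(fn, warmup: int = 10, iters: int = 100) -> float:
    """Mean latency of fn in seconds, with warmup calls discarded so
    first-call JIT/caching cost does not skew the average."""
    for _ in range(warmup):
        fn()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters
```

Usage in the example would be `bench(lambda: jit_kernel(a_fp8, b_uint8))`, with `torch.cuda.synchronize()` added before and after the timed section.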
66-66: Remove extraneous `f` prefix from string without placeholders.

This f-string contains no placeholder expressions.
Suggested fix
-print(f" A: float8_e4m3fn, B: FP4 (unpacked uint8)")
+print(" A: float8_e4m3fn, B: FP4 (unpacked uint8)")

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/gemm_fp4/example_gemm_a8w4_sm120.py` at line 66, The print statement uses an unnecessary f-string prefix; update the print call (the line containing print(f" A: float8_e4m3fn, B: FP4 (unpacked uint8)")) to use a regular string (remove the leading f) so it reads print(" A: float8_e4m3fn, B: FP4 (unpacked uint8)") to avoid misleading f-string usage.
11-11: Unused import `os`.

The `os` module is imported but never used in this file.

Suggested fix

-import os

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/gemm_fp4/example_gemm_a8w4_sm120.py` at line 11, Remove the unused top-level import "import os" from the module (the lone import statement) in examples/gemm_fp4/example_gemm_a8w4_sm120.py so there’s no unused dependency; simply delete the line "import os".
116-116: Remove extraneous `f` prefix from string without placeholders.

This f-string contains no placeholder expressions.
Suggested fix
-    print(f"[WARN] large diff -- may indicate layout or data flow issue")
+    print("[WARN] large diff -- may indicate layout or data flow issue")

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/gemm_fp4/example_gemm_a8w4_sm120.py` at line 116, The print statement using an f-string contains no placeholders; remove the unnecessary f prefix in the print call (the print(...) expression in examples/gemm_fp4/example_gemm_a8w4_sm120.py that prints "[WARN] large diff -- may indicate layout or data flow issue") so it becomes a plain string literal.

tilelang/intrinsics/mma_layout.py (1)
42-54: FP4 helpers duplicate existing INT8 layout logic; consider delegating to avoid drift.

Lines 45-54 mirror `ldmatrix_32x16_to_shared_16x32_layout_a/b`. Reusing the existing helpers will keep future layout fixes in one place.

♻️ Proposed simplification

 def ldmatrix_32x16_to_shared_16x32_fp4_layout_a(thread_id, local_id):
     """FP4 with unpacked shared memory (1 byte/element) uses the same layout
     as INT8 — shared memory rows are 32 bytes for K=32."""
-    row = thread_id % 16
-    col = local_id + (thread_id // 16) * 16
-    return row, col
+    return ldmatrix_32x16_to_shared_16x32_layout_a(thread_id, local_id)
@@
 def ldmatrix_32x16_to_shared_16x32_fp4_layout_b(thread_id, local_id):
     """FP4 with unpacked shared memory — same as INT8."""
-    row = (thread_id // 16) * 8 + (thread_id % 8)
-    col = local_id + 16 * ((thread_id % 16) // 8)
-    return row, col
+    return ldmatrix_32x16_to_shared_16x32_layout_b(thread_id, local_id)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/intrinsics/mma_layout.py` around lines 42 - 54, The two FP4 helpers ldmatrix_32x16_to_shared_16x32_fp4_layout_a and ldmatrix_32x16_to_shared_16x32_fp4_layout_b duplicate INT8 logic; replace their bodies to delegate to the canonical INT8 helpers ldmatrix_32x16_to_shared_16x32_layout_a and ldmatrix_32x16_to_shared_16x32_layout_b (i.e., return the result of calling the corresponding INT8 function with thread_id and local_id) and update their docstrings to state they delegate to the INT8 implementations to avoid future drift.

src/tl_templates/cuda/instruction/mma.h (1)
176-192: SM100 and SM120 handle FP4 preprocessing differently; verify this aligns with instruction semantics.

The SM100 path (tcgen05mma.h:145–270) delegates FP4 directly to the FP8 instruction handler without preprocessing, while this SM120 code applies a 2-bit left shift before the mma.sync call. This divergence may be correct if the two instruction variants have different semantic requirements, but should be documented with an explanatory comment clarifying why the shift is necessary for SM120 but not SM100.

The compile-time ternary expressions can be made more idiomatic using `if constexpr`:

♻️ Idiomatic refactor suggestion

 #pragma unroll
 for (int i = 0; i < nA; ++i)
-  as[i] = (AType == DataType::kFloat4_e2m1fn) ? (a[i] << 2) : a[i];
+  if constexpr (AType == DataType::kFloat4_e2m1fn)
+    as[i] = a[i] << 2;
+  else
+    as[i] = a[i];
 #pragma unroll
 for (int i = 0; i < nB; ++i)
-  bs[i] = (BType == DataType::kFloat4_e2m1fn) ? (b[i] << 2) : b[i];
+  if constexpr (BType == DataType::kFloat4_e2m1fn)
+    bs[i] = b[i] << 2;
+  else
+    bs[i] = b[i];

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/tl_templates/cuda/instruction/mma.h` around lines 176 - 192, The SM120 FP4 preprocessing applies a 2-bit left shift before calling Dispatcher::exec while SM100 routes FP4 to the FP8 handler without shifting; add a concise comment above the SM120 branch in mma.h explaining the semantic reason for the shift (i.e., why SM120 requires the <<2 adjustment but SM100 does not) and ensure the behavior is intentional, and at the same time make the compile-time selection idiomatic by replacing the ternary expressions in the loops that set as[i] and bs[i] (which check AType/BType against DataType::kFloat4_e2m1fn) with if constexpr branches so the choice is resolved at compile time and the intent is clearer while retaining the existing Dispatcher::exec(c, as, bs, c) and Dispatcher::exec(c, a, b, c) calls.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/gemm_fp4/example_fusedmoe_a8w4_sm120.py`:
- Line 44: Rename the parameter named "input" to "input_tensor" to avoid
shadowing Python's built-in; update the function signature where input:
T.Tensor((num_tokens, d_hidden), "float8_e4m3fn") appears and rename every use
of that symbol inside the same function (and any callers) to the new name
(input_tensor) so references like indexing, shape checks, or passes use the
updated identifier consistently.
- Around line 34-38: The stage moe_shared_expert_a8w4 declares its output width
as d_hidden but the output indexing and writes (iteration variable by over
expert tiles and writes at by * block_expert) are keyed by d_expert, which will
mis-index unless d_hidden == d_expert; either enforce d_hidden == d_expert at
the start of moe_shared_expert_a8w4 (raise/validate) or change the stage’s
output shape to use d_expert instead of d_hidden and update any output
shape/width declarations and dependent indexing (references:
moe_shared_expert_a8w4, iteration variable by, block_expert, d_hidden, d_expert)
so writes at by * block_expert map correctly and avoid out-of-bounds access.
In `@src/tl_templates/cuda/ldsm.h`:
- Around line 54-85: Wrap the SM120-only inline PTX functions
ptx_ldmatrix_b4x16_x1, ptx_ldmatrix_b4x16_x2, and ptx_ldmatrix_b4x16_x4 with an
architecture guard so they are only compiled for SM120+; specifically surround
the entire definitions with `#if` defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >=
1200) ... `#endif` (matching the pattern used in reduce.h) to prevent compilation
on older SM targets.
In `@tilelang/tileop/gemm/gemm_base.py`:
- Around line 95-97: The current remap in in_dtype_b (checking self.B.dtype ==
"uint8" and self.C.dtype in ("float32","float16")) incorrectly forces FP4
lowering via "float4_e2m1fn" for any uint8 storage; change the condition to
require an explicit FP4 capability flag in addition to the dtype checks (e.g.,
require self.target.supports_fp4() or an explicit boolean like
self.use_fp4/self.has_fp4) before returning "float4_e2m1fn", update the remap
logic in the method that reads self.B.dtype/self.C.dtype accordingly, and leave
other uint8 paths unchanged so non-FP4 kernels are not remapped.
---
Nitpick comments:
In `@examples/gemm_fp4/example_gemm_a8w4_sm120.py`:
- Around line 118-126: The benchmark measures jit_kernel but lacks warmup runs,
so initial JIT/caching skews timings; add a small warmup loop (e.g., 10–20
iterations) that calls jit_kernel(a_fp8, b_uint8) with torch.cuda.synchronize()
before starting the timed section, then proceed with the existing timing loop
and final synchronize/latency calculation to ensure measured runs reflect
steady-state performance.
- Line 66: The print statement uses an unnecessary f-string prefix; update the
print call (the line containing print(f" A: float8_e4m3fn, B: FP4 (unpacked
uint8)")) to use a regular string (remove the leading f) so it reads print(" A:
float8_e4m3fn, B: FP4 (unpacked uint8)") to avoid misleading f-string usage.
- Line 11: Remove the unused top-level import "import os" from the module (the
lone import statement) in examples/gemm_fp4/example_gemm_a8w4_sm120.py so
there’s no unused dependency; simply delete the line "import os".
- Line 116: The print statement using an f-string contains no placeholders;
remove the unnecessary f prefix in the print call (the print(...) expression in
examples/gemm_fp4/example_gemm_a8w4_sm120.py that prints "[WARN] large diff --
may indicate layout or data flow issue") so it becomes a plain string literal.
In `@src/tl_templates/cuda/instruction/mma.h`:
- Around line 176-192: The SM120 FP4 preprocessing applies a 2-bit left shift
before calling Dispatcher::exec while SM100 routes FP4 to the FP8 handler
without shifting; add a concise comment above the SM120 branch in mma.h
explaining the semantic reason for the shift (i.e., why SM120 requires the <<2
adjustment but SM100 does not) and ensure the behavior is intentional, and at
the same time make the compile-time selection idiomatic by replacing the ternary
expressions in the loops that set as[i] and bs[i] (which check AType/BType
against DataType::kFloat4_e2m1fn) with if constexpr branches so the choice is
resolved at compile time and the intent is clearer while retaining the existing
Dispatcher::exec(c, as, bs, c) and Dispatcher::exec(c, a, b, c) calls.
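The per-byte preprocessing under discussion can be modeled on the host side (an illustrative sketch only — whether the 2-bit shift is the correct alignment for the f8f6f4 operand byte is exactly the open question this comment raises; the function below just reproduces the shift the SM120 branch currently performs):

```python
def shift_fp4_operand(reg):
    """Apply the SM120 FP4 preprocessing to one 32-bit operand register:
    shift each byte's 4-bit code left by 2 within its 8-bit container.
    The mask clears the two low bits of every byte, which is where bits
    shifted out of the neighboring byte would otherwise land."""
    return (reg << 2) & 0xFCFCFCFC
```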
In `@tilelang/intrinsics/mma_layout.py`:
- Around line 42-54: The two FP4 helpers
ldmatrix_32x16_to_shared_16x32_fp4_layout_a and
ldmatrix_32x16_to_shared_16x32_fp4_layout_b duplicate INT8 logic; replace their
bodies to delegate to the canonical INT8 helpers
ldmatrix_32x16_to_shared_16x32_layout_a and
ldmatrix_32x16_to_shared_16x32_layout_b (i.e., return the result of calling the
corresponding INT8 function with thread_id and local_id) and update their
docstrings to state they delegate to the INT8 implementations to avoid future
drift.
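The delegation asked for above is a one-line change per helper; sketched below with an illustrative index mapping (the real INT8 formula lives in `mma_layout.py` and is assumed, not reproduced, here):

```python
def ldmatrix_32x16_to_shared_16x32_layout_a(thread_id, local_id):
    # Canonical INT8 mapping (illustrative formula, not the real one):
    # 16 rows x 32 columns, one 8-bit element per (thread, local) pair.
    return thread_id % 16, (thread_id // 16) * 16 + local_id

def ldmatrix_32x16_to_shared_16x32_fp4_layout_a(thread_id, local_id):
    """FP4 delegates to the INT8 implementation: after the b4x16_p64
    ldmatrix unpack, each FP4 value occupies one byte, so the index
    math is identical."""
    return ldmatrix_32x16_to_shared_16x32_layout_a(thread_id, local_id)
```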
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 295023f2-777b-4946-96af-18553152cc84
📒 Files selected for processing (10)
- examples/gemm_fp4/example_fusedmoe_a8w4_sm120.py
- examples/gemm_fp4/example_gemm_a8w4_sm120.py
- examples/gemm_fp4/example_gemm_fp4_sm120.py
- src/target/codegen_cuda.cc
- src/tl_templates/cuda/gemm_mma.h
- src/tl_templates/cuda/instruction/mma.h
- src/tl_templates/cuda/ldsm.h
- tilelang/intrinsics/mma_layout.py
- tilelang/tileop/gemm/gemm_base.py
- tilelang/tileop/gemm/gemm_mma.py
```python
def moe_shared_expert_a8w4(
    num_tokens, d_hidden, d_expert,
    block_token=128, block_hidden=128, block_expert=128,
    threads=128, num_stages=1,
):
```
Output indexing is keyed by d_expert, but output width is declared as d_hidden.
Line 53 iterates `by` over expert tiles, while Line 48 declares the output width as d_hidden and Line 80 writes at `by * block_expert`. This is only valid when d_hidden == d_expert; otherwise the writes are incorrect or out of bounds. Enforce the equality or change this stage's output shape to d_expert.
🛡️ Minimal guard to prevent invalid configs
```diff
 def moe_shared_expert_a8w4(
     num_tokens, d_hidden, d_expert,
     block_token=128, block_hidden=128, block_expert=128,
     threads=128, num_stages=1,
 ):
+    if d_hidden != d_expert:
+        raise ValueError(
+            "moe_shared_expert_a8w4 currently emits gate/up stage only; require d_hidden == d_expert"
+        )
     """Single shared expert: gate_up GEMM -> SiLU*up -> down GEMM."""
```

Also applies to: 48-55, 80-80
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/gemm_fp4/example_fusedmoe_a8w4_sm120.py` around lines 34 - 38, The
stage moe_shared_expert_a8w4 declares its output width as d_hidden but the
output indexing and writes (iteration variable by over expert tiles and writes
at by * block_expert) are keyed by d_expert, which will mis-index unless
d_hidden == d_expert; either enforce d_hidden == d_expert at the start of
moe_shared_expert_a8w4 (raise/validate) or change the stage’s output shape to
use d_expert instead of d_hidden and update any output shape/width declarations
and dependent indexing (references: moe_shared_expert_a8w4, iteration variable
by, block_expert, d_hidden, d_expert) so writes at by * block_expert map
correctly and avoid out-of-bounds access.
```python
@T.prim_func
def main(
    input: T.Tensor((num_tokens, d_hidden), "float8_e4m3fn"),
```
Avoid shadowing Python’s built-in input.
Line 44 uses input as a parameter name; rename to input_tensor (or similar) to satisfy lint and reduce ambiguity.
🧰 Tools
🪛 Ruff (0.15.5)
[error] 44-44: Function argument input is shadowing a Python builtin
(A002)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/gemm_fp4/example_fusedmoe_a8w4_sm120.py` at line 44, Rename the
parameter named "input" to "input_tensor" to avoid shadowing Python's built-in;
update the function signature where input: T.Tensor((num_tokens, d_hidden),
"float8_e4m3fn") appears and rename every use of that symbol inside the same
function (and any callers) to the new name (input_tensor) so references like
indexing, shape checks, or passes use the updated identifier consistently.
```cpp
// ldmatrix for 4-bit sub-byte types (FP4/INT4 on SM120+).
// Reads packed 4-bit data from shared memory and unpacks each 4-bit value
// into the low 4 bits of an 8-bit container (upper 4 bits zeroed).
TL_DEVICE void ptx_ldmatrix_b4x16_x1(void const *const smem_ptr,
                                     void *const local_ptr) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
  asm volatile(
      "ldmatrix.sync.aligned.m8n16.x1.shared.b8x16.b4x16_p64 {%0}, [%1];\n"
      : "=r"(value[0])
      : "r"(smem_int_ptr));
}

TL_DEVICE void ptx_ldmatrix_b4x16_x2(void const *const smem_ptr,
                                     void *const local_ptr) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
  asm volatile(
      "ldmatrix.sync.aligned.m8n16.x2.shared.b8x16.b4x16_p64 {%0, %1}, [%2];\n"
      : "=r"(value[0]), "=r"(value[1])
      : "r"(smem_int_ptr));
}

TL_DEVICE void ptx_ldmatrix_b4x16_x4(void const *const smem_ptr,
                                     void *const local_ptr) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
  asm volatile(
      "ldmatrix.sync.aligned.m8n16.x4.shared.b8x16.b4x16_p64 {%0, %1, %2, %3}, "
      "[%4];\n"
      : "=r"(value[0]), "=r"(value[1]), "=r"(value[2]), "=r"(value[3])
      : "r"(smem_int_ptr));
}
```
🧩 Analysis chain
🏁 Script executed:

```shell
#!/bin/bash
set -euo pipefail
FILE="$(fd 'ldsm\.h$' src | head -n1)"
echo "Inspecting: $FILE"
rg -n -C3 'ptx_ldmatrix_b4x16|__CUDA_ARCH__|__CUDA_ARCH_LIST__' "$FILE"
echo
echo "Call sites:"
rg -n -C2 'ptx_ldmatrix_b4x16_(x1|x2|x4)' src tilelang
```

Length of output: 2350

🏁 Script executed:

```shell
cat -n src/tl_templates/cuda/ldsm.h | head -100
```

Length of output: 4914

🏁 Script executed:

```shell
# Check if TL_DEVICE macro or other guards exist
rg -n 'TL_DEVICE|#if.*CUDA|#ifdef.*CUDA' src/tl_templates/cuda/ldsm.h | head -20
```

Length of output: 1047

🏁 Script executed:

```shell
# Check the entire file for any architectural gating patterns
wc -l src/tl_templates/cuda/ldsm.h
```

Length of output: 92

🏁 Script executed:

```shell
# Check for call sites of these new functions
rg -r 'ptx_ldmatrix_b4x16_(x1|x2|x4)' --type cpp --type h --type cu
```

Length of output: 456

🏁 Script executed:

```shell
# Check TL_DEVICE macro definition
rg -n 'define TL_DEVICE' src/
```

Length of output: 464

🏁 Script executed:

```shell
# Check common.h to understand the macro
head -50 src/tl_templates/cuda/common.h
```

Length of output: 1622

🏁 Script executed:

```shell
rg 'ptx_ldmatrix_b4x16' src/ tilelang/ 2>/dev/null | grep -v 'void ptx_ldmatrix_b4x16'
```

Length of output: 42

🏁 Script executed:

```shell
# Look for any architecture guards in the build system or other headers
rg '__CUDA_ARCH__|SM[0-9]+|cudaArchitecture' src/tl_templates/cuda/ | head -20
```

Length of output: 2285

🏁 Script executed:

```shell
# Check if there are any compilation flags or CMake guards for SM120
fd -e cmake -e CMakeLists | head -5 | xargs cat 2>/dev/null | grep -i 'sm\|arch\|compute' | head -10
```

Length of output: 42
🌐 Web query:
CUDA b4x16_p64 ldmatrix SM120 architecture requirement
💡 Result:
For the ldmatrix variant that uses the 4-bit packed source format (i.e., `.src_fmt = .b4x16_p64`, producing `.dst_fmt = .b8x16` for shapes like `.m8n16` / `.m16n16`):

- It is not enough to just be on "SM120" in general; this form requires the architecture-specific Blackwell target `sm_120a` (likewise `sm_100a`/`sm_101a` for other Blackwell chips). The PTX ISA states that the `.m16n16`/`.m8n16` shapes (the ones that use `.b4x16_p64`/`.b6x16_p32`) are supported on `sm_120a`. [1]
- If you are using family targets (the `...f` suffix), the PTX ISA notes these shapes are supported on `sm_120f` (or higher in the same family) starting with PTX ISA 8.8. [1]
- Background/versioning: the `sm_120` and `sm_120a` targets were added in PTX ISA 8.7. [2]

So the practical requirement for `ldmatrix ... .b8x16 .b4x16_p64` on "SM120" hardware is: compile for `sm_120a` (or `sm_120f` with a new-enough PTX toolchain), not plain `sm_120`. [1]

Sources: PTX ISA ldmatrix target notes and format description [1]; PTX ISA release history noting the addition of `sm_120`/`sm_120a` [2].
Guard SM120-specific ldmatrix ... b4x16_p64 PTX behind an SM120+ compile check.
Lines 57–85 emit inline PTX that is only available on sm_120a and later. Without an architecture gate, CUDA compilation fails on lower SM targets (e.g., sm_80, sm_90) even if these helpers are never called. Gate these three functions with #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1200) to match the pattern used elsewhere in the codebase (see reduce.h).
🔧 Suggested guard shape
```diff
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1200)
 TL_DEVICE void ptx_ldmatrix_b4x16_x1(void const *const smem_ptr,
                                      void *const local_ptr) {
   ...
 }
 ...
 TL_DEVICE void ptx_ldmatrix_b4x16_x4(void const *const smem_ptr,
                                      void *const local_ptr) {
   ...
 }
+#endif
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```cpp
// ldmatrix for 4-bit sub-byte types (FP4/INT4 on SM120+).
// Reads packed 4-bit data from shared memory and unpacks each 4-bit value
// into the low 4 bits of an 8-bit container (upper 4 bits zeroed).
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1200)
TL_DEVICE void ptx_ldmatrix_b4x16_x1(void const *const smem_ptr,
                                     void *const local_ptr) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
  asm volatile(
      "ldmatrix.sync.aligned.m8n16.x1.shared.b8x16.b4x16_p64 {%0}, [%1];\n"
      : "=r"(value[0])
      : "r"(smem_int_ptr));
}

TL_DEVICE void ptx_ldmatrix_b4x16_x2(void const *const smem_ptr,
                                     void *const local_ptr) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
  asm volatile(
      "ldmatrix.sync.aligned.m8n16.x2.shared.b8x16.b4x16_p64 {%0, %1}, [%2];\n"
      : "=r"(value[0]), "=r"(value[1])
      : "r"(smem_int_ptr));
}

TL_DEVICE void ptx_ldmatrix_b4x16_x4(void const *const smem_ptr,
                                     void *const local_ptr) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
  asm volatile(
      "ldmatrix.sync.aligned.m8n16.x4.shared.b8x16.b4x16_p64 {%0, %1, %2, %3}, "
      "[%4];\n"
      : "=r"(value[0]), "=r"(value[1]), "=r"(value[2]), "=r"(value[3])
      : "r"(smem_int_ptr));
}
#endif
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/tl_templates/cuda/ldsm.h` around lines 54 - 85, Wrap the SM120-only
inline PTX functions ptx_ldmatrix_b4x16_x1, ptx_ldmatrix_b4x16_x2, and
ptx_ldmatrix_b4x16_x4 with an architecture guard so they are only compiled for
SM120+; specifically surround the entire definitions with `#if`
defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1200) ... `#endif` (matching the
pattern used in reduce.h) to prevent compilation on older SM targets.
```python
dtype = self.B.dtype
if dtype == "uint8" and self.C.dtype in ("float32", "float16"):
    return "float4_e2m1fn"
```
in_dtype_b remap is too broad and can misroute non-FP4 kernels.
Line 96-97 maps any uint8 B with float accumulation to float4_e2m1fn. That can also match non-FP4 uint8 storage paths and incorrectly force FP4 MMA lowering. Please gate this remap behind an explicit FP4 signal/capability instead of the current heuristic.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tilelang/tileop/gemm/gemm_base.py` around lines 95 - 97, The current remap in
in_dtype_b (checking self.B.dtype == "uint8" and self.C.dtype in
("float32","float16")) incorrectly forces FP4 lowering via "float4_e2m1fn" for
any uint8 storage; change the condition to require an explicit FP4 capability
flag in addition to the dtype checks (e.g., require self.target.supports_fp4()
or an explicit boolean like self.use_fp4/self.has_fp4) before returning
"float4_e2m1fn", update the remap logic in the method that reads
self.B.dtype/self.C.dtype accordingly, and leave other uint8 paths unchanged so
non-FP4 kernels are not remapped.
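The gating requested above could look like this (a standalone sketch; `fp4_enabled` is a hypothetical flag standing in for whatever explicit FP4 signal — e.g. a target capability check — the fix introduces):

```python
def in_dtype_b(b_dtype, c_dtype, fp4_enabled):
    """Remap uint8 storage to FP4 only when the kernel explicitly
    opted in, so plain uint8 GEMMs keep their original dtype and are
    not rerouted through the FP4 MMA lowering."""
    if fp4_enabled and b_dtype == "uint8" and c_dtype in ("float32", "float16"):
        return "float4_e2m1fn"
    return b_dtype
```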
Add the plumbing required to route float4_e2m1fn through the TCGEN5 MMA code-generation path so that FP4 GEMM kernels can be emitted on SM100.

Changes:
- ptx.h / ptx.cc: add kFloat4_e2m1fn enum, string tables, DTypeFromString
- common.h: add kFloat4_e2m1fn to device-side DataType enum
- tcgen5_meta.h: add FP4 branch in encode_dtype (format code 2)
- tcgen05mma.h: add kFloat4_e2m1fn specializations for SS/TS/WS_SS (delegates to the existing f8f6f4 PTX kind)
- mma_macro_generator.py: add dtype_abbrv mapping for float4_e2m1fn
- docs/GEMM_NV_FP4_FEATURE_STEPS.md: design doc and progress tracker

Addresses tile-ai#1592
Made-with: Cursor
… raw bytes, get_ldmatrix_offset add fp4 special layout, add SM120_FP4_FP4_F32_TN MmaDispatcher specialization + fp4 << 2 bit-shift.
…escriptor, cp.async copy, and packed SMEM addressing
Switch the FP4 TMA data type from CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B to ALIGN16B so the TMA hardware unpacks FP4 to 1 byte/element in SMEM, matching the SMEM pointer arithmetic (sizeof(fp4_e2_t) == 1). Remove all sub-byte special-case branches in copy.cc since dtype.bytes() == 1 now correctly represents the SMEM element size. Fix tcgen05_macro_generator to use elem_bytes instead of elem_bits // 8 for byte offset and descriptor calculations in both SS and TS variants. Enable the TMA path and disable warp specialization for the T.Pipelined FP4 GEMM example.

Made-with: Cursor
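What the ALIGN16B mode buys can be illustrated on the host side (a sketch; the low-nibble-first ordering is an assumption here, not taken from the TMA documentation):

```python
def unpack_fp4(packed_bytes):
    """Expand packed FP4 storage (two 4-bit codes per uint8) into one
    code per byte -- the layout the ALIGN16B TMA mode produces in SMEM,
    which is why sizeof(fp4_e2_t) == 1 makes the pointer math line up."""
    out = []
    for b in packed_bytes:
        out.append(b & 0xF)         # low nibble  -> element 2k
        out.append((b >> 4) & 0xF)  # high nibble -> element 2k + 1
    return out
```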
force-pushed from 0fa709c to 41ade2d
…gemm-nv-fp4

# Conflicts:
#	tilelang/intrinsics/tcgen05_macro_generator.py
May I take this PR?
Thanks for your contribution! but closed as we should use
Hi @LeiWang1999
GEMM NV FP4 Feature – Design & Progress
Addresses #1592
Summary
Add `float4_e2m1fn` (NV FP4) GEMM support on both SM100 (TCGEN05 path) and SM120 (fragment MMA path).

SM100 (B100/B200) uses `tcgen05.mma` with TMEM; SM120 (RTX 5080/5090) uses `mma.sync.aligned.kind::f8f6f4.m16n8k32` with register fragments. The CUTE library already provides MMA atoms for both architectures.

Changes in this PR
WIP: SM120 fragment-MMA FP4 support
- `src/tl_templates/cuda/common.h`: `tl::float_e2m1_t` (inherits `cute::float_e2m1_t`) + `to_cute_type` specialization
- `src/tl_templates/cuda/cuda_fp4.h`: `fp4_e2_t` from custom struct to `using fp4_e2_t = tl::float_e2m1_t` (aligns with CUTE MMA atoms)
- `src/tl_templates/cuda/gemm_mma.h`: `TL_DISPATCH_MMA_TEMPLATE(fp4_e2_t, fp4_e2_t, float, SM120_16x8x32_TN)` in SM120 section
- `tilelang/intrinsics/mma_macro_generator.py`: `k_dim` at 32 for sub-byte types (FP4 MMA is m16n8k32, same as FP8)
- `tilelang/intrinsics/utils.py`: `get_ldmatrix_offset` (reuses 8-bit layout)
- `examples/gemm_fp4/example_gemm_fp4_sm120.py`: `T.alloc_fragment` (no TMEM)

Python-side pipeline (LayoutInference + LowerTileOp) passes successfully. CUDA kernel source is generated. Three C++ compilation issues remain:
Known remaining issues (CUDA compile stage)
`cuda_fp4.h` `__x` member access — `fp4_e2_2_t` helper functions reference `val.__x`, but `tl::float_e2m1_t` (via `cutlass::float_e2m1_t`) uses a different storage accessor. Need to update the pack/unpack helpers to use the CUTLASS storage API.

Codegen variable naming for sub-byte types — TVM codegen renames sub-byte fragment buffers with a `_packed` suffix (e.g. `A_local_packed`), but MMA lowering emits references to the original name (`A_local`). This is a TVM codegen interaction issue for 4-bit types.

`mma.h`: missing `MmaDispatcher` specialization — `tl::mma_sync<kFloat4_e2m1fn, kFloat4_e2m1fn, kFloat32, 16, 8, 32>` has no matching dispatch in `instruction/mma.h`. Need to add the FP4 PTX inline assembly specialization.

Architecture: SM100 vs SM120
| | SM100 | SM120 |
| --- | --- | --- |
| MMA instruction | `tcgen05.mma` (async) | `mma.sync.aligned` (sync) |
| Accumulator | TMEM (`T.alloc_tmem`) | fragments (`T.alloc_fragment`) |
| Template header | `tcgen05mma.h` | `gemm_mma.h` |
| MMA atoms | `SM100_MMA_*` | `SM120_16x8x32_TN` |

How to test
mxf4nvf4.block_scale) support as a follow-upSummary by CodeRabbit
New Features
Examples
Documentation