Add regression test for 1D TMA load compilation and execution (#1989)
Conversation
Signed-off-by: Nguyen Huy Hoang <181364121+huyhoang171106@users.noreply.github.com>
📝 Walkthrough

A CUDA-gated regression test was added to validate 1D TMA load codegen and runtime. It includes helpers to detect device capability, extract the generated kernel source, build a 1D global→shared→global TMA copy kernel, compile for sm_90a, assert TMA patterns in the source, and run the kernel on float16 tensors.

Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~12 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
🧹 Nitpick comments (1)

testing/python/tilelang/test_tma_load.py (1)

18-40: Use the official `get_kernel_source()` API instead of probing internal attributes.

The `JITKernel` object returned by `tl.compile()` exposes `get_kernel_source()` as the documented method for retrieving generated source code. The current implementation unnecessarily probes internal attributes (`get_source()`, `module.imported_modules`, `rt_mod.imported_modules`), which are fragile and not part of the public API.

♻️ Proposed simplification

```diff
 def _extract_source(kernel) -> str:
-    if hasattr(kernel, "get_source"):
-        source = kernel.get_source()
-        if isinstance(source, str) and source:
-            return source
-
-    module = getattr(kernel, "module", None)
-    if module is not None and hasattr(module, "imported_modules"):
-        imported = getattr(module, "imported_modules", [])
-        if imported:
-            source = imported[0].get_source()
-            if isinstance(source, str) and source:
-                return source
-
-    runtime_mod = getattr(kernel, "rt_mod", None)
-    if runtime_mod is not None and hasattr(runtime_mod, "imported_modules"):
-        imported = getattr(runtime_mod, "imported_modules", [])
-        if imported:
-            source = imported[0].get_source()
-            if isinstance(source, str) and source:
-                return source
-
-    raise RuntimeError("Unable to extract generated source from compiled kernel")
+    return kernel.get_kernel_source()
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `testing/python/tilelang/test_tma_load.py` around lines 18-40, `_extract_source` currently probes internal attributes (`get_source`, `module.imported_modules`, `rt_mod.imported_modules`) to find generated source; replace that with the public API by calling the JITKernel's documented `get_kernel_source()` (e.g., use `kernel.get_kernel_source()` or `getattr(kernel, "get_kernel_source", None)` and validate it returns a non-empty str) and raise the same RuntimeError if it's missing or empty; update references inside `_extract_source` to remove the internal-attribute branches and rely solely on `get_kernel_source()` for extraction.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: ce4e9697-6861-49db-b787-4c01fb8b934c
📒 Files selected for processing (1)
testing/python/tilelang/test_tma_load.py
Pull request overview
Adds a focused Python regression test intended to prevent reintroducing the 1D TMA load compile-time signature mismatch reported in #1842 by validating both generated source patterns and runtime correctness on Hopper-class GPUs.
Changes:
- Add a new CUDA-only regression test that compiles a 1D `T.copy` global→shared→global kernel.
- Assert that the generated source contains TMA-related instruction markers.
- Execute the compiled kernel and verify the output equals the input.
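The source-pattern check mentioned above amounts to plain string matching on the generated CUDA source. A minimal sketch, assuming the source is available as a string; the marker list here is illustrative, not the test's actual token set:

```python
# Sketch of a TMA-marker assertion on generated CUDA source.
# TMA_MARKERS is an illustrative assumption; the real test may check
# different tokens.
TMA_MARKERS = ("tma_load", "cp.async.bulk", "CUtensorMap")


def assert_tma_in_source(source: str) -> None:
    # Require at least one TMA-related marker in the source text.
    if not any(marker in source for marker in TMA_MARKERS):
        raise AssertionError("generated source contains no TMA markers")


assert_tma_in_source("... tl::tma_load(desc, smem_addr, 0) ...")  # passes
```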
```python
def _build_1d_tma_copy():
    @T.prim_func
    def main(A: T.Buffer((N,), "float16"), B: T.Buffer((N,), "float16")):
        with T.Kernel(T.ceildiv(N, BLOCK), threads=128) as bx:
```
`threads=128` while copying `BLOCK=256` elements is an unusual configuration for a simple 1D bulk copy (the original repro uses 256 threads). If the intent is to force the 1D TMA lowering path, aligning `threads` with the copy granularity (or matching the issue's 256-thread setup) will make the test less brittle and more representative of the failing case.
```diff
-        with T.Kernel(T.ceildiv(N, BLOCK), threads=128) as bx:
+        with T.Kernel(T.ceildiv(N, BLOCK), threads=BLOCK) as bx:
```
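The launch arithmetic behind this suggestion can be sketched independently of TileLang; `ceildiv` below mirrors what `T.ceildiv` computes, and the values are the test's assumed `N` and `BLOCK`:

```python
# Sketch of the grid/thread arithmetic: ceildiv(N, BLOCK) blocks, and
# aligning threads with BLOCK gives one element per thread per copy.
def ceildiv(a: int, b: int) -> int:
    # Integer ceiling division without floats.
    return -(-a // b)


N, BLOCK = 4096, 256
num_blocks = ceildiv(N, BLOCK)
assert num_blocks == 16
assert num_blocks * BLOCK >= N       # grid covers every element
assert ceildiv(N + 1, BLOCK) == 17   # a partial tail tile gets its own block
```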
```python
def _extract_source(kernel) -> str:
    if hasattr(kernel, "get_source"):
        source = kernel.get_source()
        if isinstance(source, str) and source:
            return source

    module = getattr(kernel, "module", None)
    if module is not None and hasattr(module, "imported_modules"):
        imported = getattr(module, "imported_modules", [])
        if imported:
            source = imported[0].get_source()
            if isinstance(source, str) and source:
                return source

    runtime_mod = getattr(kernel, "rt_mod", None)
    if runtime_mod is not None and hasattr(runtime_mod, "imported_modules"):
        imported = getattr(runtime_mod, "imported_modules", [])
        if imported:
            source = imported[0].get_source()
            if isinstance(source, str) and source:
                return source

    raise RuntimeError("Unable to extract generated source from compiled kernel")
```
`_extract_source()` will always fail for `tl.compile(...)` results: `JITKernel` exposes `get_kernel_source()`, but does not have `get_source`, `module`, or `rt_mod` attributes. This will raise `RuntimeError` and make the test fail even when compilation succeeds. Prefer calling `kernel.get_kernel_source()` directly (optionally with a small fallback to `kernel.adapter.get_kernel_source()` if needed).
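A minimal sketch of that preferred shape, assuming only the attribute names the review describes (`get_kernel_source` on the kernel, with an optional `adapter` fallback); the fake kernel below is a stand-in, not TileLang's real `JITKernel`:

```python
def extract_source(kernel) -> str:
    # Prefer the public get_kernel_source(); fall back to an adapter
    # attribute if present. Attribute names follow the review's description.
    for owner in (kernel, getattr(kernel, "adapter", None)):
        if owner is None:
            continue
        fn = getattr(owner, "get_kernel_source", None)
        if callable(fn):
            source = fn()
            if isinstance(source, str) and source:
                return source
    raise RuntimeError("Unable to extract generated source from compiled kernel")


# Stand-in object mimicking the described JITKernel surface:
class _FakeKernel:
    def get_kernel_source(self) -> str:
        return "__global__ void main_kernel() {}"


assert "main_kernel" in extract_source(_FakeKernel())
```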
```python
b = torch.empty_like(a)

kernel(a, b)
```
With `out_idx=[1]`, the default tvm_ffi adapter wraps the kernel to accept only non-output tensors (here: only `A`) and allocate `B` internally. Calling `kernel(a, b)` will raise a `ValueError` due to the input count mismatch. Either call `b = kernel(a)` (and compare `b` to `a`), or compile with `out_idx=None` if you want to pass `B` explicitly.
```diff
-b = torch.empty_like(a)
-kernel(a, b)
+b = kernel(a)
```
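The calling-convention mismatch can be illustrated with a toy wrapper. `ToyAdapter` is hypothetical and only mimics the `out_idx` behavior described above (outputs are allocated internally and returned, never passed in):

```python
class ToyAdapter:
    """Toy stand-in for an out_idx-style wrapper: with two declared
    parameters and out_idx=[1], callers pass only parameter 0 and
    receive parameter 1 back as the return value."""

    def __init__(self, num_params: int, out_idx: list):
        # Outputs are allocated by the wrapper, so callers supply fewer args.
        self.expected_inputs = num_params - len(out_idx)

    def __call__(self, *inputs):
        if len(inputs) != self.expected_inputs:
            raise ValueError(
                f"expected {self.expected_inputs} input(s), got {len(inputs)}")
        # Mimic the identity-copy kernel under test: output == input.
        return list(inputs[0])


kernel = ToyAdapter(num_params=2, out_idx=[1])
a = [1.0, 2.0, 3.0]
b = kernel(a)              # correct: the output is returned
assert b == a

try:
    kernel(a, [0.0] * 3)   # wrong: passing the output raises
except ValueError:
    pass
```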
```python
def _get_device_capability() -> tuple[int, int]:
    if not torch.cuda.is_available():
        return (0, 0)
    return torch.cuda.get_device_capability()


def _extract_source(kernel) -> str:
    if hasattr(kernel, "get_source"):
        source = kernel.get_source()
        if isinstance(source, str) and source:
            return source

    module = getattr(kernel, "module", None)
    if module is not None and hasattr(module, "imported_modules"):
        imported = getattr(module, "imported_modules", [])
        if imported:
            source = imported[0].get_source()
            if isinstance(source, str) and source:
                return source

    runtime_mod = getattr(kernel, "rt_mod", None)
    if runtime_mod is not None and hasattr(runtime_mod, "imported_modules"):
        imported = getattr(runtime_mod, "imported_modules", [])
        if imported:
            source = imported[0].get_source()
            if isinstance(source, str) and source:
                return source

    raise RuntimeError("Unable to extract generated source from compiled kernel")


def _build_1d_tma_copy():
    @T.prim_func
    def main(A: T.Buffer((N,), "float16"), B: T.Buffer((N,), "float16")):
        with T.Kernel(T.ceildiv(N, BLOCK), threads=128) as bx:
            A_shared = T.alloc_shared((BLOCK,), "float16")
            T.copy(A[bx * BLOCK:(bx + 1) * BLOCK], A_shared)
            T.copy(A_shared, B[bx * BLOCK:(bx + 1) * BLOCK])

    return main


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required")
@pytest.mark.skipif(_get_device_capability()[0] < 9, reason="Hopper (sm90+) is required for TMA")
def test_tma_load_1d_compile_and_run_regression():
```
The test uses ad-hoc `pytest.mark.skipif` checks and a custom `_get_device_capability()` helper. The repo already provides `tilelang.testing.requires_cuda` and `tilelang.testing.requires_cuda_compute_version_ge(...)` for consistent skipping and clearer failure reasons across environments. Consider switching to those decorators and dropping `_get_device_capability()`.
```python
@pytest.mark.skipif(_get_device_capability()[0] < 9, reason="Hopper (sm90+) is required for TMA")
def test_tma_load_1d_compile_and_run_regression():
    program = _build_1d_tma_copy()
    kernel = tl.compile(program, out_idx=[1], target="cuda -arch=sm_90a")
```
Hard-coding `target="cuda -arch=sm_90a"` can make the test fail on Hopper devices that are sm_90 but not sm_90a (your skip condition only checks major capability >= 9). Prefer `target="cuda"` (let TileLang pick the device arch) or `target="cuda -arch=sm_90"` to match the sm90 requirement enforced by the skip decorator.
```diff
-    kernel = tl.compile(program, out_idx=[1], target="cuda -arch=sm_90a")
+    kernel = tl.compile(program, out_idx=[1], target="cuda -arch=sm_90")
```
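If the arch string must still be built explicitly, one option is deriving it from the reported capability rather than hard-coding it. This mapping is a simplified assumption: it ignores suffixed variants like sm_90a, which is exactly the distinction at issue here:

```python
def target_for_capability(major: int, minor: int) -> str:
    # Naive capability -> target mapping; deliberately ignores arch
    # suffixes such as the 'a' in sm_90a.
    return f"cuda -arch=sm_{major}{minor}"


assert target_for_capability(9, 0) == "cuda -arch=sm_90"
assert target_for_capability(8, 6) == "cuda -arch=sm_86"
```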
```python
def _build_1d_tma_copy():
    @T.prim_func
    def main(A: T.Buffer((N,), "float16"), B: T.Buffer((N,), "float16")):
        with T.Kernel(T.ceildiv(N, BLOCK), threads=128) as bx:
            A_shared = T.alloc_shared((BLOCK,), "float16")
            T.copy(A[bx * BLOCK:(bx + 1) * BLOCK], A_shared)
            T.copy(A_shared, B[bx * BLOCK:(bx + 1) * BLOCK])
```
This kernel shape doesn't mirror the reported repro in #1842 (single-CTA copy of a 1D tensor with length 7168 and float32). Using a tiled `BLOCK=256` slice copy and float16 may not exercise the exact lowering path that previously produced the bad `tl::tma_load` call signature. Consider adjusting the test to match the issue more closely (dtype, length, single `T.copy(A, A_shared)` / `T.copy(A_shared, B)` pattern) so it reliably catches the regression.
Summary
This issue is a compile-time mismatch; a focused regression test must validate both source emission and runtime correctness for the reported pattern. Without this, similar signature regressions can reappear silently.
Files changed
- `testing/python/tilelang/test_tma_load.py` (new)

Testing
Closes #1842