[Feature] Introduce T.CUDASourceCodeKernel#1970
Conversation
Note: Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the review settings. Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 Walkthrough

Adds TileLang support for embedding external CUDA source kernels: new TileLang attributes and API, propagation through lowering and host/device split, validation and early-emission in CUDA codegen, module-level device-symbol uniqueness checks, and tests for inline and file-backed CUDA sources.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User as User Code
    participant API as CUDASourceCodeKernel
    participant PrimFunc as TileLang PrimFunc
    participant Splitter as split_host_device
    participant RT as rt_mod_cuda
    participant Validator as ValidateExternalKernelEntryName
    participant Codegen as codegen_cuda.AddFunction
    User->>API: call with source_or_path (+ entry_name?)
    API->>API: resolve/load source, set attrs (code_block_source, entry_name)
    API->>PrimFunc: create prim_func with attrs
    PrimFunc->>Splitter: host/device split (capture host ABI)
    Splitter->>Splitter: extract/erase code_block attrs, derive device params (or parse kernel signature)
    Splitter->>RT: validate device global symbols unique
    Splitter->>Codegen: emit device PrimFunc (with code_block_source)
    Codegen->>Validator: parse source, verify entry_name matches __global__ kernel names
    Validator-->>Codegen: success / error
    Codegen->>Codegen: emit external CUDA source directly into artifact
    Codegen->>User: return compiled artifact (includes kernel_source)
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes

Possibly related PRs
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches

🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run the project's formatting checks before submitting. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Actionable comments posted: 1
🧹 Nitpick comments (4)
tilelang/language/kernel.py (1)
354-370: Consider clarifying the fallback behavior.

The function returns the input string unchanged (line 370) when it doesn't match file paths or contain CUDA markers. While this flexibility allows users to pass minimal CUDA snippets, it might silently accept invalid inputs. Consider adding a comment explaining this intentional permissiveness.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/language/kernel.py` around lines 354-370: the function `_load_cuda_source` intentionally falls back to returning the original input when it isn't a file path and doesn't contain CUDA markers; add a concise inline comment above the final `return source` explaining this permissive fallback behavior (minimal CUDA snippets or single-line kernels are accepted, so the function returns the input unchanged rather than raising), and reference the earlier checks in the comment (`os.path.expanduser`/`isfile`, the `source_markers` tuple, and the `FileNotFoundError` branch) so future readers know this is deliberate.

tilelang/engine/runner.py (3)
20-23: Consider more specific exception handling for the torch import.

The broad `Exception` catch could mask unexpected issues during torch import. Consider catching `ImportError` specifically.

♻️ Proposed fix

```diff
 try:
     import torch
-except Exception:  # pragma: no cover
+except ImportError:  # pragma: no cover
     torch = None
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/engine/runner.py` around lines 20 - 23, The current blanket except Exception around the torch import in runner.py should be narrowed to avoid hiding unrelated errors; update the try/except that imports torch (the block assigning torch = None) to catch ImportError (or ImportError/ModuleNotFoundError) instead of Exception, preserve the existing behavior of setting torch = None and keep the pragma comment if needed for coverage tools.
170-179: Cleanup exception handling is acceptable but could benefit from logging.

The silent exception handling in cleanup is a common pattern to avoid masking original errors. Debug logging could help troubleshoot cleanup issues without disrupting normal flow.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/engine/runner.py` around lines 170 - 179, The _cleanup_temp_sources method currently swallows all exceptions silently; modify the except block to log the failure (including the path and exception) at debug level instead of passing silently. Use the runner's logger (e.g., self.logger.debug) if available, falling back to the module logger (logging.getLogger(__name__)) to record f"failed to remove temp source {path}: {e}" while keeping the behavior of not re-raising the exception; keep references to self._temp_sources, keep_temporary_sources, path.exists, and path.unlink unchanged.
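The pattern the comment asks for can be sketched as a self-contained helper (the names `temp_sources` and `keep_temporary_sources` mirror the review's description of the runner; the actual TileLang method signature may differ):

```python
import logging
from pathlib import Path

logger = logging.getLogger(__name__)


def cleanup_temp_sources(temp_sources: list[Path],
                         keep_temporary_sources: bool = False) -> None:
    """Remove temporary source files, logging failures instead of raising."""
    if keep_temporary_sources:
        return
    for path in temp_sources:
        try:
            if path.exists():
                path.unlink()
        except OSError as e:
            # Log at debug level rather than passing silently, so cleanup
            # issues can be diagnosed without disrupting the normal flow
            # or masking an original error.
            logger.debug("failed to remove temp source %s: %s", path, e)
```

The key point is that the exception is still swallowed, so cleanup never masks the original failure, but the path and error are now recoverable from debug logs.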
105-111: Consider narrowing exception handling or adding logging.

The broad `Exception` catch could hide unexpected errors. Since this is looking up TVM global functions, consider catching a more specific exception or adding debug logging.

♻️ Proposed fix with logging

```diff
+import logging
+
+_logger = logging.getLogger(__name__)
+
 def _get_global_func(self, names: list[str]) -> Any | None:
     for name in names:
         try:
             return tvm.ffi.get_global_func(name)
-        except Exception:
+        except Exception as e:
+            _logger.debug("Failed to get global func %s: %s", name, e)
             continue
     return None
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/engine/runner.py` around lines 105 - 111, The loop in _get_global_func currently swallows all exceptions from tvm.ffi.get_global_func; change the broad except to catch the specific TVM error (e.g., tvm.error.TVMError or the concrete exception type raised by tvm.ffi.get_global_func) so missing globals are handled quietly, and add a debug log for unexpected exceptions: in _get_global_func import logging, get a logger for the module, catch tvm.error.TVMError (continue) and add a separate except Exception as e block that logs logger.debug/exception with the function name and error before continuing or rethrowing as appropriate.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tilelang/language/kernel.py`:
- Around line 433-436: The docstring in CUDASourceCodeKernel contains a stale
parameter description for "is_cpu" which is not present in
CUDASourceCodeKernel's signature (it was copied from Kernel); remove the "is_cpu
: bool" section and any sentences referencing binding threadIdx/blockIdx from
the CUDASourceCodeKernel docstring so the docs match the actual parameters and
behavior of CUDASourceCodeKernel.
---
Nitpick comments:
In `@tilelang/engine/runner.py`:
- Around line 20-23: The current blanket except Exception around the torch
import in runner.py should be narrowed to avoid hiding unrelated errors; update
the try/except that imports torch (the block assigning torch = None) to catch
ImportError (or ImportError/ModuleNotFoundError) instead of Exception, preserve
the existing behavior of setting torch = None and keep the pragma comment if
needed for coverage tools.
- Around line 170-179: The _cleanup_temp_sources method currently swallows all
exceptions silently; modify the except block to log the failure (including the
path and exception) at debug level instead of passing silently. Use the runner's
logger (e.g., self.logger.debug) if available, falling back to the module logger
(logging.getLogger(__name__)) to record f"failed to remove temp source {path}:
{e}" while keeping the behavior of not re-raising the exception; keep references
to self._temp_sources, keep_temporary_sources, path.exists, and path.unlink
unchanged.
- Around line 105-111: The loop in _get_global_func currently swallows all
exceptions from tvm.ffi.get_global_func; change the broad except to catch the
specific TVM error (e.g., tvm.error.TVMError or the concrete exception type
raised by tvm.ffi.get_global_func) so missing globals are handled quietly, and
add a debug log for unexpected exceptions: in _get_global_func import logging,
get a logger for the module, catch tvm.error.TVMError (continue) and add a
separate except Exception as e block that logs logger.debug/exception with the
function name and error before continuing or rethrowing as appropriate.
In `@tilelang/language/kernel.py`:
- Around line 354-370: The function _load_cuda_source intentionally falls back
to returning the original input when it isn't a file path and doesn't contain
CUDA markers; add a concise inline comment above the final "return source"
explaining this permissive fallback behaviour (that minimal CUDA snippets or
single-line kernels are accepted and therefore the function returns the input
unchanged rather than raising), and reference the earlier checks in the comment
(os.path.expanduser/isfile, source_markers tuple and the FileNotFoundError
branch) so future readers know this is deliberate.
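For illustration, the resolution logic described above can be reduced to a stand-alone sketch (the real `_load_cuda_source` in `tilelang/language/kernel.py` may differ in details such as the exact marker set):

```python
import os

# Markers suggesting the string is already CUDA source rather than a path.
_SOURCE_MARKERS = ("__global__", "__device__", "#include", 'extern "C"')


def load_cuda_source(source_or_path: str) -> str:
    """Return CUDA source text, resolving file paths and passing snippets through."""
    candidate = os.path.expanduser(source_or_path)
    if os.path.isfile(candidate):
        with open(candidate) as f:
            return f.read()
    if any(marker in source_or_path for marker in _SOURCE_MARKERS):
        return source_or_path
    if source_or_path.endswith((".cu", ".cuh")):
        # Looks like a path but does not exist on disk.
        raise FileNotFoundError(f"CUDA source file not found: {source_or_path}")
    # Permissive fallback: minimal CUDA snippets (even one-liners without any
    # of the markers above) are accepted unchanged rather than rejected.
    return source_or_path
```

The final `return` is the deliberately permissive branch the review asks to document: anything that is neither an existing file nor an obvious dangling `.cu`/`.cuh` path is treated as inline source.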
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 41a8e8bc-61cb-4342-86b8-acbe38f88d93
📒 Files selected for processing (12)
- examples/gemm_streamk/example_tilelang_gemm_streamk.py
- src/target/codegen_cuda.cc
- src/transform/common/attr.h
- src/transform/lower_opaque_block.cc
- src/transform/split_host_device.cc
- testing/python/language/test_tilelang_language_source_kernel.py
- tilelang/engine/__init__.py
- tilelang/engine/lower.py
- tilelang/engine/runner.py
- tilelang/language/__init__.py
- tilelang/language/eager/builder.py
- tilelang/language/kernel.py
👮 Files not reviewed due to content moderation or server errors (8)
- src/transform/common/attr.h
- tilelang/language/eager/builder.py
- tilelang/language/__init__.py
- tilelang/engine/lower.py
- examples/gemm_streamk/example_tilelang_gemm_streamk.py
- tilelang/engine/__init__.py
- src/transform/lower_opaque_block.cc
- src/target/codegen_cuda.cc
Actionable comments posted: 2
🧹 Nitpick comments (1)
testing/python/language/test_tilelang_language_source_kernel.py (1)
22-32: Please add a regression where the enclosing `PrimFunc` has an extra unused arg.

These helpers only cover the happy path where the external CUDA signature exactly matches `A, B, N`. Given the new ABI reconstruction in `src/transform/split_host_device.cc`, we should also cover a `@T.prim_func` with an unused tensor or scalar so launch-arity/order mismatches are caught in CI.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@testing/python/language/test_tilelang_language_source_kernel.py` around lines 22 - 32, Update the test helper make_source_kernel to include a regression case where the enclosing `@T.prim_func` (the function main created inside make_source_kernel) declares an additional unused parameter (e.g., an extra tensor or scalar) that is not referenced by the CUDA kernel signature; ensure the test invokes T.CUDASourceCodeKernel with the same existing launch parameters (T.ceildiv(N, 128), threads=128) but the PrimFunc signature now contains the extra unused arg so ABI reconstruction mismatches (launch arity/order) are exercised—modify the test that calls make_source_kernel (or add a new test variant) to pass a PrimFunc with that unused parameter while keeping kernel source and expected behavior otherwise unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/transform/split_host_device.cc`:
- Around line 134-176: CollectSourceKernelSignature currently captures all host
scalars and buffer symbols (host_params_ and host_buffer_map_) regardless of
whether the external CUDA source kernel actually uses them, which causes
mismatched arity/order for emitted __global__ functions; fix by restricting
captured params/buffers to only those referenced by the source kernel: compute a
used-symbol set (e.g., from the source-kernel AST or the kernel's explicit
parameter list) and change the push logic and the host_buffer_map_ loop to skip
any var or buffer whose data/shape/strides/elem_offset vars are not in that
used-symbol set, only add buffers_to_declare for buffers that contribute any
used vars, then call SortDeviceParams(&params) and return as before (update
signature of CollectSourceKernelSignature or have it obtain the used-symbol set
from the surrounding context to locate relevant symbols).
In `@tilelang/language/kernel.py`:
- Around line 97-129: Replace the flaky assert and add explicit validation in
_normalize_threads and _normalize_cluster_dims: in _normalize_threads (and when
handling list/tuple branches) check that sequences have 1–3 elements, that each
element is an int (and positive if required), and raise ValueError with a clear
message instead of using assert; similarly, in _normalize_cluster_dims validate
that incoming int/list/tuple values produce 1–3 elements, that each element is
an int, and raise ValueError for lengths >3 or non-int entries; keep the
existing defaulting behavior (e.g., threads=None -> 128 when not is_cpu, and
returning None when cluster_dims == [1,1,1]) and ensure invalid inputs fail
early before reaching _ffi_api.KernelLaunch.
---
Nitpick comments:
In `@testing/python/language/test_tilelang_language_source_kernel.py`:
- Around line 22-32: Update the test helper make_source_kernel to include a
regression case where the enclosing `@T.prim_func` (the function main created
inside make_source_kernel) declares an additional unused parameter (e.g., an
extra tensor or scalar) that is not referenced by the CUDA kernel signature;
ensure the test invokes T.CUDASourceCodeKernel with the same existing launch
parameters (T.ceildiv(N, 128), threads=128) but the PrimFunc signature now
contains the extra unused arg so ABI reconstruction mismatches (launch
arity/order) are exercised—modify the test that calls make_source_kernel (or add
a new test variant) to pass a PrimFunc with that unused parameter while keeping
kernel source and expected behavior otherwise unchanged.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: e10e4225-8a6a-45d3-aa83-61037fac4105
📒 Files selected for processing (3)
- src/transform/split_host_device.cc
- testing/python/language/test_tilelang_language_source_kernel.py
- tilelang/language/kernel.py
Actionable comments posted: 1
♻️ Duplicate comments (1)
tilelang/language/kernel.py (1)
97-129: ⚠️ Potential issue | 🟡 Minor

Fail fast on invalid launch shapes here instead of relying on `assert`.

`python -O` strips the assert on line 112, and empty or >3-D sequences still pass through unchanged. Since both `Kernel()` and `CUDASourceCodeKernel()` now share these helpers, bad inputs fail later inside `_ffi_api.KernelLaunch` instead of at the Python API boundary.

♻️ Minimal fix

```diff
 def _normalize_threads(
     threads: int | list[int] | tuple | None,
     *,
     is_cpu: bool,
 ) -> list[int] | None:
     if not is_cpu and threads is None:
         threads = 128  # default thread number
+
+    if threads is None:
+        return None
     if isinstance(threads, int):
         return [threads, 1, 1]
-    if isinstance(threads, list):
-        return threads + [1] * (3 - len(threads))
-    if isinstance(threads, tuple):
-        return list(threads) + [1] * (3 - len(threads))
+    if isinstance(threads, (list, tuple)):
+        if not 1 <= len(threads) <= 3:
+            raise ValueError("threads must have 1-3 dimensions")
+        return list(threads) + [1] * (3 - len(threads))
-    assert is_cpu, "threads must be an integer or a list of integers"
-    return None
+    raise ValueError("threads must be an integer or a list/tuple of 1-3 integers")


 def _normalize_cluster_dims(
     cluster_dims: int | tuple[int, int, int] | list[int] | None,
 ) -> list[int] | None:
     if cluster_dims is None:
         return None
     if isinstance(cluster_dims, (list, tuple)):
+        if not 1 <= len(cluster_dims) <= 3:
+            raise ValueError("cluster_dims must have 1-3 dimensions")
         cluster_dims = list(cluster_dims) + [1] * (3 - len(cluster_dims))
```
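As a stand-alone illustration, the validated helper behaves like this sketch (simplified from the proposed fix; the real TileLang helper is a private function and may differ slightly):

```python
def normalize_threads(threads, *, is_cpu: bool = False):
    """Normalize a launch shape to [x, y, z], failing fast on bad input."""
    if not is_cpu and threads is None:
        threads = 128  # default thread count on GPU targets
    if threads is None:
        return None  # CPU kernels may omit a thread shape entirely
    if isinstance(threads, int):
        return [threads, 1, 1]
    if isinstance(threads, (list, tuple)):
        if not 1 <= len(threads) <= 3:
            raise ValueError("threads must have 1-3 dimensions")
        return list(threads) + [1] * (3 - len(threads))
    raise ValueError("threads must be an integer or a list/tuple of 1-3 integers")
```

Unlike the `assert`-based version, bad inputs now raise `ValueError` at the Python API boundary even under `python -O`.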
🧹 Nitpick comments (1)
src/target/codegen_cuda.cc (1)
4107-4142: Refactor to avoid duplicate external kernel parsing across lowering and codegen.
`split_host_device.cc` already resolves the external kernel entry name from the CUDA source during lowering and stores it as `kGlobalSymbol`. `codegen_cuda.cc` then receives this resolved name but re-parses the raw source with an identical regex to re-validate it. This duplication creates drift risk: both parsers must stay in sync around CUDA qualifier and attribute syntax, which varies across CUDA versions (`__launch_bounds__` parameter count, etc.). The resolved kernel signature should be passed forward instead of re-parsing.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/target/codegen_cuda.cc` around lines 4107 - 4142, The codegen currently re-parses CUDA source in ValidateExternalKernelEntryName using a regex, duplicating logic already performed during lowering in split_host_device.cc which produced kGlobalSymbol; to fix, stop re-parsing the raw source in codegen_cuda.cc and instead use the resolved symbol carried forward (kGlobalSymbol) from the lowering pass: update the T.CUDASourceCodeKernel handling so the resolved kernel name is passed into codegen and have ValidateExternalKernelEntryName (or its caller) accept and use that provided symbol rather than running the regex; remove the regex/kernel_names logic or guard it as a fallback only when kGlobalSymbol is not present to avoid drift.
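To make the duplicated validation concrete, what both passes do amounts to roughly the following (an illustrative Python sketch of the check; the actual regex lives in C++ and may handle more syntax):

```python
import re

# Matches a plain `__global__ void name(` definition. Real CUDA adds
# qualifiers such as extern "C" or __launch_bounds__(...) between
# `void` and the name -- exactly the syntax two独立 parsers drift on.
_GLOBAL_KERNEL_RE = re.compile(r"__global__\s+void\s+([A-Za-z_]\w*)\s*\(")


def validate_entry_name(source: str, entry_name: str) -> list[str]:
    """Check entry_name names a __global__ kernel defined in source."""
    kernel_names = _GLOBAL_KERNEL_RE.findall(source)
    if entry_name not in kernel_names:
        raise ValueError(
            f"entry_name {entry_name!r} not found among __global__ kernels "
            f"{kernel_names} in the provided CUDA source")
    return kernel_names
```

Because both lowering and codegen would need this same fragile pattern, passing the already-resolved symbol forward (as the comment suggests) removes one copy entirely.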
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/target/rt_mod_cuda.cc`:
- Around line 20-40: The current ValidateUniqueDeviceGlobalSymbols rejects any
repeated GetDeviceGlobalSymbol value, which wrongly blocks multiple PrimFunc
launches that reference the same external .cu entry; modify the check to allow
identical source-backed kernels by deduplicating or by comparing definitions:
when encountering an existing global_symbol in symbol_to_gvar, if both PrimFuncs
are source-backed (e.g., produced by T.CUDASourceCodeKernel / have the external
entry name set by split_host_device.cc) then compare their device
code/attributes and treat identical definitions as non-conflicting (do not
ICHECK), otherwise keep the existing ICHECK failure; update
ValidateUniqueDeviceGlobalSymbols (and its use of symbol_to_gvar,
GetDeviceGlobalSymbol, and the PrimFunc comparison logic) to implement this
dedupe-or-compare behavior.
---
Nitpick comments:
In `@src/target/codegen_cuda.cc`:
- Around line 4107-4142: The codegen currently re-parses CUDA source in
ValidateExternalKernelEntryName using a regex, duplicating logic already
performed during lowering in split_host_device.cc which produced kGlobalSymbol;
to fix, stop re-parsing the raw source in codegen_cuda.cc and instead use the
resolved symbol carried forward (kGlobalSymbol) from the lowering pass: update
the T.CUDASourceCodeKernel handling so the resolved kernel name is passed into
codegen and have ValidateExternalKernelEntryName (or its caller) accept and use
that provided symbol rather than running the regex; remove the
regex/kernel_names logic or guard it as a fallback only when kGlobalSymbol is
not present to avoid drift.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: fb6ebbb7-7cdb-4948-9d45-29bc09f51f9d
📒 Files selected for processing (10)
- src/target/codegen_c_host.cc
- src/target/codegen_cuda.cc
- src/target/rt_mod_cuda.cc
- src/transform/common/attr.h
- src/transform/lower_device_kernel_launch.cc
- src/transform/lower_opaque_block.cc
- src/transform/make_packed_api.cc
- src/transform/split_host_device.cc
- testing/python/language/test_tilelang_language_source_kernel.py
- tilelang/language/kernel.py
🚧 Files skipped from review as they are similar to previous changes (4)
- src/transform/lower_opaque_block.cc
- src/transform/common/attr.h
- testing/python/language/test_tilelang_language_source_kernel.py
- src/transform/split_host_device.cc
1ae4952 to fc77671 (Compare)
♻️ Duplicate comments (1)
src/target/rt_mod_cuda.cc (1)
20-39: ⚠️ Potential issue | 🟠 Major

Allow repeated launches of the same external CUDA entry.

This rejects any duplicate `global_symbol`, including the valid case where two launch stubs intentionally reference the same source-backed `__global__` function. That makes one external kernel unusable from multiple `T.CUDASourceCodeKernel` call sites. Please deduplicate identical source-backed kernels, or only fail when the same symbol resolves to a different definition.

Also applies to: 98-99, 135
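The requested dedupe-or-compare behavior can be sketched in Python terms (the real fix belongs in the C++ `ValidateUniqueDeviceGlobalSymbols`; the data shapes here are illustrative):

```python
def check_unique_symbols(kernels: list[tuple[str, str]]) -> dict[str, str]:
    """Map global_symbol -> device source, allowing identical re-registrations.

    kernels: (global_symbol, device_source) pairs collected from the module.
    Raises only when one symbol maps to two *different* definitions.
    """
    symbol_to_source: dict[str, str] = {}
    for symbol, source in kernels:
        existing = symbol_to_source.get(symbol)
        if existing is None:
            symbol_to_source[symbol] = source
        elif existing != source:
            # Same name, different code: a genuine conflict -> fail.
            raise ValueError(f"conflicting definitions for device symbol {symbol!r}")
        # else: identical source-backed kernel referenced twice -> OK,
        # e.g. two launch stubs sharing one external .cu entry.
    return symbol_to_source
```

This keeps the uniqueness guarantee for genuinely conflicting symbols while permitting multiple `T.CUDASourceCodeKernel` call sites to share one entry.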
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: b2338ec3-29dc-4aa6-ad25-47801a92b34f
📒 Files selected for processing (12)
- src/target/codegen_c_host.cc
- src/target/codegen_cuda.cc
- src/target/rt_mod_cuda.cc
- src/transform/common/attr.h
- src/transform/lower_device_kernel_launch.cc
- src/transform/lower_opaque_block.cc
- src/transform/make_packed_api.cc
- src/transform/split_host_device.cc
- testing/python/language/test_tilelang_language_source_kernel.py
- tilelang/engine/lower.py
- tilelang/language/__init__.py
- tilelang/language/kernel.py
✅ Files skipped from review due to trivial changes (2)
- tilelang/language/__init__.py
- src/transform/lower_device_kernel_launch.cc
🚧 Files skipped from review as they are similar to previous changes (6)
- src/transform/lower_opaque_block.cc
- tilelang/engine/lower.py
- tilelang/language/kernel.py
- src/transform/split_host_device.cc
- testing/python/language/test_tilelang_language_source_kernel.py
- src/target/codegen_cuda.cc
6aa15de to 465caea (Compare)

f9958be to eba7591 (Compare)
This PR introduces `T.CUDASourceCodeKernel`, which is another approach in TileLang to launch a GPU kernel. Instead of expressing the device-side body through the TileLang DSL, users can provide CUDA source directly, either as inline code or via a source file path, and launch it from a TileLang PrimFunc.

Some valuable points of this new API:

- Kernels can be exported via `tl_templates.kernel..export_sources` and imported back through `T.CUDASourceCodeKernel`.
- Compared to existing workarounds (`T.import_source` or `tilelang_callback_cuda_postproc`), this approach is more elegant, more maintainable, and better aligned with TileLang's standard lowering and codegen pipeline.
- This API introduces a more explicit host-driven launch abstraction, which can be a precursor to HostDSL (an early building block for making TileLang a host-side control layer over both TileLang-generated kernels and externally authored device kernels).
Summary by CodeRabbit
New Features
Bug Fixes
Tests