[Reorganize] Reorganize intra-node examples and add tests #52
Rachmanino wants to merge 8 commits into main from
Conversation
👋 Hi! Thank you for contributing to the TileLang project. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in your CodeRabbit settings. Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 Walkthrough

Removed multiple NVSHMEM/Triton-backed distributed examples/utilities, added several intranode IPC examples and tests, updated the README, and made minor intranode tweaks.

Changes
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (1 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@examples/distributed/intranode/test_intranode.py`:
- Around line 13-14: The test passes args=(2, None) which sends None into the
main functions and causes AttributeError when they access attributes; update the
test (test_example_allgather_gemm_overlapped) to pass a minimal valid args
object (e.g., an argparse.Namespace or simple object) containing the attributes
expected by example_allgather_gemm_overlapped.main,
example_gemm_rs_overlapped.main and example_sp_ag_attention_intra_node.main (at
least persistent for the first two, and batch_size, q_head, etc. for the
attention example), or alternately add defensive checks inside those main
functions to handle args is None before accessing attributes — pick one approach
and implement it consistently for the three mains named above.
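For the first option, a minimal sketch of such an args object (a hypothetical helper; `persistent` is the only attribute named above, and the attention example's full attribute set would still need to be filled in from its `main`):

```python
# Hypothetical helper for the tests: build a minimal argparse.Namespace
# instead of passing None into the spawned main functions.
from argparse import Namespace

def make_test_args(**overrides):
    # `persistent` is the attribute the review says the first two mains read;
    # the default here is illustrative, not taken from the PR.
    defaults = {"persistent": False}
    defaults.update(overrides)
    return Namespace(**defaults)

args = make_test_args()
print(args.persistent)                          # attribute access succeeds instead of raising
print(make_test_args(persistent=True).persistent)
```

The spawn call would then become `torch.multiprocessing.spawn(module.main, args=(2, make_test_args()), nprocs=2)`.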
🧹 Nitpick comments (1)
examples/distributed/intranode/test_intranode.py (1)
10-28: Consider `pytest.mark.parametrize` to reduce boilerplate. All three tests share the same decorator stack and spawn logic, differing only in the imported module. A parametrized test would reduce duplication:
♻️ Optional refactor
```diff
+@tilelang.testing.requires_distributed
+@tilelang.testing.requires_cuda
+@tilelang.testing.requires_cuda_compute_version_eq(9, 0)
+@pytest.mark.parametrize("module", [
+    example_allgather_gemm_overlapped,
+    example_gemm_rs_overlapped,
+    example_sp_ag_attention_intra_node,
+])
+def test_intranode_example(module):
+    torch.multiprocessing.spawn(module.main, args=(2, None), nprocs=2)
```
```python
def test_example_allgather_gemm_overlapped():
    torch.multiprocessing.spawn(example_allgather_gemm_overlapped.main, args=(2, None), nprocs=2)
```
🧩 Analysis chain
🏁 Script executed:
```shell
#!/bin/bash
# First, find the location of the example modules
find . -name "example_allgather_gemm_overlapped.py" -o -name "example_gemm_rs_overlapped.py" -o -name "example_sp_ag_attention_intra_node.py" | head -20
```

Repository: tile-ai/tilescale
Length of output: 264
🏁 Script executed:
```shell
#!/bin/bash
# Extract main function definitions with context
for f in ./examples/distributed/intranode/example_allgather_gemm_overlapped.py ./examples/distributed/intranode/example_gemm_rs_overlapped.py ./examples/distributed/intranode/example_sp_ag_attention_intra_node.py; do
  echo "=== $(basename $f) ==="
  sed -n '/^def main/,/^def \|^class \|^$/p' "$f" | head -40
done
```

Repository: tile-ai/tilescale
Length of output: 931
🏁 Script executed:
```shell
#!/bin/bash
# Get the full main function for each module (up to 60 lines)
for f in ./examples/distributed/intranode/example_allgather_gemm_overlapped.py ./examples/distributed/intranode/example_gemm_rs_overlapped.py ./examples/distributed/intranode/example_sp_ag_attention_intra_node.py; do
  echo "=== $(basename $f) ==="
  sed -n '/^def main/,/^def \|^class \|^$/p' "$f" | head -60
  echo ""
done
```

Repository: tile-ai/tilescale
Length of output: 934
🏁 Script executed:
```shell
#!/bin/bash
# Check the test file to confirm args pattern
echo "=== test_intranode.py ==="
cat -n examples/distributed/intranode/test_intranode.py
echo ""
echo "=== Checking for any args usage in example_sp_ag_attention_intra_node.py ==="
grep -n "args\." examples/distributed/intranode/example_sp_ag_attention_intra_node.py || echo "No args. accesses found"
```

Repository: tile-ai/tilescale
Length of output: 1792
All three main functions will crash with AttributeError when args=None is passed.
The tests call torch.multiprocessing.spawn(..., args=(2, None), nprocs=2), which passes None as the third argument to each main function. However:
- `example_allgather_gemm_overlapped.main` and `example_gemm_rs_overlapped.main` unconditionally access `args.persistent` on line 5.
- `example_sp_ag_attention_intra_node.main` unconditionally accesses `args.batch_size`, `args.q_head`, and multiple other attributes (lines 293-304, 409).
All three will fail with AttributeError: 'NoneType' object has no attribute .... Either pass valid args or add None checks to all attribute accesses in these functions.
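The second option (defensive checks) could look like this sketch; the attribute names and default values are assumptions for illustration, not taken from the example files:

```python
from argparse import Namespace

# Module-level fallbacks used when a test spawns main(..., args=None).
# The attribute set and values here are illustrative.
_DEFAULT_ARGS = Namespace(persistent=False, batch_size=1, q_head=8)

def resolve_args(args):
    """Return the caller's args, or the defaults when None is passed."""
    return args if args is not None else _DEFAULT_ARGS

# Inside each main(), `args = resolve_args(args)` would run before any
# attribute access, so args.persistent etc. no longer raise AttributeError.
print(resolve_args(None).persistent)
print(resolve_args(Namespace(persistent=True, batch_size=4, q_head=16)).batch_size)
```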
🤖 Prompt for AI Agents
In `@examples/distributed/intranode/test_intranode.py` around lines 13 - 14, The
test passes args=(2, None) which sends None into the main functions and causes
AttributeError when they access attributes; update the test
(test_example_allgather_gemm_overlapped) to pass a minimal valid args object
(e.g., an argparse.Namespace or simple object) containing the attributes
expected by example_allgather_gemm_overlapped.main,
example_gemm_rs_overlapped.main and example_sp_ag_attention_intra_node.main (at
least persistent for the first two, and batch_size, q_head, etc. for the
attention example), or alternately add defensive checks inside those main
functions to handle args is None before accessing attributes — pick one approach
and implement it consistently for the three mains named above.
Actionable comments posted: 2
🧹 Nitpick comments (1)
examples/distributed/intranode/example_reduce_scatter.py (1)
10-18: Avoid shadowing the Python builtin `input`. The parameter name `input` shadows the Python builtin. Consider renaming to `input_tensor` or `inp` for clarity and to avoid potential confusion.
♻️ Suggested rename

```diff
 def torch_reduce_scatter(
     pg: torch.distributed.ProcessGroup,
-    input: torch.Tensor,
+    input_tensor: torch.Tensor,
     num_local_ranks: int,
 ) -> torch.Tensor:
-    M, N = input.shape
-    output = torch.empty((M // num_local_ranks, N), dtype=input.dtype, device=input.device)
-    torch.distributed.reduce_scatter_tensor(output, input, group=pg)
+    M, N = input_tensor.shape
+    output = torch.empty((M // num_local_ranks, N), dtype=input_tensor.dtype, device=input_tensor.device)
+    torch.distributed.reduce_scatter_tensor(output, input_tensor, group=pg)
     return output
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/distributed/intranode/example_reduce_scatter.py` around lines 10 - 18, The function torch_reduce_scatter currently uses a parameter named input which shadows the Python builtin; rename that parameter (and all internal references) to input_tensor (or inp) in the function signature and body of torch_reduce_scatter so the code uses the new name when computing M, N, creating output, and calling torch.distributed.reduce_scatter_tensor; ensure any callers or references to torch_reduce_scatter are updated accordingly to match the new parameter name.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/distributed/intranode/example_reduce_scatter.py`:
- Around line 21-25: The function main currently assumes args is not None and
directly reads args.M and args.N; add a None guard like other examples so tests
that call main(args=None) won't crash: inside main (function main) check if args
is None (or use a fallback pattern) and set M and N to default values when args
is None before computing M_per_rank; ensure you reference and preserve dtype,
local_rank, num_local_ranks and only replace the unconditional accesses to
args.M / args.N with guarded assignments.
In `@examples/distributed/intranode/test_intranode.py`:
- Around line 35-36: The test test_example_reduce_scatter calls
torch.multiprocessing.spawn(example_reduce_scatter.main, args=(2, None),
nprocs=2) which passes args=None into example_reduce_scatter.main and causes
attribute access errors on args.M and args.N; fix by either updating the test to
pass a proper args object (replace the None with a simple namespace/config
providing M and N) or make example_reduce_scatter.main resilient by using safe
fallbacks when args is None (e.g., refer to args.M as args.M if args else
<default_M> and similarly for args.N) so the test no longer crashes; reference
example_reduce_scatter.main and test_example_reduce_scatter when making the
change.
---
Nitpick comments:
In `@examples/distributed/intranode/example_reduce_scatter.py`:
- Around line 10-18: The function torch_reduce_scatter currently uses a
parameter named input which shadows the Python builtin; rename that parameter
(and all internal references) to input_tensor (or inp) in the function signature
and body of torch_reduce_scatter so the code uses the new name when computing M,
N, creating output, and calling torch.distributed.reduce_scatter_tensor; ensure
any callers or references to torch_reduce_scatter are updated accordingly to
match the new parameter name.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 1ee0a1af-0fbe-4d5c-b907-42da4de6164f
📒 Files selected for processing (15)
- examples/distributed/README.md
- examples/distributed/intranode/example_gemm_rs_overlapped.py
- examples/distributed/intranode/example_reduce_scatter.py
- examples/distributed/intranode/reduce_scatter.py
- examples/distributed/intranode/test_intranode.py
- examples/distributed/nvshmem/example_all_to_all.py
- examples/distributed/nvshmem/example_allgather.py
- examples/distributed/nvshmem/example_cannon.py
- examples/distributed/nvshmem/example_nvshmem.py
- examples/distributed/nvshmem/example_overlapping_allgather.py
- examples/distributed/nvshmem/example_post_attn_all2all_transpose.py
- examples/distributed/nvshmem/example_pre_attn_all2all.py
- examples/distributed/nvshmem/example_pre_attn_all2all_transpose.py
- examples/distributed/nvshmem/example_simple_shift.py
- examples/distributed/nvshmem/example_summa.py
💤 Files with no reviewable changes (2)
- examples/distributed/intranode/example_gemm_rs_overlapped.py
- examples/distributed/intranode/reduce_scatter.py
✅ Files skipped from review due to trivial changes (1)
- examples/distributed/README.md
```python
def test_example_reduce_scatter():
    torch.multiprocessing.spawn(example_reduce_scatter.main, args=(2, None), nprocs=2)
```
Same args=None issue applies to the new reduce_scatter test.
This test passes args=None, which will crash in example_reduce_scatter.main when accessing args.M and args.N. The fix proposed in example_reduce_scatter.py (adding if args else fallbacks) would resolve this.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/distributed/intranode/test_intranode.py` around lines 35 - 36, The
test test_example_reduce_scatter calls
torch.multiprocessing.spawn(example_reduce_scatter.main, args=(2, None),
nprocs=2) which passes args=None into example_reduce_scatter.main and causes
attribute access errors on args.M and args.N; fix by either updating the test to
pass a proper args object (replace the None with a simple namespace/config
providing M and N) or make example_reduce_scatter.main resilient by using safe
fallbacks when args is None (e.g., refer to args.M as args.M if args else
<default_M> and similarly for args.N) so the test no longer crashes; reference
example_reduce_scatter.main and test_example_reduce_scatter when making the
change.
Actionable comments posted: 1
♻️ Duplicate comments (1)
examples/distributed/intranode/example_reduce_scatter.py (1)
21-24: ⚠️ Potential issue | 🟠 Major
Guard `args` before dereferencing to avoid a test/runtime crash. At Line 23 and Line 24, `args` is used unconditionally. Calls that pass `args=None` will raise `AttributeError`.
🐛 Proposed fix

```diff
 def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
     dtype = torch.float16
-    M = args.M
-    N = args.N
+    M = args.M if args is not None else 8192
+    N = args.N if args is not None else 8192
     M_per_rank = M // num_local_ranks
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/distributed/intranode/example_reduce_scatter.py` around lines 21 - 24, In main(local_rank: int, num_local_ranks: int, args: argparse.Namespace) guard access to args before using args.M and args.N: check if args is not None (or supply sensible defaults) and only then read M and N, otherwise raise a clear ValueError or set defaults; update the references to args.M and args.N in the function to use the guarded values so calls that pass args=None do not raise AttributeError.
🧹 Nitpick comments (1)
examples/distributed/intranode/example_reduce_scatter.py (1)
27-59: Ensure process-group teardown runs even on failure. If an exception occurs before Line 59, child workers can exit without cleanup. Wrap the body after init in `try/finally` and tear down conditionally.
♻️ Suggested structure

```diff
-    rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
-    assert rank == local_rank and num_ranks == num_local_ranks, "only support single node for now"
-
-    allocator = tilelang.get_allocator(
-        size=2**30, device="cuda", is_distributed=True, local_rank=local_rank, num_local_ranks=num_local_ranks, group=group
-    )
-    ...
-    dist.destroy_process_group()
+    rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
+    try:
+        assert rank == local_rank and num_ranks == num_local_ranks, "only support single node for now"
+        allocator = tilelang.get_allocator(
+            size=2**30, device="cuda", is_distributed=True, local_rank=local_rank, num_local_ranks=num_local_ranks, group=group
+        )
+        ...
+    finally:
+        if dist.is_initialized():
+            dist.destroy_process_group()
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/distributed/intranode/example_reduce_scatter.py` around lines 27 - 59, After calling init_dist (which returns rank, num_ranks, group) wrap the rest of the worker logic (allocator creation, tensor setup, create_reduce_scater_2d_ctx, reduce_scatter_2d_op call, comparisons, perf loop, etc.) in a try/finally so teardown always runs; in finally call dist.destroy_process_group() only if the process group is initialized (e.g., check dist.is_initialized() or that group is not None) to avoid errors, and re-raise any caught exception so failures are not swallowed.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/distributed/intranode/example_reduce_scatter.py`:
- Line 25: Before computing M_per_rank, validate that M is evenly divisible by
num_local_ranks: add a guard that checks M % num_local_ranks == 0 and raises a
clear ValueError (or AssertionError) if not, then compute M_per_rank = M //
num_local_ranks; reference the variables M, num_local_ranks and the target
variable M_per_rank in the check so the failure message clearly states the
offending values.
---
Duplicate comments:
In `@examples/distributed/intranode/example_reduce_scatter.py`:
- Around line 21-24: In main(local_rank: int, num_local_ranks: int, args:
argparse.Namespace) guard access to args before using args.M and args.N: check
if args is not None (or supply sensible defaults) and only then read M and N,
otherwise raise a clear ValueError or set defaults; update the references to
args.M and args.N in the function to use the guarded values so calls that pass
args=None do not raise AttributeError.
---
Nitpick comments:
In `@examples/distributed/intranode/example_reduce_scatter.py`:
- Around line 27-59: After calling init_dist (which returns rank, num_ranks,
group) wrap the rest of the worker logic (allocator creation, tensor setup,
create_reduce_scater_2d_ctx, reduce_scatter_2d_op call, comparisons, perf loop,
etc.) in a try/finally so teardown always runs; in finally call
dist.destroy_process_group() only if the process group is initialized (e.g.,
check dist.is_initialized() or that group is not None) to avoid errors, and
re-raise any caught exception so failures are not swallowed.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: e85443cb-266d-4851-92a4-6662e45bfd82
📒 Files selected for processing (4)
- examples/distributed/intranode/example_gemm_rs_overlapped.py
- examples/distributed/intranode/example_reduce_scatter.py
- examples/distributed/intranode/reduce_scatter.py
- examples/distributed/intranode/test_intranode.py
💤 Files with no reviewable changes (2)
- examples/distributed/intranode/reduce_scatter.py
- examples/distributed/intranode/example_gemm_rs_overlapped.py
✅ Files skipped from review due to trivial changes (1)
- examples/distributed/intranode/test_intranode.py
```python
    dtype = torch.float16
    M = args.M
    N = args.N
    M_per_rank = M // num_local_ranks
```
Add explicit shape validation before computing per-rank split.
At Line 25, floor-division can mask invalid input and defer failure into lower-level asserts. Validate M % num_local_ranks == 0 up front with a clear error.
✅ Proposed guard
```diff
-    M_per_rank = M // num_local_ranks
+    if M <= 0 or N <= 0:
+        raise ValueError("M and N must be > 0")
+    if M % num_local_ranks != 0:
+        raise ValueError("M must be divisible by num_local_ranks")
+    M_per_rank = M // num_local_ranks
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
    if M <= 0 or N <= 0:
        raise ValueError("M and N must be > 0")
    if M % num_local_ranks != 0:
        raise ValueError("M must be divisible by num_local_ranks")
    M_per_rank = M // num_local_ranks
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/distributed/intranode/example_reduce_scatter.py` at line 25, Before
computing M_per_rank, validate that M is evenly divisible by num_local_ranks:
add a guard that checks M % num_local_ranks == 0 and raises a clear ValueError
(or AssertionError) if not, then compute M_per_rank = M // num_local_ranks;
reference the variables M, num_local_ranks and the target variable M_per_rank in
the check so the failure message clearly states the offending values.
Actionable comments posted: 3
🧹 Nitpick comments (5)
examples/distributed/intranode/example_post_attn_all2all_transpose_intranode.py (2)
29-29: Consider prefixing unused unpacked variables with an underscore. Static analysis flags `S` (line 29) and `rank`, `num_ranks` (line 82) as unused.
Suggested fix

```diff
-    B, _, S, D = src.shape
+    B, _, _S, D = src.shape
```

```diff
-    rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
+    _rank, _num_ranks, group = init_dist(local_rank, num_local_ranks)
```

Also applies to: 82-82
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/distributed/intranode/example_post_attn_all2all_transpose_intranode.py` at line 29, The unpacked tuple and some local variables are flagged unused; rename unused unpack targets and variables by prefixing them with underscores to avoid linter warnings: change the unpacking B, _, S, D = src.shape so the unused element currently named S is prefixed (e.g., B, _, _, D or B, _, _S, D) and rename the local variables rank and num_ranks to _rank and _num_ranks (or _rank, _num_ranks) wherever they appear in this file to indicate they are intentionally unused.
114-114: Minor inconsistency: emoji in output message. This file uses `\u2705` (✅) in the success message while the other two example files don't. Consider making the output format consistent across all examples.
Suggested fix for consistency

```diff
-    print(f"rank {local_rank} check passed. \u2705")
+    print(f"rank {local_rank} check passed.")
```
Verify each finding against the current code and only fix it if needed. In `@examples/distributed/intranode/example_post_attn_all2all_transpose_intranode.py` at line 114, The success print in the example uses an emoji escape sequence in the statement print(f"rank {local_rank} check passed. \u2705") which is inconsistent with the other example files; update that print call in example_post_attn_all2all_transpose_intranode.py (the line containing print(f"rank {local_rank} check passed. \u2705")) to match the other examples by removing the emoji (e.g., print(f"rank {local_rank} check passed.") ) or otherwise using the same formatting used across the other two example files.

examples/distributed/intranode/example_pre_attn_all2all_transpose_intranode.py (1)
29-29: Consider prefixing unused unpacked variables with an underscore. Static analysis flags `NH` (line 29) and `rank`, `num_ranks` (line 82) as unused.
Suggested fix

```diff
-    B, _, NH, D = src.shape
+    B, _, _NH, D = src.shape
```

```diff
-    rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
+    _rank, _num_ranks, group = init_dist(local_rank, num_local_ranks)
```

Also applies to: 82-82
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/distributed/intranode/example_pre_attn_all2all_transpose_intranode.py` at line 29, The variable unpacking and some local variables are flagged as unused; update the tuple unpack and the rank variables to use underscore-prefixed names to indicate intentional unused values: change the unpack from "B, _, NH, D = src.shape" to use a prefixed unused name (e.g., "B, _, _NH, D = src.shape" or "B, _, _, D = src.shape" if NH truly isn't needed) and rename "rank" and "num_ranks" to "_rank" and "_num_ranks" wherever they are defined/used in this file (example symbol names: src.shape unpack, rank, num_ranks) so static analysis no longer reports them as unused.

examples/distributed/intranode/example_pre_attn_all2all_intranode.py (1)
28-28: Consider prefixing unused unpacked variables with an underscore. Static analysis flags `NH` (line 28) and `rank`, `num_ranks` (line 78) as unused. Prefixing with `_` signals intentional discard.
Suggested fix

```diff
-    B, NH, _, D = src.shape
+    B, _NH, _, D = src.shape
```

```diff
-    rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
+    _rank, _num_ranks, group = init_dist(local_rank, num_local_ranks)
```

Also applies to: 78-78
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/distributed/intranode/example_pre_attn_all2all_intranode.py` at line 28, The unpacked variable NH from src.shape and the variables rank and num_ranks are currently unused; rename them to indicate intentional discard by prefixing with an underscore (e.g., _NH for the unpack from src.shape and _rank/_num_ranks where rank and num_ranks are defined) so static analysis won't flag them as unused while keeping intent clear; update occurrences of those names accordingly in the functions using B, _NH, _, D, and where rank/num_ranks are defined so only the used symbols remain unchanged.

examples/distributed/nvshmem/example_nvshmem.py (1)
22-32: Wrap module-level execution in a `__main__` guard. Line 22–Line 32 executes compile/profile/print at import time. For examples, this is better gated to avoid unintended side effects during test discovery/imports.
♻️ Proposed refactor

```diff
-func = dist_test(128, 128, 128, 128)
-
-kernel = tilelang.compile(func, out_idx=-1)
-
-# Get CUDA Source
-print(kernel.get_kernel_source())
-
-profiler = kernel.get_profiler()
-out = profiler.run_once()
-
-print(out)
+if __name__ == "__main__":
+    func = dist_test(128, 128, 128, 128)
+    kernel = tilelang.compile(func, out_idx=-1)
+
+    # Get CUDA Source
+    print(kernel.get_kernel_source())
+
+    profiler = kernel.get_profiler()
+    out = profiler.run_once()
+    print(out)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/distributed/nvshmem/example_nvshmem.py` around lines 22 - 32, The top-level execution that compiles and runs the kernel is happening at import time; wrap the calls to dist_test(...), tilelang.compile(...), kernel.get_kernel_source(), kernel.get_profiler(), and profiler.run_once() inside an if __name__ == "__main__": guard so importing example_nvshmem.py no longer triggers compilation/profiling, and indent those statements into that block to preserve behavior when the file is executed as a script.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/distributed/nvshmem/example_nvshmem.py`:
- Around line 29-30: The example calls profiler.run_once() (via
kernel.get_profiler()) without initializing the distributed NVSHMEM runtime,
which can cause failures when functions like T.get_pe() are used; fix this by
initializing the NVSHMEM/distributed runtime before calling profiler.run_once()
— e.g., invoke the library's NVSHMEM initialization routine (or
kernel.profiler/init method if available) on all PEs, ensure any required
communicator/PE setup is complete, then call profiler.run_once(), and finally
finalize the NVSHMEM runtime after the run.
- Around line 11-17: The kernel uses T.ceildiv grid sizing but always copies a
full block_M x block_N tile, which can read/write out-of-bounds for edge CTAs;
modify the kernel (T.Kernel) to compute the actual tile extents using N, M and
the tile indices (bx, by) — e.g., tile_h = min(block_M, M - by*block_M), tile_w
= min(block_N, N - bx*block_N) — and use those extents when performing the
shared-memory loads/stores (the T.copy calls that move between A, A_shared, and
B) or gate the copies with conditional bounds checks so you only access valid
elements. Ensure any per-CTA temporary mype/A_shared logic still works with
variable tile sizes.
In `@examples/distributed/nvshmem/example_summa.py`:
- Line 18: The call to driver.get_num_sms() in summa() uses the default device 0
and will return the wrong SM count in multi-process runs; modify the summa(...)
function signature to add device_id: int = 0, use that parameter when calling
driver.get_num_sms(device_id) inside summa, and update the call site that
launches summa(...) to pass device_id=LOCAL_RANK so each process queries its
local GPU.
---
Nitpick comments:
In
`@examples/distributed/intranode/example_post_attn_all2all_transpose_intranode.py`:
- Line 29: The unpacked tuple and some local variables are flagged unused;
rename unused unpack targets and variables by prefixing them with underscores to
avoid linter warnings: change the unpacking B, _, S, D = src.shape so the unused
element currently named S is prefixed (e.g., B, _, _, D or B, _, _S, D) and
rename the local variables rank and num_ranks to _rank and _num_ranks (or _rank,
_num_ranks) wherever they appear in this file to indicate they are intentionally
unused.
- Line 114: The success print in the example uses an emoji escape sequence in
the statement print(f"rank {local_rank} check passed. \u2705") which is
inconsistent with the other example files; update that print call in
example_post_attn_all2all_transpose_intranode.py (the line containing
print(f"rank {local_rank} check passed. \u2705")) to match the other examples by
removing the emoji (e.g., print(f"rank {local_rank} check passed.") ) or
otherwise using the same formatting used across the other two example files.
In `@examples/distributed/intranode/example_pre_attn_all2all_intranode.py`:
- Line 28: The unpacked variable NH from src.shape and the variables rank and
num_ranks are currently unused; rename them to indicate intentional discard by
prefixing with an underscore (e.g., _NH for the unpack from src.shape and
_rank/_num_ranks where rank and num_ranks are defined) so static analysis won't
flag them as unused while keeping intent clear; update occurrences of those
names accordingly in the functions using B, _NH, _, D, and where rank/num_ranks
are defined so only the used symbols remain unchanged.
In
`@examples/distributed/intranode/example_pre_attn_all2all_transpose_intranode.py`:
- Line 29: The variable unpacking and some local variables are flagged as
unused; update the tuple unpack and the rank variables to use
underscore-prefixed names to indicate intentional unused values: change the
unpack from "B, _, NH, D = src.shape" to use a prefixed unused name (e.g., "B,
_, _NH, D = src.shape" or "B, _, _, D = src.shape" if NH truly isn't needed) and
rename "rank" and "num_ranks" to "_rank" and "_num_ranks" wherever they are
defined/used in this file (example symbol names: src.shape unpack, rank,
num_ranks) so static analysis no longer reports them as unused.
In `@examples/distributed/nvshmem/example_nvshmem.py`:
- Around line 22-32: The top-level execution that compiles and runs the kernel
is happening at import time; wrap the calls to dist_test(...),
tilelang.compile(...), kernel.get_kernel_source(), kernel.get_profiler(), and
profiler.run_once() inside an if __name__ == "__main__": guard so importing
example_nvshmem.py no longer triggers compilation/profiling, and indent those
statements into that block to preserve behavior when the file is executed as a
script.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 2b2681ce-9775-448f-b49e-0e692140ef8f
📒 Files selected for processing (9)
- examples/distributed/README.md
- examples/distributed/intranode/example_post_attn_all2all_transpose_intranode.py
- examples/distributed/intranode/example_pre_attn_all2all_intranode.py
- examples/distributed/intranode/example_pre_attn_all2all_transpose_intranode.py
- examples/distributed/intranode/test_intranode.py
- examples/distributed/nvshmem/example_all_to_all.py
- examples/distributed/nvshmem/example_cannon.py
- examples/distributed/nvshmem/example_nvshmem.py
- examples/distributed/nvshmem/example_summa.py
💤 Files with no reviewable changes (2)
- examples/distributed/nvshmem/example_cannon.py
- examples/distributed/nvshmem/example_all_to_all.py
✅ Files skipped from review due to trivial changes (1)
- examples/distributed/README.md
🚧 Files skipped from review as they are similar to previous changes (1)
- examples/distributed/intranode/test_intranode.py
```python
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
    A_shared = T.alloc_shared((block_M, block_N), dtype)
    mype = T.alloc_local([1], "int32")

    mype[0] = T.get_pe()
    T.copy(A[by * block_M, bx * block_N], A_shared)
    T.copy(A_shared, B[by * block_M, bx * block_N])
```
Guard partial tiles or enforce divisibility before launching ceildiv grid.
Line 11 uses ceil-div launch bounds, but Line 16–Line 17 copy full block_M x block_N tiles unconditionally. For non-divisible M/N, edge CTAs can access out of bounds.
🔧 Proposed fix (safe precondition)
```diff
 def dist_test(M, N, block_M, block_N, dtype="int16"):
+    if M % block_M != 0 or N % block_N != 0:
+        raise ValueError(
+            "dist_test currently requires M and N to be divisible by block_M and block_N"
+        )
     @T.prim_func
     def main(
         A: T.Buffer((M, N), dtype),
         B: T.Buffer((M, N), dtype),
     ):
```
):🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/distributed/nvshmem/example_nvshmem.py` around lines 11-17: the kernel uses T.ceildiv grid sizing but always copies a full block_M x block_N tile, which can read/write out of bounds for edge CTAs. Modify the kernel (T.Kernel) to compute the actual tile extents from M, N and the tile indices (bx, by), e.g. tile_h = min(block_M, M - by*block_M) and tile_w = min(block_N, N - bx*block_N), and use those extents when performing the shared-memory loads/stores (the T.copy calls that move between A, A_shared, and B), or gate the copies with conditional bounds checks so only valid elements are accessed. Ensure any per-CTA mype/A_shared logic still works with variable tile sizes.
```python
profiler = kernel.get_profiler()
out = profiler.run_once()
```
run_once() may execute without distributed runtime initialization.
Line 30 directly calls profiler.run_once(). From tilelang/profiler/__init__.py, run_once() just invokes the kernel and does not initialize distributed state. In an NVSHMEM example, this can fail at runtime or produce undefined behavior when T.get_pe() is used.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/distributed/nvshmem/example_nvshmem.py` around lines 29 - 30, The
example calls profiler.run_once() (via kernel.get_profiler()) without
initializing the distributed NVSHMEM runtime, which can cause failures when
functions like T.get_pe() are used; fix this by initializing the
NVSHMEM/distributed runtime before calling profiler.run_once() — e.g., invoke
the library's NVSHMEM initialization routine (or kernel.profiler/init method if
available) on all PEs, ensure any required communicator/PE setup is complete,
then call profiler.run_once(), and finally finalize the NVSHMEM runtime after
the run.
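One cheap mitigation is to fail fast before calling run_once() when the launcher environment is missing. The sketch below uses only the stdlib; the helper is hypothetical (not part of tilelang), and the variable names assume a torchrun-style launcher:

```python
import os


def ensure_launcher_env():
    """Raise early if distributed launcher env vars are absent, instead of
    letting PE-aware calls like T.get_pe() misbehave inside the kernel."""
    required = ("RANK", "WORLD_SIZE", "LOCAL_RANK")
    missing = [v for v in required if v not in os.environ]
    if missing:
        raise RuntimeError(f"distributed environment not initialized; missing: {missing}")


# ensure_launcher_env()  # call before profiler.run_once()
```

This does not initialize NVSHMEM itself, but it turns an undefined-behavior failure into a clear error message when the example is launched as a plain script.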
```diff
 accum_dtype = "float32"

-sm_num = 132  # 132 SMs for H100
+sm_num = driver.get_num_sms()
```
🧩 Analysis chain

🏁 Script executed:

```shell
#!/bin/bash
# Verify API default and current callsites that omit device_id.
# Expected: get_num_sms defaults to 0; this file currently calls without args.
rg -nP --type=py '\bdef\s+get_num_sms\s*\(\s*device_id\s*:\s*int\s*=\s*0' -C2
rg -nP --type=py '\bget_num_sms\s*\(\s*\)' -C2
rg -nP --type=py 'torch\.device\(f"cuda:\{RANK\}"\)|LOCAL_RANK' -C2 examples/distributed/nvshmem/example_summa.py
```

Repository: tile-ai/tilescale
Length of output: 5369

🏁 Script executed:

```shell
# Get full context of summa function definition
head -50 examples/distributed/nvshmem/example_summa.py
# Find all calls to summa function
rg -n 'summa\s*\(' examples/distributed/nvshmem/example_summa.py -A2 -B2
```

Repository: tile-ai/tilescale
Length of output: 2233

🏁 Script executed:

```shell
# Find all call sites of summa function
rg -n 'summa\s*\(' --type=py | head -20
# Check how other distributed examples handle device_id in similar functions
rg -n 'def\s+\w+\(.*device_id' examples/distributed/ --type=py -A1
```

Repository: tile-ai/tilescale
Length of output: 300

🏁 Script executed:

```shell
# Check other distributed examples that call get_num_sms()
rg -n 'get_num_sms\(\)' examples/distributed/ --type=py -B5 -A5
# Get the full main block context around the summa call to verify LOCAL_RANK scope
sed -n '145,175p' examples/distributed/nvshmem/example_summa.py
```

Repository: tile-ai/tilescale
Length of output: 3119
Pass LOCAL_RANK to get_num_sms() to query the correct GPU's SM count in multi-process distributed runs.
Line 18 calls driver.get_num_sms() without arguments, defaulting to device_id=0. In multi-process setups, non-zero ranks map to different GPUs and will incorrectly read SM count from device 0 instead of their rank-local device. LOCAL_RANK is available at the call site (line 150) and should be passed through the function.
Add device_id: int = 0 parameter to the summa() function signature, pass it to driver.get_num_sms(device_id) at line 18, and pass device_id=LOCAL_RANK at the call site (line 163).
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/distributed/nvshmem/example_summa.py` at line 18, The call to
driver.get_num_sms() in summa() uses the default device 0 and will return the
wrong SM count in multi-process runs; modify the summa(...) function signature
to add device_id: int = 0, use that parameter when calling
driver.get_num_sms(device_id) inside summa, and update the call site that
launches summa(...) to pass device_id=LOCAL_RANK so each process queries its
local GPU.
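The suggested plumbing looks like this in outline. Here get_num_sms is a stub standing in for driver.get_num_sms, and the per-device SM counts are illustrative values, not queried from real hardware:

```python
import os

# Stub for driver.get_num_sms: per-device SM counts (illustrative values;
# the real call queries the CUDA driver for the given device).
_SM_COUNTS = {0: 132, 1: 132}


def get_num_sms(device_id: int = 0) -> int:
    return _SM_COUNTS[device_id]


def summa(m, n, k, device_id: int = 0):
    # Query the rank-local GPU instead of always defaulting to device 0.
    sm_num = get_num_sms(device_id)
    return sm_num


# Each process reads its rank-local device index from the launcher env.
LOCAL_RANK = int(os.environ.get("LOCAL_RANK", "0"))
print(summa(4096, 4096, 4096, device_id=LOCAL_RANK))
```

The default device_id=0 keeps single-GPU runs working unchanged, while multi-process launches pass device_id=LOCAL_RANK at the call site.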
Intranode examples weren't taken into consideration in CI previously, and I think we need to sort the examples into intranode, IPC-based, and others.
rdc-related issues

Summary by CodeRabbit
Tests
New Features
Documentation
Chores