Skip to content

[Reorgnize] Reorganize intra-node examples add add tests#52

Open
Rachmanino wants to merge 8 commits intomainfrom
wt/reorg
Open

[Reorgnize] Reorganize intra-node examples add add tests#52
Rachmanino wants to merge 8 commits intomainfrom
wt/reorg

Conversation

@Rachmanino
Copy link
Copy Markdown
Collaborator

@Rachmanino Rachmanino commented Feb 13, 2026

  • Reorgnize intra-node examples and add tests for them in CI
    Intranode examples weren't taken into consideration in CI previously. And i think we need to sort examples into intranode, ipc-based and others.
  • Fix bugs in allocators and let CI pass
  • Investigate rdc-related issues

Summary by CodeRabbit

  • Tests

    • Added intranode distributed tests to validate multi-process, CUDA-based overlapped execution scenarios.
  • New Features

    • Added several intranode IPC examples (pre/post attention all‑to‑all variants and a reduce‑scatter example) and a new NVSHMEM example with CLI, correctness checks, and benchmarking.
  • Documentation

    • Reorganized distributed examples README into NVSHMEM and IPC sections with explicit run commands.
  • Chores

    • Removed legacy distributed example implementations and low‑level utilities; removed a few global cache-disabling calls.

@github-actions
Copy link
Copy Markdown

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Copy link
Copy Markdown

coderabbitai bot commented Feb 13, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

Removed multiple NVSHMEM/Triton-backed distributed examples/utilities, added several intranode IPC examples and tests, updated README and minor intranode tweaks.

Changes

Cohort / File(s) Summary
Removed Allgather GEMM Example
examples/distributed/example_allgather_gemm.py
Deleted the TileLang NVSHMEM all-gather GEMM example and its exported allgather_gemm(...) (kernel, NVSHMEM setup, compile, run, validation).
Removed GEMM/Reduce-Scatter Utilities
examples/distributed/gemm_rs_utils.py
Deleted utility module that provided barrier/context dataclasses, NVSHMEM buffer allocators, and helpers for reduce-scatter / GEMM tensor-parallel contexts.
Removed Triton SP Module
examples/distributed/triton_sp.py
Deleted Triton/NVSHMEM persistent GEMM + all-to-all push kernels, related inline vector externs, launcher functions, and SpUlysessOAll2AllGemmKernel class (buffer/sync/GEMM flows).
New intranode tests
examples/distributed/intranode/test_intranode.py
Added pytest suite that conditionally spawns 2-process CUDA runs to exercise several intranode example entrypoints (gated on distributed + CUDA + compute capability 9.0).
New intranode reduce-scatter example
examples/distributed/intranode/example_reduce_scatter.py
Added IPC-based intranode reduce-scatter example: torch_reduce_scatter, main to init dist, build reduce-scatter context, run TileLang op, validate vs PyTorch, benchmark, and CLI spawn.
New intranode pre/post-attn all2all examples
examples/distributed/intranode/example_pre_attn_all2all_intranode.py, examples/distributed/intranode/example_pre_attn_all2all_transpose_intranode.py, examples/distributed/intranode/example_post_attn_all2all_transpose_intranode.py
Added IPC-style all-to-all examples (pre/post-attn variants) with TileLang @T.prim_func kernels using T.put_block, PyTorch reference implementations, CLI entrypoints, correctness checks and perf measurements.
NVSHMEM examples & small edits
examples/distributed/nvshmem/example_nvshmem.py, examples/distributed/nvshmem/example_all_to_all.py, examples/distributed/nvshmem/example_cannon.py, examples/distributed/nvshmem/example_summa.py
Added example_nvshmem.py; removed module-level tilelang.disable_cache() calls in several nvshmem examples; summa now queries SM count at runtime instead of hard-coded value.
Docs & trivial edits
examples/distributed/README.md, examples/distributed/intranode/example_gemm_rs_overlapped.py, examples/distributed/intranode/reduce_scatter.py
Reorganised README into NVSHMEM vs IPC sections and added launch examples; removed a from __future__ line and removed explicit tilelang.disable_cache() in one file.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested reviewers

  • chengyupku
  • tzj-fxz

Poem

🐰 I hopped through kernels, wires, and streams,
I pruned old scripts and planted fresh tests' dreams.
Two processes tango where CUDA lights beam,
Ranks whisper sync, tensors stitch seam by seam.
A rabbit cheers — run clean, run supreme!

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 11.11% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
Title check ❓ Inconclusive The title mentions reorganizing intra-node examples and adding tests, which aligns with the main objective, but contains a typo ('[Reorgnize]' instead of '[Reorganize]') and the phrase 'add add' appears to be a duplicate word error. Correct the typo 'Reorgnize' to 'Reorganize' and fix the duplicate 'add add' to 'add' for clarity and professionalism.
✅ Passed checks (1 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch wt/reorg

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@examples/distributed/intranode/test_intranode.py`:
- Around line 13-14: The test passes args=(2, None) which sends None into the
main functions and causes AttributeError when they access attributes; update the
test (test_example_allgather_gemm_overlapped) to pass a minimal valid args
object (e.g., an argparse.Namespace or simple object) containing the attributes
expected by example_allgather_gemm_overlapped.main,
example_gemm_rs_overlapped.main and example_sp_ag_attention_intra_node.main (at
least persistent for the first two, and batch_size, q_head, etc. for the
attention example), or alternately add defensive checks inside those main
functions to handle args is None before accessing attributes — pick one approach
and implement it consistently for the three mains named above.
🧹 Nitpick comments (1)
examples/distributed/intranode/test_intranode.py (1)

10-28: Consider pytest.mark.parametrize to reduce boilerplate.

All three tests share the same decorator stack and spawn logic, differing only in the imported module. A parametrized test would reduce duplication:

♻️ Optional refactor
+@tilelang.testing.requires_distributed
+@tilelang.testing.requires_cuda
+@tilelang.testing.requires_cuda_compute_version_eq(9, 0)
+@pytest.mark.parametrize("module", [
+    example_allgather_gemm_overlapped,
+    example_gemm_rs_overlapped,
+    example_sp_ag_attention_intra_node,
+])
+def test_intranode_example(module):
+    torch.multiprocessing.spawn(module.main, args=(2, None), nprocs=2)

Comment on lines +13 to +14
def test_example_allgather_gemm_overlapped():
torch.multiprocessing.spawn(example_allgather_gemm_overlapped.main, args=(2, None), nprocs=2)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# First, find the location of the example modules
find . -name "example_allgather_gemm_overlapped.py" -o -name "example_gemm_rs_overlapped.py" -o -name "example_sp_ag_attention_intra_node.py" | head -20

Repository: tile-ai/tilescale

Length of output: 264


🏁 Script executed:

#!/bin/bash
# Extract main function definitions with context
for f in ./examples/distributed/intranode/example_allgather_gemm_overlapped.py ./examples/distributed/intranode/example_gemm_rs_overlapped.py ./examples/distributed/intranode/example_sp_ag_attention_intra_node.py; do
  echo "=== $(basename $f) ==="
  sed -n '/^def main/,/^def \|^class \|^$/p' "$f" | head -40
done

Repository: tile-ai/tilescale

Length of output: 931


🏁 Script executed:

#!/bin/bash
# Get the full main function for each module (up to 60 lines)
for f in ./examples/distributed/intranode/example_allgather_gemm_overlapped.py ./examples/distributed/intranode/example_gemm_rs_overlapped.py ./examples/distributed/intranode/example_sp_ag_attention_intra_node.py; do
  echo "=== $(basename $f) ==="
  sed -n '/^def main/,/^def \|^class \|^$/p' "$f" | head -60
  echo ""
done

Repository: tile-ai/tilescale

Length of output: 934


🏁 Script executed:

#!/bin/bash
# Check the test file to confirm args pattern
echo "=== test_intranode.py ==="
cat -n examples/distributed/intranode/test_intranode.py

echo ""
echo "=== Checking for any args usage in example_sp_ag_attention_intra_node.py ==="
grep -n "args\." examples/distributed/intranode/example_sp_ag_attention_intra_node.py || echo "No args. accesses found"

Repository: tile-ai/tilescale

Length of output: 1792


All three main functions will crash with AttributeError when args=None is passed.

The tests call torch.multiprocessing.spawn(..., args=(2, None), nprocs=2), which passes None as the third argument to each main function. However:

  • example_allgather_gemm_overlapped.main and example_gemm_rs_overlapped.main unconditionally access args.persistent on line 5.
  • example_sp_ag_attention_intra_node.main unconditionally accesses args.batch_size, args.q_head, and multiple other attributes (lines 293-304, 409).

All three will fail with AttributeError: 'NoneType' object has no attribute .... Either pass valid args or add None checks to all attribute accesses in these functions.

🤖 Prompt for AI Agents
In `@examples/distributed/intranode/test_intranode.py` around lines 13 - 14, The
test passes args=(2, None) which sends None into the main functions and causes
AttributeError when they access attributes; update the test
(test_example_allgather_gemm_overlapped) to pass a minimal valid args object
(e.g., an argparse.Namespace or simple object) containing the attributes
expected by example_allgather_gemm_overlapped.main,
example_gemm_rs_overlapped.main and example_sp_ag_attention_intra_node.main (at
least persistent for the first two, and batch_size, q_head, etc. for the
attention example), or alternately add defensive checks inside those main
functions to handle args is None before accessing attributes — pick one approach
and implement it consistently for the three mains named above.

Copy link
Copy Markdown

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🧹 Nitpick comments (1)
examples/distributed/intranode/example_reduce_scatter.py (1)

10-18: Avoid shadowing the Python builtin input.

The parameter name input shadows the Python builtin. Consider renaming to input_tensor or inp for clarity and to avoid potential confusion.

♻️ Suggested rename
 def torch_reduce_scatter(
     pg: torch.distributed.ProcessGroup,
-    input: torch.Tensor,
+    input_tensor: torch.Tensor,
     num_local_ranks: int,
 ) -> torch.Tensor:
-    M, N = input.shape
-    output = torch.empty((M // num_local_ranks, N), dtype=input.dtype, device=input.device)
-    torch.distributed.reduce_scatter_tensor(output, input, group=pg)
+    M, N = input_tensor.shape
+    output = torch.empty((M // num_local_ranks, N), dtype=input_tensor.dtype, device=input_tensor.device)
+    torch.distributed.reduce_scatter_tensor(output, input_tensor, group=pg)
     return output
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/distributed/intranode/example_reduce_scatter.py` around lines 10 -
18, The function torch_reduce_scatter currently uses a parameter named input
which shadows the Python builtin; rename that parameter (and all internal
references) to input_tensor (or inp) in the function signature and body of
torch_reduce_scatter so the code uses the new name when computing M, N, creating
output, and calling torch.distributed.reduce_scatter_tensor; ensure any callers
or references to torch_reduce_scatter are updated accordingly to match the new
parameter name.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/distributed/intranode/example_reduce_scatter.py`:
- Around line 21-25: The function main currently assumes args is not None and
directly reads args.M and args.N; add a None guard like other examples so tests
that call main(args=None) won't crash: inside main (function main) check if args
is None (or use a fallback pattern) and set M and N to default values when args
is None before computing M_per_rank; ensure you reference and preserve dtype,
local_rank, num_local_ranks and only replace the unconditional accesses to
args.M / args.N with guarded assignments.

In `@examples/distributed/intranode/test_intranode.py`:
- Around line 35-36: The test test_example_reduce_scatter calls
torch.multiprocessing.spawn(example_reduce_scatter.main, args=(2, None),
nprocs=2) which passes args=None into example_reduce_scatter.main and causes
attribute access errors on args.M and args.N; fix by either updating the test to
pass a proper args object (replace the None with a simple namespace/config
providing M and N) or make example_reduce_scatter.main resilient by using safe
fallbacks when args is None (e.g., refer to args.M as args.M if args else
<default_M> and similarly for args.N) so the test no longer crashes; reference
example_reduce_scatter.main and test_example_reduce_scatter when making the
change.

---

Nitpick comments:
In `@examples/distributed/intranode/example_reduce_scatter.py`:
- Around line 10-18: The function torch_reduce_scatter currently uses a
parameter named input which shadows the Python builtin; rename that parameter
(and all internal references) to input_tensor (or inp) in the function signature
and body of torch_reduce_scatter so the code uses the new name when computing M,
N, creating output, and calling torch.distributed.reduce_scatter_tensor; ensure
any callers or references to torch_reduce_scatter are updated accordingly to
match the new parameter name.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 1ee0a1af-0fbe-4d5c-b907-42da4de6164f

📥 Commits

Reviewing files that changed from the base of the PR and between b9026fd and cadd362.

📒 Files selected for processing (15)
  • examples/distributed/README.md
  • examples/distributed/intranode/example_gemm_rs_overlapped.py
  • examples/distributed/intranode/example_reduce_scatter.py
  • examples/distributed/intranode/reduce_scatter.py
  • examples/distributed/intranode/test_intranode.py
  • examples/distributed/nvshmem/example_all_to_all.py
  • examples/distributed/nvshmem/example_allgather.py
  • examples/distributed/nvshmem/example_cannon.py
  • examples/distributed/nvshmem/example_nvshmem.py
  • examples/distributed/nvshmem/example_overlapping_allgather.py
  • examples/distributed/nvshmem/example_post_attn_all2all_transpose.py
  • examples/distributed/nvshmem/example_pre_attn_all2all.py
  • examples/distributed/nvshmem/example_pre_attn_all2all_transpose.py
  • examples/distributed/nvshmem/example_simple_shift.py
  • examples/distributed/nvshmem/example_summa.py
💤 Files with no reviewable changes (2)
  • examples/distributed/intranode/example_gemm_rs_overlapped.py
  • examples/distributed/intranode/reduce_scatter.py
✅ Files skipped from review due to trivial changes (1)
  • examples/distributed/README.md

Comment on lines +35 to +36
def test_example_reduce_scatter():
torch.multiprocessing.spawn(example_reduce_scatter.main, args=(2, None), nprocs=2)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Same args=None issue applies to the new reduce_scatter test.

This test passes args=None, which will crash in example_reduce_scatter.main when accessing args.M and args.N. The fix proposed in example_reduce_scatter.py (adding if args else fallbacks) would resolve this.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/distributed/intranode/test_intranode.py` around lines 35 - 36, The
test test_example_reduce_scatter calls
torch.multiprocessing.spawn(example_reduce_scatter.main, args=(2, None),
nprocs=2) which passes args=None into example_reduce_scatter.main and causes
attribute access errors on args.M and args.N; fix by either updating the test to
pass a proper args object (replace the None with a simple namespace/config
providing M and N) or make example_reduce_scatter.main resilient by using safe
fallbacks when args is None (e.g., refer to args.M as args.M if args else
<default_M> and similarly for args.N) so the test no longer crashes; reference
example_reduce_scatter.main and test_example_reduce_scatter when making the
change.

Copy link
Copy Markdown

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

♻️ Duplicate comments (1)
examples/distributed/intranode/example_reduce_scatter.py (1)

21-24: ⚠️ Potential issue | 🟠 Major

Guard args before dereferencing to avoid test/runtime crash.

At Line 23 and Line 24, args is used unconditionally. Calls that pass args=None will raise AttributeError.

🐛 Proposed fix
 def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
     dtype = torch.float16
-    M = args.M
-    N = args.N
+    M = args.M if args is not None else 8192
+    N = args.N if args is not None else 8192
     M_per_rank = M // num_local_ranks
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/distributed/intranode/example_reduce_scatter.py` around lines 21 -
24, In main(local_rank: int, num_local_ranks: int, args: argparse.Namespace)
guard access to args before using args.M and args.N: check if args is not None
(or supply sensible defaults) and only then read M and N, otherwise raise a
clear ValueError or set defaults; update the references to args.M and args.N in
the function to use the guarded values so calls that pass args=None do not raise
AttributeError.
🧹 Nitpick comments (1)
examples/distributed/intranode/example_reduce_scatter.py (1)

27-59: Ensure process-group teardown runs even on failure.

If an exception occurs before Line 59, child workers can exit without cleanup. Wrap the body after init in try/finally and teardown conditionally.

♻️ Suggested structure
-    rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
-    assert rank == local_rank and num_ranks == num_local_ranks, "only support single node for now"
-
-    allocator = tilelang.get_allocator(
-        size=2**30, device="cuda", is_distributed=True, local_rank=local_rank, num_local_ranks=num_local_ranks, group=group
-    )
-    ...
-    dist.destroy_process_group()
+    rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
+    try:
+        assert rank == local_rank and num_ranks == num_local_ranks, "only support single node for now"
+        allocator = tilelang.get_allocator(
+            size=2**30, device="cuda", is_distributed=True, local_rank=local_rank, num_local_ranks=num_local_ranks, group=group
+        )
+        ...
+    finally:
+        if dist.is_initialized():
+            dist.destroy_process_group()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/distributed/intranode/example_reduce_scatter.py` around lines 27 -
59, After calling init_dist (which returns rank, num_ranks, group) wrap the rest
of the worker logic (allocator creation, tensor setup,
create_reduce_scater_2d_ctx, reduce_scatter_2d_op call, comparisons, perf loop,
etc.) in a try/finally so teardown always runs; in finally call
dist.destroy_process_group() only if the process group is initialized (e.g.,
check dist.is_initialized() or that group is not None) to avoid errors, and
re-raise any caught exception so failures are not swallowed.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/distributed/intranode/example_reduce_scatter.py`:
- Line 25: Before computing M_per_rank, validate that M is evenly divisible by
num_local_ranks: add a guard that checks M % num_local_ranks == 0 and raises a
clear ValueError (or AssertionError) if not, then compute M_per_rank = M //
num_local_ranks; reference the variables M, num_local_ranks and the target
variable M_per_rank in the check so the failure message clearly states the
offending values.

---

Duplicate comments:
In `@examples/distributed/intranode/example_reduce_scatter.py`:
- Around line 21-24: In main(local_rank: int, num_local_ranks: int, args:
argparse.Namespace) guard access to args before using args.M and args.N: check
if args is not None (or supply sensible defaults) and only then read M and N,
otherwise raise a clear ValueError or set defaults; update the references to
args.M and args.N in the function to use the guarded values so calls that pass
args=None do not raise AttributeError.

---

Nitpick comments:
In `@examples/distributed/intranode/example_reduce_scatter.py`:
- Around line 27-59: After calling init_dist (which returns rank, num_ranks,
group) wrap the rest of the worker logic (allocator creation, tensor setup,
create_reduce_scater_2d_ctx, reduce_scatter_2d_op call, comparisons, perf loop,
etc.) in a try/finally so teardown always runs; in finally call
dist.destroy_process_group() only if the process group is initialized (e.g.,
check dist.is_initialized() or that group is not None) to avoid errors, and
re-raise any caught exception so failures are not swallowed.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: e85443cb-266d-4851-92a4-6662e45bfd82

📥 Commits

Reviewing files that changed from the base of the PR and between cadd362 and 5952f00.

📒 Files selected for processing (4)
  • examples/distributed/intranode/example_gemm_rs_overlapped.py
  • examples/distributed/intranode/example_reduce_scatter.py
  • examples/distributed/intranode/reduce_scatter.py
  • examples/distributed/intranode/test_intranode.py
💤 Files with no reviewable changes (2)
  • examples/distributed/intranode/reduce_scatter.py
  • examples/distributed/intranode/example_gemm_rs_overlapped.py
✅ Files skipped from review due to trivial changes (1)
  • examples/distributed/intranode/test_intranode.py

dtype = torch.float16
M = args.M
N = args.N
M_per_rank = M // num_local_ranks
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Add explicit shape validation before computing per-rank split.

At Line 25, floor-division can mask invalid input and defer failure into lower-level asserts. Validate M % num_local_ranks == 0 up front with a clear error.

✅ Proposed guard
-    M_per_rank = M // num_local_ranks
+    if M <= 0 or N <= 0:
+        raise ValueError("M and N must be > 0")
+    if M % num_local_ranks != 0:
+        raise ValueError("M must be divisible by num_local_ranks")
+    M_per_rank = M // num_local_ranks
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
M_per_rank = M // num_local_ranks
if M <= 0 or N <= 0:
raise ValueError("M and N must be > 0")
if M % num_local_ranks != 0:
raise ValueError("M must be divisible by num_local_ranks")
M_per_rank = M // num_local_ranks
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/distributed/intranode/example_reduce_scatter.py` at line 25, Before
computing M_per_rank, validate that M is evenly divisible by num_local_ranks:
add a guard that checks M % num_local_ranks == 0 and raises a clear ValueError
(or AssertionError) if not, then compute M_per_rank = M // num_local_ranks;
reference the variables M, num_local_ranks and the target variable M_per_rank in
the check so the failure message clearly states the offending values.

Copy link
Copy Markdown

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

🧹 Nitpick comments (5)
examples/distributed/intranode/example_post_attn_all2all_transpose_intranode.py (2)

29-29: Consider prefixing unused unpacked variables with underscore.

Static analysis flags S (line 29) and rank, num_ranks (line 82) as unused.

Suggested fix
-    B, _, S, D = src.shape
+    B, _, _S, D = src.shape
-    rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
+    _rank, _num_ranks, group = init_dist(local_rank, num_local_ranks)

Also applies to: 82-82

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@examples/distributed/intranode/example_post_attn_all2all_transpose_intranode.py`
at line 29, The unpacked tuple and some local variables are flagged unused;
rename unused unpack targets and variables by prefixing them with underscores to
avoid linter warnings: change the unpacking B, _, S, D = src.shape so the unused
element currently named S is prefixed (e.g., B, _, _, D or B, _, _S, D) and
rename the local variables rank and num_ranks to _rank and _num_ranks (or _rank,
_num_ranks) wherever they appear in this file to indicate they are intentionally
unused.

114-114: Minor inconsistency: emoji in output message.

This file uses \u2705 (✅) in the success message while the other two example files don't. Consider making output format consistent across all examples.

Suggested fix for consistency
-        print(f"rank {local_rank} check passed. \u2705")
+        print(f"rank {local_rank} check passed.")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@examples/distributed/intranode/example_post_attn_all2all_transpose_intranode.py`
at line 114, The success print in the example uses an emoji escape sequence in
the statement print(f"rank {local_rank} check passed. \u2705") which is
inconsistent with the other example files; update that print call in
example_post_attn_all2all_transpose_intranode.py (the line containing
print(f"rank {local_rank} check passed. \u2705")) to match the other examples by
removing the emoji (e.g., print(f"rank {local_rank} check passed.") ) or
otherwise using the same formatting used across the other two example files.
examples/distributed/intranode/example_pre_attn_all2all_transpose_intranode.py (1)

29-29: Consider prefixing unused unpacked variables with underscore.

Static analysis flags NH (line 29) and rank, num_ranks (line 82) as unused.

Suggested fix
-    B, _, NH, D = src.shape
+    B, _, _NH, D = src.shape
-    rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
+    _rank, _num_ranks, group = init_dist(local_rank, num_local_ranks)

Also applies to: 82-82

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@examples/distributed/intranode/example_pre_attn_all2all_transpose_intranode.py`
at line 29, The variable unpacking and some local variables are flagged as
unused; update the tuple unpack and the rank variables to use
underscore-prefixed names to indicate intentional unused values: change the
unpack from "B, _, NH, D = src.shape" to use a prefixed unused name (e.g., "B,
_, _NH, D = src.shape" or "B, _, _, D = src.shape" if NH truly isn't needed) and
rename "rank" and "num_ranks" to "_rank" and "_num_ranks" wherever they are
defined/used in this file (example symbol names: src.shape unpack, rank,
num_ranks) so static analysis no longer reports them as unused.
examples/distributed/intranode/example_pre_attn_all2all_intranode.py (1)

28-28: Consider prefixing unused unpacked variables with underscore.

Static analysis flags NH (line 28) and rank, num_ranks (line 78) as unused. Prefixing with _ signals intentional discard.

Suggested fix
-    B, NH, _, D = src.shape
+    B, _NH, _, D = src.shape
-    rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
+    _rank, _num_ranks, group = init_dist(local_rank, num_local_ranks)

Also applies to: 78-78

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/distributed/intranode/example_pre_attn_all2all_intranode.py` at line
28, The unpacked variable NH from src.shape and the variables rank and num_ranks
are currently unused; rename them to indicate intentional discard by prefixing
with an underscore (e.g., _NH for the unpack from src.shape and _rank/_num_ranks
where rank and num_ranks are defined) so static analysis won't flag them as
unused while keeping intent clear; update occurrences of those names accordingly
in the functions using B, _NH, _, D, and where rank/num_ranks are defined so
only the used symbols remain unchanged.
examples/distributed/nvshmem/example_nvshmem.py (1)

22-32: Wrap module-level execution in a __main__ guard.

Line 22–Line 32 executes compile/profile/print at import time. For examples, this is better gated to avoid unintended side effects during test discovery/imports.

♻️ Proposed refactor
-func = dist_test(128, 128, 128, 128)
-
-kernel = tilelang.compile(func, out_idx=-1)
-
-# Get CUDA Source
-print(kernel.get_kernel_source())
-
-profiler = kernel.get_profiler()
-out = profiler.run_once()
-
-print(out)
+if __name__ == "__main__":
+    func = dist_test(128, 128, 128, 128)
+    kernel = tilelang.compile(func, out_idx=-1)
+
+    # Get CUDA Source
+    print(kernel.get_kernel_source())
+
+    profiler = kernel.get_profiler()
+    out = profiler.run_once()
+    print(out)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/distributed/nvshmem/example_nvshmem.py` around lines 22 - 32, The
top-level execution that compiles and runs the kernel is happening at import
time; wrap the calls to dist_test(...), tilelang.compile(...),
kernel.get_kernel_source(), kernel.get_profiler(), and profiler.run_once()
inside an if __name__ == "__main__": guard so importing example_nvshmem.py no
longer triggers compilation/profiling, and indent those statements into that
block to preserve behavior when the file is executed as a script.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/distributed/nvshmem/example_nvshmem.py`:
- Around line 29-30: The example calls profiler.run_once() (via
kernel.get_profiler()) without initializing the distributed NVSHMEM runtime,
which can cause failures when functions like T.get_pe() are used; fix this by
initializing the NVSHMEM/distributed runtime before calling profiler.run_once()
— e.g., invoke the library's NVSHMEM initialization routine (or
kernel.profiler/init method if available) on all PEs, ensure any required
communicator/PE setup is complete, then call profiler.run_once(), and finally
finalize the NVSHMEM runtime after the run.
- Around line 11-17: The kernel uses T.ceildiv grid sizing but always copies a
full block_M x block_N tile, which can read/write out-of-bounds for edge CTAs;
modify the kernel (T.Kernel) to compute the actual tile extents using N, M and
the tile indices (bx, by) — e.g., tile_h = min(block_M, M - by*block_M), tile_w
= min(block_N, N - bx*block_N) — and use those extents when performing the
shared-memory loads/stores (the T.copy calls that move between A, A_shared, and
B) or gate the copies with conditional bounds checks so you only access valid
elements. Ensure any per-CTA temporary mype/A_shared logic still works with
variable tile sizes.

In `@examples/distributed/nvshmem/example_summa.py`:
- Line 18: The call to driver.get_num_sms() in summa() uses the default device 0
and will return the wrong SM count in multi-process runs; modify the summa(...)
function signature to add device_id: int = 0, use that parameter when calling
driver.get_num_sms(device_id) inside summa, and update the call site that
launches summa(...) to pass device_id=LOCAL_RANK so each process queries its
local GPU.

---

Nitpick comments:
In
`@examples/distributed/intranode/example_post_attn_all2all_transpose_intranode.py`:
- Line 29: The unpacked tuple and some local variables are flagged unused;
rename unused unpack targets and variables by prefixing them with underscores to
avoid linter warnings: change the unpacking B, _, S, D = src.shape so the unused
element currently named S is prefixed (e.g., B, _, _, D or B, _, _S, D) and
rename the local variables rank and num_ranks to _rank and _num_ranks (or _rank,
_num_ranks) wherever they appear in this file to indicate they are intentionally
unused.
- Line 114: The success print in the example uses an emoji escape sequence in
the statement print(f"rank {local_rank} check passed. \u2705") which is
inconsistent with the other example files; update that print call in
example_post_attn_all2all_transpose_intranode.py (the line containing
print(f"rank {local_rank} check passed. \u2705")) to match the other examples by
removing the emoji (e.g., print(f"rank {local_rank} check passed.") ) or
otherwise using the same formatting used across the other two example files.

In `@examples/distributed/intranode/example_pre_attn_all2all_intranode.py`:
- Line 28: The unpacked variable NH from src.shape and the variables rank and
num_ranks are currently unused; rename them to indicate intentional discard by
prefixing with an underscore (e.g., _NH for the unpack from src.shape and
_rank/_num_ranks where rank and num_ranks are defined) so static analysis won't
flag them as unused while keeping intent clear; update occurrences of those
names accordingly in the functions using B, _NH, _, D, and where rank/num_ranks
are defined so only the used symbols remain unchanged.

In
`@examples/distributed/intranode/example_pre_attn_all2all_transpose_intranode.py`:
- Line 29: The variable unpacking and some local variables are flagged as
unused; update the tuple unpack and the rank variables to use
underscore-prefixed names to indicate intentional unused values: change the
unpack from "B, _, NH, D = src.shape" to use a prefixed unused name (e.g., "B,
_, _NH, D = src.shape" or "B, _, _, D = src.shape" if NH truly isn't needed) and
rename "rank" and "num_ranks" to "_rank" and "_num_ranks" wherever they are
defined/used in this file (example symbol names: src.shape unpack, rank,
num_ranks) so static analysis no longer reports them as unused.

In `@examples/distributed/nvshmem/example_nvshmem.py`:
- Around line 22-32: The top-level execution that compiles and runs the kernel
is happening at import time; wrap the calls to dist_test(...),
tilelang.compile(...), kernel.get_kernel_source(), kernel.get_profiler(), and
profiler.run_once() inside an if __name__ == "__main__": guard so importing
example_nvshmem.py no longer triggers compilation/profiling, and indent those
statements into that block to preserve behavior when the file is executed as a
script.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 2b2681ce-9775-448f-b49e-0e692140ef8f

📥 Commits

Reviewing files that changed from the base of the PR and between 5952f00 and 1dd58f6.

📒 Files selected for processing (9)
  • examples/distributed/README.md
  • examples/distributed/intranode/example_post_attn_all2all_transpose_intranode.py
  • examples/distributed/intranode/example_pre_attn_all2all_intranode.py
  • examples/distributed/intranode/example_pre_attn_all2all_transpose_intranode.py
  • examples/distributed/intranode/test_intranode.py
  • examples/distributed/nvshmem/example_all_to_all.py
  • examples/distributed/nvshmem/example_cannon.py
  • examples/distributed/nvshmem/example_nvshmem.py
  • examples/distributed/nvshmem/example_summa.py
💤 Files with no reviewable changes (2)
  • examples/distributed/nvshmem/example_cannon.py
  • examples/distributed/nvshmem/example_all_to_all.py
✅ Files skipped from review due to trivial changes (1)
  • examples/distributed/README.md
🚧 Files skipped from review as they are similar to previous changes (1)
  • examples/distributed/intranode/test_intranode.py

Comment on lines +11 to +17
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_N), dtype)
mype = T.alloc_local([1], "int32")

mype[0] = T.get_pe()
T.copy(A[by * block_M, bx * block_N], A_shared)
T.copy(A_shared, B[by * block_M, bx * block_N])
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Guard partial tiles or enforce divisibility before launching ceildiv grid.

Line 11 uses ceil-div launch bounds, but Line 16–Line 17 copy full block_M x block_N tiles unconditionally. For non-divisible M/N, edge CTAs can access out of bounds.

🔧 Proposed fix (safe precondition)
 def dist_test(M, N, block_M, block_N, dtype="int16"):
+    if M % block_M != 0 or N % block_N != 0:
+        raise ValueError(
+            "dist_test currently requires M and N to be divisible by block_M and block_N"
+        )
     `@T.prim_func`
     def main(
         A: T.Buffer((M, N), dtype),
         B: T.Buffer((M, N), dtype),
     ):
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/distributed/nvshmem/example_nvshmem.py` around lines 11 - 17, The
kernel uses T.ceildiv grid sizing but always copies a full block_M x block_N
tile, which can read/write out-of-bounds for edge CTAs; modify the kernel
(T.Kernel) to compute the actual tile extents using N, M and the tile indices
(bx, by) — e.g., tile_h = min(block_M, M - by*block_M), tile_w = min(block_N, N
- bx*block_N) — and use those extents when performing the shared-memory
loads/stores (the T.copy calls that move between A, A_shared, and B) or gate the
copies with conditional bounds checks so you only access valid elements. Ensure
any per-CTA temporary mype/A_shared logic still works with variable tile sizes.

Comment on lines +29 to +30
profiler = kernel.get_profiler()
out = profiler.run_once()
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

run_once() may execute without distributed runtime initialization.

Line 30 directly calls profiler.run_once(). From tilelang/profiler/__init__.py, run_once() just invokes the kernel and does not initialize distributed state. In an NVSHMEM example, this can fail at runtime or produce undefined behavior when T.get_pe() is used.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/distributed/nvshmem/example_nvshmem.py` around lines 29 - 30, The
example calls profiler.run_once() (via kernel.get_profiler()) without
initializing the distributed NVSHMEM runtime, which can cause failures when
functions like T.get_pe() are used; fix this by initializing the
NVSHMEM/distributed runtime before calling profiler.run_once() — e.g., invoke
the library's NVSHMEM initialization routine (or kernel.profiler/init method if
available) on all PEs, ensure any required communicator/PE setup is complete,
then call profiler.run_once(), and finally finalize the NVSHMEM runtime after
the run.

accum_dtype = "float32"

sm_num = 132 # 132 SMs for H100
sm_num = driver.get_num_sms()
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verify API default and current callsites that omit device_id.
# Expected: get_num_sms defaults to 0; this file currently calls without args.

rg -nP --type=py '\bdef\s+get_num_sms\s*\(\s*device_id\s*:\s*int\s*=\s*0' -C2
rg -nP --type=py '\bget_num_sms\s*\(\s*\)' -C2
rg -nP --type=py 'torch\.device\(f"cuda:\{RANK\}"\)|LOCAL_RANK' -C2 examples/distributed/nvshmem/example_summa.py

Repository: tile-ai/tilescale

Length of output: 5369


🏁 Script executed:

# Get full context of summa function definition
head -50 examples/distributed/nvshmem/example_summa.py

# Find all calls to summa function
rg -n 'summa\s*\(' examples/distributed/nvshmem/example_summa.py -A2 -B2

Repository: tile-ai/tilescale

Length of output: 2233


🏁 Script executed:

# Find all call sites of summa function
rg -n 'summa\s*\(' --type=py | head -20

# Check how other distributed examples handle device_id in similar functions
rg -n 'def\s+\w+\(.*device_id' examples/distributed/ --type=py -A1

Repository: tile-ai/tilescale

Length of output: 300


🏁 Script executed:

# Check other distributed examples that call get_num_sms() 
rg -n 'get_num_sms\(\)' examples/distributed/ --type=py -B5 -A5

# Get the full main block context around the summa call to verify LOCAL_RANK scope
sed -n '145,175p' examples/distributed/nvshmem/example_summa.py

Repository: tile-ai/tilescale

Length of output: 3119


Pass LOCAL_RANK to get_num_sms() to query the correct GPU's SM count in multi-process distributed runs.

Line 18 calls driver.get_num_sms() without arguments, defaulting to device_id=0. In multi-process setups, non-zero ranks map to different GPUs and will incorrectly read SM count from device 0 instead of their rank-local device. LOCAL_RANK is available at the call site (line 150) and should be passed through the function.

Add device_id: int = 0 parameter to the summa() function signature, pass it to driver.get_num_sms(device_id) at line 18, and pass device_id=LOCAL_RANK at the call site (line 163).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/distributed/nvshmem/example_summa.py` at line 18, The call to
driver.get_num_sms() in summa() uses the default device 0 and will return the
wrong SM count in multi-process runs; modify the summa(...) function signature
to add device_id: int = 0, use that parameter when calling
driver.get_num_sms(device_id) inside summa, and update the call site that
launches summa(...) to pass device_id=LOCAL_RANK so each process queries its
local GPU.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant