Conversation
📝 Walkthrough

Introduces three new example scripts for distributed tensor operations: an IPC-based all-to-all implementation benchmarked against a PyTorch reference, a sequence-parallel all-to-all attention module with an intra-node fused kernel, and supporting TileLang kernel implementations with CUDA async memcpy and barrier synchronization.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Host as Host (Main Process)
    participant LocalRank as Local Rank N
    participant PeerRank as Peer Rank (N+i) % size
    participant CUDA as CUDA Async Stream
    Host->>LocalRank: Allocate distributed peer buffers
    Host->>LocalRank: Pack local Q/K/V into unified layout
    Host->>CUDA: Submit async 2D memcpy for Q/K/V head slices
    LocalRank->>PeerRank: Copy Q/K/V data via ring topology
    CUDA->>CUDA: Async transfer across distributed ranks
    PeerRank->>LocalRank: Receive packed data into destination buffer
    Host->>CUDA: Synchronize stream (await all transfers)
    Host->>Host: Extract Q/K/V views from packed destination
    Host->>Host: Verify vs PyTorch reference (torch.allclose)
```

```mermaid
sequenceDiagram
    participant Host as Host (Main Process)
    participant ComputeStream as Compute Stream
    participant AGStream as AG (AllGather) Stream
    participant Device as CUDA Device
    participant Barrier as Sync Barrier
    Host->>ComputeStream: Wait for prior AG stream ops
    Host->>AGStream: Self-copy local Q/K/V shard to output
    AGStream->>Device: Async 2D memcpy local data
    Host->>AGStream: Pull remote K/V token blocks from peers
    AGStream->>Device: Async memcpy remote data to local destination
    AGStream->>Barrier: Schedule barrier sync on AG stream
    Host->>ComputeStream: Launch packed flash-attention kernel
    ComputeStream->>Device: Compute attention over packed Q/K/V (shared memory tiling)
    ComputeStream->>Barrier: Wait for barrier completion
    Device->>Host: Output attention results (per-token, unpadded)
    Host->>Host: Verify output vs PyTorch reference
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes
🚥 Pre-merge checks: ✅ 1 passed | ❌ 2 failed (1 warning, 1 inconclusive)
👋 Hi! Thank you for contributing to the TileLang project. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Actionable comments posted: 6
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/distributed/example_pre_attn_all2all_ipc.py`:
- Around line 63-68: The PyTorch path allocates and repacks a2a_output inside
each timed iteration while the IPC path reuses qkv_src_peers/qkv_dst_peers
prepared once and uses a different synchronization scope, so the reported
Speedup mixes unrelated work; to fix, make the two benchmarks do equivalent work
and syncs: either precompute and reuse qkv_src_peers/qkv_dst_peers and allocate
a2a_output once (reuse it across iterations) in the PyTorch path (symbol:
a2a_output) or move the qkv_src_peers/qkv_dst_peers preparation inside the IPC
timed loop so both paths perform the same prep cost (symbols: qkv_src_peers,
qkv_dst_peers); also ensure the CUDA event timing brackets identical work and
place the process-group barrier consistently (either both before or both after
the end event) so timings measure the same scope.
In `@examples/distributed/example_sp_all2all_attention_intra_node.py`:
- Around line 196-207: The code asserts divisibility but never validates that
the declared maxima actually cover the per-batch lengths; before sizing buffers
and creating packed tensors (references: max_seqlen_q, max_seqlen_k, and the
allocation/packing code that uses cu_seqlens_q_list / cu_seqlens_k_list ->
cu_seqlens_q / cu_seqlens_k), add checks that max_seqlen_q >= max(seqlens_q) and
max_seqlen_k >= max(seqlens_k) (or compute max_seqlen_q = max(seqlens_q) /
max_seqlen_k = max(seqlens_k) if maxima come from args) and raise/assert with a
clear message if violated so the subsequent buffer sizing cannot be overrun.
Ensure the assertions are placed before any buffer allocation or use of the
cu_seqlens_* tensors.
- Around line 316-317: The help text for parser.add_argument("--q_head", ...)
and ("--kv_head", ...) is misleading: these arguments are treated as global head
counts and the code divides them by num_ranks immediately, so update the help
strings to indicate they are global counts (or state that they will be divided
by num_ranks), e.g. change "local num q heads per rank" / "local num kv heads
per rank" to "global num q heads (will be divided by num_ranks)" / "global num
kv heads (will be divided by num_ranks)" to match the logic in the code where
q_head and kv_head are divided by num_ranks.
- Around line 136-140: The code currently only checks divisibility by world_size
but not that q_head is an integer multiple of kv_head; before computing
self.q_head_per_rank, self.kv_head_per_rank and self.max_q_shard_len add
explicit guards: assert kv_head > 0 and assert q_head % kv_head == 0 (in
addition to the existing q_head % world_size and kv_head % world_size checks) so
that the derived groups = q_head // kv_head used by the fused kernel is exact
and prevents out-of-bounds packed-buffer indexing.
In `@examples/distributed/sp_all2all_attention_intra_node.py`:
- Around line 111-118: The zig-zag mapping can produce incorrect half/rank
assignment when tiles straddle the sequence midpoint; update the code that
computes sp_block_idx, wait_rank and kv_load_offset inside the T.Pipelined loop
(variables: sp_block_idx, wait_rank, kv_load_offset, q_current_seqlen,
k_current_seqlen, block_M, block_N, num_ranks, enable_zig_zag) to enforce an
explicit tile-alignment contract or handle mapping at tile boundaries: either
(A) add an upfront assertion/guard when enable_zig_zag is true that
q_current_seqlen % (2 * block_M) == 0 and k_current_seqlen % (2 * num_ranks *
block_N) == 0, or (B) change the branch logic to compute half/rank per-tile
(using tile-aligned indices derived from bx and k) so no tile can span the
midpoint; pick one approach and apply the same fix for the other occurrence
around lines 276-280.
- Around line 460-467: The packer is using K-length strides so Q gets written to
the wrong block when Q and K seqlens differ; update the Q-pack path to use
Q-specific lengths/starts (e.g., use seqlen_q and cu_seqlens_q_start or the
q_current_seqlen-based stride) when computing
local_seq_len/src_token_start/dst_token_start for packed_shards and
packed_buffers (instead of seqlen_k/cu_seqlens_k_start), and apply the same fix
to the analogous block around the other occurrence noted (the chunk at 483-491)
so flashattn_packed reads the correct rank-local Q slice.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 10bea30e-8b0e-4312-a5fb-7fb994a72cf3
📒 Files selected for processing (3)
- examples/distributed/example_pre_attn_all2all_ipc.py
- examples/distributed/example_sp_all2all_attention_intra_node.py
- examples/distributed/sp_all2all_attention_intra_node.py
```python
a2a_output = torch.empty(
    (world_size, a2a_heads // world_size, a2a_seq_per_pe, a2a_batch, a2a_head_dim),
    dtype=a2a_input.dtype,
    device=a2a_input.device,
    requires_grad=False,
)
```
The reported speedup is not a like-for-like benchmark.
The PyTorch path allocates fresh receive tensors on Lines 63-68 and repacks/permutes inside every timed iteration, while the IPC path reuses qkv_src_peers / qkv_dst_peers prepared once on Lines 510-520. The IPC loop also times a different synchronization scope, because it includes a post-copy process-group barrier before recording the end event. The reported speedup therefore mixes different work rather than comparing transport cost alone.
Also applies to: 510-520, 568-628
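A minimal sketch of a like-for-like harness (plain-Python stand-in with hypothetical names; in the example itself the timing bracket would use `torch.cuda.Event` and place the `dist.barrier()` identically for both paths):

```python
import time

def bench(fn, warmup=5, iters=50):
    # Warm up, then time only steady-state iterations. In the CUDA
    # example this bracket would be start/end cuda Events recording
    # the same scope for both paths, with the process-group barrier
    # consistently outside (or inside) the bracket for both.
    for _ in range(warmup):
        fn()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - t0) / iters

# Both paths reuse buffers allocated once, so neither pays a
# per-iteration allocation/repack cost that the other skips.
src = bytearray(1 << 16)
dst = bytearray(1 << 16)

def reference_path():
    dst[:] = src  # stand-in for the PyTorch all-to-all copy

def ipc_path():
    src[:] = dst  # stand-in for the IPC peer copy

t_ref, t_ipc = bench(reference_path), bench(ipc_path)
```

With identical prep, identical sync scope, and reused buffers, the ratio `t_ref / t_ipc` measures transport cost alone.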
```python
assert self.q_head % self.world_size == 0, "q_head should be divisible by world_size"
assert self.kv_head % self.world_size == 0, "kv_head should be divisible by world_size"
self.q_head_per_rank = self.q_head // self.world_size
self.kv_head_per_rank = self.kv_head // self.world_size
self.max_q_shard_len = self.max_seqlen_q // self.world_size
```
Reject non-integer GQA ratios here.
The class only validates divisibility by world_size, but the fused kernel later derives groups = q_head // kv_head and assumes that ratio is exact. Inputs such as q_head=12, kv_head=8, world_size=2 pass these checks, then make the packed-buffer indexing larger than the KV slice that was allocated. Add an explicit kv_head > 0 / q_head % kv_head == 0 guard before computing per-rank heads.
Suggested guard:

```diff
 assert self.q_head % self.world_size == 0, "q_head should be divisible by world_size"
 assert self.kv_head % self.world_size == 0, "kv_head should be divisible by world_size"
+assert self.kv_head > 0 and self.q_head % self.kv_head == 0, "q_head should be divisible by kv_head"
 self.q_head_per_rank = self.q_head // self.world_size
 self.kv_head_per_rank = self.kv_head // self.world_size
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
assert self.q_head % self.world_size == 0, "q_head should be divisible by world_size"
assert self.kv_head % self.world_size == 0, "kv_head should be divisible by world_size"
assert self.kv_head > 0 and self.q_head % self.kv_head == 0, "q_head should be divisible by kv_head"
self.q_head_per_rank = self.q_head // self.world_size
self.kv_head_per_rank = self.kv_head // self.world_size
self.max_q_shard_len = self.max_seqlen_q // self.world_size
```
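As a standalone illustration (hypothetical helper, not from the PR), the combined guard logic can be exercised like this:

```python
def validate_heads(q_head, kv_head, world_size):
    # Mirror of the suggested guards: per-rank divisibility plus an
    # exact GQA ratio, so groups = q_head // kv_head is an integer
    # and packed-buffer indexing stays within the allocated KV slice.
    assert q_head % world_size == 0, "q_head should be divisible by world_size"
    assert kv_head % world_size == 0, "kv_head should be divisible by world_size"
    assert kv_head > 0 and q_head % kv_head == 0, "q_head should be divisible by kv_head"
    return q_head // kv_head  # GQA group size

validate_heads(32, 8, 2)   # ok: 4 query heads per KV head
# validate_heads(12, 8, 2) passes the world_size checks but fails the
# GQA-ratio guard, which is exactly the case the review flags.
```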
```python
seqlens_q = args.seqlens_q
seqlens_k = args.seqlens_k
assert len(seqlens_q) == batch_size and len(seqlens_k) == batch_size
assert q_head % num_ranks == 0, "q_head should be divisible by world size"
assert kv_head % num_ranks == 0, "kv_head should be divisible by world size"
for s in seqlens_q + seqlens_k:
    assert s % num_ranks == 0, "all2all requires per-batch sequence length divisible by world size"

cu_seqlens_q_list = [0] + list(accumulate(seqlens_q))
cu_seqlens_k_list = [0] + list(accumulate(seqlens_k))
cu_seqlens_q = torch.tensor(cu_seqlens_q_list, dtype=torch.int32, device=device) // num_ranks
cu_seqlens_k = torch.tensor(cu_seqlens_k_list, dtype=torch.int32, device=device)
```
Validate the declared maxima before sizing buffers from them.
The code only checks divisibility, but Lines 224-233 size the packed buffers and synthetic inputs from max_seqlen_q / max_seqlen_k. If either max is smaller than an entry in seqlens_q or seqlens_k, cu_seqlens_* will drive the copy/kernel path past the allocated rows.
Suggested bounds checks:

```diff
 seqlens_q = args.seqlens_q
 seqlens_k = args.seqlens_k
 assert len(seqlens_q) == batch_size and len(seqlens_k) == batch_size
+assert max(seqlens_q) <= max_seqlen_q, "max_seqlen_q must cover all query sequence lengths"
+assert max(seqlens_k) <= max_seqlen_k, "max_seqlen_k must cover all key/value sequence lengths"
 assert q_head % num_ranks == 0, "q_head should be divisible by world size"
 assert kv_head % num_ranks == 0, "kv_head should be divisible by world size"
```

Also applies to: 224-233
🧰 Tools

🪛 Ruff (0.15.9)

- [warning] 204-204 (RUF005): Consider `[0, *list(accumulate(seqlens_q))]` instead of concatenation
- [warning] 205-205 (RUF005): Consider `[0, *list(accumulate(seqlens_k))]` instead of concatenation
```python
parser.add_argument("--q_head", type=int, default=32, help="local num q heads per rank")
parser.add_argument("--kv_head", type=int, default=8, help="local num kv heads per rank")
```
The head-count help text is backwards.
The parser says --q_head and --kv_head are local counts per rank, but the code immediately divides both by num_ranks. Following the current help text makes the example run with only 1 / num_ranks of the intended global heads.
Suggested wording:

```diff
-parser.add_argument("--q_head", type=int, default=32, help="local num q heads per rank")
-parser.add_argument("--kv_head", type=int, default=8, help="local num kv heads per rank")
+parser.add_argument("--q_head", type=int, default=32, help="global number of Q heads")
+parser.add_argument("--kv_head", type=int, default=8, help="global number of KV heads")
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
parser.add_argument("--q_head", type=int, default=32, help="global number of Q heads")
parser.add_argument("--kv_head", type=int, default=8, help="global number of KV heads")
```
```python
for k in T.Pipelined(loop_range, num_stages=num_stages):
    sp_block_idx = (k * block_N) // kv_len_per_sp_block
    wait_rank = sp_block_idx if sp_block_idx < num_ranks else 2 * num_ranks - sp_block_idx - 1
    kv_load_offset = (
        (k * block_N) % kv_len_per_sp_block
        + sp_block_idx // num_ranks * kv_len_per_sp_block
        + wait_rank * (k_current_seqlen // num_ranks)
    )
```
enable_zig_zag needs an explicit tile-alignment contract.
The zig-zag path picks the Q-side offset once per bx and the KV source-rank mapping once per k. If q_current_seqlen % (2 * block_M) != 0 or k_current_seqlen % (2 * num_ranks * block_N) != 0, one tile straddles the midpoint and applies the wrong half/rank mapping to part of the block. Please either enforce that alignment up front or compute the zig-zag branch at row/tile-boundary granularity.
Also applies to: 276-280
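Option (A) as a standalone sketch (hypothetical helper name and placement; the real guard would live wherever the kernel is specialized for `enable_zig_zag`):

```python
def check_zig_zag_alignment(q_current_seqlen, k_current_seqlen,
                            block_M, block_N, num_ranks):
    # A tile must never straddle the zig-zag midpoint: Q tiles are
    # block_M rows within each sequence half, and KV tiles are
    # block_N rows within each (half, source-rank) segment, so both
    # lengths must be multiples of the corresponding tile granularity.
    if q_current_seqlen % (2 * block_M) != 0:
        raise ValueError("zig-zag requires q_current_seqlen % (2 * block_M) == 0")
    if k_current_seqlen % (2 * num_ranks * block_N) != 0:
        raise ValueError(
            "zig-zag requires k_current_seqlen % (2 * num_ranks * block_N) == 0")
```

With this contract enforced up front, the per-`bx` Q-side offset and the per-`k` KV source-rank mapping are constant across every row of a tile, so no block mixes the two halves.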
```python
local_seq_len = seqlen_k // world_size
src_token_start = cu_seqlens_k_start // world_size
dst_token_start = cu_seqlens_k_start + rank * local_seq_len

src_head_offset_bytes = rank * packed_heads_per_rank * head_dim * dtype_itemsize
src_ptr = packed_shards[rank].data_ptr() + src_token_start * src_token_bytes + src_head_offset_bytes
dst_ptr = packed_buffers[rank].data_ptr() + dst_token_start * dst_token_bytes
_cp_engine_copy_2d(
```
Pack Q with the Q shard length, not the K shard length.
flashattn_packed later reads Q from the rank-local block at Lines 91-99 using rank * q_current_seqlen, but this producer writes every packed slice at rank * (seqlen_k // world_size). As soon as the per-batch Q and K lengths differ, each rank starts reading Q from the wrong sequence block and the fused result is corrupted. Either give Q its own destination stride based on cu_seqlens_q, or explicitly reject mixed Q/K sequence lengths for this path.
Also applies to: 483-491
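A sketch of the Q-side offset computation (hypothetical helper; it mirrors the K-side code above but is keyed on the Q lengths, so the consumer's `rank * q_current_seqlen` read lines up with what the packer wrote):

```python
def q_pack_offsets(rank, world_size, seqlen_q, cu_seqlens_q_start):
    # Stride by the Q shard length, not seqlen_k // world_size, so
    # each rank's Q slice lands at rank * (seqlen_q // world_size)
    # within the batch's packed region.
    local_seq_len = seqlen_q // world_size
    src_token_start = cu_seqlens_q_start // world_size
    dst_token_start = cu_seqlens_q_start + rank * local_seq_len
    return local_seq_len, src_token_start, dst_token_start
```

For example, with `seqlen_q=128`, `world_size=4`, and `cu_seqlens_q_start=0`, rank 1 writes a 32-token slice starting at destination token 32, which is exactly where the fused kernel reads it back.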