
[Example] Add Ulysses SP example #55

Open
zyy3077 wants to merge 6 commits into tile-ai:main from zyy3077:zyy/dev

Conversation


@zyy3077 zyy3077 commented Mar 28, 2026

  • IPC-based implementation of pre-attention all2all
  • input/output dimension handling and profiling
  • fused computation with all2all

Summary by CodeRabbit

  • New Features
    • Added new distributed attention example scripts showcasing pre-attention all-to-all operations using custom CUDA/IPC implementations with golden reference verification.
    • Added intra-node fused sequence-parallel all-to-all attention example with optimized kernel execution and performance benchmarking utilities.

@coderabbitai

coderabbitai bot commented Mar 28, 2026

📝 Walkthrough

Introduces three new example scripts for distributed tensor operations: an IPC-based all-to-all implementation benchmarked against a PyTorch reference, a sequence-parallel all-to-all attention module with intra-node fused kernel, and supporting TileLang kernel implementations with CUDA async memcpy and barrier synchronization.

Changes

Pre-attention All-to-All IPC — examples/distributed/example_pre_attn_all2all_ipc.py
Introduces a benchmark script comparing the custom CUDA/IPC all-to-all implementation with a PyTorch reference. Includes packing/unpacking routines for Q/K/V tensors, async memcpy-based data movement across distributed peers, a packed-layout optimization for fused transfers, and verification via tensor comparison with detailed diagnostics.

Sequence-Parallel All-to-All Attention — examples/distributed/example_sp_all2all_attention_intra_node.py, examples/distributed/sp_all2all_attention_intra_node.py
Adds an intra-node fused SP all-to-all attention implementation and its validation. Includes TileLang JIT kernels (packed flash attention with zigzag support, barrier synchronization), a distributed-context dataclass with peer buffers, output, and an AG stream, an async memcpy producer for cross-rank data movement, and an orchestration routine that coordinates async AG-stream operations with compute-stream attention.
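To make the exchange pattern in these summaries concrete, here is a toy, plain-Python sketch (hypothetical shapes and names, no CUDA/IPC) of the pre-attention all-to-all: before the exchange each rank holds all heads for its slice of the sequence; afterwards it holds the full sequence for its slice of the heads.

```python
def pre_attn_all2all(shards, world_size):
    """shards[r][h] = tokens rank r holds for head h (its sequence slice).
    Returns out[r][h'] = the full sequence for the h'-th head assigned to rank r."""
    heads = len(shards[0])
    assert heads % world_size == 0, "heads must divide evenly across ranks"
    hpr = heads // world_size
    out = []
    for r in range(world_size):
        # Rank r gathers its head slice [r*hpr, (r+1)*hpr) from every peer,
        # concatenating the peers' sequence shards back into the full sequence.
        out.append([
            [tok for peer in range(world_size) for tok in shards[peer][r * hpr + h]]
            for h in range(hpr)
        ])
    return out

# 2 ranks, 4 heads, 2 tokens per rank shard; token "h{h}t{t}" of head h
W, H, T = 2, 4, 2
shards = [[[f"h{h}t{r * T + t}" for t in range(T)] for h in range(H)] for r in range(W)]
gathered = pre_attn_all2all(shards, W)
assert gathered[1][0] == ["h2t0", "h2t1", "h2t2", "h2t3"]
```

The real examples implement the same permutation with async 2D memcpys over IPC-mapped peer buffers rather than Python lists.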

Sequence Diagram(s)

sequenceDiagram
    participant Host as Host (Main Process)
    participant LocalRank as Local Rank N
    participant PeerRank as Peer Rank (N+i) % size
    participant CUDA as CUDA Async Stream

    Host->>LocalRank: Allocate distributed peer buffers
    Host->>LocalRank: Pack local Q/K/V into unified layout
    Host->>CUDA: Submit async 2D memcpy for Q/K/V head slices
    LocalRank->>PeerRank: Copy Q/K/V data via ring topology
    CUDA->>CUDA: Async transfer across distributed ranks
    PeerRank->>LocalRank: Receive packed data into destination buffer
    Host->>CUDA: Synchronize stream (await all transfers)
    Host->>Host: Extract Q/K/V views from packed destination
    Host->>Host: Verify vs PyTorch reference (torch.allclose)
sequenceDiagram
    participant Host as Host (Main Process)
    participant ComputeStream as Compute Stream
    participant AGStream as AG (AllGather) Stream
    participant Device as CUDA Device
    participant Barrier as Sync Barrier

    Host->>ComputeStream: Wait for prior AG stream ops
    Host->>AGStream: Self-copy local Q/K/V shard to output
    AGStream->>Device: Async 2D memcpy local data
    Host->>AGStream: Pull remote K/V token blocks from peers
    AGStream->>Device: Async memcpy remote data to local destination
    AGStream->>Barrier: Schedule barrier sync on AG stream
    Host->>ComputeStream: Launch packed flash-attention kernel
    ComputeStream->>Device: Compute attention over packed Q/K/V (shared memory tiling)
    ComputeStream->>Barrier: Wait for barrier completion
    Device->>Host: Output attention results (per-token, unpadded)
    Host->>Host: Verify output vs PyTorch reference

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

Possibly related PRs

  • [Feat] support sp all2all without transpose example #16: Adds/modifies sequence-parallel and pre-attention all-to-all reference implementations with similarly named utility functions (e.g., torch_pre_attn_qkv_a2a_reference) and distributed process-group initialization patterns.

Poem

🐰 Hoppy hops through distributed lanes,
All-to-all data on CUDA planes,
Packed layouts dance, barriers sync tight,
IPC whispers in the night—
Flash-attention batches take flight! ✨

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

  • Docstring Coverage — ⚠️ Warning: docstring coverage is 28.13%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
  • Title check — ❓ Inconclusive: the title '[Example] Add Ulysses SP example' is vague and generic, using non-specific terminology like 'example' and 'Ulysses SP' without clearly conveying what the changes accomplish. Resolution: consider a more specific title, such as '[Example] Add sequence-parallel all-to-all attention with IPC implementation', to better reflect the content of the PR.

✅ Passed checks (1 passed)

  • Description Check — ✅ Passed: check skipped because CodeRabbit's high-level summary is enabled.


@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run `pre-commit run --all-files` in the root directory of the project so your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@zyy3077 zyy3077 marked this pull request as ready for review April 7, 2026 05:50

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 6

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/distributed/example_pre_attn_all2all_ipc.py`:
- Around line 63-68: The PyTorch path allocates and repacks a2a_output inside
each timed iteration while the IPC path reuses qkv_src_peers/qkv_dst_peers
prepared once and uses a different synchronization scope, so the reported
Speedup mixes unrelated work; to fix, make the two benchmarks do equivalent work
and syncs: either precompute and reuse qkv_src_peers/qkv_dst_peers and allocate
a2a_output once (reuse it across iterations) in the PyTorch path (symbol:
a2a_output) or move the qkv_src_peers/qkv_dst_peers preparation inside the IPC
timed loop so both paths perform the same prep cost (symbols: qkv_src_peers,
qkv_dst_peers); also ensure the CUDA event timing brackets identical work and
place the process-group barrier consistently (either both before or both after
the end event) so timings measure the same scope.

In `@examples/distributed/example_sp_all2all_attention_intra_node.py`:
- Around line 196-207: The code asserts divisibility but never validates that
the declared maxima actually cover the per-batch lengths; before sizing buffers
and creating packed tensors (references: max_seqlen_q, max_seqlen_k, and the
allocation/packing code that uses cu_seqlens_q_list / cu_seqlens_k_list ->
cu_seqlens_q / cu_seqlens_k), add checks that max_seqlen_q >= max(seqlens_q) and
max_seqlen_k >= max(seqlens_k) (or compute max_seqlen_q = max(seqlens_q) /
max_seqlen_k = max(seqlens_k) if maxima come from args) and raise/assert with a
clear message if violated so the subsequent buffer sizing cannot be overrun.
Ensure the assertions are placed before any buffer allocation or use of the
cu_seqlens_* tensors.
- Around line 316-317: The help text for parser.add_argument("--q_head", ...)
and ("--kv_head", ...) is misleading: these arguments are treated as global head
counts and the code divides them by num_ranks immediately, so update the help
strings to indicate they are global counts (or state that they will be divided
by num_ranks), e.g. change "local num q heads per rank" / "local num kv heads
per rank" to "global num q heads (will be divided by num_ranks)" / "global num
kv heads (will be divided by num_ranks)" to match the logic in the code where
q_head and kv_head are divided by num_ranks.
- Around line 136-140: The code currently only checks divisibility by world_size
but not that q_head is an integer multiple of kv_head; before computing
self.q_head_per_rank, self.kv_head_per_rank and self.max_q_shard_len add
explicit guards: assert kv_head > 0 and assert q_head % kv_head == 0 (in
addition to the existing q_head % world_size and kv_head % world_size checks) so
that the derived groups = q_head // kv_head used by the fused kernel is exact
and prevents out-of-bounds packed-buffer indexing.

In `@examples/distributed/sp_all2all_attention_intra_node.py`:
- Around line 111-118: The zig-zag mapping can produce incorrect half/rank
assignment when tiles straddle the sequence midpoint; update the code that
computes sp_block_idx, wait_rank and kv_load_offset inside the T.Pipelined loop
(variables: sp_block_idx, wait_rank, kv_load_offset, q_current_seqlen,
k_current_seqlen, block_M, block_N, num_ranks, enable_zig_zag) to enforce an
explicit tile-alignment contract or handle mapping at tile boundaries: either
(A) add an upfront assertion/guard when enable_zig_zag is true that
q_current_seqlen % (2 * block_M) == 0 and k_current_seqlen % (2 * num_ranks *
block_N) == 0, or (B) change the branch logic to compute half/rank per-tile
(using tile-aligned indices derived from bx and k) so no tile can span the
midpoint; pick one approach and apply the same fix for the other occurrence
around lines 276-280.
- Around line 460-467: The packer is using K-length strides so Q gets written to
the wrong block when Q and K seqlens differ; update the Q-pack path to use
Q-specific lengths/starts (e.g., use seqlen_q and cu_seqlens_q_start or the
q_current_seqlen-based stride) when computing
local_seq_len/src_token_start/dst_token_start for packed_shards and
packed_buffers (instead of seqlen_k/cu_seqlens_k_start), and apply the same fix
to the analogous block around the other occurrence noted (the chunk at 483-491)
so flashattn_packed reads the correct rank-local Q slice.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 10bea30e-8b0e-4312-a5fb-7fb994a72cf3

📥 Commits

Reviewing files that changed from the base of the PR and between 4704282 and d3fb5f2.

📒 Files selected for processing (3)
  • examples/distributed/example_pre_attn_all2all_ipc.py
  • examples/distributed/example_sp_all2all_attention_intra_node.py
  • examples/distributed/sp_all2all_attention_intra_node.py

Comment on lines +63 to +68
a2a_output = torch.empty(
    (world_size, a2a_heads // world_size, a2a_seq_per_pe, a2a_batch, a2a_head_dim),
    dtype=a2a_input.dtype,
    device=a2a_input.device,
    requires_grad=False,
)

⚠️ Potential issue | 🟠 Major

The reported speedup is not a like-for-like benchmark.

The PyTorch path allocates fresh receive tensors on Lines 63-68 and repacks/permutates inside every timed iteration, while the IPC path reuses qkv_src_peers / qkv_dst_peers prepared once on Lines 510-520. The IPC loop also times a different synchronization scope because it includes a post-copy process-group barrier before recording the end event. The final Speedup therefore mixes different work rather than comparing transport cost.

Also applies to: 510-520, 568-628
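The fairness principle behind this comment can be sketched with a CPU-only harness (time.perf_counter stands in for CUDA events; all names are illustrative, not from the PR): both candidates get identical warmup, identical iteration counts, and nothing allocated or prepped inside the timed region.

```python
import time

def bench(fn, warmup=5, iters=50):
    """Time fn() with an identical scope for every candidate:
    same warmup, same iteration count, no allocation inside the timed loop."""
    for _ in range(warmup):
        fn()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters

# Hypothetical stand-in for the transport under test: the destination
# buffer is allocated once, outside the timed region, and reused.
src = list(range(1 << 12))
dst = [0] * len(src)

def copy_path():
    dst[:] = src  # only the copy itself is timed

t = bench(copy_path)
assert t > 0
```

On CUDA the same discipline applies with torch.cuda.Event pairs: record start/end around identical work on both paths, and place any process-group barrier consistently on the same side of the end event.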


Comment on lines +136 to +140
assert self.q_head % self.world_size == 0, "q_head should be divisible by world_size"
assert self.kv_head % self.world_size == 0, "kv_head should be divisible by world_size"
self.q_head_per_rank = self.q_head // self.world_size
self.kv_head_per_rank = self.kv_head // self.world_size
self.max_q_shard_len = self.max_seqlen_q // self.world_size

⚠️ Potential issue | 🔴 Critical

Reject non-integer GQA ratios here.

The class only validates divisibility by world_size, but the fused kernel later derives groups = q_head // kv_head and assumes that ratio is exact. Inputs such as q_head=12, kv_head=8, world_size=2 pass these checks, then make the packed-buffer indexing larger than the KV slice that was allocated. Add an explicit kv_head > 0 / q_head % kv_head == 0 guard before computing per-rank heads.

Suggested guard
         assert self.q_head % self.world_size == 0, "q_head should be divisible by world_size"
         assert self.kv_head % self.world_size == 0, "kv_head should be divisible by world_size"
+        assert self.kv_head > 0 and self.q_head % self.kv_head == 0, "q_head should be divisible by kv_head"
         self.q_head_per_rank = self.q_head // self.world_size
         self.kv_head_per_rank = self.kv_head // self.world_size

Comment on lines +196 to +207
seqlens_q = args.seqlens_q
seqlens_k = args.seqlens_k
assert len(seqlens_q) == batch_size and len(seqlens_k) == batch_size
assert q_head % num_ranks == 0, "q_head should be divisible by world size"
assert kv_head % num_ranks == 0, "kv_head should be divisible by world size"
for s in seqlens_q + seqlens_k:
    assert s % num_ranks == 0, "all2all requires per-batch sequence length divisible by world size"

cu_seqlens_q_list = [0] + list(accumulate(seqlens_q))
cu_seqlens_k_list = [0] + list(accumulate(seqlens_k))
cu_seqlens_q = torch.tensor(cu_seqlens_q_list, dtype=torch.int32, device=device) // num_ranks
cu_seqlens_k = torch.tensor(cu_seqlens_k_list, dtype=torch.int32, device=device)

⚠️ Potential issue | 🔴 Critical

Validate the declared maxima before sizing buffers from them.

The code only checks divisibility, but Lines 224-233 size the packed buffers and synthetic inputs from max_seqlen_q / max_seqlen_k. If either max is smaller than an entry in seqlens_q or seqlens_k, cu_seqlens_* will drive the copy/kernel path past the allocated rows.

Suggested bounds checks
         seqlens_q = args.seqlens_q
         seqlens_k = args.seqlens_k
         assert len(seqlens_q) == batch_size and len(seqlens_k) == batch_size
+        assert max(seqlens_q) <= max_seqlen_q, "max_seqlen_q must cover all query sequence lengths"
+        assert max(seqlens_k) <= max_seqlen_k, "max_seqlen_k must cover all key/value sequence lengths"
         assert q_head % num_ranks == 0, "q_head should be divisible by world size"
         assert kv_head % num_ranks == 0, "kv_head should be divisible by world size"

Also applies to: 224-233

🧰 Tools
🪛 Ruff (0.15.9)

[warning] 204-204: Consider [0, *list(accumulate(seqlens_q))] instead of concatenation

Replace with [0, *list(accumulate(seqlens_q))]

(RUF005)


[warning] 205-205: Consider [0, *list(accumulate(seqlens_k))] instead of concatenation

Replace with [0, *list(accumulate(seqlens_k))]

(RUF005)
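For reference, the cumulative-length construction Ruff flags can use iterable unpacking instead of list concatenation; a minimal standalone sketch with hypothetical lengths:

```python
from itertools import accumulate

seqlens_q = [4, 6, 2]                            # hypothetical per-batch lengths
cu_seqlens_q_list = [0, *accumulate(seqlens_q)]  # RUF005-preferred form
assert cu_seqlens_q_list == [0, 4, 10, 12]
```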


Comment on lines +316 to +317
parser.add_argument("--q_head", type=int, default=32, help="local num q heads per rank")
parser.add_argument("--kv_head", type=int, default=8, help="local num kv heads per rank")

⚠️ Potential issue | 🟡 Minor

The head-count help text is backwards.

The parser says --q_head and --kv_head are local counts per rank, but the code immediately divides both by num_ranks. Following the current help text makes the example run with only 1 / num_ranks of the intended global heads.

Suggested wording
-    parser.add_argument("--q_head", type=int, default=32, help="local num q heads per rank")
-    parser.add_argument("--kv_head", type=int, default=8, help="local num kv heads per rank")
+    parser.add_argument("--q_head", type=int, default=32, help="global number of Q heads")
+    parser.add_argument("--kv_head", type=int, default=8, help="global number of KV heads")

Comment on lines +111 to +118
for k in T.Pipelined(loop_range, num_stages=num_stages):
    sp_block_idx = (k * block_N) // kv_len_per_sp_block
    wait_rank = sp_block_idx if sp_block_idx < num_ranks else 2 * num_ranks - sp_block_idx - 1
    kv_load_offset = (
        (k * block_N) % kv_len_per_sp_block
        + sp_block_idx // num_ranks * kv_len_per_sp_block
        + wait_rank * (k_current_seqlen // num_ranks)
    )

⚠️ Potential issue | 🟠 Major

enable_zig_zag needs an explicit tile-alignment contract.

The zig-zag path picks the Q-side offset once per bx and the KV source-rank mapping once per k. If q_current_seqlen % (2 * block_M) != 0 or k_current_seqlen % (2 * num_ranks * block_N) != 0, one tile straddles the midpoint and applies the wrong half/rank mapping to part of the block. Please either enforce that alignment up front or compute the zig-zag branch at row/tile-boundary granularity.

Also applies to: 276-280
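A plain-Python sketch of the zig-zag source-rank mapping under discussion, plus the upfront alignment guard of option (A). Names mirror the review comment but the function boundaries are illustrative, not the PR's actual code:

```python
def zigzag_wait_rank(k, block_N, kv_len_per_sp_block, num_ranks):
    """Map KV tile k to its producing rank under zig-zag sharding:
    ranks 0..R-1 own the first-half blocks in order, then R-1..0 own the rest."""
    sp_block_idx = (k * block_N) // kv_len_per_sp_block
    if sp_block_idx < num_ranks:
        return sp_block_idx
    return 2 * num_ranks - sp_block_idx - 1

def check_zigzag_alignment(q_len, k_len, block_M, block_N, num_ranks):
    # Option (A) from the comment: reject lengths whose tiles would
    # straddle the sequence midpoint and get a mixed half/rank mapping.
    assert q_len % (2 * block_M) == 0, "q length must be a multiple of 2*block_M"
    assert k_len % (2 * num_ranks * block_N) == 0, "k length must tile the zig-zag blocks"

# 4 ranks, kv_len_per_sp_block=128, block_N=64 -> two tiles per sp block:
# source ranks run 0..3 forward, then mirror back 3..0.
check_zigzag_alignment(q_len=256, k_len=1024, block_M=64, block_N=64, num_ranks=4)
ranks = [zigzag_wait_rank(k, 64, 128, 4) for k in range(16)]
assert ranks == [0, 0, 1, 1, 2, 2, 3, 3, 3, 3, 2, 2, 1, 1, 0, 0]
```

With the guard in place, every tile maps to exactly one sp block, so the per-k rank lookup in the kernel stays well-defined.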


Comment on lines +460 to +467
local_seq_len = seqlen_k // world_size
src_token_start = cu_seqlens_k_start // world_size
dst_token_start = cu_seqlens_k_start + rank * local_seq_len

src_head_offset_bytes = rank * packed_heads_per_rank * head_dim * dtype_itemsize
src_ptr = packed_shards[rank].data_ptr() + src_token_start * src_token_bytes + src_head_offset_bytes
dst_ptr = packed_buffers[rank].data_ptr() + dst_token_start * dst_token_bytes
_cp_engine_copy_2d(

⚠️ Potential issue | 🔴 Critical

Pack Q with the Q shard length, not the K shard length.

flashattn_packed later reads Q from the rank-local block at Lines 91-99 using rank * q_current_seqlen, but this producer writes every packed slice at rank * (seqlen_k // world_size). As soon as the per-batch Q and K lengths differ, each rank starts reading Q from the wrong sequence block and the fused result is corrupted. Either give Q its own destination stride based on cu_seqlens_q, or explicitly reject mixed Q/K sequence lengths for this path.

Also applies to: 483-491
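The stride bug described above comes down to index arithmetic; a hypothetical sketch (helper names are illustrative) contrasting the K-based destination offset with the Q-based one the comment asks for:

```python
def dst_q_token_start(rank, cu_seqlens_q_start, seqlen_q, world_size):
    # Correct: stride the Q destination by the per-rank Q shard length.
    return cu_seqlens_q_start + rank * (seqlen_q // world_size)

def dst_q_token_start_buggy(rank, cu_seqlens_k_start, seqlen_k, world_size):
    # As reviewed: strides Q by the K shard length, which lands in the
    # wrong block whenever seqlen_q != seqlen_k for a batch.
    return cu_seqlens_k_start + rank * (seqlen_k // world_size)

seqlen_q, seqlen_k, world_size = 512, 2048, 4
assert dst_q_token_start(1, 0, seqlen_q, world_size) == 128
assert dst_q_token_start_buggy(1, 0, seqlen_k, world_size) == 512  # past the Q block
```

Either fix the stride as sketched, or explicitly reject mixed Q/K sequence lengths for this path so the mismatch cannot occur silently.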

