Draft

24 commits
d21bb20
Initial plan
Copilot Mar 12, 2026
445e4ec
Implement Ring Attention (arxiv:2310.01889)
Copilot Mar 12, 2026
32c79d7
Replace torch.distributed with iris.put for ring KV rotation
Copilot Mar 12, 2026
2eefc41
Fuse K+V put into single kernel; validate on AMD GPUs (16/16 tests pass)
Copilot Mar 12, 2026
17c9856
Add benchmark script with roofline analysis and measured results on M…
Copilot Mar 12, 2026
bbdac9b
Address code review: typed buffer cache, named constants, fix --causa…
Copilot Mar 12, 2026
0f01c45
Add strong/weak scaling benchmark (1-8 GPUs) and update results figure
Copilot Mar 13, 2026
5357874
Optimize ring attention: comm-compute overlap + bug fixes
mawad-amd Mar 23, 2026
7b76443
Apply Ruff auto-fixes
github-actions[bot] Mar 23, 2026
a118644
ring_attention: fuse KV rotation into attention kernel
mawad-amd Mar 23, 2026
647b53d
Apply Ruff auto-fixes
github-actions[bot] Mar 23, 2026
5037b94
ring_attention: add per-step profiling script
mawad-amd Mar 23, 2026
f68bd36
Apply Ruff auto-fixes
github-actions[bot] Mar 23, 2026
6f9fa80
ring_attention: eliminate causal load imbalance
mawad-amd Mar 23, 2026
0430a23
Apply Ruff auto-fixes
github-actions[bot] Mar 23, 2026
061903c
perf: use fp16 MFMA and causal early exit in ring attention kernel
mawad-amd Mar 23, 2026
651d68a
ring_attention: persistent kernel with signal-flag synchronization
mawad-amd Mar 23, 2026
dc7b345
Apply Ruff auto-fixes
github-actions[bot] Mar 23, 2026
8b9ba4c
ring_attention: fix put visibility — use sys scope on completion counter
mawad-amd Mar 23, 2026
929e90c
ring_attention: simplify causal path to match original kernel structure
mawad-amd Mar 23, 2026
cf49885
DEBUG: disable inner-loop skip to test causal mask
mawad-amd Mar 23, 2026
7eb6ab0
fix: pass pre-computed kv_rank_starts array to avoid in-kernel conste…
mawad-amd Mar 23, 2026
6709eb0
Apply Ruff auto-fixes
github-actions[bot] Mar 23, 2026
a6f701e
re-enable causal early-exit optimization for KV blocks
mawad-amd Mar 23, 2026
122 changes: 122 additions & 0 deletions examples/32_ring_attention/README.md
@@ -0,0 +1,122 @@
<!--
SPDX-License-Identifier: MIT
Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
-->

# Ring Attention

An implementation of **Ring Attention with Blockwise Transformers** for
near-infinite context on AMD GPUs using [Iris](../../README.md).

> Liu, H., Li, M., Hall, A., Dao, T., & Abbeel, P. (2023).
> *Ring Attention with Blockwise Transformers for Near-Infinite Context.*
> arXiv:2310.01889. <https://arxiv.org/pdf/2310.01889>

---

## Algorithm

Standard self-attention requires O(n²) memory in the sequence length n.
Ring Attention enables sequences far longer than what fits on a single device
by distributing them across a *ring* of GPUs:

1. The full sequence is split evenly across **N GPUs** along the sequence
dimension. Each device holds a chunk of Q, K, and V of length
`seq_total / N`.
2. **Q stays local**. K and V rotate around the ring one step at a time.
3. At each of the **N steps**, every device runs a local
[Flash Attention](https://arxiv.org/abs/2205.14135) pass and accumulates
the result using **online softmax**.
4. After all N steps the accumulator is normalised to yield the final output.

For **causal (autoregressive) attention**, only the steps where the KV chunk
precedes or coincides with the Q chunk contribute, allowing some ranks to
terminate early and reducing total compute.

```
Step 0: rank r processes its own K_r, V_r (causal block diagonal)
Step 1: rank r receives K_{r-1}, V_{r-1} (full attention, past)
...
Step r: rank r receives K_0, V_0 (full attention, past)
Step r+1..N-1: all-future chunks – skipped (causal mode only)
```
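The rotation schedule above can be sketched in plain Python. The helper below (`ring_schedule` is an illustrative name, not part of this example's API) does only the bookkeeping: which KV chunk each rank holds at each step, and which steps the causal mode skips.

```python
# Sketch of the ring schedule only -- no attention math, no Iris calls.
# `ring_schedule` is a hypothetical helper for illustration.

def ring_schedule(num_ranks, rank, causal=True):
    """List the (step, kv_owner) pairs that `rank` actually computes."""
    computed = []
    for step in range(num_ranks):
        # KV chunks travel one hop per step, so at step s rank r holds the
        # chunk originally owned by rank (r - s) mod N.
        kv_owner = (rank - step) % num_ranks
        if causal and kv_owner > rank:
            continue  # strictly-future chunk: contributes nothing under the mask
        computed.append((step, kv_owner))
    return computed

print(ring_schedule(4, rank=1))                # [(0, 1), (1, 0)]
print(ring_schedule(4, rank=1, causal=False))  # [(0, 1), (1, 0), (2, 3), (3, 2)]
```

This matches the diagram: rank 1 of 4 handles its own chunk, then chunk 0, and in causal mode sits out the remaining steps while still forwarding KV so the barrier stays well-defined.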

---

## Files

| File | Description |
|------|-------------|
| `ring_attention_kernels.py` | Triton flash-attention kernel + Python ring-rotation helper |
| `ring_attention_layer.py` | `RingAttention` – a `torch.nn.Module` wrapper |
| `example_run.py` | End-to-end demo with timing |

---

## Usage

### Quick demo

```bash
# 2 GPUs, causal attention (default)
python examples/32_ring_attention/example_run.py

# 4 GPUs, bidirectional
python examples/32_ring_attention/example_run.py --num_ranks 4 --no_causal

# Custom sizes
python examples/32_ring_attention/example_run.py \
    --num_ranks 8 \
    --total_seq_len 131072 \
    --num_heads 32 \
    --head_dim 128
```

### Validation

```bash
python tests/run_tests_distributed.py tests/examples/test_ring_attention.py --num_ranks 2 -v
```

---

## Python API

```python
import iris
from examples.ring_attention.ring_attention_layer import RingAttention

shmem = iris.iris()

# Each rank holds its local chunk
layer = RingAttention(
    shmem,
    num_heads=16,
    head_dim=64,
    causal=True,  # autoregressive masking
)

# q, k, v: [seq_local, num_heads, head_dim] (float16 or bfloat16)
output = layer(q, k, v) # [seq_local, num_heads, head_dim]
```

---

## Design Notes

* **Communication**: KV rotation uses `iris.put` Triton kernels — each rank
pushes its K/V chunk directly to the next rank's symmetric heap buffer.
A `shmem.barrier()` after each push ensures all ranks have received the
data before the next attention step proceeds. No `torch.distributed` APIs
are used.
* **Ping-pong buffers**: Two symmetric buffer pairs (`k_ping`/`k_pong` and
`v_ping`/`v_pong`) alternate as source and destination on each step. This
guarantees the source being read and the destination being written are
always different allocations, avoiding any read-after-write hazard.
* **Online softmax**: The kernel maintains running max (`M`) and sum (`L`)
accumulators in float32 for numerical stability. The final output is
`O / L` after all ring steps.
* **Causal masking**: Handled entirely at the granularity of KV *chunks*
  (full attention, diagonal block attention, or skip), so the per-element mask
  is applied only in the same-block diagonal case. All ranks still
  participate in the rotation (required for the barrier to be well-defined).
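The online-softmax merge described above can be checked numerically in a few lines. This is a standalone sketch (NumPy stands in for the Triton kernel; `merge_step` is an illustrative name): folding chunked scores into running max/sum/output accumulators and dividing by `L` at the end reproduces full softmax attention exactly.

```python
import numpy as np

def merge_step(m, l, o, scores, v):
    """Fold one KV chunk into running (max M, sum L, output O) accumulators.
    scores: [q, k_chunk], v: [k_chunk, d]; m, l: [q]; o: [q, d]."""
    m_new = np.maximum(m, scores.max(axis=-1))
    alpha = np.exp(m - m_new)            # rescale previous accumulators
    p = np.exp(scores - m_new[:, None])  # probabilities in the new scale
    l = l * alpha + p.sum(axis=-1)
    o = o * alpha[:, None] + p @ v
    return m_new, l, o

rng = np.random.default_rng(0)
q_len, seq, d, n_chunks = 4, 8, 16, 2
scores = rng.normal(size=(q_len, seq))   # pre-softmax attention scores
v = rng.normal(size=(seq, d))

m = np.full(q_len, -np.inf)
l = np.zeros(q_len)
o = np.zeros((q_len, d))
for s_c, v_c in zip(np.split(scores, n_chunks, axis=1),
                    np.split(v, n_chunks, axis=0)):
    m, l, o = merge_step(m, l, o, s_c, v_c)
out = o / l[:, None]                     # final normalisation O / L

# Reference: ordinary softmax over the full score row, then weight V.
e = np.exp(scores - scores.max(axis=-1, keepdims=True))
ref = (e / e.sum(axis=-1, keepdims=True)) @ v
assert np.allclose(out, ref)
```

The same recurrence is what lets the kernel keep `M` and `L` in float32 registers while the ring delivers one KV chunk per step, in any order of chunk arrival.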