Draft
Changes from all commits (55 commits)
d44394c
wip back of sdma integration
dsidler Nov 6, 2025
c50e761
Apply Ruff auto-fixes
github-actions[bot] Nov 6, 2025
2f7bc5e
message passing example working
dsidler Nov 6, 2025
5e38fd6
Merge branch 'dev/dasidler/sdma' of https://github.com/ROCm/iris into…
dsidler Nov 6, 2025
759f662
Apply Ruff auto-fixes
github-actions[bot] Nov 6, 2025
ad7769d
update put example to use ce
dsidler Nov 7, 2025
b8862cc
update api calls
dsidler Nov 7, 2025
75c5626
update submodule
dsidler Nov 7, 2025
2b228ab
Merge branch 'dev/dasidler/sdma' of https://github.com/ROCm/iris into…
dsidler Nov 7, 2025
e3aef16
fix merge
dsidler Nov 7, 2025
df04547
Apply Ruff auto-fixes
github-actions[bot] Nov 7, 2025
c5e4735
wip fixed wrap into ring when placing
dsidler Dec 5, 2025
ea17dd6
Merge branch 'dev/dasidler/sdma' of https://github.com/ROCm/iris into…
dsidler Dec 5, 2025
5362318
to_rank 7 working
dsidler Dec 5, 2025
a6b1d40
Apply Ruff auto-fixes
github-actions[bot] Dec 10, 2025
224511f
Merge branch 'main' into dev/dasidler/sdma
dsidler Jan 14, 2026
400b5b7
use triton commit with fix
dsidler Jan 14, 2026
d06cb72
Apply Ruff auto-fixes
github-actions[bot] Jan 14, 2026
b2e358b
send to all ranks but always same stride
dsidler Jan 20, 2026
b245899
update submodule
dsidler Jan 20, 2026
0e7fbd6
Merge branch 'dev/dasidler/sdma' of https://github.com/ROCm/iris into…
dsidler Jan 20, 2026
1ee4c58
use 32B copy packets workaround
dsidler Jan 30, 2026
1c384c3
submodule update
dsidler Jan 30, 2026
0224866
use window command
dsidler Mar 4, 2026
40c228a
use new acquire function
dsidler Mar 5, 2026
34d4ffc
update submodule
dsidler Mar 5, 2026
c8d4b46
Apply Ruff auto-fixes
github-actions[bot] Mar 5, 2026
53f1a20
move padding code
dsidler Mar 5, 2026
099a84c
update submodule for nop packet
dsidler Mar 5, 2026
75b55b2
enable flat copy
dsidler Mar 5, 2026
17d0696
Merge branch 'dev/dasidler/sdma' of https://github.com/ROCm/iris into…
dsidler Mar 5, 2026
e5a38dd
Apply Ruff auto-fixes
github-actions[bot] Mar 5, 2026
02d08c9
Merge branch 'main' into dev/dasidler/sdma
dsidler Mar 5, 2026
0b6ff1a
clean up
dsidler Mar 5, 2026
bfe4548
add copy engine support to fused gemm-allscatter
dsidler Mar 18, 2026
27040c8
Apply Ruff auto-fixes
github-actions[bot] Mar 18, 2026
bf55b6d
switch to acquire_fadd
dsidler Mar 24, 2026
aef1411
update submodule
dsidler Mar 24, 2026
2cea9f7
initial host initiated sdma
dsidler Mar 24, 2026
53bfeaa
refactor&cleanup
dsidler Mar 24, 2026
831be93
Merge branch 'dev/dasidler/sdma' of https://github.com/ROCm/iris into…
dsidler Mar 24, 2026
d191276
Apply Ruff auto-fixes
github-actions[bot] Mar 24, 2026
7374413
Merge branch 'main' into dev/dasidler/sdma
dsidler Apr 20, 2026
53bbd2f
importing host-initiated changes
dsidler Apr 20, 2026
fc6814c
Apply Ruff auto-fixes
github-actions[bot] Apr 20, 2026
2ad3777
update message passing tests
dsidler Apr 22, 2026
6095aed
adding unit tests, zero size failing atm
dsidler Apr 23, 2026
cf789c1
Merge branch 'dev/dasidler/sdma' of https://github.com/ROCm/iris into…
dsidler Apr 23, 2026
e973978
Apply Ruff auto-fixes
github-actions[bot] Apr 23, 2026
72bb2e8
cleanup
dsidler Apr 23, 2026
0ad4625
add atomic_cas copy engine support
dsidler Apr 23, 2026
9c74f88
Merge branch 'dev/dasidler/sdma' of https://github.com/ROCm/iris into…
dsidler Apr 23, 2026
1859a88
Apply Ruff auto-fixes
github-actions[bot] Apr 23, 2026
cfb37b6
address todo, import constants
dsidler Apr 23, 2026
d873234
Merge branch 'dev/dasidler/sdma' of https://github.com/ROCm/iris into…
dsidler Apr 23, 2026
3 changes: 3 additions & 0 deletions .gitmodules
@@ -0,0 +1,3 @@
[submodule "ext/shader_sdma"]
path = ext/shader_sdma
url = https://github.com/AARInternal/shader_sdma.git
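Note: the new ext/shader_sdma submodule has to be fetched before the SDMA path can be built or run; a standard checkout would be git submodule update --init ext/shader_sdma (build steps for the submodule itself are assumed here, not shown in this change).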
261 changes: 261 additions & 0 deletions examples/06_message_passing/message_passing_host_initiated.py
@@ -0,0 +1,261 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
"""
Host-Initiated Message Passing Example

This example demonstrates message passing where the producer (GPU 0) is
controlled by the HOST (Python/CPU) instead of a device kernel, while
the consumer (GPU 1) remains a device kernel.

Key difference from message_passing_put.py:
- Producer: Host uses anvil to initiate SDMA transfers from Python
- Consumer: Same device kernel waiting for data

This shows how to orchestrate GPU-to-GPU transfers from Python without
requiring kernel launches on the source GPU.
"""

import argparse

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import triton
import triton.language as tl
import random

import iris


@triton.jit
def consumer_kernel(
buffer, # tl.tensor: pointer to shared buffer (read from target_rank)
flag, # tl.tensor: sync flag per block
buffer_size, # int32: total number of elements
consumer_rank: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
heap_bases_ptr: tl.tensor, # tl.tensor: pointer to heap bases pointers
):
pid = tl.program_id(0)

block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < buffer_size

# Spin-wait until writer sets flag[pid] = 1
done = 0
while done == 0:
done = iris.atomic_cas(
flag + pid, 1, 0, consumer_rank, consumer_rank, heap_bases_ptr, sem="acquire", scope="sys"
)

# Read from the target buffer (written by producer)
values = tl.load(buffer + offsets, mask=mask)

# Do something with values...
# (Here you might write to output, do computation, etc.)
values = values * 2

# Store chunk to target buffer
tl.store(
buffer + offsets,
values,
mask=mask,
)

# Optionally reset the flag for next iteration
tl.store(flag + pid, 0)


torch.manual_seed(123)
random.seed(123)


def torch_dtype_from_str(datatype: str) -> torch.dtype:
dtype_map = {
"fp16": torch.float16,
"fp32": torch.float32,
"int8": torch.int8,
"bf16": torch.bfloat16,
}
try:
return dtype_map[datatype]
except KeyError:
print(f"Unknown datatype: {datatype}")
exit(1)


def parse_args():
parser = argparse.ArgumentParser(
description="Host-Initiated SDMA Message Passing Example",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"-t",
"--datatype",
type=str,
default="fp32",
choices=["fp16", "fp32", "int8", "bf16"],
help="Datatype of computation",
)
parser.add_argument("-s", "--buffer_size", type=int, default=4096, help="Buffer Size")
parser.add_argument("-b", "--block_size", type=int, default=512, help="Block Size")
parser.add_argument("-p", "--heap_size", type=int, default=1 << 33, help="Iris heap size")
parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes")

return vars(parser.parse_args())


def host_initiated_producer(shmem, source_buffer, destination_buffer, flags, consumer_rank, block_size, verbose=True):
"""
Producer rank logic for host-initiated SDMA transfers.

Args:
shmem: Iris instance
source_buffer: Source buffer (symmetric)
destination_buffer: Destination buffer (symmetric)
flags: Flag buffer for synchronization (symmetric)
consumer_rank: Destination rank
block_size: Block size for chunking
verbose: Whether to print timing information
"""
n_elements = source_buffer.numel()
num_blocks = triton.cdiv(n_elements, block_size)

if verbose:
shmem.info(f"Rank {shmem.get_rank()} (HOST) is sending data to rank {consumer_rank}.")

# Initialize CUDA context even though we're doing host-side operations
# This is needed for the barrier to work
torch.cuda.current_device()

if verbose:
import time

start_time = time.time()

for block_id in range(num_blocks):
block_start = block_id * block_size
block_end = min(block_start + block_size, n_elements)
block_slice = slice(block_start, block_end)

# Views remain symmetric, so Iris can translate remote pointers automatically
src_chunk = source_buffer[block_slice]
dst_chunk = destination_buffer[block_slice]
flag_view = flags[block_id : block_id + 1]

shmem.put(
src_chunk,
dst_rank=consumer_rank,
dst_tensor=dst_chunk,
signal_flag=flag_view,
async_op=True,
)

shmem.quiet(dst_rank=consumer_rank)

if verbose:
end_time = time.time()
elapsed_ms = (end_time - start_time) * 1000
shmem.info(
f"Host SDMA loop took {elapsed_ms:.2f} ms for {num_blocks} blocks ({elapsed_ms / num_blocks:.2f} ms/block)"
)


def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
"""Worker function for PyTorch distributed execution."""
backend = "nccl" if torch.cuda.is_available() else "gloo"
dist.init_process_group(
backend=backend,
init_method=init_url,
world_size=world_size,
rank=local_rank,
device_id=torch.device(f"cuda:{local_rank}"),
)

# Main benchmark logic
shmem = iris.iris(args["heap_size"])
dtype = torch_dtype_from_str(args["datatype"])
cur_rank = shmem.get_rank()
world_size = shmem.get_num_ranks()

# Allocate source and destination buffers on the symmetric heap
destination_buffer = shmem.zeros(args["buffer_size"], device="cuda", dtype=dtype)
if dtype.is_floating_point:
source_buffer = shmem.randn(args["buffer_size"], device="cuda", dtype=dtype)
else:
ii = torch.iinfo(dtype)
source_buffer = shmem.randint(ii.min, ii.max, (args["buffer_size"],), device="cuda", dtype=dtype)

if world_size != 2:
raise ValueError("This example requires exactly two processes.")

producer_rank = 0
consumer_rank = 1

n_elements = source_buffer.numel()
BLOCK_SIZE = args["block_size"]
num_blocks = triton.cdiv(n_elements, BLOCK_SIZE)
grid = (num_blocks,)

# Allocate flags on the symmetric heap
flags = shmem.zeros((num_blocks,), device="cuda", dtype=torch.int32)

if cur_rank == producer_rank:
host_initiated_producer(
shmem, source_buffer, destination_buffer, flags, consumer_rank, BLOCK_SIZE, verbose=True
)
else:
shmem.info(f"Rank {cur_rank} is receiving data from rank {producer_rank}.")
kk = consumer_kernel[grid](
destination_buffer, flags, n_elements, consumer_rank, BLOCK_SIZE, shmem.get_heap_bases()
)

shmem.barrier()
shmem.info(f"Rank {cur_rank} has finished sending/receiving data.")
shmem.info("Validating output...")

success = True
if cur_rank == consumer_rank:
expected = source_buffer * 2
diff_mask = ~torch.isclose(destination_buffer, expected, atol=1)
breaking_indices = torch.nonzero(diff_mask, as_tuple=False)

if not torch.allclose(destination_buffer, expected, atol=1):
max_diff = (destination_buffer - expected).abs().max().item()
shmem.info(f"Max absolute difference: {max_diff}")
for idx in breaking_indices:
idx = tuple(idx.tolist())
computed_val = destination_buffer[idx]
expected_val = expected[idx]
shmem.info(f"Mismatch at index {idx}: C={computed_val}, expected={expected_val}")
success = False
break

if success:
shmem.info("Validation successful.")
else:
shmem.info(f"Validation failed with {len(breaking_indices)} errors / {destination_buffer.numel()}")

shmem.barrier()

dist.barrier()
dist.destroy_process_group()


def main():
args = parse_args()

num_ranks = args["num_ranks"]

init_url = "tcp://127.0.0.1:29500"
mp.spawn(
fn=_worker,
args=(num_ranks, init_url, args),
nprocs=num_ranks,
join=True,
)


if __name__ == "__main__":
main()
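Usage (assumed, based on the argument parser above): on a node with two visible GPUs, run python examples/06_message_passing/message_passing_host_initiated.py -t fp32 -s 4096 -b 512 (the defaults from parse_args). Rank 0 then drives the SDMA transfers from the host through shmem.put and shmem.quiet, while rank 1 launches consumer_kernel and validates the doubled data.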
43 changes: 36 additions & 7 deletions examples/06_message_passing/message_passing_put.py
@@ -23,6 +23,8 @@ def producer_kernel(
consumer_rank: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
heap_bases_ptr: tl.tensor, # tl.tensor: pointer to heap bases pointers
copy_engine_handle_ptr,
USE_COPY_ENGINE: tl.constexpr,
):
pid = tl.program_id(0)

@@ -34,10 +36,30 @@
mask = offsets < buffer_size

# Put chunk into remote buffer
iris.put(source_buffer + offsets, target_buffer + offsets, producer_rank, consumer_rank, heap_bases_ptr, mask=mask)
iris.put(
source_buffer + offsets,
target_buffer + offsets,
producer_rank,
consumer_rank,
heap_bases_ptr,
copy_engine_handle_ptr,
mask=mask,
USE_COPY_ENGINE=USE_COPY_ENGINE,
)

# Set flag to signal completion
iris.atomic_cas(flag + pid, 0, 1, producer_rank, consumer_rank, heap_bases_ptr, sem="release", scope="sys")
iris.atomic_cas(
flag + pid,
0,
1,
producer_rank,
consumer_rank,
heap_bases_ptr,
sem="release",
scope="sys",
USE_COPY_ENGINE=USE_COPY_ENGINE,
copy_engine_ctx=copy_engine_handle_ptr,
)


@triton.jit
@@ -113,9 +135,11 @@ def parse_args():
)
parser.add_argument("-s", "--buffer_size", type=int, default=4096, help="Buffer Size")
parser.add_argument("-b", "--block_size", type=int, default=512, help="Block Size")

parser.add_argument("-p", "--heap_size", type=int, default=1 << 33, help="Iris heap size")
parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes")
parser.add_argument(
"-c", "--use_copy_engine", action="store_true", help="Use copy engine for device-to-device copies"
)

return vars(parser.parse_args())

@@ -138,12 +162,12 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
world_size = shmem.get_num_ranks()

# Allocate source and destination buffers on the symmetric heap
source_buffer = shmem.zeros(args["buffer_size"], device="cuda", dtype=dtype)
destination_buffer = shmem.zeros(args["buffer_size"], device="cuda", dtype=dtype)
if dtype.is_floating_point:
destination_buffer = shmem.randn(args["buffer_size"], device="cuda", dtype=dtype)
source_buffer = shmem.randn(args["buffer_size"], device="cuda", dtype=dtype)
else:
ii = torch.iinfo(dtype)
destination_buffer = shmem.randint(ii.min, ii.max, (args["buffer_size"],), device="cuda", dtype=dtype)
source_buffer = shmem.randint(ii.min, ii.max, (args["buffer_size"],), device="cuda", dtype=dtype)

if world_size != 2:
raise ValueError("This example requires exactly two processes.")
@@ -158,6 +182,9 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
# Allocate flags on the symmetric heap
flags = shmem.zeros((num_blocks,), device="cuda", dtype=torch.int32)

# Get copy engine context
copy_engine_ctx = shmem.get_copy_engine_ctx()

if cur_rank == producer_rank:
shmem.info(f"Rank {cur_rank} is sending data to rank {consumer_rank}.")
kk = producer_kernel[grid](
@@ -169,6 +196,8 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
consumer_rank,
args["block_size"],
shmem.get_heap_bases(),
copy_engine_ctx,
USE_COPY_ENGINE=args["use_copy_engine"],
)
else:
shmem.info(f"Rank {cur_rank} is receiving data from rank {producer_rank}.")
@@ -199,7 +228,7 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
if success:
shmem.info("Validation successful.")
else:
shmem.info("Validation failed.")
shmem.info(f"Validation failed with {len(breaking_indices)} errors / {destination_buffer.numel()}")

shmem.barrier()

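Usage (assumed, based on the new flag above): python examples/06_message_passing/message_passing_put.py keeps the original load/store put path, while adding -c (--use_copy_engine) routes the same iris.put and iris.atomic_cas calls through the copy-engine context returned by shmem.get_copy_engine_ctx().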