-
Notifications
You must be signed in to change notification settings - Fork 7
[Feature] Semantics support and remote atomic add #48
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
c859a16
cd4e509
e37fea4
12a98d0
231dad1
b496a54
6f072ab
268f54a
e8d036a
fc98be0
8404728
7b00d85
067511c
31a4643
67065ab
e2d8ee3
b00bdd8
2822ced
81af526
5f584e5
1d4bbd3
b065199
3845d36
6f570f4
5abf3f1
a9dae4c
0d1dc1c
baf1fc4
52ad42e
d05df28
777e391
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,114 @@ | ||||||||||||||
| import tilelang | ||||||||||||||
| import tilelang.language as T | ||||||||||||||
| from tilelang.distributed import init_dist | ||||||||||||||
| import torch | ||||||||||||||
| import torch.distributed as dist | ||||||||||||||
| import argparse | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
def alltoall(PE_num, M, N, block_M, block_N, threads):
    """Build a tilelang all-to-all kernel.

    Each rank pushes one (block_M x block_N) tile per grid block to every
    peer with ``T.put_block``, landing the data in the destination rank's
    ``dst`` buffer at the row offset owned by the *sending* rank.

    Args:
        PE_num: number of participating PEs/ranks (grid x-dimension).
        M: rows per rank; ``src``/``dst`` are (PE_num * M, N).
        N: row width in elements.
        block_M: tile height per grid block.
        block_N: tile width; must equal N (rows are copied whole).
        threads: threads per kernel block.

    Returns:
        The compiled-ready ``T.prim_func``.
    """
    # Whole rows must be transferred in one put: tiled (partial-row) copy
    # is not supported, hence block_N == N.
    assert block_N == N

    @T.prim_func
    def main(
            src: T.Tensor((PE_num * M, N), "float16"),
            dst: T.Tensor((PE_num * M, N), "float16"),
            # NOTE(review): `barrier` is accepted but never used in the body.
            # Presumably reserved for inter-rank completion signaling; only
            # the RELEASE fence below orders the puts — confirm whether a
            # cross-rank barrier is still needed by callers.
            barrier: T.Tensor((PE_num), "int32"),
    ):
        # Currently not support tiled copy
        with T.Kernel(
                PE_num, T.ceildiv(M, block_M), T.ceildiv(N, block_N),
                threads=threads) as (bx, by, bz):
            rank = T.alloc_local([1], "int32")
            num_ranks = T.alloc_local([1], "int32")

            # bx selects the destination PE; by selects the row tile.
            dst_rank = bx
            rank[0] = T.get_rank()
            num_ranks[0] = T.get_num_ranks()

            # Send the tile of `src` that belongs to `dst_rank`; it lands in
            # the remote `dst` at this rank's row offset (rank[0] * M).
            T.put_block(
                src=T.address_of(src[dst_rank * M + by * block_M, 0]),
                dst=T.address_of(dst[rank[0] * M + by * block_M, 0]),
                size=block_M * block_N,
                dst_pe=dst_rank,
            )
            # System-scope release fence: makes the puts visible in order.
            # It provides memory ordering only, not inter-rank completion.
            T.fence_sys(sem=T.MemorySemantic.RELEASE)

    return main
|
|
||||||||||||||
|
|
||||||||||||||
def run_alltoall(local_rank, num_ranks, args):
    """Benchmark and verify the tilelang all-to-all kernel on one rank.

    Spawned once per process by ``torch.multiprocessing.spawn``.

    Args:
        local_rank: this process's rank index (supplied by ``spawn``).
        num_ranks: total number of participating ranks.
        args: parsed CLI namespace with ``PE_num``, ``M``, ``N``,
            ``warmup`` and ``iter``.
    """
    PE_num = args.PE_num
    M = args.M
    N = args.N
    block_M = 512
    block_N = N  # kernel asserts block_N == N (whole rows are copied)
    threads = 512

    # NOTE(review): despite the name, the third return value is used as a
    # process group below — verify against init_dist's contract.
    local_rank, num_ranks, group_size = init_dist(local_rank, num_ranks)
    allocator = tilelang.get_allocator(
        size=2**34,
        device="cuda",
        is_distributed=True,
        local_rank=local_rank,
        num_local_ranks=num_ranks,
        group=group_size,
    )
    kernel = tilelang.compile(alltoall(PE_num, M, N, block_M, block_N, threads))
    kernel.initialize(allocator=allocator)
    src = tilelang.tensor((PE_num * M, N), torch.float16, allocator=allocator).random_()
    dst = tilelang.tensor((PE_num * M, N), torch.float16, allocator=allocator).zero_()
    barrier = tilelang.tensor((PE_num), torch.int32, allocator=allocator).zero_()

    torch.cuda.synchronize()
    dist.barrier(group_size)

    # Warmup: run the kernel a few times, resetting dst each pass so the
    # timed runs start from the same state.
    for _ in range(args.warmup):
        kernel(src, dst, barrier)
        dst.zero_()
    torch.cuda.synchronize()
    dist.barrier(group_size)

    # Timed runs. The per-iteration synchronize + barrier keeps all ranks
    # in lockstep, so barrier wait time is included in elapsed_time.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(args.iter):
        kernel(src, dst, barrier)
        torch.cuda.synchronize()
        dist.barrier(group_size)
    end.record()
    torch.cuda.synchronize()
    dist.barrier(group_size)
    elapsed_time = start.elapsed_time(end) / args.iter
    # Bandwidth must count bytes, not elements: each rank both sends and
    # receives PE_num*M*N float16 elements, hence 2 (directions) and
    # elem_bytes (sizeof float16). elapsed_time is in ms, so /1e6 -> GB/s.
    # (Fix: the original formula omitted the 2-byte element size and
    # under-reported bandwidth by 2x.)
    elem_bytes = 2  # sizeof(float16)
    print(
        f"Rank {local_rank} Average Kernel execution time: {elapsed_time:.3f} ms, Bandwidth: {2 * PE_num * M * N * elem_bytes / (elapsed_time * 1e6):.3f} GB/s"
    )

    # Torch Reference
    torch.cuda.synchronize()
    dst_ref = torch.zeros((PE_num * M, N), dtype=torch.float16, device="cuda")
    dist.all_to_all_single(dst_ref, src, group=group_size)
    torch.cuda.synchronize()

    if torch.allclose(dst, dst_ref, atol=1e-2, rtol=1e-2):
        print(f"Rank {local_rank} Verification Passed! ✅")
    else:
        max_diff = (dst - dst_ref).abs().max()
        print(f"Rank {local_rank} Verification Failed! ❌ Max diff: {max_diff}")
        print(f"dst: {dst}")
        print(f"dst_ref: {dst_ref}")

    dist.destroy_process_group()
|
|
||||||||||||||
|
|
||||||||||||||
if __name__ == "__main__":
    # CLI entry point: parse benchmark options, then launch one process per
    # PE; spawn invokes run_alltoall(rank, PE_num, parsed) in each of them.
    cli = argparse.ArgumentParser()
    cli.add_argument("--PE_num", type=int, default=8)
    cli.add_argument("--M", type=int, default=8192)
    cli.add_argument("--N", type=int, default=7168)
    cli.add_argument("--warmup", type=int, default=5, help="Number of warmup iterations")
    cli.add_argument("--iter", type=int, default=10, help="Number of benchmark iterations")
    parsed = cli.parse_args()

    torch.multiprocessing.spawn(run_alltoall, args=(parsed.PE_num, parsed), nprocs=parsed.PE_num)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unused
`barrier` parameter in kernel signature. The
`barrier` parameter is declared in the kernel signature but never used within the kernel body. This could indicate either:
T.fence_sysis called which provides memory ordering but not inter-rank synchronization.💡 Suggested fix: Either use the barrier or remove it
Option 1 - Add barrier synchronization:
T.put_block( src=T.address_of(src[dst_rank * M + by * block_M, 0]), dst=T.address_of(dst[rank[0] * M + by * block_M, 0]), size=block_M * block_N, dst_pe=dst_rank, ) T.fence_sys(sem=T.MemorySemantic.RELEASE) + T.barrier_blocks(barrier) return mainOption 2 - Remove unused parameter:
`@T.prim_func` def main( src: T.Tensor((PE_num * M, N), "float16"), dst: T.Tensor((PE_num * M, N), "float16"), - barrier: T.Tensor((PE_num), "int32"), ):📝 Committable suggestion
🧰 Tools
🪛 Ruff (0.14.14)
[warning] 16-16: Unused function argument:
barrier(ARG001)
[warning] 21-21: Unpacked variable
bzis never usedPrefix it with an underscore or any other dummy variable pattern
(RUF059)
🤖 Prompt for AI Agents