Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
e041f6b
PR to main
Boss2002n Mar 18, 2026
a402f8a
fix
Boss2002n Mar 18, 2026
a9645cd
temp fix
Boss2002n Mar 18, 2026
957ea09
fix
Boss2002n Mar 18, 2026
bcc2007
fix
Boss2002n Mar 18, 2026
27f7631
fix?
Boss2002n Mar 18, 2026
858ff7d
working
Boss2002n Mar 18, 2026
3136c6e
gfx-12 pass
Boss2002n Mar 18, 2026
df46bc0
lint
Boss2002n Mar 18, 2026
9ba6bfc
fix
Boss2002n Mar 18, 2026
b17c4b0
fix
Boss2002n Mar 18, 2026
e278456
remove convert layout
Boss2002n Mar 18, 2026
84c0583
fix
Boss2002n Mar 18, 2026
684fc79
Merge branch 'main' into satya/gfx12_mxfp4_gemm
Boss2002n Mar 19, 2026
5d8d768
Merge branch 'main' into satya/gfx12_mxfp4_gemm
Boss2002n Mar 24, 2026
d380fef
Update arch_info.py
Boss2002n Mar 24, 2026
5f9a79d
latest
Boss2002n Mar 29, 2026
9c01398
small fix
Boss2002n Mar 29, 2026
a2a82a2
fix
Boss2002n Mar 29, 2026
db502af
fix
Boss2002n Mar 29, 2026
be0b1ff
Fix
Boss2002n Mar 30, 2026
ea853bf
waves =2
Boss2002n Mar 31, 2026
99630be
fix
Boss2002n Mar 31, 2026
f3185b8
fix
Boss2002n Apr 1, 2026
ec13e9e
optimized config
Boss2002n Apr 1, 2026
5075a87
fix
Boss2002n Apr 9, 2026
b6061b1
fix layout cuz A is not preshuf
Boss2002n Apr 9, 2026
79014b0
Merge branch 'main' into satya/gfx12_mxfp4_gemm
Boss2002n Apr 14, 2026
121fede
hacky b128 loads
Boss2002n Apr 15, 2026
1d37870
revert - with claude comments
Boss2002n Apr 15, 2026
f15d7d4
Merge branch 'main' into satya/gfx12_mxfp4_gemm
Boss2002n Apr 16, 2026
9ff0d5c
fix config
Boss2002n Apr 16, 2026
4056c4d
alex_pipelining
Boss2002n Apr 16, 2026
e270d52
k-tile-preshuf-fix
Boss2002n May 1, 2026
5a5d2f0
tdm advance
Boss2002n May 6, 2026
4f9c11f
remove update bounds
Boss2002n May 7, 2026
eee0e05
temp change -TO BE REVERTED
Boss2002n May 8, 2026
289e5f3
32x16
Boss2002n May 12, 2026
03fb8ef
update shuffle
Boss2002n May 13, 2026
4b7222b
fix - depreshuf -scales
Boss2002n May 13, 2026
24e4910
Merge branch 'main' into satya/gfx12_mxfp4_gemm
Boss2002n May 13, 2026
f21584d
address comments
Boss2002n May 13, 2026
bd773d9
black - format
Boss2002n May 13, 2026
08e3e1b
black - format
Boss2002n May 13, 2026
4278ec0
.load instead of relaxed shared load
Boss2002n May 19, 2026
c4371d2
B32_test
Boss2002n May 20, 2026
eefa2cd
Merge branch 'main' into satya/gfx12_mxfp4_gemm
Boss2002n May 20, 2026
1716df8
formatting
Boss2002n May 20, 2026
f269068
fix formatting
Boss2002n May 20, 2026
ac2a49e
ruff fix
Boss2002n May 20, 2026
0039f8b
fix
Boss2002n May 21, 2026
9bf916e
Merge branch 'main' into satya/gfx12_mxfp4_gemm
Boss2002n May 22, 2026
a7359e1
remove unused params from 1250 mxfp4 config
Boss2002n May 22, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions aiter/ops/shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,44 @@
import torch


def shuffle_weight_gfx1250(w: torch.Tensor) -> torch.Tensor:
"""
Preshuffle weights for gfx1250 WMMA.

For 2D input (N, K): view as (N//16, 16, K//32, 2, 16) ->
permute(0, 2, 3, 1, 4) -> reshape (N//16, K*16).
For 3D input (E, N, K) or (E, K, N): transpose to (E, N, K) first,
then apply the same pattern per-expert.

The result is reshaped to (N//16, K*16) for TDM-optimal loading.
"""
x_type = w.dtype
if hasattr(torch, "float4_e2m1fn_x2") and x_type == torch.float4_e2m1fn_x2:
w = w.view(torch.uint8)

if w.ndim == 2:
N, K = w.shape
assert N % 16 == 0, f"N={N} must be divisible by 16"
assert K % 32 == 0, f"K={K} must be divisible by 32"
w = w.view(N // 16, 16, K // 32, 2, 16)
w = w.permute(0, 2, 3, 1, 4).contiguous()
w = w.view(N // 16, K * 16)
elif w.ndim == 3:
E, K, N = w.shape
assert K % 32 == 0, f"K={K} must be divisible by 32"
assert N % 16 == 0, f"N={N} must be divisible by 16"
w = w.transpose(-1, -2) # (E, N, K)
w = w.view(E, N // 16, 16, K // 32, 2, 16)
w = w.permute(0, 1, 3, 4, 2, 5).contiguous()
w = w.view(E, N // 16, K * 16)
w = w.transpose(-1, -2) # (E, K*16, N//16)
Comment thread
vgokhale marked this conversation as resolved.
else:
raise ValueError(f"Expected 2D or 3D tensor, got {w.ndim}D")

w = w.view(x_type)
return w


def shuffle_weight(
x: torch.Tensor,
layout=(16, 16),
Expand Down
Loading
Loading