Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 54 additions & 1 deletion aiter/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@

BLOCK_SIZE_M = 32

# Default to Opus unless CK sorting is explicitly requested.
# Default to Opus unless CK or FlyDSL sorting is explicitly requested.
_USE_CK_MOE_SORTING = os.environ.get("AITER_USE_CK_MOE_SORTING", "0") == "1"
_USE_FLYDSL_MOE_SORTING = os.environ.get("AITER_USE_FLYDSL_MOE_SORTING", "0") == "1"
_ACT_TYPE_DISABLED_KEY = "__ignore__"
_SWIGLU_MXFP4_BF16_BOUND = int(os.environ.get("GPTOSS_SWIGLU_MXFP4_BF16_BOUND", "256"))

Expand Down Expand Up @@ -105,6 +106,47 @@ def _moe_sorting_impl(
return ret


def _flydsl_moe_sorting(
topk_ids,
topk_weights,
num_experts,
model_dim,
moebuf_dtype,
block_size,
expert_mask,
num_local_tokens,
):
"""FlyDSL sorting dispatch — called outside torch_compile_guard."""
from aiter.ops.flydsl.moe_sorting import flydsl_moe_sorting_fwd

device = topk_ids.device
M, topk = topk_ids.shape
max_num_tokens_padded = int(topk_ids.numel() + num_experts * block_size - topk)
max_num_m_blocks = int((max_num_tokens_padded + block_size - 1) // block_size)
sorted_ids = torch.empty(max_num_tokens_padded, dtype=dtypes.i32, device=device)
sorted_weights = torch.empty(
max_num_tokens_padded, dtype=dtypes.fp32, device=device
)
sorted_expert_ids = torch.empty(max_num_m_blocks, dtype=dtypes.i32, device=device)
num_valid_ids = torch.empty(2, dtype=dtypes.i32, device=device)
moe_buf = torch.empty((M, model_dim), dtype=moebuf_dtype, device=device)

flydsl_moe_sorting_fwd(
topk_ids,
topk_weights,
sorted_ids,
sorted_weights,
sorted_expert_ids,
num_valid_ids,
moe_buf,
num_experts,
int(block_size),
expert_mask,
num_local_tokens,
)
return sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids, moe_buf


def moe_sorting(
topk_ids,
topk_weights,
Expand All @@ -117,6 +159,17 @@ def moe_sorting(
dispatch_policy=0,
return_local_topk_ids=False,
):
if _USE_FLYDSL_MOE_SORTING:
return _flydsl_moe_sorting(
topk_ids,
topk_weights,
num_experts,
model_dim,
moebuf_dtype,
block_size,
expert_mask,
num_local_tokens,
)
try:
return _moe_sorting_impl(
topk_ids,
Expand Down
Loading
Loading