refactor: reorganize moe ops and kernels #243

Open
zhenhuang12 wants to merge 3 commits into main from refactor/moe-ops-reorganize

@zhenhuang12
Collaborator

No description provided.

- Merge fused_router, permute, unpermute... into a single moe_utils.py file
- Rename token_permute/token_unpermute to moe_permute/moe_unpermute
- Update all references in token_dispatcher.py and test files
- Merge fused_router, multihot_to_indices, tokens_per_expert_to_mask into moe_utils.py
- Rename token_permute/token_unpermute to moe_permute/moe_unpermute
Copilot AI review requested due to automatic review settings March 5, 2026 12:36
@zhenhuang12 changed the title from "refactor: moe ops reorganize" to "refactor: reorganize moe ops and kernels" on Mar 5, 2026
Contributor

Copilot AI left a comment

Pull request overview

Refactors MoE-related ops by consolidating previously separate router / indices conversion / permutation / tokens-per-expert utilities into shared moe_utils modules (PyTorch + Triton), and updates call sites/tests to the new API names and import paths.

Changes:

  • Move Triton MoE kernels (router, indices↔multihot, tokens-per-expert mask) into primus_turbo/triton/moe/moe_utils.py.
  • Consolidate PyTorch MoE ops into primus_turbo/pytorch/ops/moe/moe_utils.py, renaming token_permute/token_unpermute to moe_permute/moe_unpermute.
  • Update tests and module call sites to new names/imports and simplify MoE package exports.

Reviewed changes

Copilot reviewed 13 out of 13 changed files in this pull request and generated no comments.

| File | Description |
| --- | --- |
| tests/pytorch/ops/test_tokens_per_expert_to_mask.py | Updates import to new consolidated MoE utils location. |
| tests/pytorch/ops/test_permutation.py | Updates permutation API calls to moe_permute / moe_unpermute. |
| primus_turbo/triton/moe/tokens_per_expert_to_mask_kernel.py | Deleted; kernel moved into Triton moe_utils. |
| primus_turbo/triton/moe/multihot_to_indices.py | Deleted; kernels moved into Triton moe_utils. |
| primus_turbo/triton/moe/moe_utils.py | Becomes the consolidated Triton MoE kernel module. |
| primus_turbo/pytorch/ops/moe/tokens_per_expert_to_mask.py | Deleted; op moved into PyTorch moe_utils. |
| primus_turbo/pytorch/ops/moe/moe_utils.py | New single entry point for router, indices conversion, permute/unpermute, and tokens-per-expert mask. |
| primus_turbo/pytorch/ops/moe/indices_converter.py | Deleted; functionality moved into PyTorch moe_utils. |
| primus_turbo/pytorch/ops/moe/fused_moe_router.py | Deleted; functionality moved into PyTorch moe_utils. |
| primus_turbo/pytorch/ops/moe/__init__.py | Re-exports MoE API via moe_dispatch_combine + moe_utils. |
| primus_turbo/pytorch/modules/moe/token_dispatcher.py | Updates call sites to renamed moe_permute/moe_unpermute API. |
| primus_turbo/pytorch/kernels/moe/tokens_per_expert_to_mask_impl.py | Updates kernel import to consolidated Triton moe_utils. |
| primus_turbo/pytorch/kernels/moe/fused_moe_router_impl.py | Updates kernel import to consolidated Triton moe_utils. |

Comments suppressed due to low confidence (7)

primus_turbo/pytorch/ops/moe/moe_utils.py:100

  • torch.zeros(..., device="cuda") allocates on the current CUDA device (GPU 0 by default); this will break on multi-GPU if logits is on a different CUDA device. Allocate on logits.device (and similarly avoid hard-coded "cuda" elsewhere in this module) so the op is device-correct.

primus_turbo/pytorch/ops/moe/moe_utils.py:190

  • These outputs are always allocated on device="cuda", which forces the current CUDA device (CUDA:0 by default) even if indices is on another CUDA device. Allocate multihot_indices, probs_in_multihot, and position_map on indices.device to keep the op multi-GPU safe.

primus_turbo/pytorch/ops/moe/moe_utils.py:236

  • grad_probs_indices is allocated on device="cuda", which can put the gradient tensor on the wrong GPU when running on CUDA:1+. Allocate it on grad_probs_in_multihot.device (or position_map.device) instead.

primus_turbo/pytorch/ops/moe/moe_utils.py:255

  • IndicesToMultihot.forward takes 3 inputs, but backward returns 4 gradients. PyTorch will error with “returned an incorrect number of gradients”. Return exactly 3 items (one per forward input).

primus_turbo/pytorch/ops/moe/moe_utils.py:560

  • The return type annotation for moe_permute doesn’t match actual returns: this function returns 4 values (both in fused and non-fused paths), but the annotation declares only 3. Update the signature typing to reflect the real return tuple shape to avoid misleading callers and type-checkers.

primus_turbo/pytorch/ops/moe/moe_utils.py:592

  • Use is not None / is None for None checks instead of != None to avoid surprising behavior with overloaded __eq__ and to match the rest of this module’s style.

primus_turbo/pytorch/ops/moe/moe_utils.py:100

  • In this autograd Function, g_probs can be None when backpropagating only through output_scores (e.g., calling output_scores.backward(...) without touching output_probs). Multiplying g_probs before checking it will raise a runtime error; guard the multiplication (or treat missing grads as zeros) before using it.
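
The patterns these suppressed comments ask for can be sketched on a toy op. This is a minimal illustration, not the PR's code: TwoOutputOp and zeros_like_device are hypothetical names, and the op only stands in for a router-style Function that returns two outputs. Two points are worth noting. First, PyTorch materializes missing output gradients as zero tensors by default, so g_probs can only arrive as None if the Function calls ctx.set_materialize_grads(False); the sketch assumes that it does. Second, PyTorch tolerates extra trailing values from backward only when they are None, so returning exactly one gradient per forward input is the safe form.

```python
import torch


def zeros_like_device(shape, ref):
    """Hypothetical helper: allocate zeros on ref's device and dtype,
    instead of hard-coding device="cuda"."""
    return torch.zeros(shape, dtype=ref.dtype, device=ref.device)


class TwoOutputOp(torch.autograd.Function):
    """Hypothetical stand-in for an op that returns (probs, scores)."""

    @staticmethod
    def forward(ctx, logits, scale):
        # Assumption: materialization is disabled, so gradients of unused
        # outputs reach backward() as None rather than zero tensors.
        ctx.set_materialize_grads(False)
        ctx.scale = scale
        probs = logits * scale
        scores = logits + zeros_like_device(logits.shape, logits)
        return probs, scores

    @staticmethod
    def backward(ctx, g_probs, g_scores):
        grad_logits = None
        if g_probs is not None:  # guard before multiplying
            grad_logits = g_probs * ctx.scale
        if g_scores is not None:
            grad_logits = g_scores if grad_logits is None else grad_logits + g_scores
        # forward() took two inputs (logits, scale): return two gradients.
        return grad_logits, None


logits = torch.ones(3, requires_grad=True)
probs, scores = TwoOutputOp.apply(logits, 2.0)
scores.sum().backward()  # only the scores branch contributes a gradient
```

Because the guard uses is not None rather than != None, it never invokes the tensor's overloaded comparison operators, which is the same point the style comment on line 592 makes.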
