[PyTorch] torch.compile support for permutation functions#2686

Open
pggPL wants to merge 10 commits into NVIDIA:main from pggPL:moe_torch_compile

Conversation


pggPL commented Feb 17, 2026

Description

This PR adds torch.compile(fullgraph=True) support for MoE permutation operations (moe_permute, moe_unpermute, moe_sort_chunks_by_index) by converting all torch.autograd.Function implementations to PyTorch custom operators using torch.library.custom_op.

Note that this PR does not add torch.compile support for QuantizedTensor as an input.

Related to #2590

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
pggPL force-pushed the moe_torch_compile branch from 41e22ef to 8159d26 on February 18, 2026 17:31
pre-commit-ci bot and others added 4 commits February 18, 2026 17:32
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
pggPL marked this pull request as ready for review on February 19, 2026 15:45

pggPL commented Feb 19, 2026

/te-ci pytorch


greptile-apps bot commented Feb 19, 2026

Greptile Summary

Refactored MoE permutation operations (moe_permute, moe_unpermute, moe_sort_chunks_by_index) from torch.autograd.Function to torch.library.custom_op to enable torch.compile(fullgraph=True) support.

Key changes:

  • Converted three autograd function classes to 10+ custom operators with forward, backward, and fake (shape inference) implementations
  • Added passthrough mechanism in QuantizedTensor.__torch_dispatch__ to prevent unwrapping for custom ops
  • Added runtime checks to reject FP8 quantized inputs under torch.compile (documented limitation)
  • Registered custom ops in _quantized_tensor_passthrough_ops set for proper FP8 handling
  • Comprehensive test coverage added for torch.compile path (tested on subset of configurations)

Note: Global workspace variables (_moe_permute_index_map_workspace, _moe_permute_index_map_max_expanded_token_num) remain as module-level state, which was already flagged in previous thread.
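The passthrough mechanism summarized above can be illustrated with a toy wrapper subclass. This is a hypothetical sketch, not TE's actual `QuantizedTensor`: the `Wrapper` class, the `_passthrough_impls` table, and the `aten.mul` handler are all illustrative, showing only the idea that listed ops receive the wrapped tensors untouched so the op-specific implementation can extract `._data` itself.

```python
import torch

# op overload -> implementation that accepts wrapped-tensor arguments
_passthrough_impls = {}

class Wrapper(torch.Tensor):
    __torch_function__ = torch._C._disabled_torch_function_impl

    @staticmethod
    def __new__(cls, data):
        return torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=data.dtype, device=data.device
        )

    def __init__(self, data):
        self._data = data  # underlying storage (cf. Float8Tensor._data)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in _passthrough_impls:
            # Passthrough: do NOT unwrap; the impl handles Wrapper itself.
            return _passthrough_impls[func](*args, **kwargs)
        # Default path: unwrap to plain tensors and re-dispatch.
        unwrap = lambda a: a._data if isinstance(a, Wrapper) else a
        return func(
            *[unwrap(a) for a in args],
            **{k: unwrap(v) for k, v in kwargs.items()},
        )

def _mul_impl(a, b):
    # Op-specific handling: use the wrapped payload directly, rewrap result.
    return Wrapper(a._data * (b._data if isinstance(b, Wrapper) else b))

_passthrough_impls[torch.ops.aten.mul.Tensor] = _mul_impl
```

Ops not in the table take the generic unwrapping path and lose the wrapper, which is exactly what the passthrough set prevents for the new custom ops.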

Confidence Score: 4/5

  • Safe to merge with awareness of global state limitation
  • The refactoring correctly implements torch.compile support through custom operators with proper forward/backward/fake implementations. Tests validate the new code paths. However, the global workspace variables remain a known issue (already commented on), and the PR explicitly documents the limitation that FP8 quantized inputs are not supported under torch.compile.
  • No files require special attention beyond the already-documented global state concern in transformer_engine/pytorch/permutation.py:30-31

Important Files Changed

Filename Overview
transformer_engine/pytorch/permutation.py Converted torch.autograd.Function classes to torch.library.custom_op for torch.compile support. Adds runtime check to prevent FP8 quantized inputs under torch.compile. Global workspace variables remain for index map operations.
tests/pytorch/test_permutation.py Added use_torch_compile parameter to test functions with torch.compile(fullgraph=True) wrapping. Tests cover limited configurations (single config per test) with torch.compile enabled.
transformer_engine/pytorch/quantized_tensor.py Added _quantized_tensor_passthrough_ops set and passthrough logic in __torch_dispatch__ to handle custom ops without unwrapping quantized tensors.

Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[User calls moe_permute/moe_unpermute/moe_sort_chunks_by_index] --> B{torch.compile enabled?}
    B -->|Yes| C{Input is FP8 QuantizedTensor?}
    B -->|No| D[Direct execution]
    C -->|Yes| E[RuntimeError: FP8 not supported under torch.compile]
    C -->|No| F[torch.ops.te_moe.* custom op]
    D --> G[torch.ops.te_moe.* custom op]
    F --> H[Custom op forward]
    G --> H
    H --> I{Needs shape inference?}
    I -->|Yes - compile tracing| J[register_fake: return fake tensors]
    I -->|No - eager execution| K[Actual kernel execution]
    K --> L[tex.moe_permute_fwd/triton_permutation.*]
    L --> M{FP8 handling?}
    M -->|Yes| N[Extract _data, wrap result in Float8Tensor]
    M -->|No| O[Direct tensor operations]
    N --> P[Return output]
    O --> P
    J --> Q[Compile graph construction]
    Q --> R[Backward registration]
    R --> S[register_autograd: setup_context + backward_wrapper]
    S --> T[torch.ops.te_moe.*_bwd custom op]
    T --> U[Backward kernel execution]
```

Last reviewed commit: 3472e14

@greptile-apps bot left a comment

3 files reviewed, no comments

pggPL and others added 2 commits February 19, 2026 15:57
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL
Copy link
Collaborator Author

pggPL commented Feb 19, 2026

/te-ci pytorch

@greptile-apps bot left a comment

3 files reviewed, 1 comment

Comment on lines +225 to +227:

```python
import torch._functorch.config as functorch_config

functorch_config.donated_buffer = False
```
Member

What does it do and why do we need to do that? Could we add a comment here, especially since we are using an internal function (and so it will most probably break at some point)?

Collaborator Author

This is an optimization in torch.compile that is not compatible with the retain_graph=True used in the tests.

Collaborator Author

I added a comment.

Member

Where?

```python
# ===================== _moe_permute_index_map custom ops =====================

topK = index.size(1)
# Workspace state for moe_permute_index_map
```
Member

I don't like it (although I realize this is not really a problem with this PR, but rather with the original implementation).

Member

If we can figure out how to change that, however, that would be great. Maybe we could make moe_compute a functor (a struct MoECompute with __call__ methods and the workspaces); then moe_compute would just be an object of that class that we create at the very beginning.
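The functor idea could look roughly like this. This is a hypothetical sketch: `MoECompute`, its members, and the simplified row gather standing in for the real permutation kernel are all illustrative, not existing TE code.

```python
import torch

class MoECompute:
    """Holds the permutation workspace as instance state instead of
    module-level globals, so separate instances do not share buffers."""

    def __init__(self):
        self._index_map_workspace = None
        self._max_expanded_token_num = 0

    def __call__(self, inp: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
        num_tokens, topk = index.shape
        expanded = num_tokens * topk
        if expanded > self._max_expanded_token_num:
            # Grow this instance's workspace only; other instances keep
            # their own buffers instead of racing on a shared global.
            self._index_map_workspace = torch.empty(
                expanded, dtype=torch.int64, device=inp.device
            )
            self._max_expanded_token_num = expanded
        # Simplified gather in place of the real permute kernel; index
        # values are assumed to be valid row indices into inp.
        row_map = index.reshape(-1).to(torch.int64)
        return inp.index_select(0, row_map)

# Created once at import time, mirroring the current module-level API.
moe_compute = MoECompute()
```

As noted below, this alone does not fix concurrent use of the single shared `moe_compute` object; it only makes the workspace ownership explicit and makes per-stream instances possible.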

Collaborator Author

Why? I mean, what don't you like about it?

Member

Well, the main thing is that we implicitly rely on the fact that there is only one permutation happening at a time (and that problem would not be solved by my proposal either, by the way; it would require changing this into an actual nn.Module, which has its own problems since it is effectively an API break, though we should still do it for TE 3.0). If you run two permutations on two streams, there is a chance of silent data corruption, since both kernels would be using the same underlying workspace. This is something the user has no way of knowing about without consulting the code. And with torch.compile the chance of this happening may be even bigger: we are at the whim of the compiler's optimizations at this point.

Signed-off-by: root <pgadzinski@nvidia.com>
Signed-off-by: root <pgadzinski@nvidia.com>
@greptile-apps bot left a comment

3 files reviewed, no comments
