[PyTorch] torch.compile support for permutation functions#2686
pggPL wants to merge 10 commits into NVIDIA:main from
Conversation
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
41e22ef to 8159d26
for more information, see https://pre-commit.ci
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
/te-ci pytorch
Greptile Summary
Refactored MoE permutation operations ( Key changes:
Note: Global workspace variables (
Confidence Score: 4/5
Important Files Changed
Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[User calls moe_permute/moe_unpermute/moe_sort_chunks_by_index] --> B{torch.compile enabled?}
B -->|Yes| C{Input is FP8 QuantizedTensor?}
B -->|No| D[Direct execution]
C -->|Yes| E[RuntimeError: FP8 not supported under torch.compile]
C -->|No| F[torch.ops.te_moe.* custom op]
D --> G[torch.ops.te_moe.* custom op]
F --> H[Custom op forward]
G --> H
H --> I{Needs shape inference?}
I -->|Yes - compile tracing| J[register_fake: return fake tensors]
I -->|No - eager execution| K[Actual kernel execution]
K --> L[tex.moe_permute_fwd/triton_permutation.*]
L --> M{FP8 handling?}
M -->|Yes| N[Extract _data, wrap result in Float8Tensor]
M -->|No| O[Direct tensor operations]
N --> P[Return output]
O --> P
J --> Q[Compile graph construction]
Q --> R[Backward registration]
R --> S[register_autograd: setup_context + backward_wrapper]
S --> T[torch.ops.te_moe.*_bwd custom op]
T --> U[Backward kernel execution]
Last reviewed commit: 3472e14
/te-ci pytorch
import torch._functorch.config as functorch_config

functorch_config.donated_buffer = False
What does it do, and why do we need to do that? Could we add a comment here, especially since we are using an internal function (so it will most probably break at some point)?
This is an optimization of torch.compile which is not compatible with the retain_graph=True used in tests.
I added a comment.
# ===================== _moe_permute_index_map custom ops =====================

topK = index.size(1)
# Workspace state for moe_permute_index_map
I don't like it (although I realize this is not really a problem with this PR, but rather with the original implementation).
If we can figure out how to change that, however, that would be great. Maybe we could make moe_compute a functor (a struct MoECompute with __call__ methods and the workspaces); then moe_compute would just be an object of that class that we would create at the very beginning.
Why? I mean, what don't you like about it?
Well, the main thing is that we implicitly rely on the fact that only one permutation is happening at a time (and that problem would not be solved by my proposal BTW - that would require changing this into an actual nn.Module, which has its own problems by effectively being an API break; we should still do it for TE 3.0, though). If you run 2 permutations on 2 streams, there is a chance of silent data corruption, since both kernels would be using the same underlying workspace. This is something the user has no way of knowing about without consulting the code. And with torch.compile the chance of this happening may be even bigger - we are at the whim of the compiler's optimizations at this point.
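The functor idea discussed above can be sketched in a few lines: each instance owns its workspace, so two concurrent permutations (e.g. on two CUDA streams) can use separate instances and never alias the same buffer. `MoECompute` and its methods here are hypothetical names from the review discussion, not TE API, and the "kernel" is a plain `index_select` stand-in:

```python
# Sketch of per-instance workspace state replacing module-level globals.
# Assumption: the real workspace is a grow-only cached buffer, which this
# mimics with a lazily allocated int32 tensor.
import torch


class MoECompute:
    def __init__(self):
        self._row_id_map = None  # per-instance workspace, lazily allocated

    def _workspace(self, n: int, device) -> torch.Tensor:
        # Grow-only reuse, like the original globals, but scoped to self.
        if self._row_id_map is None or self._row_id_map.numel() < n:
            self._row_id_map = torch.empty(n, dtype=torch.int32, device=device)
        return self._row_id_map[:n]

    def permute(self, inp: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
        row_id_map = self._workspace(index.numel(), inp.device)
        row_id_map.copy_(index.to(torch.int32))
        return inp.index_select(0, row_id_map.long())


# Two instances hold two distinct workspaces, so concurrent use cannot
# silently clobber a shared buffer (the hazard described above).
moe_a, moe_b = MoECompute(), MoECompute()
x = torch.arange(12.0).reshape(4, 3)
out = moe_a.permute(x, torch.tensor([3, 2, 1, 0]))
```

As noted in the thread, this alone does not stop a user from sharing one instance across streams, but it at least makes the ownership of the workspace explicit.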
Signed-off-by: root <pgadzinski@nvidia.com>
Description
This PR adds torch.compile(fullgraph=True) support for MoE permutation operations (moe_permute, moe_unpermute, moe_sort_chunks_by_index) by converting all torch.autograd.Function implementations to PyTorch custom operators using torch.library.custom_op. Note that this PR does not add torch.compile support for QuantizedTensor as an input.
Related to #2590
Type of change
Checklist: