- Use per-step micro-accumulation over multiple packed sequences so updates are less sensitive to sparse expert token assignment, and make backend progress accounting accumulation-aware.
- Correct LoRA shard export behavior so non-zero TP ranks in EP/ETP topologies contribute when required, while still filtering replicated-only entries.
- Move normalization logic into ForwardTraceCapture so saved traces are canonicalized toward world-size-1 semantics (expert row identity/order and ETP fc1 layout).
- Rework oracle pass/fail evaluation with per-phase functions, layer-averaged metrics, deterministic init, expanded sensitivity mutations, and a smaller Adam epsilon for tiny-gradient regimes.
- Redirect suite stdout/stderr into local correctness/sensitivity logs and make skip/report messaging point to those artifacts instead of terminal output.
Summary
This PR fixes multiple correctness issues in the Megatron LoRA backend and adds an oracle-based validation harness for distributed topologies.
It addresses real correctness bugs in gradient synchronization, replicated/sharded LoRA factor math, TP/SP forward semantics, and DP loss normalization.
It also adds deterministic testing that validates representative TP / EP / ETP / DP topologies against a world-size-1 oracle.
Reviewer Notes
What was wrong before
The previous implementation could run while still being mathematically wrong across topologies.
Main failure modes: missing or wrong-group gradient reductions, unreduced local partials in the forward pass, and rank-local loss normalization.
That meant the same logical batch could produce different final LoRA updates depending on topology.
What changed
Megatron-style gradient handling
The training step now follows a Megatron-style lifecycle:
`finalize_model_grads_extended(model, num_tokens=...)`

Megatron's native finalize path now handles the standard dense DP/CP sync, and expert-DP sync via the normal `allreduce=True/False` routing on parameters. ART then adds one narrow extension for non-default domains: parameters tagged with `grad_sync_domain` and `grad_sync_op` metadata get that reduction applied directly to `param.main_grad`.

So the correction is not just "more reduction". It is reduction on the correct group, at the correct point in the Megatron step lifecycle, using the gradients Megatron itself owns.
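To make the extension concrete, here is a minimal sketch of how `grad_sync_domain` / `grad_sync_op` metadata could drive one extra reduction on `param.main_grad` after Megatron's native finalize pass. This is not ART's actual code: process groups are simulated as per-rank dicts, and the helper name is illustrative.

```python
# Hedged sketch: an extra reduction on param.main_grad, driven by
# grad_sync_domain / grad_sync_op metadata. Ranks and groups are simulated.

SUM = "sum"
AVG = "avg"

def extended_grad_sync(per_rank_main_grads, domain_ranks, op):
    """Reduce main_grad over one non-default domain (e.g. an expert-TP group)."""
    total = sum(per_rank_main_grads[r] for r in domain_ranks)
    if op == AVG:
        total /= len(domain_ranks)
    # Every rank in the domain ends up with the same reduced gradient.
    return {r: total for r in domain_ranks}

# A replicated LoRA factor: each rank in the shard domain holds a different
# partial gradient, so the correct op is SUM, not AVG.
partials = {0: 0.25, 1: 0.5}
synced = extended_grad_sync(partials, domain_ranks=[0, 1], op=SUM)
print(synced)  # {0: 0.75, 1: 0.75}
```

The key point the sketch mirrors: the op and the group come from the parameter's metadata, not from a global default.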
Correct replicated/sharded LoRA math
The LoRA factors are now handled according to their actual parallel math.
For dense attention:
- QKV LoRA uses `A` replicated over TP and `B` sharded over TP
- `o_proj` LoRA uses `A` sharded over TP and `B` replicated over TP

For expert layers:

- fc1 LoRA uses `A` replicated and `B` sharded
- fc2 LoRA uses `A` sharded and `B` replicated

The important correction is the sync math for replicated factors.
Replicated factors do not get averaged here. They get summed on their shard domain, because each rank contributes a different partial gradient:
- QKV's replicated `A` gets a TP SUM over output-shard contributions
- `o_proj`'s replicated `B` gets a TP SUM over input-shard contributions
- expert fc1's replicated `A` gets an expert-TP SUM
- expert fc2's replicated `B` gets an expert-TP SUM

Sharded factors do not participate in those replica-domain reductions. They only participate in the appropriate DP / expert-DP reductions.
Replicated parameters are also broadcast-initialized on their shard domain so they start from identical values rather than drifting immediately from rank-local random init.
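The two rules above can be sketched side by side. This is a toy simulation, not the backend's code: ranks are dict keys, the "broadcast" is a shared seed standing in for a rank-0 broadcast, and group names are illustrative.

```python
# Hedged sketch of the sync rules for replicated vs sharded LoRA factors,
# with ranks simulated as dict entries rather than real process groups.
import random

def broadcast_init(shard_domain_ranks, seed=1234):
    """Replicated factors: identical init on every rank of their shard domain."""
    rng = random.Random(seed)          # stands in for a rank-0 broadcast
    value = rng.uniform(-0.1, 0.1)
    return {r: value for r in shard_domain_ranks}

def sync_replicated_grad(per_rank_grads):
    """Replicated factor: SUM over the shard domain (each rank holds a partial)."""
    total = sum(per_rank_grads.values())
    return {r: total for r in per_rank_grads}

def sync_sharded_grad(per_rank_grads, dp_groups):
    """Sharded factor: AVG only over each data-parallel replica group."""
    out = {}
    for group in dp_groups:
        avg = sum(per_rank_grads[r] for r in group) / len(group)
        out.update({r: avg for r in group})
    return out

tp_ranks = [0, 1]
init = broadcast_init(tp_ranks)
assert init[0] == init[1]                      # no rank-local drift at step 0

rep = sync_replicated_grad({0: 1.0, 1: 3.0})   # SUM -> 4.0 on both ranks
shd = sync_sharded_grad({0: 1.0, 1: 3.0}, dp_groups=[[0], [1]])  # untouched
print(rep, shd)
```

Averaging `rep` instead of summing would halve the true gradient, which is exactly the kind of silent topology-dependent error the PR describes.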
Correct TP/SP forward semantics
The forward path for row-parallel LoRA branches now matches Megatron math instead of treating local partials as final outputs.
For `o_proj`, the LoRA branch's local partial is now reduced over TP together with the base output, so the branch matches the base-layer distributed equation instead of adding an unreduced local partial.
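The row-parallel math can be checked numerically. In this sketch numpy stands in for the TP all-reduce; shapes, rank simulation, and variable names are illustrative, not Megatron's API.

```python
# Hedged sketch of row-parallel o_proj LoRA: each TP rank's output (base plus
# LoRA branch) is a *partial* that only becomes final after the TP reduce.
import numpy as np

rng = np.random.default_rng(0)
tp = 2
d_in, d_out, r = 8, 6, 2

X = rng.normal(size=(4, d_in))        # activations, split on the input dim
W = rng.normal(size=(d_in, d_out))    # base o_proj weight, row-parallel
A = rng.normal(size=(d_in, r))        # LoRA A: sharded over TP (input dim)
B = rng.normal(size=(r, d_out))       # LoRA B: replicated on every TP rank

X_shards = np.split(X, tp, axis=1)
W_shards = np.split(W, tp, axis=0)
A_shards = np.split(A, tp, axis=0)

# The old bug: add the unreduced LoRA partial to the already-reduced base
# output. The fix: reduce base + LoRA partials together.
partials = [X_shards[i] @ W_shards[i] + X_shards[i] @ A_shards[i] @ B
            for i in range(tp)]
Y = sum(partials)                      # stands in for the TP all-reduce

Y_ref = X @ W + X @ A @ B              # world-size-1 oracle
assert np.allclose(Y, Y_ref)
```

Because `A` is sharded on the contraction dimension, `X @ A @ B` decomposes into the same sum over ranks as `X @ W`, which is why one shared reduce is correct.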
The QKV and expert FC2 branches receive the analogous corrections for their respective parallel layouts.
Correct DP normalization
The DP normalization fix is explicit.
Per micro-step: the loss uses `reduction="sum"`.

Per optimizer step: gradients are normalized by `num_tokens` summed across the relevant distributed domain.

So the corrected objective is effectively the token-sum loss divided by the global token count, which is identical across DP layouts.

This is what fixes the old local-normalization bug and makes `dp>1` comparable to the world-size-1 oracle when both consume the same total sequences.

Grad accumulation
Grad accumulation is now treated as a global sequence-count concept:

- `grad_accumulation_sequences` is global
- each DP rank consumes `grad_accumulation_sequences / dp_world_size` real sequences per optimizer step

This makes numerics more stable and gives a clean correctness story for DP oracle comparisons.
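The accumulation and normalization rules combine into a simple invariant, sketched here with per-rank lists standing in for real DP ranks (the function name and layout are illustrative, not the backend's API):

```python
# Hedged sketch: global sequence-count accumulation with token-sum
# normalization. Any DP layout consuming the same total sequences must
# produce the same optimizer-step loss as the world-size-1 oracle.

def train_step(per_rank_micro_losses, per_rank_token_counts):
    """Each inner list is one rank's sum-reduced micro losses for this step."""
    loss_sum = sum(sum(ms) for ms in per_rank_micro_losses)    # DP all-reduce
    num_tokens = sum(sum(ts) for ts in per_rank_token_counts)  # DP all-reduce
    return loss_sum / num_tokens   # identical on every rank and topology

global_grad_accumulation_sequences = 4
dp_world_size = 2
per_rank_sequences = global_grad_accumulation_sequences // dp_world_size

# dp=2: each rank accumulates 2 sequences; the dp=1 oracle consumes all 4.
dp2 = train_step([[3.0, 5.0], [2.0, 6.0]], [[10, 10], [10, 10]])
dp1 = train_step([[3.0, 5.0, 2.0, 6.0]], [[10, 10, 10, 10]])
assert dp2 == dp1 == 0.4
```

Normalizing each rank locally instead (e.g. `8.0 / 20` on rank 0) would break this equality whenever ranks see different token counts.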
This is configurable through the Megatron training config via `grad_accumulation_sequences`.

Oracle harness
This PR adds a deterministic oracle harness that compares each topology's training behavior against a world-size-1 reference.
This does not prove the oracle is mathematically perfect, but it provides a strong practical baseline.
Expert replay
The harness also adds expert replay on the training side so MoE topology comparisons are actually comparable to the oracle.
This was necessary to make the oracle tests meaningful in expert settings. For now it is train-side only, which is intentional for this training-correctness work.
Sensitivity tests
The suite includes sensitivity tests that intentionally introduce known-bad mutations.
These are important because they show the harness is capable of detecting the kinds of distributed bugs this PR is meant to prevent.
In particular, the DP sensitivity checks include intentionally broken reduction and normalization variants.
Those fail clearly, which is strong evidence that the passing DP checks are meaningful.
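The shape of such a sensitivity check can be sketched as follows. The mutation, threshold, and function names here are illustrative, not the suite's actual config; the point is that a known-bad variant must land past the failure threshold.

```python
# Hedged sketch of a sensitivity check: a deliberate known-bad mutation
# (rank-local normalization instead of the global DP reduction) should
# clearly fail against the world-size-1 oracle.

def step_loss(rank_losses, rank_tokens, mutate_skip_reduce=False):
    if mutate_skip_reduce:
        # BUG on purpose: rank 0 normalizes locally instead of globally.
        return sum(rank_losses[0]) / sum(rank_tokens[0])
    return sum(map(sum, rank_losses)) / sum(map(sum, rank_tokens))

oracle = step_loss([[3.0, 9.0, 2.0, 6.0]], [[10, 10, 10, 10]])
good = step_loss([[3.0, 9.0], [2.0, 6.0]], [[10, 10], [10, 10]])
bad = step_loss([[3.0, 9.0], [2.0, 6.0]], [[10, 10], [10, 10]],
                mutate_skip_reduce=True)

THRESHOLD = 1e-6
assert abs(good - oracle) <= THRESHOLD   # the passing check is meaningful...
assert abs(bad - oracle) > THRESHOLD     # ...because the mutation clearly fails
```

A harness whose known-bad mutations did not fail would tell you nothing about its passing results.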
Scoring notes
Layer averaging
The harness includes layer-averaged summaries.
That is intentional: these are backend/distributed-semantics bugs, so if they exist they are usually systemic rather than unique to a single decoder layer.
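A minimal sketch of what a layer-averaged summary metric might look like (the metric and epsilon here are illustrative, not the harness's actual formulas):

```python
# Hedged sketch of layer averaging: per-layer relative diffs against the
# oracle are collapsed into one summary number, since distributed-semantics
# bugs tend to be systemic across decoder layers rather than layer-specific.

def rel_diff(test, ref, eps=1e-12):
    return abs(test - ref) / max(abs(ref), eps)

def layer_averaged(per_layer_test, per_layer_oracle):
    diffs = [rel_diff(t, o) for t, o in zip(per_layer_test, per_layer_oracle)]
    return sum(diffs) / len(diffs)

# A systemic 1% error shows up in every layer, so the average stays at 1%.
score = layer_averaged([1.01, 2.02, 4.04], [1.0, 2.0, 4.0])
assert abs(score - 0.01) < 1e-9
```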
Thresholds
Some thresholds may look high, especially for expert-related comparisons, but that is mostly because sparse/random expert values can show large relative percentages even when absolute differences are tiny.
The detailed logs give the more useful picture.
So the thresholds are mainly there to avoid false positives from noisy relative metrics.
Outcome
This PR turns the Megatron LoRA backend from “runs on these topologies” into “has explicit correctness evidence on these topologies.”
The passing matrix should be treated as validated support for the tested Megatron backend topologies.