
Megatron LoRA correctness: align distributed semantics with Megatron and validate TP/EP/ETP/DP against an oracle #619

Draft
FurtherAI wants to merge 20 commits into main from austin/megatron_lora_correctness_oracle_tests

Conversation


@FurtherAI FurtherAI commented Mar 17, 2026

Summary

This PR fixes multiple correctness issues in the Megatron LoRA backend and adds an oracle-based validation harness for distributed topologies.

It addresses real correctness bugs in:

  • TP/SP forward semantics
  • replicated LoRA parameter handling
  • QKV ordering/layout handling
  • DP token normalization

It also adds deterministic testing that validates representative TP / EP / ETP / DP topologies against a world-size-1 oracle.

Reviewer Notes

  • aligns Megatron LoRA training with Megatron-style distributed semantics instead of relying on ad hoc custom reductions
  • fixes topology-dependent final adapter drift
  • adds grad accumulation as a global sequence-count concept
  • adds a deterministic oracle harness comparing forwards, losses, grads, and LoRA deltas
  • adds sensitivity tests that intentionally break key invariants and confirm the harness catches them
  • passing topologies in this suite should be treated as validated Megatron backend topologies

What was wrong before

The previous implementation could run while still being mathematically wrong across topologies.

Main failure modes:

  • row-parallel LoRA forward paths did not fully follow Megatron TP/SP semantics
  • replicated LoRA factors were not handled according to their true distributed domain
  • QKV ordering/layout handling was not robust enough for topology parity
  • DP normalization was local in places where it needed to be global

That meant the same logical batch could produce different final LoRA updates depending on topology.

What changed

Megatron-style gradient handling

The training step now follows a Megatron-style lifecycle:

  • zero grad buffers
  • backward through all local micros
  • finalize_model_grads_extended(model, num_tokens=...)
  • optimizer step

Megatron's native finalize path now handles the standard dense DP/CP sync and expert-DP sync via the normal allreduce=True/False routing on parameters.
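A toy single-process sketch of that lifecycle (the distributed reductions inside the real finalize are elided; here "finalize" only applies the global token-count scaling, and the model/optimizer names are illustrative, not ART's):

```python
import torch

# Toy stand-in for the Megatron-style step lifecycle described above.
model = torch.nn.Linear(4, 1, bias=False)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

micros = [torch.ones(3, 4), 2 * torch.ones(5, 4)]   # two local micro-batches
opt.zero_grad()                                     # 1. zero grad buffers
num_tokens = 0
for x in micros:                                    # 2. backward through all micros
    model(x).sum().backward()                       #    per-micro reduction="sum"
    num_tokens += x.shape[0]
for p in model.parameters():                        # 3. "finalize": scale grads by
    p.grad /= num_tokens                            #    the step's token count
opt.step()                                          # 4. optimizer step
```

Because every micro accumulates an unscaled sum into the same grad buffer, the division happens exactly once per step, at finalize time, rather than per micro.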

ART then adds one narrow extension for non-default domains:

  • LoRA params carry explicit grad_sync_domain and grad_sync_op metadata
  • the finalize extension reads param.main_grad
  • grads are bucketed and coalesced
  • the buckets are all-reduced on the correct TP or expert-TP group

So the correction is not just "more reduction". It is reduction on the correct group, at the correct point in the Megatron step lifecycle, using the gradients Megatron itself owns.
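A minimal sketch of the bucketing step, with the group all-reduce faked by summing per-rank copies (the real extension reads param.main_grad and reduces on a TP or expert-TP process group):

```python
import torch

def coalesced_sum(per_rank_grads):
    """Coalesce each rank's grads into one flat buffer, reduce, scatter back.

    per_rank_grads: one list of gradient tensors per simulated rank.
    torch.stack(...).sum(0) stands in for all_reduce(op=SUM) on the group."""
    flats = [torch.cat([g.flatten() for g in grads]) for grads in per_rank_grads]
    reduced = torch.stack(flats).sum(dim=0)
    out, offset = [], 0
    for g in per_rank_grads[0]:            # unflatten back into param shapes
        n = g.numel()
        out.append(reduced[offset:offset + n].view_as(g))
        offset += n
    return out

rank0 = [torch.ones(2, 2), torch.zeros(3)]
rank1 = [2 * torch.ones(2, 2), torch.ones(3)]
summed = coalesced_sum([rank0, rank1])
```

Coalescing into one buffer per sync domain means one collective per bucket instead of one per LoRA parameter.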

Correct replicated/sharded LoRA math

The LoRA factors are now handled according to their actual parallel math.

For dense attention:

  • QKV LoRA uses A replicated over TP and B sharded over TP
  • o_proj LoRA uses A sharded over TP and B replicated over TP

For expert layers:

  • FC1 LoRA uses the same pattern as a column-parallel layer on expert-TP: A replicated, B sharded
  • FC2 LoRA uses the same pattern as a row-parallel layer on expert-TP: A sharded, B replicated

The important correction is the sync math for replicated factors.

Replicated factors do not get averaged here. They get summed on their shard domain, because each rank contributes a different partial gradient:

  • QKV replicated A gets a TP SUM over output-shard contributions
  • o_proj replicated B gets a TP SUM over input-shard contributions
  • expert FC1 replicated A gets an expert-TP SUM
  • expert FC2 replicated B gets an expert-TP SUM

Sharded factors do not participate in those replica-domain reductions. They only participate in the appropriate DP / expert-DP reductions.

Replicated parameters are also broadcast-initialized on their shard domain so they start from identical values rather than drifting immediately from rank-local random init.
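The SUM-not-average rule can be checked in a single process. The sketch below (shapes and names are illustrative) shards B over its output rows, replicates A, and confirms that summing the per-shard gradients of A recovers the unsharded gradient:

```python
import torch

torch.manual_seed(0)
n, d_in, r, d_out = 3, 4, 2, 6
x = torch.randn(n, d_in)
A = torch.randn(r, d_in, requires_grad=True)    # LoRA A, replicated over TP
B = torch.randn(d_out, r)                       # LoRA B, output-sharded over TP

# Oracle: unsharded LoRA branch with a summed scalar loss.
(x @ A.T @ B.T).sum().backward()
full_grad = A.grad.clone()

# tp=2: each rank holds half of B's output rows and a full replica of A.
tp_grads = []
for B_shard in B.chunk(2, dim=0):
    A_rep = A.detach().clone().requires_grad_(True)
    (x @ A_rep.T @ B_shard.T).sum().backward()
    tp_grads.append(A_rep.grad)

# The replicated factor needs a SUM over output-shard contributions, not an
# average: each rank's loss covers a disjoint slice of the outputs, so each
# rank holds a different partial gradient of A.
assert torch.allclose(tp_grads[0] + tp_grads[1], full_grad, atol=1e-5)
```

Averaging here would silently scale the replicated factor's gradient by 1/tp, which is exactly the kind of topology-dependent drift the oracle tests catch.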

Correct TP/SP forward semantics

The forward path for row-parallel LoRA branches now matches Megatron math instead of treating local partials as final outputs.

For o_proj:

  • each TP rank computes its local LoRA partial from its local input shard
  • without sequence parallelism, those partials are summed across TP before being added to the base output
  • with sequence parallelism, the LoRA branch follows the same reduce-scatter semantics as the base row-parallel layer

So the branch now matches the base-layer distributed equation instead of adding an unreduced local partial.
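The row-parallel identity behind this fix is simple to state in a single process (toy shapes, no actual communication):

```python
import torch

torch.manual_seed(0)
x = torch.randn(3, 8)                  # full activation
W = torch.randn(5, 8)                  # row-parallel weight (input-sharded)
full = x @ W.T                         # world-size-1 output

# tp=2: each rank holds an input shard of x and the matching columns of W.
# A local matmul is only a partial sum over the input dimension; the partials
# must be reduced across TP (all-reduce, or reduce-scatter under SP) before
# they are added to the base output.
partials = [xs @ Ws.T for xs, Ws in zip(x.chunk(2, dim=1), W.chunk(2, dim=1))]
assert torch.allclose(partials[0] + partials[1], full, atol=1e-5)
```

Adding a single unreduced partial, as the old path effectively did, leaves out every other rank's contribution to the LoRA branch.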

For QKV:

  • the adapter output is rebuilt in Megatron's mixed-QKV packing order before it is added to the base projection output
  • this fixes the topology-parity issue where the local tensors could be numerically plausible but packed in the wrong order relative to the Megatron layout
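A sketch of the packing issue (the exact fused layout depends on the Megatron version and head configuration; the grouped-interleaved shape here is illustrative):

```python
import torch

torch.manual_seed(0)
tokens, groups, heads_per_group, head_dim = 2, 3, 2, 4

# Adapter outputs in separate-Q/K/V form (hypothetical shapes).
q = torch.randn(tokens, groups, heads_per_group * head_dim)
k = torch.randn(tokens, groups, head_dim)
v = torch.randn(tokens, groups, head_dim)

# Megatron's fused QKV projection interleaves per query group as
# [q heads..., k, v], not [all Q | all K | all V]. Rebuild that mixed
# order before adding the adapter output to the base projection output.
mixed = torch.cat([q, k, v], dim=-1).flatten(1)

# The naive concatenation has the same shape and the same numbers, just in
# the wrong positions relative to the fused layout — numerically plausible,
# but wrong once added to the base output.
naive = torch.cat([q.flatten(1), k.flatten(1), v.flatten(1)], dim=-1)
assert mixed.shape == naive.shape and not torch.equal(mixed, naive)
```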

For expert FC2:

  • the wrapper does not add extra forward TP communication
  • that communication is already owned by the MoE token-routing path, so adding more would be the wrong math

Correct DP normalization

The DP normalization fix is explicit.

Per micro:

  • policy loss is computed with reduction="sum"

Per optimizer step:

  • micro losses are summed across the local grad-accumulation loop
  • local trainable-token counts are summed across the same micros
  • that token count is passed into Megatron finalize
  • Megatron globalizes num_tokens across the relevant distributed domain
  • the step loss is then formed as summed loss divided by global trainable-token count

So the corrected objective is effectively:

  • total summed loss over the full step
  • divided by the global trainable-token count for that same step

This is what fixes the old local-normalization bug and makes dp>1 comparable to the world-size-1 oracle when both consume the same total sequences.
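The difference between the corrected objective and the old local normalization shows up as soon as DP ranks hold unequal trainable-token counts. A single-process sketch with stand-in per-token losses:

```python
import torch

torch.manual_seed(0)
per_token_loss = torch.rand(16)            # stand-in per-token losses
oracle = per_token_loss.sum() / 16         # world-size-1 step loss

# dp=2 with uneven trainable-token counts: sum losses and token counts
# globally first, then divide exactly once.
shards = torch.split(per_token_loss, [10, 6])
step_loss = sum(s.sum() for s in shards) / sum(s.numel() for s in shards)
assert torch.allclose(step_loss, oracle)

# The old bug: each rank divides by its own local count, and the per-rank
# means are then averaged. That is a different objective whenever the
# counts differ.
buggy = sum(s.sum() / s.numel() for s in shards) / len(shards)
assert not torch.allclose(buggy, oracle)
```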

Grad accumulation

Grad accumulation is now treated as a global sequence-count concept.

  • grad_accumulation_sequences is global
  • each DP rank consumes global_grad_accumulation_sequences / dp_world_size real sequences
  • the micro loop is built from that global step shape rather than implicitly repeating or wrapping samples
  • zero-contribution dummy micros are used when needed so the micro-step structure still matches the intended global step shape without introducing fake loss/token contributions

This makes numerics more stable and gives a clean correctness story for DP oracle comparisons.

This is configurable through the Megatron training config via grad_accumulation_sequences.
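The dummy-micro padding can be sketched with a hypothetical helper (build_micro_schedule is not a name from the PR; it just illustrates the rule):

```python
def build_micro_schedule(local_sequences, micros_per_rank):
    """Pad a rank's real sequences with dummy micros so every DP rank runs
    the same number of micro-steps.

    Dummy micros (marked None here) contribute zero loss and zero tokens,
    so the global step loss and token count are unchanged."""
    schedule = list(local_sequences)
    while len(schedule) < micros_per_rank:
        schedule.append(None)
    return schedule

# e.g. global_grad_accumulation_sequences = 8, dp_world_size = 4
# -> each rank runs 2 micro-steps, padded if it has fewer real sequences.
assert build_micro_schedule(["seq0"], 2) == ["seq0", None]
```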

Oracle harness

This PR adds a deterministic oracle harness that compares each topology against a world-size-1 reference on:

  • forward activations
  • losses
  • gradients
  • LoRA parameter deltas

This does not prove the oracle is mathematically perfect, but it provides a strong practical baseline:

  • deterministic
  • no distributed communication in the reference path
  • all tested parallel topologies are checked against the same target
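A minimal sketch of the comparison shape (rel_diff_pct is a hypothetical summary metric in the spirit of the harness; the real suite also logs per-layer detail across all four phases):

```python
import torch

def rel_diff_pct(actual, oracle, eps=1e-12):
    """Worst-case relative difference in percent, one scalar per phase."""
    return (100 * (actual - oracle).abs() / (oracle.abs() + eps)).max().item()

oracle_logits = torch.tensor([1.0, 2.0, 4.0])   # world-size-1 reference
topo_logits = torch.tensor([1.001, 2.0, 4.0])   # e.g. a tp=2 run
assert rel_diff_pct(topo_logits, oracle_logits) < 0.2   # forward-style ~0.1%
```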

Expert replay

The harness also adds expert replay on the training side so MoE topology comparisons are actually comparable to the oracle.

This was necessary to make the oracle tests meaningful in expert settings. For now it is train-side only, which is intentional for this training-correctness work.

Sensitivity tests

The suite includes sensitivity tests that intentionally introduce known-bad mutations.

These are important because they show the harness is capable of detecting the kinds of distributed bugs this PR is meant to prevent.

In particular, the DP sensitivity checks include:

  • incorrect grad-accumulation sequencing
  • incorrect local token normalization

Those fail clearly, which is strong evidence that the passing DP checks are meaningful.

Scoring notes

Layer averaging

The harness includes layer-averaged summaries.

That is intentional: these are backend/distributed-semantics bugs, so if they exist they are usually systemic rather than unique to a single decoder layer.

Thresholds

Some thresholds may look high, especially for expert-related comparisons, but that is mostly because sparse/random expert values can show large relative percentages even when absolute differences are tiny.

The detailed logs give the more useful picture:

  • forward differences are generally closer to ~0.1%
  • backward / gradient-side differences are generally closer to ~1%

So the thresholds are mainly there to avoid false positives from noisy relative metrics.
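The noisy-relative-metric effect is easy to see with plain numbers (purely illustrative values, not from the suite):

```python
# A tiny absolute difference on a near-zero expert value produces a huge
# relative percentage, which is why expert-side thresholds look loose.
oracle_val, topo_val = 1e-6, 2e-6
abs_diff = abs(topo_val - oracle_val)        # 1e-6: numerically negligible
rel_pct = 100 * abs_diff / abs(oracle_val)   # 100%: looks alarming
assert abs_diff < 1e-5 and abs(rel_pct - 100.0) < 1e-6
```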

Outcome

This PR turns the Megatron LoRA backend from “runs on these topologies” into “has explicit correctness evidence on these topologies.”

The passing matrix should be treated as validated support for the tested Megatron backend topologies.

Additional changes in this PR:

  • Use per-step micro-accumulation over multiple packed sequences so updates are less sensitive to sparse expert token assignment, and make backend progress accounting accumulation-aware.
  • Correct LoRA shard export behavior so non-zero TP ranks in EP/ETP topologies contribute when required, while still filtering replicated-only entries.
  • Move normalization logic into ForwardTraceCapture so saved traces are canonicalized toward world-size-1 semantics (expert row identity/order and ETP fc1 layout).
  • Rework oracle pass/fail evaluation with per-phase functions, layer-averaged metrics, deterministic init, expanded sensitivity mutations, and a smaller Adam epsilon for tiny-gradient regimes.
  • Redirect suite stdout/stderr into local correctness/sensitivity logs and make skip/report messaging point to those artifacts instead of terminal output.
@FurtherAI FurtherAI requested a review from bradhilton March 18, 2026 03:31
@FurtherAI FurtherAI marked this pull request as draft March 18, 2026 07:13
@bradhilton bradhilton left a comment


Let's goooooooo! 🔥
