Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
20ad279
megatron: integrate lora grad sync with finalize_model_grads
FurtherAI Mar 10, 2026
112e97c
megatron: harden sharded lora merge validation
FurtherAI Mar 10, 2026
4d5c345
tests: add megatron lora oracle correctness harness
FurtherAI Mar 10, 2026
fde2ff3
Minor typing changes
FurtherAI Mar 10, 2026
d2c1161
megatron: extend LoRA grad-sync semantics across tp/expert-tp
FurtherAI Mar 12, 2026
e418018
megatron: add MoE routing replay core and unit tests
FurtherAI Mar 12, 2026
bc5e7a4
megatron runtime/service: wire routing replay into training jobs
FurtherAI Mar 12, 2026
c5e06d9
oracle worker/trace: capture forward traces and emit replay bundles
FurtherAI Mar 12, 2026
a73ca1a
oracle harness/tests: refactor suite and add oracle-replay parity flow
FurtherAI Mar 12, 2026
ec83716
typing: clear blocking ty errors in oracle replay and LoRA paths
FurtherAI Mar 12, 2026
83d871b
megatron: reduce oracle variance with sequence grad accumulation
FurtherAI Mar 14, 2026
84e2ea7
megatron lora: fix TP/EP export participation rules
FurtherAI Mar 14, 2026
0bc9919
oracle trace: canonicalize MoE outputs across arbitrary topologies
FurtherAI Mar 14, 2026
8370c7d
oracle harness: stabilize scoring and expand sensitivity mutations
FurtherAI Mar 14, 2026
d396bfd
oracle tests: write suite output tables to log files
FurtherAI Mar 14, 2026
5385fbb
Add correct data parallelism.
FurtherAI Mar 16, 2026
7525567
Fix per-token DP normalization in Megatron training
FurtherAI Mar 17, 2026
7eb96e5
Expand the oracle harness for DP correctness checks
FurtherAI Mar 17, 2026
204e580
Merge origin/main into austin/megatron_lora_correctness_oracle_tests
FurtherAI Mar 17, 2026
9cde0d4
Clean up type errors in Megatron correctness changes
FurtherAI Mar 17, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/art/dev/train.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from typing import Literal
from typing import TYPE_CHECKING, Literal

from typing_extensions import TypedDict

if TYPE_CHECKING:
from art.megatron.routing_replay import MoeRoutingReplayBundle


class TrainConfig(TypedDict, total=False):
advantage_balance: float
Expand All @@ -22,6 +25,9 @@ class TrainConfig(TypedDict, total=False):
logprob_calculation_chunk_size: int
mask_prob_ratio: bool
max_negative_advantage_importance_sampling_weight: float
moe_routing_replay_bundle: "MoeRoutingReplayBundle | None"
moe_routing_replay_path: str | None
moe_routing_replay_strict: bool
num_trajectories_learning_rate_multiplier_power: float
plot_tensors: bool
ppo: bool
Expand Down
5 changes: 4 additions & 1 deletion src/art/local/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,7 +778,10 @@ async def _train_model(
packed_tensors, f"{get_model_dir(model=model, art_path=self._path)}/tensors"
)
# Note: scale_learning_rate_by_reward_std_dev is now handled by the frontend (Model.train())
estimated_gradient_steps = disk_packed_tensors["num_sequences"]
grad_accumulation_sequences = max(1, int(config.grad_accumulation_sequences))
estimated_gradient_steps = math.ceil(
disk_packed_tensors["num_sequences"] / grad_accumulation_sequences
)
pbar = tqdm.tqdm(total=estimated_gradient_steps, desc="train")
async for result in service.train(
disk_packed_tensors, config, dev_config, verbose
Expand Down
34 changes: 23 additions & 11 deletions src/art/loss.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Literal

from pydantic import BaseModel, ConfigDict
import torch
Expand All @@ -13,8 +13,10 @@

class Loss(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
mean_policy_loss: torch.Tensor
mean_entropy: torch.Tensor | None
reduction: Literal["mean", "sum"]
policy_loss: torch.Tensor
kl: torch.Tensor
entropy: torch.Tensor | None
policy_loss_sum: torch.Tensor
probs_corr: torch.Tensor
kl_policy_ref: torch.Tensor | None = None
Expand All @@ -26,6 +28,7 @@ def loss_fn(
ref_logprobs: torch.Tensor | None,
entropies: torch.Tensor | None,
experimental_config: dev.TrainConfig,
reduction: Literal["mean", "sum"] = "mean",
) -> Loss:
old_logprobs = shift_tensor(inputs["logprobs"], float("nan"))
advantages = shift_tensor(inputs["advantages"], 0.0)
Expand Down Expand Up @@ -123,19 +126,28 @@ def loss_fn(
logprob_diff = old_logprobs - original_logprobs
prob_ratio = torch.exp(logprob_diff)
policy_loss *= torch.clamp(prob_ratio, max=upper_bound).detach()
if ref_logprobs is not None:
kl_div = (
torch.exp(ref_logprobs - new_logprobs) - (ref_logprobs - new_logprobs) - 1.0
)
else:
kl_div = torch.zeros_like(policy_loss)
policy_loss = policy_loss * weights * assistant_mask
mean_policy_loss = policy_loss.sum() / (assistant_mask.sum() + 1e-6)
# Compute mean entropy for the current step
kl_div = kl_div * weights * assistant_mask
denominator = assistant_mask.sum() + 1e-6 if reduction == "mean" else 1.0
reduced_policy_loss = policy_loss.sum() / denominator
kl = kl_div.sum() / denominator
# Compute reduced entropy for the current step.
if entropies is not None:
shifted_entropies = shift_tensor(entropies, 0.0)
mean_entropy = (shifted_entropies * weights * assistant_mask).sum() / (
assistant_mask.sum() + 1e-6
)
entropy = (shifted_entropies * weights * assistant_mask).sum() / denominator
else:
mean_entropy = None
entropy = None
return Loss(
mean_policy_loss=mean_policy_loss,
mean_entropy=mean_entropy,
reduction=reduction,
policy_loss=reduced_policy_loss,
kl=kl,
entropy=entropy,
policy_loss_sum=policy_loss.sum(),
probs_corr=probs_corr,
kl_policy_ref=kl_policy_ref,
Expand Down
134 changes: 134 additions & 0 deletions src/art/megatron/finalize_grads.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
from collections import defaultdict
from collections.abc import Iterable
from typing import Any, Literal, cast

from megatron.core import parallel_state as ps
from megatron.core.distributed.finalize_model_grads import finalize_model_grads
from megatron.core.transformer.module import MegatronModule
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

# Grad-sync metadata vocabulary: parameters may carry `grad_sync_domain` /
# `grad_sync_op` attributes selecting which tensor-parallel group (if any)
# their gradients are all-reduced over after Megatron's own finalize pass.
GradSyncDomain = Literal["tp_default", "expert_tp"]
GradSyncOp = Literal["none", "sum", "avg"]

# Canonical values for the two metadata attributes above.
TP_DEFAULT_GRAD_SYNC_DOMAIN: GradSyncDomain = "tp_default"
EXPERT_TP_GRAD_SYNC_DOMAIN: GradSyncDomain = "expert_tp"
GRAD_SYNC_OP_NONE: GradSyncOp = "none"  # annotated but deliberately not reduced
GRAD_SYNC_OP_SUM: GradSyncOp = "sum"
GRAD_SYNC_OP_AVG: GradSyncOp = "avg"
# NOTE(review): VALID_DOMAINS is not referenced anywhere in this module —
# presumably consumed by the code that attaches the metadata; verify before removing.
VALID_DOMAINS = (TP_DEFAULT_GRAD_SYNC_DOMAIN, EXPERT_TP_GRAD_SYNC_DOMAIN)
VALID_SYNC_OPS = (GRAD_SYNC_OP_NONE, GRAD_SYNC_OP_SUM, GRAD_SYNC_OP_AVG)


def _iter_named_trainable_parameters(
    model: list[MegatronModule],
) -> Iterable[tuple[str, torch.nn.Parameter]]:
    """Yield ``(qualified_name, param)`` for every trainable parameter.

    Names are prefixed with the owning chunk's index (``chunk0.`` ...). A
    parameter object reachable more than once — shared across chunks or
    registered under several names — is yielded only on its first sighting.
    """
    yielded_ids: set[int] = set()
    for index, chunk in enumerate(model):
        prefix = f"chunk{index}."
        for local_name, parameter in chunk.named_parameters():
            key = id(parameter)
            if parameter.requires_grad and key not in yielded_ids:
                yielded_ids.add(key)
                yield prefix + local_name, parameter


def _resolve_domain_group(
    domain: GradSyncDomain,
) -> Any | None:
    """Return the process group for ``domain``, or ``None`` when no sync applies.

    A missing group or a single-rank group means there is nothing to reduce
    across. Raises ``RuntimeError`` for an unrecognized domain value.
    """
    if domain == TP_DEFAULT_GRAD_SYNC_DOMAIN:
        getter = ps.get_tensor_model_parallel_group
    elif domain == EXPERT_TP_GRAD_SYNC_DOMAIN:
        getter = ps.get_expert_tensor_parallel_group
    else:
        raise RuntimeError(f"Unknown grad sync domain: {domain}")
    candidate = getter(check_initialized=False)
    if candidate is None or candidate.size() <= 1:
        return None
    return candidate


def _resolve_reduce_op(op: GradSyncOp) -> Any:
    """Translate a grad-sync op name into the matching torch reduce op.

    Raises ``RuntimeError`` for any op with no collective equivalent ("none"
    is intentionally unmapped — callers skip such parameters entirely).
    """
    reduce_ops = torch.distributed.ReduceOp  # ty: ignore[possibly-missing-attribute]
    if op == GRAD_SYNC_OP_SUM:
        return reduce_ops.SUM
    if op == GRAD_SYNC_OP_AVG:
        return reduce_ops.AVG
    raise RuntimeError(f"Unknown grad sync op: {op}")


def finalize_model_grads_extended(
    model: list[MegatronModule],
    num_tokens: torch.Tensor | None = None,
) -> None:
    """Run Megatron's grad finalize, then apply extra LoRA grad-sync reductions.

    Megatron's ``finalize_model_grads`` already covers DP/CP reductions (via
    ``param.allreduce=True``) and expert-DP reductions (via
    ``param.allreduce=False``). This extension adds the TP / expert-TP
    all-reduces for parameters annotated with ``grad_sync_domain`` /
    ``grad_sync_op`` metadata (see the module-level constants).

    Args:
        model: Model chunks whose parameters may carry grad_sync_* metadata.
        num_tokens: Optional token count, forwarded unchanged to Megatron's
            finalize.

    Raises:
        RuntimeError: if a parameter requests a reduction but carries an
            unsupported ``grad_sync_op``, or lacks a usable ``main_grad``.
    """
    # All-reduce all model grads across DP replicas, layernorm grads for sequence parallelism,
    # embedding grads across first and last pipeline stages (if not tied)
    finalize_model_grads(
        cast(list[torch.nn.Module], model),
        num_tokens=num_tokens,
    )

    # Bucket reducible grads by (domain, op, dtype, device) so each group gets
    # one coalesced all-reduce instead of one collective per parameter.
    buckets: dict[
        tuple[GradSyncDomain, GradSyncOp, torch.dtype, torch.device],
        list[tuple[str, torch.Tensor]],
    ] = defaultdict(list)

    for name, param in _iter_named_trainable_parameters(model):
        # Params without metadata default to the regular TP domain / no-op sync.
        domain: GradSyncDomain = getattr(
            param, "grad_sync_domain", TP_DEFAULT_GRAD_SYNC_DOMAIN
        )
        # No group (or a single-rank group) means nothing to reduce across.
        if _resolve_domain_group(domain) is None:
            continue

        op: GradSyncOp = getattr(param, "grad_sync_op", GRAD_SYNC_OP_NONE)
        if op not in VALID_SYNC_OPS:
            raise RuntimeError(f"{name}: unsupported grad_sync_op={op}")
        if op == GRAD_SYNC_OP_NONE:
            continue

        # A param that asked for a reduction must expose a populated main_grad;
        # fail loudly rather than silently skipping the sync.
        if not hasattr(param, "main_grad"):
            raise RuntimeError(
                f"{name}: expected main_grad for domain={domain} reduce_op={op}, but attribute is missing"
            )
        grad = param.main_grad
        if grad is None:
            raise RuntimeError(
                f"{name}: expected non-None main_grad for domain={domain} reduce_op={op}"
            )
        local_grad = cast(  # for DTensor grads, operate on the local shard; plain tensors pass through
            torch.Tensor, grad._local_tensor if hasattr(grad, "_local_tensor") else grad
        )
        buckets[(domain, op, local_grad.dtype, local_grad.device)].append(
            (name, local_grad)
        )

    for (domain, op, _dtype, _device), entries in buckets.items():
        group = _resolve_domain_group(
            domain
        )  # non-None by construction: params whose domain resolved to None were skipped above

        grads = [grad for _name, grad in entries]
        # Flatten into one contiguous buffer; non-fp32 floating grads are
        # upcast to fp32 for the reduction (presumably for accumulation
        # accuracy — confirm) and cast back afterwards.
        coalesced = _flatten_dense_tensors(grads)
        reduced = (
            coalesced.float()
            if torch.is_floating_point(coalesced) and coalesced.dtype != torch.float32
            else coalesced
        )
        torch.distributed.all_reduce(  # ty: ignore[possibly-missing-attribute]
            reduced,
            op=_resolve_reduce_op(op),
            group=group,
        )
        if reduced is not coalesced:
            reduced = reduced.to(dtype=coalesced.dtype)
        # Write the reduced values back in place into each param's grad storage.
        for grad, synced in zip(grads, _unflatten_dense_tensors(reduced, grads)):
            grad.copy_(synced)
Loading
Loading