From 20ad27933f3879993e93667a732c1dc3cf4a66b8 Mon Sep 17 00:00:00 2001 From: FurtherAI Date: Tue, 10 Mar 2026 05:12:50 +0000 Subject: [PATCH 01/19] megatron: integrate lora grad sync with finalize_model_grads --- src/art/megatron/finalize_grads.py | 112 +++++ src/art/megatron/lora.py | 615 +++++++++++++++++++-------- src/art/megatron/train.py | 642 ++++++++++++++++++----------- 3 files changed, 957 insertions(+), 412 deletions(-) create mode 100644 src/art/megatron/finalize_grads.py diff --git a/src/art/megatron/finalize_grads.py b/src/art/megatron/finalize_grads.py new file mode 100644 index 00000000..8c496667 --- /dev/null +++ b/src/art/megatron/finalize_grads.py @@ -0,0 +1,112 @@ +from collections import defaultdict +from collections.abc import Iterable +from typing import Any, Literal, cast + +from megatron.core import parallel_state as ps +from megatron.core.distributed.finalize_model_grads import finalize_model_grads +import torch +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + +GradSyncDomain = Literal["tp_default", "expert_tp"] +GradSyncOp = Literal["none", "avg"] + +TP_DEFAULT_GRAD_SYNC_DOMAIN: GradSyncDomain = "tp_default" +EXPERT_TP_GRAD_SYNC_DOMAIN: GradSyncDomain = "expert_tp" +GRAD_SYNC_OP_NONE: GradSyncOp = "none" +GRAD_SYNC_OP_AVG: GradSyncOp = "avg" +VALID_DOMAINS = (TP_DEFAULT_GRAD_SYNC_DOMAIN, EXPERT_TP_GRAD_SYNC_DOMAIN) +VALID_SYNC_OPS = (GRAD_SYNC_OP_NONE, GRAD_SYNC_OP_AVG) + + +def _iter_named_trainable_parameters( + model: list[torch.nn.Module], +) -> Iterable[tuple[str, torch.nn.Parameter]]: + seen: set[int] = set() + for chunk_index, model_chunk in enumerate(model): + for name, param in model_chunk.named_parameters(): + if not param.requires_grad: + continue + param_id = id(param) + if param_id in seen: + continue + seen.add(param_id) + yield f"chunk{chunk_index}.{name}", param + + +def _resolve_domain_group( + domain: GradSyncDomain, +) -> torch.distributed.ProcessGroup | None: + if domain == TP_DEFAULT_GRAD_SYNC_DOMAIN: + return None + if domain != EXPERT_TP_GRAD_SYNC_DOMAIN: + raise RuntimeError(f"Unknown grad sync domain: {domain}") + + group = ps.get_expert_tensor_parallel_group(check_initialized=False) + if group is None: + return None + if group.size() <= 1: + return None + return group + + +def _resolve_reduce_op(op: GradSyncOp) -> Any: + if op == GRAD_SYNC_OP_AVG: + return torch.distributed.ReduceOp.AVG + raise RuntimeError(f"Unknown grad sync op: {op}") + + +def finalize_model_grads_extended(model: list[torch.nn.Module]) -> None: + """Run Megatron finalize, then apply non-default grad-sync reductions. + + Megatron finalize handles DP/CP (and expert-DP via `param.allreduce=False`) internally. + This extension only handles extra reductions outside Megatron's default TP path, + currently expert-TP reductions for params annotated with grad_sync_* metadata. + """ + finalize_model_grads(model) + + buckets: dict[ + tuple[GradSyncDomain, GradSyncOp, torch.dtype, torch.device], + list[tuple[str, torch.Tensor]], + ] = defaultdict(list) + + for name, param in _iter_named_trainable_parameters(model): + domain: GradSyncDomain = getattr( + param, "grad_sync_domain", TP_DEFAULT_GRAD_SYNC_DOMAIN + ) + if domain == TP_DEFAULT_GRAD_SYNC_DOMAIN: + continue + if domain not in VALID_DOMAINS: + raise RuntimeError(f"{name}: unsupported grad_sync_domain={domain}") + + op: GradSyncOp = getattr(param, "grad_sync_op", GRAD_SYNC_OP_NONE) + if op not in VALID_SYNC_OPS: + raise RuntimeError(f"{name}: unsupported grad_sync_op={op}") + if op == GRAD_SYNC_OP_NONE: + continue + + if not hasattr(param, "main_grad"): + raise RuntimeError( + f"{name}: expected main_grad for domain={domain} reduce_op={op}, but attribute is missing" + ) + grad = param.main_grad + if grad is None: + raise RuntimeError( + f"{name}: expected non-None main_grad for domain={domain} reduce_op={op}" + ) + local_grad = cast( + torch.Tensor, grad._local_tensor if hasattr(grad, "_local_tensor") else grad + ) + buckets[(domain, op, local_grad.dtype, local_grad.device)].append( + (name, local_grad) + ) + + for (domain, op, _dtype, _device), entries in buckets.items(): + group = _resolve_domain_group(domain) + if group is None: + continue + + grads = [grad for _name, grad in entries] + coalesced = _flatten_dense_tensors(grads) + torch.distributed.all_reduce(coalesced, op=_resolve_reduce_op(op), group=group) + for grad, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): + grad.copy_(synced) diff --git a/src/art/megatron/lora.py b/src/art/megatron/lora.py index 3ba97a77..12a38dec 100644 --- a/src/art/megatron/lora.py +++ b/src/art/megatron/lora.py @@ -1,5 +1,6 @@ +from collections.abc import Sequence import math -from typing import Sequence +from typing import Any, Literal from megatron.bridge.models.gpt_provider import GPTModelProvider from megatron.core import parallel_state as ps @@ -9,12 +10,111 @@ TERowParallelGroupedLinear, TERowParallelLinear, ) +from megatron.core.tensor_parallel.mappings import ( + reduce_from_tensor_model_parallel_region, + reduce_scatter_to_sequence_parallel_region, +) from megatron.core.transformer.attention import SelfAttention from megatron.core.transformer.moe import grouped_gemm_util from megatron.core.transformer.moe.experts import TEGroupedMLP from megatron.core.transformer.transformer_layer import TransformerLayer +from pydantic import BaseModel, ConfigDict import torch +ShardDomain = Literal["tp", "expert_tp"] +GradSyncDomain = Literal["tp_default", "expert_tp"] +GradSyncOp = Literal["none", "avg"] + +TP_DEFAULT_GRAD_SYNC_DOMAIN: GradSyncDomain = "tp_default" +EXPERT_TP_GRAD_SYNC_DOMAIN: GradSyncDomain = "expert_tp" +GRAD_SYNC_OP_NONE: GradSyncOp = "none" +GRAD_SYNC_OP_AVG: GradSyncOp = "avg" + + +class LoRAParallelSpec(BaseModel): + # This spec only describes TP / expert-TP behavior. + # DP/CP vs expert-DP behavior is selected separately via `allreduce`. + model_config = ConfigDict(frozen=True) + + shard_domain: ShardDomain = "tp" + sharded: bool = False + shard_axis: int | None = None + grad_sync_domain: GradSyncDomain = TP_DEFAULT_GRAD_SYNC_DOMAIN + grad_sync_op: GradSyncOp = GRAD_SYNC_OP_NONE + + +def _distributed_initialized() -> bool: + return torch.distributed.is_available() and torch.distributed.is_initialized() + + +def _get_shard_world_size(domain: ShardDomain) -> int: + if not _distributed_initialized(): + return 1 + if domain == "tp": + return ps.get_tensor_model_parallel_world_size() + group = ps.get_expert_tensor_parallel_group(check_initialized=False) + if group is None: + return 1 + return group.size() + + +def _get_shard_rank(domain: ShardDomain) -> int: + if not _distributed_initialized(): + return 0 + if domain == "tp": + return ps.get_tensor_model_parallel_rank() + group = ps.get_expert_tensor_parallel_group(check_initialized=False) + if group is None: + return 0 + return group.rank() + + +def _get_shard_group(domain: ShardDomain) -> torch.distributed.ProcessGroup | None: + if not _distributed_initialized(): + return None + if domain == "tp": + return ps.get_tensor_model_parallel_group() + return ps.get_expert_tensor_parallel_group(check_initialized=False) + + +def _normalize_axis(axis: int, ndim: int) -> int: + if axis < 0: + axis += ndim + if axis < 0 or axis >= ndim: + raise ValueError(f"Invalid shard axis {axis} for tensor ndim={ndim}") + return axis + + +def _set_lora_parallel_metadata( + param: torch.nn.Parameter, + *, + parallel_spec: LoRAParallelSpec, + allreduce: bool, +) -> None: + replicated = not parallel_spec.sharded + setattr(param, "lora_shard_domain", parallel_spec.shard_domain) + setattr(param, "lora_tp_sharded", parallel_spec.sharded) + setattr(param, "lora_tp_replicated", replicated) + setattr(param, "lora_tp_shard_axis", parallel_spec.shard_axis) + setattr(param, "grad_sync_domain", parallel_spec.grad_sync_domain) + setattr(param, "grad_sync_op", parallel_spec.grad_sync_op) + # Megatron DDP routing flag: + # - allreduce=True: sync with regular DP/CP replicas. + # - allreduce=False: sync with expert-DP replicas. + # TP / expert-TP replica handling is controlled by grad_sync_* metadata. + setattr(param, "allreduce", allreduce) + + # Megatron's native TP finalize path consumes this attr. + setattr( + param, + "average_gradients_across_tp_domain", + ( + replicated + and parallel_spec.grad_sync_domain == TP_DEFAULT_GRAD_SYNC_DOMAIN + and parallel_spec.grad_sync_op == GRAD_SYNC_OP_AVG + ), + ) + class LoRA(torch.nn.Module): def __init__( @@ -27,6 +127,9 @@ def __init__( dtype: torch.dtype, device: torch.device, num_local_experts: int = 1, + a_parallel_spec: LoRAParallelSpec = LoRAParallelSpec(), + b_parallel_spec: LoRAParallelSpec = LoRAParallelSpec(), + allreduce: bool = True, ) -> None: super().__init__() assert num_local_experts == 1 or "{expert}" in adapter_model_prefix, ( @@ -44,6 +147,16 @@ def __init__( num_local_experts, rank, out_features, dtype=dtype, device=device ).squeeze(0) ) + _set_lora_parallel_metadata( + self.A_T, + parallel_spec=a_parallel_spec, + allreduce=allreduce, + ) + _set_lora_parallel_metadata( + self.B_T, + parallel_spec=b_parallel_spec, + allreduce=allreduce, + ) self._expert_offset = ps.get_expert_model_parallel_rank() * num_local_experts self.reset_lora_parameters() @@ -51,6 +164,21 @@ def __init__( def num_local_experts(self) -> int: return self.A_T.shape[0] if self.A_T.ndim == 3 else 1 + def _broadcast_if_replicated(self, param: torch.nn.Parameter) -> None: + if not getattr(param, "lora_tp_replicated", False): + return + domain = getattr(param, "lora_shard_domain") + world_size = _get_shard_world_size(domain) + if world_size <= 1: + return + group = _get_shard_group(domain) + if group is None: + raise RuntimeError( + f"{self.adapter_model_prefix}: missing process group for replicated parameter domain={domain}" + ) + src = torch.distributed.get_global_rank(group, 0) + torch.distributed.broadcast(param.data, src=src, group=group) + def reset_lora_parameters(self) -> None: """Initialize LoRA weights (A=Kaiming, B=zeros) like PEFT defaults.""" if self.A_T.ndim == 3: @@ -59,22 +187,38 @@ def reset_lora_parameters(self) -> None: else: torch.nn.init.kaiming_uniform_(self.A_T.T, a=math.sqrt(5)) torch.nn.init.zeros_(self.B_T) + self._broadcast_if_replicated(self.A_T) + self._broadcast_if_replicated(self.B_T) + + def _expected_weight_keys(self, suffix: str) -> list[str]: + if self.num_local_experts > 1: + return [ + f"{self.adapter_model_prefix.format(expert=expert + self._expert_offset)}.{suffix}.weight" + for expert in range(self.num_local_experts) + ] + return [f"{self.adapter_model_prefix}.{suffix}.weight"] def load_lora(self, adapter_model: dict[str, torch.Tensor]) -> None: - try: - self.load_weights( - adapter_model, - suffix="lora_A", - into=self.A_T, + missing_keys = [ + key + for suffix in ("lora_A", "lora_B") + for key in self._expected_weight_keys(suffix) + if key not in adapter_model + ] + if missing_keys: + raise KeyError( + f"Missing LoRA adapter keys for {self.adapter_model_prefix}: {sorted(missing_keys)}" ) - self.load_weights( - adapter_model, - suffix="lora_B", - into=self.B_T, - ) - except KeyError: - print("Unable to find LoRA weights for", self.adapter_model_prefix) - self.reset_lora_parameters() + self.load_weights( + adapter_model, + suffix="lora_A", + into=self.A_T, + ) + self.load_weights( + adapter_model, + suffix="lora_B", + into=self.B_T, + ) def load_weights( self, @@ -83,65 +227,111 @@ def load_weights( suffix: str, into: torch.nn.Parameter, ) -> None: - self.load_weight( - ( - torch.stack( - [ - adapter_model[ - f"{self.adapter_model_prefix.format(expert=expert + self._expert_offset)}.{suffix}.weight" - ].T - for expert in range(self.num_local_experts) - ] - ) - if self.num_local_experts > 1 - else adapter_model[f"{self.adapter_model_prefix}.{suffix}.weight"].T - ), - into=into, - ) + keys = self._expected_weight_keys(suffix) + if self.num_local_experts > 1: + weight = torch.stack([adapter_model[key].T for key in keys]) + else: + weight = adapter_model[keys[0]].T + self.load_weight(weight, into=into) def load_weight(self, weight: torch.Tensor, *, into: torch.nn.Parameter) -> None: - setattr(into, "sharded", False) - tp_world_size = ps.get_tensor_model_parallel_world_size() - tp_rank = ps.get_tensor_model_parallel_rank() - for axis in (-2, -1): - if weight.shape[axis] == into.shape[axis]: - continue - # assume our param is tensor sharded along this axis - assert weight.shape[axis] // tp_world_size == into.shape[axis], ( - f"Weight shape {weight.shape} does not match into shape {into.shape} along axis {axis}" + domain = getattr(into, "lora_shard_domain") + sharded = bool(getattr(into, "lora_tp_sharded")) + if sharded: + axis = getattr(into, "lora_tp_shard_axis") + if axis is None: + raise RuntimeError( + f"{self.adapter_model_prefix}: missing shard axis for sharded parameter" + ) + axis = _normalize_axis(axis, weight.ndim) + world_size = _get_shard_world_size(domain) + rank = _get_shard_rank(domain) + if weight.shape[axis] % world_size != 0: + raise ValueError( + f"{self.adapter_model_prefix}: weight shape {tuple(weight.shape)} is not divisible by world size " + f"{world_size} on axis {axis}" + ) + local_size = weight.shape[axis] // world_size + if into.shape[axis] != local_size: + raise ValueError( + f"{self.adapter_model_prefix}: expected local shard size {into.shape[axis]}, got {local_size}" + ) + weight = weight.narrow(axis, rank * local_size, local_size) + elif tuple(weight.shape) != tuple(into.shape): + raise ValueError( + f"{self.adapter_model_prefix}: unsharded load shape mismatch, got {tuple(weight.shape)} " + f"expected {tuple(into.shape)}" ) - s = into.shape[axis] - weight = weight.narrow(axis, tp_rank * s, s) - setattr(into, "sharded", True) into.data.copy_(weight) into.requires_grad = True - def sharded_lora_state_dict(self) -> dict[str, torch.Tensor]: - if self.num_local_experts > 1: + def _should_export_parameter(self, param: torch.nn.Parameter) -> bool: + if self.num_local_experts > 1: # self is a MoE layer if ps.get_expert_data_parallel_rank() != 0: - return {} - return { - f"{self.adapter_model_prefix.format(expert=expert + self._expert_offset)}.{key}": param.data[ - expert - ].T - for expert in range(self.num_local_experts) - for key, param in ( - ("lora_A.weight", self.A_T), - ("lora_B.weight", self.B_T), - ) - } - if ps.get_data_parallel_rank() != 0 or torch.all(self.A_T == 0): - return {} + return False + else: # self is a non-MoE layer + if ps.get_data_parallel_rank() != 0: + return False + # Non-MoE layers are replicated across expert-model-parallel ranks. + if ( + ps.get_expert_model_parallel_world_size() > 1 + and ps.get_expert_model_parallel_rank() != 0 + ): + return False + + if getattr(param, "lora_tp_sharded", False): + # this param is fully sharded, all shard ranks participate + return True + + domain = getattr(param, "lora_shard_domain") + # param is replicated, tp rank 0 or etp rank 0 participates + return _get_shard_rank(domain) == 0 + + def _manifest_for_param(self, param: torch.nn.Parameter) -> dict[str, Any]: + domain = getattr(param, "lora_shard_domain") + sharded = bool(getattr(param, "lora_tp_sharded", False)) + shard_axis = getattr(param, "lora_tp_shard_axis", None) return { - f"{self.adapter_model_prefix}.{key}": param.data.T - for key, param in ( - ("lora_A.weight", self.A_T), - ("lora_B.weight", self.B_T), - ) - if getattr(param, "sharded", False) - or ps.get_tensor_model_parallel_rank() == 0 + "domain": domain, + "sharded": sharded, + "shard_axis": shard_axis, + "shard_world_size": _get_shard_world_size(domain) if sharded else 1, + "shard_rank": _get_shard_rank(domain) if sharded else 0, + } + + def _lora_params(self) -> list[tuple[str, torch.nn.Parameter]]: + return [ + ("lora_A.weight", self.A_T), + ("lora_B.weight", self.B_T), + ] + + def _export_items( + self, + ) -> list[tuple[str, torch.nn.Parameter, int | None]]: + export_items: list[tuple[str, torch.nn.Parameter, int | None]] = [] + for key, param in self._lora_params(): + if not self._should_export_parameter(param): + continue + if self.num_local_experts > 1: + for expert in range(self.num_local_experts): + full_key = f"{self.adapter_model_prefix.format(expert=expert + self._expert_offset)}.{key}" + export_items.append((full_key, param, expert)) + else: + export_items.append((f"{self.adapter_model_prefix}.{key}", param, None)) + return export_items + + def sharded_lora_manifest(self) -> dict[str, dict[str, Any]]: + return { + key: self._manifest_for_param(param) + for key, param, _expert in self._export_items() } + def sharded_lora_state_dict(self) -> dict[str, torch.Tensor]: + state: dict[str, torch.Tensor] = {} + for key, param, expert in self._export_items(): + state[key] = param.data[expert].T if expert is not None else param.data.T + return state + def forward( self, x: torch.Tensor, tokens_per_expert: list[int] | torch.Tensor | None = None ) -> torch.Tensor: @@ -152,14 +342,13 @@ def forward( bsz = tokens_per_expert if isinstance(bsz, list): bsz = torch.tensor(bsz, dtype=torch.int64, device="cpu") - # If no tokens routed locally, return zeros + # If no tokens routed locally, return zeros. if isinstance(bsz, torch.Tensor) and int(torch.count_nonzero(bsz)) == 0: return x.new_zeros((x.shape[0], self.B_T.shape[-1])) tmp = grouped_gemm_util.ops.gmm(x, self.A_T, bsz, trans_b=False) # type: ignore[attr-defined] out = grouped_gemm_util.ops.gmm(tmp, self.B_T, bsz, trans_b=False) # type: ignore[attr-defined] return out * self.scale - else: - return ((x @ self.A_T) @ self.B_T) * self.scale + return ((x @ self.A_T) @ self.B_T) * self.scale class SelfAttentionLinearProjLoRA(torch.nn.Module): @@ -175,6 +364,20 @@ def __init__( self.provider = provider self.linear_proj = linear_proj assert isinstance(linear_proj.weight, torch.Tensor) + a_parallel_spec = LoRAParallelSpec( + shard_domain="tp", + sharded=True, + shard_axis=-2, + grad_sync_domain=TP_DEFAULT_GRAD_SYNC_DOMAIN, + grad_sync_op=GRAD_SYNC_OP_NONE, # only need DP-type reductions + ) + b_parallel_spec = a_parallel_spec.model_copy( + update={ + "sharded": False, + "shard_axis": None, + "grad_sync_op": GRAD_SYNC_OP_AVG, # megatron reduces across TP ranks + } + ) self.lora = LoRA( adapter_model_prefix=adapter_model_prefix, in_features=linear_proj.in_features, @@ -183,22 +386,23 @@ def __init__( alpha=alpha, dtype=linear_proj.weight.dtype, device=linear_proj.weight.device, + a_parallel_spec=a_parallel_spec, + b_parallel_spec=b_parallel_spec, + # Non-expert LoRA params use Megatron's dense DP/CP gradient buckets. + allreduce=True, ) def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]: base_output, bias_output = self.linear_proj(x) assert isinstance(base_output, torch.Tensor) assert isinstance(bias_output, (torch.Tensor, type(None))) + lora_output = self.lora(x) - if ( - self.provider.sequence_parallel - and self.provider.tensor_model_parallel_size > 1 - ): - tp_rank = ps.get_tensor_model_parallel_rank() - tokens_per_rank = base_output.shape[0] - start = tp_rank * tokens_per_rank - end = start + tokens_per_rank - lora_output = lora_output[start:end] + if self.provider.tensor_model_parallel_size > 1: + if self.provider.sequence_parallel: + lora_output = reduce_scatter_to_sequence_parallel_region(lora_output) + else: + lora_output = reduce_from_tensor_model_parallel_region(lora_output) return base_output + lora_output, bias_output @@ -231,32 +435,64 @@ def __init__( q_out_features_per_rank = q_out_features // tp_world_size kv_out_features_per_rank = kv_out_features // tp_world_size assert isinstance(linear_qkv.weight, torch.Tensor) - self.q_proj_lora = LoRA( + self.q_proj_lora = self._build_qkv_lora( adapter_model_prefix=f"{adapter_model_prefix}.q_proj", - in_features=linear_qkv.in_features, - out_features=q_out_features_per_rank, + linear_qkv=linear_qkv, rank=rank, alpha=alpha, - dtype=linear_qkv.weight.dtype, - device=linear_qkv.weight.device, + out_features=q_out_features_per_rank, ) - self.k_proj_lora = LoRA( + self.k_proj_lora = self._build_qkv_lora( adapter_model_prefix=f"{adapter_model_prefix}.k_proj", - in_features=linear_qkv.in_features, - out_features=kv_out_features_per_rank, + linear_qkv=linear_qkv, rank=rank, alpha=alpha, - dtype=linear_qkv.weight.dtype, - device=linear_qkv.weight.device, + out_features=kv_out_features_per_rank, ) - self.v_proj_lora = LoRA( + self.v_proj_lora = self._build_qkv_lora( adapter_model_prefix=f"{adapter_model_prefix}.v_proj", - in_features=linear_qkv.in_features, + linear_qkv=linear_qkv, + rank=rank, + alpha=alpha, out_features=kv_out_features_per_rank, + ) + + @staticmethod + def _build_qkv_lora( + *, + adapter_model_prefix: str, + linear_qkv: TELayerNormColumnParallelLinear, + rank: int, + alpha: float, + out_features: int, + ) -> LoRA: + assert isinstance(linear_qkv.weight, torch.Tensor) + a_parallel_spec = LoRAParallelSpec( + shard_domain="tp", + sharded=False, + shard_axis=None, + grad_sync_domain=TP_DEFAULT_GRAD_SYNC_DOMAIN, + grad_sync_op=GRAD_SYNC_OP_AVG, # megatron reduces across TP ranks + ) + b_parallel_spec = a_parallel_spec.model_copy( + update={ + "sharded": True, + "shard_axis": -1, + "grad_sync_op": GRAD_SYNC_OP_NONE, # only need DP-type reductions + } + ) + return LoRA( + adapter_model_prefix=adapter_model_prefix, + in_features=linear_qkv.in_features, + out_features=out_features, rank=rank, alpha=alpha, dtype=linear_qkv.weight.dtype, device=linear_qkv.weight.device, + a_parallel_spec=a_parallel_spec, + b_parallel_spec=b_parallel_spec, + # Non-expert LoRA params use Megatron's dense DP/CP gradient buckets. + allreduce=True, ) def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]: @@ -302,19 +538,48 @@ def __init__( super().__init__() assert linear_fc1 is not None self.linear_fc1 = linear_fc1 - assert isinstance(linear_fc1.weight0, torch.Tensor) - self.gate_lora = LoRA( + self.gate_lora = self._build_fc1_lora( adapter_model_prefix=f"{adapter_model_prefix}.{{expert}}.gate_proj", - in_features=linear_fc1.in_features, - out_features=linear_fc1.out_features // 2, + linear_fc1=linear_fc1, rank=rank, alpha=alpha, - dtype=linear_fc1.weight0.dtype, - device=linear_fc1.weight0.device, num_local_experts=num_local_experts, ) - self.up_lora = LoRA( + self.up_lora = self._build_fc1_lora( adapter_model_prefix=f"{adapter_model_prefix}.{{expert}}.up_proj", + linear_fc1=linear_fc1, + rank=rank, + alpha=alpha, + num_local_experts=num_local_experts, + ) + + @staticmethod + def _build_fc1_lora( + *, + adapter_model_prefix: str, + linear_fc1: TEColumnParallelGroupedLinear, + rank: int, + alpha: float, + num_local_experts: int, + ) -> LoRA: + assert isinstance(linear_fc1.weight0, torch.Tensor) + a_parallel_spec = LoRAParallelSpec( + shard_domain="expert_tp", + sharded=False, + shard_axis=None, + grad_sync_domain=EXPERT_TP_GRAD_SYNC_DOMAIN, + grad_sync_op=GRAD_SYNC_OP_AVG, # we handle this with extended finalize_grads + ) + b_parallel_spec = a_parallel_spec.model_copy( + update={ + "sharded": True, + "shard_axis": -1, + "grad_sync_domain": EXPERT_TP_GRAD_SYNC_DOMAIN, + "grad_sync_op": GRAD_SYNC_OP_NONE, # only need DP-type reductions + } + ) + return LoRA( + adapter_model_prefix=adapter_model_prefix, in_features=linear_fc1.in_features, out_features=linear_fc1.out_features // 2, rank=rank, @@ -322,6 +587,10 @@ def __init__( dtype=linear_fc1.weight0.dtype, device=linear_fc1.weight0.device, num_local_experts=num_local_experts, + a_parallel_spec=a_parallel_spec, + b_parallel_spec=b_parallel_spec, + # Expert LoRA params use Megatron's expert-DP gradient buckets. + allreduce=False, ) def forward( @@ -347,6 +616,21 @@ def __init__( assert linear_fc2 is not None assert isinstance(linear_fc2.weight0, torch.Tensor) self.linear_fc2 = linear_fc2 + a_parallel_spec = LoRAParallelSpec( + shard_domain="expert_tp", + sharded=True, + shard_axis=-2, + grad_sync_domain=EXPERT_TP_GRAD_SYNC_DOMAIN, + grad_sync_op=GRAD_SYNC_OP_NONE, # only need DP-type reductions + ) + b_parallel_spec = a_parallel_spec.model_copy( + update={ + "sharded": False, + "shard_axis": None, + "grad_sync_domain": EXPERT_TP_GRAD_SYNC_DOMAIN, + "grad_sync_op": GRAD_SYNC_OP_AVG, # we handle this with extended finalize_grads + } + ) self.lora = LoRA( adapter_model_prefix=f"{adapter_model_prefix}.{{expert}}.down_proj", in_features=linear_fc2.in_features, @@ -356,6 +640,10 @@ def __init__( dtype=linear_fc2.weight0.dtype, device=linear_fc2.weight0.device, num_local_experts=num_local_experts, + a_parallel_spec=a_parallel_spec, + b_parallel_spec=b_parallel_spec, + # Expert LoRA params use Megatron's expert-DP gradient buckets. + allreduce=False, ) def forward( @@ -369,77 +657,68 @@ def forward( def apply_lora_adapters( model: Sequence[torch.nn.Module], provider: GPTModelProvider, -) -> None: - with torch.no_grad(): - for chunk in model: - for module in chunk.modules(): - if isinstance(module, TransformerLayer): - adapter_model_prefix = ( - f"base_model.model.model.layers.{module.layer_number - 1}" - ) - assert isinstance(module.self_attention, SelfAttention) - self_attention_linear_proj = module.self_attention.linear_proj - if not isinstance(self_attention_linear_proj, TERowParallelLinear): - self_attention_linear_proj = ( - self_attention_linear_proj.linear_proj - ) - assert isinstance( - self_attention_linear_proj, TERowParallelLinear - ) - module.self_attention.linear_proj = SelfAttentionLinearProjLoRA( - adapter_model_prefix=f"{adapter_model_prefix}.self_attn.o_proj", - linear_proj=self_attention_linear_proj, - rank=1, - alpha=32, - provider=provider, - ) - self_attention_linear_qkv = module.self_attention.linear_qkv - if not isinstance( - self_attention_linear_qkv, TELayerNormColumnParallelLinear - ): - self_attention_linear_qkv = self_attention_linear_qkv.linear_qkv - assert isinstance( - self_attention_linear_qkv, TELayerNormColumnParallelLinear - ) - module.self_attention.linear_qkv = SelfAttentionLinearQKVLoRA( - adapter_model_prefix=f"{adapter_model_prefix}.self_attn", - linear_qkv=self_attention_linear_qkv, - rank=1, - alpha=32, - provider=provider, - ) - assert isinstance(module.mlp.experts, TEGroupedMLP) - mlp_experts_linear_fc1 = module.mlp.experts.linear_fc1 - if not isinstance( - mlp_experts_linear_fc1, - TEColumnParallelGroupedLinear, # type: ignore - ): - mlp_experts_linear_fc1 = mlp_experts_linear_fc1.linear_fc1 - assert isinstance( - mlp_experts_linear_fc1, - TEColumnParallelGroupedLinear, # type: ignore - ) - module.mlp.experts.linear_fc1 = MLPExpertsLinearFC1LoRA( - adapter_model_prefix=f"{adapter_model_prefix}.mlp.experts", - linear_fc1=mlp_experts_linear_fc1, - rank=1, - alpha=32, - num_local_experts=module.mlp.experts.num_local_experts, - ) - mlp_experts_linear_fc2 = module.mlp.experts.linear_fc2 - if not isinstance( - mlp_experts_linear_fc2, - TERowParallelGroupedLinear, # type: ignore - ): - mlp_experts_linear_fc2 = mlp_experts_linear_fc2.linear_fc2 - assert isinstance( - mlp_experts_linear_fc2, - TERowParallelGroupedLinear, # type: ignore - ) - module.mlp.experts.linear_fc2 = MLPExpertsLinearFC2LoRA( - adapter_model_prefix=f"{adapter_model_prefix}.mlp.experts", - linear_fc2=mlp_experts_linear_fc2, - rank=1, - alpha=32, - num_local_experts=module.mlp.experts.num_local_experts, - ) +) -> list[torch.nn.Module]: + def _unwrap_attr(value: Any, attr_name: str, expected_type: type[Any]) -> Any: + if isinstance(value, expected_type): + return value + unwrapped = getattr(value, attr_name) + assert isinstance(unwrapped, expected_type) + return unwrapped + + for chunk in model: + for module in chunk.modules(): + if isinstance(module, TransformerLayer): + adapter_model_prefix = ( + f"base_model.model.model.layers.{module.layer_number - 1}" + ) + assert isinstance(module.self_attention, SelfAttention) + self_attention_linear_proj = _unwrap_attr( + module.self_attention.linear_proj, + "linear_proj", + TERowParallelLinear, + ) + module.self_attention.linear_proj = SelfAttentionLinearProjLoRA( + adapter_model_prefix=f"{adapter_model_prefix}.self_attn.o_proj", + linear_proj=self_attention_linear_proj, + rank=1, + alpha=32, + provider=provider, + ) + self_attention_linear_qkv = _unwrap_attr( + module.self_attention.linear_qkv, + "linear_qkv", + TELayerNormColumnParallelLinear, + ) + module.self_attention.linear_qkv = SelfAttentionLinearQKVLoRA( + adapter_model_prefix=f"{adapter_model_prefix}.self_attn", + linear_qkv=self_attention_linear_qkv, + rank=1, + alpha=32, + provider=provider, + ) + assert isinstance(module.mlp.experts, TEGroupedMLP) + mlp_experts_linear_fc1 = _unwrap_attr( + module.mlp.experts.linear_fc1, + "linear_fc1", + TEColumnParallelGroupedLinear, # type: ignore[arg-type] + ) + module.mlp.experts.linear_fc1 = MLPExpertsLinearFC1LoRA( + adapter_model_prefix=f"{adapter_model_prefix}.mlp.experts", + linear_fc1=mlp_experts_linear_fc1, + rank=1, + alpha=32, + num_local_experts=module.mlp.experts.num_local_experts, + ) + mlp_experts_linear_fc2 = _unwrap_attr( + module.mlp.experts.linear_fc2, + "linear_fc2", + TERowParallelGroupedLinear, # type: ignore[arg-type] + ) + module.mlp.experts.linear_fc2 = MLPExpertsLinearFC2LoRA( + adapter_model_prefix=f"{adapter_model_prefix}.mlp.experts", + linear_fc2=mlp_experts_linear_fc2, + rank=1, + alpha=32, + num_local_experts=module.mlp.experts.num_local_experts, + ) + return list(model) diff --git a/src/art/megatron/train.py b/src/art/megatron/train.py index 480a03be..abc2ef7b 100644 --- a/src/art/megatron/train.py +++ b/src/art/megatron/train.py @@ -20,20 +20,21 @@ def _set_cache_dir(env_var: str, default_path: str) -> None: import math import shutil import time -from typing import Any, cast +from typing import Any, Callable, cast from megatron.core import parallel_state as ps from megatron.core.distributed import DistributedDataParallelConfig from megatron.core.models.gpt.gpt_model import GPTModel from megatron.core.optimizer import OptimizerConfig, get_megatron_optimizer from megatron.core.transformer.module import MegatronModule -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from safetensors.torch import load_file, save_file import torch from torch._inductor.runtime.cache_dir_utils import cache_dir as inductor_cache_dir from art import dev, types from art.loss import loss_fn, shift_tensor +from art.megatron.finalize_grads import finalize_model_grads_extended from art.megatron.flex_attention import create_shared_prefix_attention_state from art.megatron.lora import apply_lora_adapters from art.megatron.offload import OffloadState, offload_to_cpu, reload_to_gpu @@ -44,9 +45,41 @@ def _set_cache_dir(env_var: str, default_path: str) -> None: packed_tensors_from_dir, ) -provider = get_provider( - os.environ.get("MODEL_IDENTIFIER", "Qwen/Qwen3-30B-A3B-Instruct-2507") -) +DEFAULT_MODEL_IDENTIFIER = "Qwen/Qwen3-30B-A3B-Instruct-2507" + + +class TrainingJob(BaseModel): + lora_path: str + optimizer_state_path: str + disk_packed_tensors: DiskPackedTensors + config: types.TrainConfig + experimental_config: dev.TrainConfig + + +class TrainingRuntime(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + provider: Any + model: list[MegatronModule] + optimizer: Any + rank: int + world_size: int + + +class TrainStepResult(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + reduced_loss: torch.Tensor + probs_corr: float + new_logprobs: torch.Tensor + update_successful: bool + grad_norm: float + num_zeros_in_grad: int | None + + +def print0(rank: int, *values: Any) -> None: + if rank == 0: + print(*values) def freeze_model(model_chunks: list[MegatronModule]) -> list[MegatronModule]: @@ -56,275 +89,396 @@ def freeze_model(model_chunks: list[MegatronModule]) -> list[MegatronModule]: return model_chunks -provider.register_pre_wrap_hook(lambda x: freeze_model(x) or x) - -model = provider.provide_distributed_model( - ddp_config=DistributedDataParallelConfig(), - data_parallel_random_init=False, -) - -rank = torch.distributed.get_rank() # ty:ignore[possibly-missing-attribute] -world_size = torch.distributed.get_world_size() # ty:ignore[possibly-missing-attribute] - -if rank == 0: - print("TORCHINDUCTOR_CACHE_DIR:", os.environ["TORCHINDUCTOR_CACHE_DIR"]) - print("Resolved inductor cache_dir():", inductor_cache_dir()) - print("TRITON_CACHE_DIR:", os.environ["TRITON_CACHE_DIR"]) +def _install_gpt_preprocess_hook(model_chunks: list[MegatronModule]) -> None: + for chunk in model_chunks: + module: Any = chunk + while not isinstance(module, GPTModel) and hasattr(module, "module"): + module = module.module + if not isinstance(module, GPTModel): + continue + preprocess = module._preprocess -for module in model: - while not isinstance(module, GPTModel) and hasattr(module, "module"): - module = module.module - if isinstance(module, GPTModel): - _preprocess = module._preprocess - - def _preprocess_hook(*args, **kwargs): + def preprocess_hook(*args, _preprocess=preprocess, **kwargs): preproc_output = list(_preprocess(*args, **kwargs)) - preproc_output[0].requires_grad = True # type: ignore - table = preproc_output[1] # [S,B,1,D] type: ignore - D = table.size(-1) # type: ignore - table_flat = table.view(table.size(0), D) # type: ignore - # position_ids: [B, S] - position_ids = kwargs["position_ids"] - B, S = position_ids.shape - gathered = table_flat.index_select(0, position_ids.reshape(-1)) # [B*S, D] - gathered = gathered.view(B, S, D).permute(1, 0, 2).contiguous() # [S, B, D] + preproc_output[0].requires_grad = True # type: ignore[index] + table = preproc_output[1] # [S, B, 1, D] # type: ignore[index] + embedding_dim = table.size(-1) + table_flat = table.view(table.size(0), embedding_dim) + position_ids = kwargs["position_ids"] # [B, S] + batch_size, sequence_length = position_ids.shape + gathered = table_flat.index_select(0, position_ids.reshape(-1)) + gathered = ( + gathered.view(batch_size, sequence_length, embedding_dim) + .permute(1, 0, 2) + .contiguous() + ) preproc_output[1] = gathered.unsqueeze(2) # [S, B, 1, D] return tuple(preproc_output) - module._preprocess = _preprocess_hook # type: ignore[attr-defined] - + module._preprocess = preprocess_hook # type: ignore[attr-defined] -apply_lora_adapters(model, provider) -optimizer = get_megatron_optimizer( - config=OptimizerConfig( +def _default_optimizer_config() -> OptimizerConfig: + return OptimizerConfig( bf16=True, lr=5e-6, adam_beta1=0.9, adam_beta2=0.99, clip_grad=0.1, weight_decay=0.1, - ), - model_chunks=model, # type: ignore -) + ) + -if rank == 0: - # Print the number of parameters in the optimizer, nicely formatted - num_params = sum( - p.numel() - for group in optimizer.param_groups - if not group["is_decoupled_lr"] - for p in group["params"] +def build_training_runtime( + *, + model_identifier: str | None = None, + provider_configure: Callable[[Any], None] | None = None, + optimizer_config: OptimizerConfig | None = None, + print_env: bool = True, + print_optimizer_stats: bool = True, +) -> TrainingRuntime: + provider = get_provider( + model_identifier or os.environ.get("MODEL_IDENTIFIER", DEFAULT_MODEL_IDENTIFIER) + ) + if provider_configure is not None: + provider_configure(provider) + provider.register_pre_wrap_hook(freeze_model) + provider.register_pre_wrap_hook( + lambda chunks: apply_lora_adapters(chunks, provider) ) - print(f"Number of parameters in optimizer: {num_params:,}") - total_params = sum(p.numel() for m in model for p in m.parameters()) - percent = (num_params / total_params) * 100 if total_params > 0 else 0 - print(f"Optimizer parameters as percent of total: {percent:0.2f}%") + model = cast( + list[MegatronModule], + provider.provide_distributed_model( + ddp_config=DistributedDataParallelConfig(), + data_parallel_random_init=False, + ), + ) -class TrainingJob(BaseModel): - lora_path: str - optimizer_state_path: str - disk_packed_tensors: DiskPackedTensors - config: types.TrainConfig - experimental_config: dev.TrainConfig + if not torch.distributed.is_initialized(): + raise RuntimeError( + "torch.distributed must be initialized before building runtime" + ) + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + if rank == 0 and print_env: + print("TORCHINDUCTOR_CACHE_DIR:", os.environ["TORCHINDUCTOR_CACHE_DIR"]) + print("Resolved inductor cache_dir():", inductor_cache_dir()) + print("TRITON_CACHE_DIR:", os.environ["TRITON_CACHE_DIR"]) -def print0(*values: Any) -> None: - if rank == 0: - print(*values) + _install_gpt_preprocess_hook(model) + optimizer = get_megatron_optimizer( + config=optimizer_config or _default_optimizer_config(), + model_chunks=model, + ) + + if rank == 0 and print_optimizer_stats: + num_params = sum( + p.numel() + for group in optimizer.param_groups + if not group["is_decoupled_lr"] + for p in group["params"] + ) + print(f"Number of parameters in optimizer: {num_params:,}") + total_params = sum(p.numel() for module in model for p in module.parameters()) + percent = (num_params / total_params) * 100 if total_params > 0 else 0 + print(f"Optimizer parameters as percent of total: {percent:0.2f}%") + + return TrainingRuntime( + provider=provider, + model=model, + optimizer=optimizer, + rank=rank, + world_size=world_size, + ) + + +def iter_modules(model_chunks: list[MegatronModule]) -> Any: + for chunk in model_chunks: + for module in chunk.modules(): + yield module + + +def load_adapter_into_model( + model_chunks: list[MegatronModule], + adapter_model: dict[str, torch.Tensor], +) -> None: + with torch.no_grad(): + for module in iter_modules(model_chunks): + if hasattr(module, "load_lora"): + module.load_lora(adapter_model) # type: ignore[attr-defined] + + +def collect_sharded_lora_state( + model_chunks: list[MegatronModule], + adapter_model: dict[str, torch.Tensor], +) -> tuple[dict[str, torch.Tensor], dict[str, dict[str, Any]]]: + sharded_state_dict: dict[str, torch.Tensor] = {} + sharded_state_manifest: dict[str, dict[str, Any]] = {} + for module in iter_modules(model_chunks): + if hasattr(module, "sharded_lora_state_dict"): + module_sharded_lora_state_dict: dict[str, torch.Tensor] = ( + module.sharded_lora_state_dict() # type: ignore[attr-defined] + ) + for key, value in module_sharded_lora_state_dict.items(): + target_dtype = ( + adapter_model[key].dtype if key in adapter_model else value.dtype + ) + sharded_state_dict[key] = value.to(target_dtype) + if hasattr(module, "sharded_lora_manifest"): + module_sharded_lora_manifest: dict[str, dict[str, Any]] = ( + module.sharded_lora_manifest() # type: ignore[attr-defined] + ) + sharded_state_manifest.update(module_sharded_lora_manifest) + return sharded_state_dict, sharded_state_manifest + + +def select_indexed_inputs(packed_tensors: PackedTensors, index: int) -> PackedTensors: + return PackedTensors( # type: ignore[call-arg] + **{ + key: value[index : index + 1] + for key, value in packed_tensors.items() + if isinstance(value, torch.Tensor) + }, + pixel_values=[None], + image_grid_thw=[None], + ) + + +def _move_inputs_to_device(inputs: PackedTensors, device: torch.device) -> None: + for key, value in inputs.items(): + if isinstance(value, torch.Tensor): + inputs[key] = value.to(device) # type: ignore[index] -offload_state = OffloadState() +def _finalize_grads(model_chunks: list[MegatronModule]) -> None: + finalize_model_grads_extended(cast(list[torch.nn.Module], model_chunks)) -offload_to_cpu(model, optimizer, rank, offload_state) -while True: - torch.distributed.barrier() # ty:ignore[possibly-missing-attribute] - jobs_dir = "/tmp/megatron_training_jobs" - os.makedirs(jobs_dir, exist_ok=True) - job_names = sorted( - job_name for job_name in os.listdir(jobs_dir) if job_name.endswith(".json") +def _optimizer_step( + optimizer: Any, + learning_rate: float, +) -> tuple[bool, float, int | None]: + for param_group in optimizer.param_groups: + param_group["lr"] = learning_rate + update_successful, grad_norm, num_zeros_in_grad = cast( + tuple[bool, float, int | None], optimizer.step() ) - if not job_names: - time.sleep(1) - continue - - wake_lock_path = "/tmp/megatron_vllm_waking" - while os.path.exists(wake_lock_path): - time.sleep(0.2) - - reload_to_gpu(model, optimizer, rank, offload_state) - - job_name = job_names[0] - job_path = os.path.join(jobs_dir, job_name) - with open(job_path, "rb") as f: - job = TrainingJob.model_validate_json(f.read()) - config = job.config - experimental_config = job.experimental_config - print0("Loaded job from", job_path) - print0("Job:", job) - adapter_model_path = f"{job.lora_path}/adapter_model.safetensors" - if os.path.exists(adapter_model_path): - print0("Loading adapter model from", adapter_model_path) - adapter_model = load_file(adapter_model_path) - with torch.no_grad(): - for chunk in model: - for module in chunk.modules(): - if hasattr(module, "load_lora"): - module.load_lora(adapter_model) # type: ignore - else: - print0("No adapter model found at", adapter_model_path) - adapter_model = {} - with torch.no_grad(): - for chunk in model: - for module in chunk.modules(): - if hasattr(module, "reset_lora_parameters"): - module.reset_lora_parameters() # type: ignore - optimizer_shard_path = os.path.join( - job.optimizer_state_path, f"{rank + 1:02d}-of-{world_size:02d}.pt" + optimizer.zero_grad() + return update_successful, grad_norm, num_zeros_in_grad + + +def _reduce_loss(loss: torch.Tensor) -> torch.Tensor: + reduced_loss = loss.detach().clone() + torch.distributed.all_reduce(reduced_loss, op=torch.distributed.ReduceOp.AVG) + return reduced_loss + + +def run_training_step( + *, + model_chunks: list[MegatronModule], + optimizer: Any, + learning_rate: float, + inputs: PackedTensors, + config: types.TrainConfig, + experimental_config: dev.TrainConfig, + ref_logprobs: torch.Tensor | None = None, +) -> TrainStepResult: + device = next(model_chunks[0].parameters()).device + _move_inputs_to_device(inputs, device) + + attention_state = create_shared_prefix_attention_state( + group_ids=inputs["group_ids"], + parent_ids=inputs["parent_ids"], ) - if os.path.exists(optimizer_shard_path): - print( - "Loading optimizer state from", - optimizer_shard_path, - ) - optimizer.load_state_dict(torch.load(optimizer_shard_path)) - else: - # No checkpoint for this run; reset optimizer state to avoid cross-run leakage - print( - "No optimizer state found at", - optimizer_shard_path, - "— resetting optimizer for new run", - ) - optimizer.optimizer.state.clear() - optimizer.reload_model_params() - print0("Loading packed tensors from", job.disk_packed_tensors["dir"]) - packed_tensors = packed_tensors_from_dir(**job.disk_packed_tensors) - num_sequences = job.disk_packed_tensors["num_sequences"] - dp_rank = ps.get_data_parallel_rank() - dp_world_size = ps.get_data_parallel_world_size() - num_indices = math.ceil(num_sequences / dp_world_size) - indices = list(range(dp_rank, num_sequences, dp_world_size)) - if not indices: - indices = [dp_rank % num_sequences] - # pad indices by repeating & slicing to target length - repeat = math.ceil(num_indices / len(indices)) - indices = (indices * repeat)[:num_indices] - for index in indices: - inputs = PackedTensors( # type: ignore - **{ - key: value[index : index + 1] - for key, value in packed_tensors.items() - if isinstance(value, torch.Tensor) - }, - pixel_values=[None], - image_grid_thw=[None], + attention_mask = torch.zeros((1, 1, 1, 1), dtype=torch.bool, device=device) + + for chunk in model_chunks: + cast(Any, chunk).zero_grad_buffer() + + new_logprobs: torch.Tensor = -model_chunks[0]( + input_ids=inputs["tokens"], + position_ids=inputs["input_pos"], + attention_mask=attention_mask, + labels=shift_tensor(inputs["tokens"], 0), + extra_block_kwargs={"attention_bias": attention_state}, + ) + + loss_info = loss_fn( + cast(Any, inputs), + new_logprobs, + ref_logprobs, + None, + experimental_config, + ) + loss = loss_info.mean_policy_loss + config.beta * loss_info.mean_kl + loss.backward() + _finalize_grads(model_chunks) + update_successful, grad_norm, num_zeros_in_grad = _optimizer_step( + optimizer, + learning_rate, + ) + reduced_loss = _reduce_loss(loss) + + return TrainStepResult( + reduced_loss=reduced_loss, + probs_corr=float(loss_info.probs_corr.item()), + new_logprobs=new_logprobs, + update_successful=update_successful, + grad_norm=grad_norm, + num_zeros_in_grad=num_zeros_in_grad, + ) + + +def _run_service_loop(runtime: TrainingRuntime) -> None: + offload_state = OffloadState() + offload_to_cpu(runtime.model, runtime.optimizer, runtime.rank, offload_state) + + while True: + torch.distributed.barrier() + jobs_dir = "/tmp/megatron_training_jobs" + os.makedirs(jobs_dir, exist_ok=True) + job_names = sorted( + job_name for job_name in os.listdir(jobs_dir) if job_name.endswith(".json") ) - ref_logprobs = None - device = next(model[0].parameters()).device - for key, value in inputs.items(): - if isinstance(value, torch.Tensor): - inputs[key] = value.to(device) # type: ignore - attention_state = create_shared_prefix_attention_state( # should happen after group_ids is moved to device - group_ids=inputs["group_ids"], - parent_ids=inputs["parent_ids"], + if not job_names: + time.sleep(1) + continue + + wake_lock_path = "/tmp/megatron_vllm_waking" + while os.path.exists(wake_lock_path): + time.sleep(0.2) + + reload_to_gpu(runtime.model, runtime.optimizer, runtime.rank, offload_state) + + job_name = job_names[0] + job_path = os.path.join(jobs_dir, job_name) + with open(job_path, "rb") as handle: + job = TrainingJob.model_validate_json(handle.read()) + config = job.config + experimental_config = job.experimental_config + + print0(runtime.rank, "Loaded job from", job_path) + print0(runtime.rank, "Job:", job) + + adapter_model_path = f"{job.lora_path}/adapter_model.safetensors" + if not os.path.exists(adapter_model_path): + raise FileNotFoundError(f"No adapter model found at {adapter_model_path}") + print0(runtime.rank, "Loading adapter model from", adapter_model_path) + adapter_model = load_file(adapter_model_path) + load_adapter_into_model(runtime.model, adapter_model) + + optimizer_shard_path = os.path.join( + job.optimizer_state_path, + f"{runtime.rank + 1:02d}-of-{runtime.world_size:02d}.pt", ) - # Megatron full-layer recompute saves positional tensor args, so keep a tiny - # placeholder Tensor here and pass flex BlockMask state via attention_bias. - attention_mask = torch.zeros((1, 1, 1, 1), dtype=torch.bool, device=device) - new_logprobs: torch.Tensor = -model[0]( - input_ids=inputs["tokens"], - position_ids=inputs["input_pos"], - attention_mask=attention_mask, - labels=shift_tensor(inputs["tokens"], 0), - extra_block_kwargs={"attention_bias": attention_state}, + if os.path.exists(optimizer_shard_path): + print("Loading optimizer state from", optimizer_shard_path) + runtime.optimizer.load_state_dict(torch.load(optimizer_shard_path)) + else: + print( + "No optimizer state found at", + optimizer_shard_path, + "- resetting optimizer for new run", + ) + runtime.optimizer.optimizer.state.clear() + runtime.optimizer.reload_model_params() + + print0( + runtime.rank, "Loading packed tensors from", job.disk_packed_tensors["dir"] ) - loss = loss_fn( - inputs, # type: ignore - new_logprobs, - ref_logprobs, - None, - experimental_config, + packed_tensors = packed_tensors_from_dir(**job.disk_packed_tensors) + num_sequences = job.disk_packed_tensors["num_sequences"] + + dp_rank = ps.get_data_parallel_rank() + dp_world_size = ps.get_data_parallel_world_size() + num_indices = math.ceil(num_sequences / dp_world_size) + indices = list(range(dp_rank, num_sequences, dp_world_size)) + if not indices: + indices = [dp_rank % num_sequences] + repeat = math.ceil(num_indices / len(indices)) + indices = (indices * repeat)[:num_indices] + + for index in indices: + inputs = select_indexed_inputs(packed_tensors, index) + step_result = run_training_step( + model_chunks=runtime.model, + optimizer=runtime.optimizer, + learning_rate=config.learning_rate, + inputs=inputs, + config=config, + experimental_config=experimental_config, + ref_logprobs=None, + ) + print0( + runtime.rank, + "Correlation between old and new probabilities:", + step_result.probs_corr, + ) + + if runtime.rank == 0: + with open( + "/tmp/megatron_training_log.jsonl", "a+", encoding="utf-8" + ) as log_file: + log_msg = json.dumps( + { + "loss": step_result.reduced_loss.item(), + "grad_norm": step_result.grad_norm, + "probs_corr": step_result.probs_corr, + } + ) + print("Logging", log_msg) + log_file.write(log_msg + "\n") + + sharded_state_dict, sharded_state_manifest = collect_sharded_lora_state( + runtime.model, + adapter_model, ) - probs_corr = loss.probs_corr.item() - print0("Correlation between old and new probabilities:", probs_corr) - loss = loss.mean_policy_loss + config.beta * loss.mean_kl - loss.backward() - # Reduce LoRA grads - start = time.perf_counter() - num_grads = 0 - for chunk in model: - for param in chunk.parameters(): - if param.grad is None: - continue - torch.distributed.all_reduce( # ty:ignore[possibly-missing-attribute] - param.grad, - op=torch.distributed.ReduceOp.AVG, # ty:ignore[possibly-missing-attribute] - group=ps.get_data_parallel_group(), - ) - num_grads += 1 - print0( - f"Reduced {num_grads} LoRA grads in {(time.perf_counter() - start) * 1e3:.1f} ms" + shard_path = os.path.join( + job.lora_path, + f"adapter_model-{runtime.rank + 1:02d}-of-{runtime.world_size:02d}.safetensors", ) - for param_group in optimizer.param_groups: - param_group["lr"] = config.learning_rate - update_successful, grad_norm, num_zeros_in_grad = cast( - tuple[bool, float, int | None], optimizer.step() + manifest_path = os.path.join( + job.lora_path, + f"adapter_manifest-{runtime.rank + 1:02d}-of-{runtime.world_size:02d}.json", ) - optimizer.zero_grad() - - # Mean reduce loss across all ranks for logging - torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG) # ty:ignore[possibly-missing-attribute] - - if rank == 0: - with open("/tmp/megatron_training_log.jsonl", "a+") as log_file: - log_msg = json.dumps( - { - "loss": loss.item(), - "grad_norm": grad_norm, - "probs_corr": probs_corr, - } - ) - print("Logging", log_msg) - log_file.write(log_msg + "\n") - - sharded_state_dict = {} - for chunk in model: - for module in chunk.modules(): - if hasattr(module, "sharded_lora_state_dict"): - module_sharded_lora_state_dict: dict[str, torch.Tensor] = ( - module.sharded_lora_state_dict() # type: ignore - ) - for key, value in module_sharded_lora_state_dict.items(): - target_dtype = ( - adapter_model[key].dtype - if key in adapter_model - else value.dtype - ) - sharded_state_dict[key] = value.to(target_dtype) - shard_path = os.path.join( - job.lora_path, - f"adapter_model-{rank + 1:02d}-of-{world_size:02d}.safetensors", + print("Saving adapter shard to", shard_path) + save_file(sharded_state_dict, shard_path) + print("Saving adapter shard manifest to", manifest_path) + with open(manifest_path, "w", encoding="utf-8") as manifest_file: + json.dump(sharded_state_manifest, manifest_file, sort_keys=True) + + print("Saving optimizer shard to", optimizer_shard_path) + os.makedirs(job.optimizer_state_path, exist_ok=True) + torch.save(runtime.optimizer.state_dict(), optimizer_shard_path) + + offload_to_cpu(runtime.model, runtime.optimizer, runtime.rank, offload_state) + + del packed_tensors + del adapter_model + if "inputs" in locals(): + del inputs + gc.collect() + torch.cuda.empty_cache() + + torch.distributed.barrier() + if runtime.rank == 0: + os.remove(job_path) + with open( + "/tmp/megatron_training_log.jsonl", "a+", encoding="utf-8" + ) as log_file: + log_file.write("all done\n") + shutil.rmtree(job.disk_packed_tensors["dir"]) + + +def main() -> None: + runtime = build_training_runtime( + model_identifier=os.environ.get("MODEL_IDENTIFIER", DEFAULT_MODEL_IDENTIFIER) ) - print("Saving adapter shard to", shard_path) - save_file(sharded_state_dict, shard_path) - print("Saving optimizer shard to", optimizer_shard_path) - os.makedirs(job.optimizer_state_path, exist_ok=True) - torch.save(optimizer.state_dict(), optimizer_shard_path) - offload_to_cpu(model, optimizer, rank, offload_state) - # Release mmap-backed packed tensor references on all ranks before rank0 cleanup. - del packed_tensors - del adapter_model - if "inputs" in locals(): - del inputs - gc.collect() - torch.cuda.empty_cache() - # Ensure all ranks have finished saving before signaling completion - torch.distributed.barrier() # ty:ignore[possibly-missing-attribute] - if rank == 0: - os.remove(job_path) - with open("/tmp/megatron_training_log.jsonl", "a+") as log_file: - log_file.write("all done\n") - shutil.rmtree(job.disk_packed_tensors["dir"]) + _run_service_loop(runtime) + + +if __name__ == "__main__": + main() From 112e97cde67deced438a473c5482467f0d5ff7d5 Mon Sep 17 00:00:00 2001 From: FurtherAI Date: Tue, 10 Mar 2026 05:13:03 +0000 Subject: [PATCH 02/19] megatron: harden sharded lora merge validation --- src/art/megatron/service.py | 91 +++++++++++++++++++++++++++++++------ 1 file changed, 78 insertions(+), 13 deletions(-) diff --git a/src/art/megatron/service.py b/src/art/megatron/service.py index 8ed6b82c..e4c99a98 100644 --- a/src/art/megatron/service.py +++ b/src/art/megatron/service.py @@ -12,7 +12,7 @@ from peft.tuners.lora.config import LoraConfig from pydantic import BaseModel from safetensors import safe_open -from safetensors.torch import load_file, save_file +from safetensors.torch import save_file import torch from vllm import AsyncEngineArgs from vllm.lora.request import LoRARequest @@ -311,26 +311,91 @@ def _merge_lora_adapter(self, lora_path: str) -> None: if not shard_filenames: return - adapter_model_path = base_dir / "adapter_model.safetensors" - sharded_tensors: dict[str, list[torch.Tensor]] = {} + shard_files_by_suffix = { + path.name.removeprefix("adapter_model-").removesuffix(".safetensors"): path + for path in shard_filenames + } + manifest_filenames = sorted(base_dir.glob("adapter_manifest-*-of-*.json")) + manifest_files_by_suffix = { + path.name.removeprefix("adapter_manifest-").removesuffix(".json"): path + for path in manifest_filenames + } - for filename in shard_filenames: - with safe_open(filename, framework="pt") as file: - for key in file.keys(): - tensor = file.get_tensor(key) - sharded_tensors.setdefault(key, []).append(tensor) + if set(shard_files_by_suffix) != set(manifest_files_by_suffix): + raise RuntimeError( + "Shard/manifest coverage mismatch: " + f"shards={sorted(shard_files_by_suffix)}, " + f"manifests={sorted(manifest_files_by_suffix)}" + ) - adapter_model: dict[str, torch.Tensor] = {} - if adapter_model_path.exists(): - adapter_model = load_file(adapter_model_path) + entries_by_key: dict[str, list[tuple[dict[str, Any], torch.Tensor]]] = {} + for suffix in sorted(shard_files_by_suffix): + shard_path = shard_files_by_suffix[suffix] + manifest_path = manifest_files_by_suffix[suffix] + with open(manifest_path, "r", encoding="utf-8") as manifest_file: + shard_manifest: dict[str, dict[str, Any]] = json.load(manifest_file) - for key, tensors in sharded_tensors.items(): - tensor = torch.cat(tensors, dim=1 if "lora_A" in key else 0) + with safe_open(shard_path, framework="pt") as file: + shard_tensors = {key: file.get_tensor(key) for key in file.keys()} + + if set(shard_tensors) != set(shard_manifest): + raise RuntimeError( + f"Tensor/manifest key mismatch for shard suffix={suffix}: " + f"tensor_keys={sorted(shard_tensors)}, " + f"manifest_keys={sorted(shard_manifest)}" + ) + + for key, tensor in shard_tensors.items(): + entries_by_key.setdefault(key, []).append((shard_manifest[key], tensor)) + + adapter_model: dict[str, torch.Tensor] = {} + for key, key_entries in entries_by_key.items(): + first_manifest = key_entries[0][0] + sharded = bool(first_manifest["sharded"]) + shard_world_size = int(first_manifest["shard_world_size"]) + + for manifest_entry, _tensor in key_entries: + if bool(manifest_entry["sharded"]) != sharded: + raise RuntimeError(f"Inconsistent sharded flag for key={key}") + if int(manifest_entry["shard_world_size"]) != shard_world_size: + raise RuntimeError(f"Inconsistent shard world size for key={key}") + + if not sharded: + if len(key_entries) != 1: + raise RuntimeError( + f"Replicated key={key} expected 1 shard, got {len(key_entries)}" + ) + tensor = key_entries[0][1] + else: + shard_rank_to_tensor: dict[int, torch.Tensor] = {} + for manifest_entry, shard_tensor in key_entries: + shard_rank = int(manifest_entry["shard_rank"]) + if shard_rank in shard_rank_to_tensor: + raise RuntimeError( + f"Duplicate shard_rank={shard_rank} for key={key}" + ) + shard_rank_to_tensor[shard_rank] = shard_tensor + + expected_shard_ranks = set(range(shard_world_size)) + if set(shard_rank_to_tensor.keys()) != expected_shard_ranks: + raise RuntimeError( + f"Shard rank coverage mismatch for key={key}: " + f"expected {sorted(expected_shard_ranks)}, got {sorted(shard_rank_to_tensor.keys())}" + ) + + ordered_shards = [ + shard_rank_to_tensor[i] for i in range(shard_world_size) + ] + concat_dim = 1 if "lora_A" in key else 0 + tensor = torch.cat(ordered_shards, dim=concat_dim) adapter_model[key] = tensor + adapter_model_path = base_dir / "adapter_model.safetensors" save_file(adapter_model, adapter_model_path) for filename in shard_filenames: filename.unlink() + for filename in manifest_filenames: + filename.unlink() @cached_property def llm(self) -> asyncio.Task[AsyncLLM]: From 4d5c3454ff2e7d3e40436628fb4e0cc51dc7af98 Mon Sep 17 00:00:00 2001 From: FurtherAI Date: Tue, 10 Mar 2026 05:13:25 +0000 Subject: [PATCH 03/19] tests: add megatron lora oracle correctness harness --- tests/integration/megatron_oracle_harness.py | 1189 +++++++++++++++++ .../test_megatron_lora_oracle_correctness.py | 100 ++ 2 files changed, 1289 insertions(+) create mode 100644 tests/integration/megatron_oracle_harness.py create mode 100644 tests/integration/test_megatron_lora_oracle_correctness.py diff --git a/tests/integration/megatron_oracle_harness.py b/tests/integration/megatron_oracle_harness.py new file mode 100644 index 00000000..0d61b41f --- /dev/null +++ b/tests/integration/megatron_oracle_harness.py @@ -0,0 +1,1189 @@ +from __future__ import annotations + +import argparse +from contextlib import contextmanager +import hashlib +import json +import os +from pathlib import Path +import random +import shutil +import subprocess +import sys +from typing import Any, Callable, Literal, cast + +import numpy as np +from pydantic import BaseModel, ConfigDict, Field + +REPO_ROOT = Path(__file__).resolve().parents[2] +ARTIFACT_ROOT = Path(REPO_ROOT / ".local/megatron_lora_oracles") + +REGENERATE_ENV = "ART_REGENERATE_MEGATRON_ORACLE" +BASE_MODEL_ENV = "ART_MEGATRON_ORACLE_BASE_MODEL" +DP_SUPPORT_ENV = "ART_MEGATRON_ORACLE_ENABLE_DP_PHASE_B" +SENSITIVITY_MUTATION_ENV = "ART_MEGATRON_ORACLE_MUTATION" + +SensitivityMutation = Literal["drop_finalize"] + +REQUIRED_PACKED_TENSOR_FILES = ( + "tokens.pt", + "group_ids.pt", + "parent_ids.pt", + "input_pos.pt", + "assistant_mask.pt", + "logprobs.pt", + "advantages.pt", + "weights.pt", +) + + +class Topology(BaseModel): + model_config = ConfigDict(frozen=True) + + tp: int + ep: int + etp: int = 1 + dp: int = 1 + sp: int = 0 + phase: Literal["A", "B"] = "A" + + def slug(self) -> str: + return f"tp{self.tp}_ep{self.ep}_etp{self.etp}_dp{self.dp}_sp{self.sp}" + + def world_size(self) -> int: + return self.tp * self.ep * self.etp * self.dp + + +class PackedTensorConfig(BaseModel): + num_sequences: int = 8 + sequence_length: int = 256 + prefill_tokens: int = 64 + decode_tokens: int = 64 + vocab_high: int = 8192 + + +class LoraConfig(BaseModel): + rank: int = 1 + alpha: int = 32 + target_modules: list[str] = Field( + default_factory=lambda: [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ] + ) + + +class ToleranceProfile(BaseModel): + outputs_abs: float = 1e-2 + outputs_rel: float = 1e-2 + losses_abs: float = 1e-4 + losses_rel: float = 1e-4 + grads_abs: float = 1e-2 + grads_rel: float = 1e-2 + deltas_abs: float = 1e-2 + deltas_rel: float = 1e-2 + + +class OracleCaseConfig(BaseModel): + base_model: str + seed: int = 20260305 + num_steps: int = 3 + learning_rate: float = 5e-6 + beta: float = 0.0 + packed_tensors: PackedTensorConfig = Field(default_factory=PackedTensorConfig) + lora: LoraConfig = Field(default_factory=LoraConfig) + tolerances: ToleranceProfile = Field(default_factory=ToleranceProfile) + + +class DiskPackedTensorsSpec(BaseModel): + dir: str + num_sequences: int + sequence_length: int + pixel_values: tuple[int, list[int]] | None = None + image_grid_thw: tuple[int, list[int]] | None = None + + +class CaseArtifacts(BaseModel): + case_id: str + case_dir: str + packed_tensors: DiskPackedTensorsSpec + shared_init_adapter_path: str + + +class WorkerRunRequest(BaseModel): + case_id: str + case_config: OracleCaseConfig + topology: Topology + topology_dir: str + packed_tensors: DiskPackedTensorsSpec + shared_init_adapter_path: str + allow_create_shared_init: bool = False + mutation: SensitivityMutation | None = None + + +class StepTrace(BaseModel): + step_index: int + loss: float + probs_corr: float + output_file: str + grads_file: str + deltas_file: str + lora_file: str + + +class RunManifest(BaseModel): + case_id: str + base_model: str + topology: str + world_size: int + seed: int + num_steps: int + packed_tensors: DiskPackedTensorsSpec + tolerances: ToleranceProfile + steps: list[StepTrace] + + +class ComparisonFailure(BaseModel): + case_id: str + topology: str + oracle_topology: str + metric: Literal["outputs", "losses", "grads", "lora_deltas"] + step_index: int + key: str + max_abs_error: float + max_rel_error: float + abs_tolerance: float + rel_tolerance: float + message: str + + +PHASE_A_TOPOLOGIES = [ + Topology(tp=1, ep=1, etp=1, dp=1, sp=0, phase="A"), + Topology(tp=2, ep=1, etp=1, dp=1, sp=1, phase="A"), + Topology(tp=1, ep=2, etp=1, dp=1, sp=0, phase="A"), + Topology(tp=2, ep=2, etp=1, dp=1, sp=1, phase="A"), +] +PHASE_B_TOPOLOGIES = [ + Topology(tp=1, ep=1, etp=1, dp=2, sp=0, phase="B"), + Topology(tp=2, ep=1, etp=1, dp=2, sp=1, phase="B"), +] +ORACLE_TOPOLOGY = PHASE_A_TOPOLOGIES[0] +SENSITIVITY_TOPOLOGY = PHASE_A_TOPOLOGIES[1] + + +def _truthy(value: str | None) -> bool: + if value is None: + return False + return value.strip().lower() in {"1", "true", "yes", "on"} + + +def sensitivity_mutation() -> SensitivityMutation | None: + raw = os.environ.get(SENSITIVITY_MUTATION_ENV) + if raw is None or raw.strip() == "": + return None + normalized = raw.strip().lower() + if normalized in {"1", "true", "yes", "on"}: + return "drop_finalize" + if normalized == "drop_finalize": + return "drop_finalize" + raise ValueError( + f"Unsupported {SENSITIVITY_MUTATION_ENV} value '{raw}'. " + "Supported values: drop_finalize, 1/true/yes/on." + ) + + +def sensitivity_enabled() -> bool: + return sensitivity_mutation() is not None + + +def phase_b_dp_enabled() -> bool: + return _truthy(os.environ.get(DP_SUPPORT_ENV)) + + +def regenerate_requested() -> bool: + return _truthy(os.environ.get(REGENERATE_ENV)) + + +def default_case_config() -> OracleCaseConfig: + def _env_float(name: str, default: str) -> float: + return float(os.environ.get(name, default)) + + tolerances = ToleranceProfile( + outputs_abs=_env_float("ART_MEGATRON_ORACLE_OUTPUTS_ABS_TOL", "1e-2"), + outputs_rel=_env_float("ART_MEGATRON_ORACLE_OUTPUTS_REL_TOL", "1e-2"), + losses_abs=_env_float("ART_MEGATRON_ORACLE_LOSSES_ABS_TOL", "1e-4"), + losses_rel=_env_float("ART_MEGATRON_ORACLE_LOSSES_REL_TOL", "1e-4"), + grads_abs=_env_float("ART_MEGATRON_ORACLE_GRADS_ABS_TOL", "1e-2"), + grads_rel=_env_float("ART_MEGATRON_ORACLE_GRADS_REL_TOL", "1e-2"), + deltas_abs=_env_float("ART_MEGATRON_ORACLE_DELTAS_ABS_TOL", "1e-2"), + deltas_rel=_env_float("ART_MEGATRON_ORACLE_DELTAS_REL_TOL", "1e-2"), + ) + return OracleCaseConfig( + base_model=os.environ.get( + BASE_MODEL_ENV, + "Qwen/Qwen3-30B-A3B-Instruct-2507", + ), + seed=int(os.environ.get("ART_MEGATRON_ORACLE_SEED", "20260305")), + num_steps=int(os.environ.get("ART_MEGATRON_ORACLE_NUM_STEPS", "3")), + learning_rate=float(os.environ.get("ART_MEGATRON_ORACLE_LR", "5e-6")), + beta=float(os.environ.get("ART_MEGATRON_ORACLE_BETA", "0.0")), + tolerances=tolerances, + ) + + +def available_gpu_count() -> int: + import torch + + if not torch.cuda.is_available(): + return 0 + return int(torch.cuda.device_count()) + + +def stable_case_id(case_config: OracleCaseConfig) -> str: + payload = case_config.model_dump(mode="json") + encoded = json.dumps(payload, sort_keys=True, separators=(",", ":")) + digest = hashlib.sha256(encoded.encode("utf-8")).hexdigest()[:16] + model_tag = ( + case_config.base_model.replace("/", "_") + .replace("-", "_") + .replace(".", "_") + .lower() + ) + return f"{model_tag}_{digest}" + + +def _write_json(path: Path, payload: Any) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", encoding="utf-8") as handle: + json.dump(payload, handle, indent=2, sort_keys=True) + + +def _read_json(path: Path) -> dict[str, Any]: + with path.open("r", encoding="utf-8") as handle: + return json.load(handle) + + +def _build_packed_tensors( + config: PackedTensorConfig, + seed: int, +) -> dict[str, Any]: + import torch + + if config.num_sequences <= 1: + raise ValueError("num_sequences must be greater than 1") + shape = (config.num_sequences, config.sequence_length) + generator = torch.Generator().manual_seed(seed) + tokens = torch.randint( + low=10, + high=config.vocab_high, + size=shape, + dtype=torch.long, + generator=generator, + ) + group_ids = torch.zeros(shape, dtype=torch.long) + parent_ids = torch.full(shape, -1, dtype=torch.long) + input_pos = ( + torch.arange(config.sequence_length, dtype=torch.long) + .unsqueeze(0) + .expand(config.num_sequences, -1) + .clone() + ) + prefix_length = max(1, min(config.sequence_length - 1, config.prefill_tokens)) + decode_span = max(1, config.decode_tokens) + cursor = prefix_length + branch = 1 + while cursor < config.sequence_length: + end = min(config.sequence_length, cursor + decode_span) + group_ids[:, cursor:end] = branch + parent_ids[:, cursor:end] = 0 + cursor = end + branch += 1 + assistant_mask = torch.zeros(shape, dtype=torch.bool) + assistant_mask[:, prefix_length:] = True + logprobs = ( + torch.randn( + shape, + generator=generator, + dtype=torch.float32, + ) + * 0.25 + - 1.75 + ) + advantages = ( + torch.randn( + shape, + generator=generator, + dtype=torch.float32, + ) + * 0.1 + + 1.0 + ) + weights = torch.ones(shape, dtype=torch.float32) + return { + "tokens": tokens, + "group_ids": group_ids, + "parent_ids": parent_ids, + "input_pos": input_pos, + "assistant_mask": assistant_mask, + "logprobs": logprobs, + "advantages": advantages, + "weights": weights, + "pixel_values": [None] * config.num_sequences, + "image_grid_thw": [None] * config.num_sequences, + } + + +def _create_packed_tensors( + case_config: OracleCaseConfig, + packed_dir: Path, +) -> DiskPackedTensorsSpec: + from art.preprocessing.pack import PackedTensors, packed_tensors_to_dir + + packed_tensors = cast( + PackedTensors, + _build_packed_tensors(case_config.packed_tensors, case_config.seed), + ) + descriptor = packed_tensors_to_dir(packed_tensors, str(packed_dir)) + return DiskPackedTensorsSpec.model_validate(descriptor) + + +def _validate_packed_tensor_files(spec: DiskPackedTensorsSpec) -> None: + tensor_dir = Path(spec.dir) + for filename in REQUIRED_PACKED_TENSOR_FILES: + file_path = tensor_dir / filename + if not file_path.exists(): + raise FileNotFoundError(f"Missing packed tensor file: {file_path}") + + +def ensure_case_artifacts(case_config: OracleCaseConfig) -> CaseArtifacts: + case_id = stable_case_id(case_config) + case_dir = ARTIFACT_ROOT / case_id + case_dir.mkdir(parents=True, exist_ok=True) + _write_json(case_dir / "case_config.json", case_config.model_dump(mode="json")) + + descriptor_path = case_dir / "packed_tensors.json" + if descriptor_path.exists(): + packed_spec = DiskPackedTensorsSpec.model_validate(_read_json(descriptor_path)) + _validate_packed_tensor_files(packed_spec) + else: + packed_spec = _create_packed_tensors(case_config, case_dir / "packed_tensors") + _write_json(descriptor_path, packed_spec.model_dump(mode="json")) + + shared_init_path = case_dir / "shared_init" / "adapter_model.safetensors" + shared_init_path.parent.mkdir(parents=True, exist_ok=True) + return CaseArtifacts( + case_id=case_id, + case_dir=str(case_dir), + packed_tensors=packed_spec, + shared_init_adapter_path=str(shared_init_path), + ) + + +def _replace_topology_dir(path: Path) -> None: + if path.exists(): + shutil.rmtree(path) + path.mkdir(parents=True, exist_ok=True) + (path / "traces").mkdir(parents=True, exist_ok=True) + + +def _run_worker_subprocess(request: WorkerRunRequest, topology_dir: Path) -> None: + request_path = topology_dir / "run_request.json" + _write_json(request_path, request.model_dump(mode="json")) + + command = [ + sys.executable, + "-m", + "torch.distributed.run", + "--standalone", + "--nproc_per_node", + str(request.topology.world_size()), + str(Path(__file__).resolve()), + "--worker-run", + "--run-request", + str(request_path), + ] + run = subprocess.run( + command, + cwd=str(REPO_ROOT), + env={**os.environ, "PYTHONUNBUFFERED": "1"}, + capture_output=True, + text=True, + check=False, + ) + combined_output = f"{run.stdout}\n{run.stderr}".strip() + (topology_dir / "worker.log").write_text(combined_output + "\n", encoding="utf-8") + if run.returncode != 0: + tail = "\n".join(combined_output.splitlines()[-80:]) + raise RuntimeError( + f"Topology run failed for {request.topology.slug()} with exit code " + f"{run.returncode}.\n{tail}" + ) + + +def ensure_topology_artifacts( + case_config: OracleCaseConfig, + topology: Topology, + *, + regenerate: bool = False, + mutation: SensitivityMutation | None = None, +) -> Path: + case_artifacts = ensure_case_artifacts(case_config) + case_dir = Path(case_artifacts.case_dir) + topology_dir = case_dir / topology.slug() + manifest_path = topology_dir / "manifest.json" + if manifest_path.exists() and not regenerate: + return topology_dir + + _replace_topology_dir(topology_dir) + shared_init_path = Path(case_artifacts.shared_init_adapter_path) + allow_create_shared_init = topology.slug() == ORACLE_TOPOLOGY.slug() + if not allow_create_shared_init and not shared_init_path.exists(): + ensure_topology_artifacts( + case_config=case_config, + topology=ORACLE_TOPOLOGY, + regenerate=False, + mutation=None, + ) + if not allow_create_shared_init and not shared_init_path.exists(): + raise FileNotFoundError( + f"Oracle shared adapter missing after oracle generation: {shared_init_path}" + ) + if mutation is not None and topology.slug() == ORACLE_TOPOLOGY.slug(): + raise RuntimeError("Sensitivity mutation cannot be applied to oracle topology") + + request = WorkerRunRequest( + case_id=case_artifacts.case_id, + case_config=case_config, + topology=topology, + topology_dir=str(topology_dir), + packed_tensors=case_artifacts.packed_tensors, + shared_init_adapter_path=str(shared_init_path), + allow_create_shared_init=allow_create_shared_init, + mutation=mutation, + ) + _run_worker_subprocess(request, topology_dir) + if not manifest_path.exists(): + raise RuntimeError(f"Missing manifest after run: {manifest_path}") + return topology_dir + + +def ensure_oracle_reference_artifacts( + *, + case_config: OracleCaseConfig, + regenerate: bool = False, +) -> Path: + return ensure_topology_artifacts( + case_config=case_config, + topology=ORACLE_TOPOLOGY, + regenerate=regenerate, + mutation=None, + ) + + +def _load_manifest(topology_dir: Path) -> RunManifest: + manifest_path = topology_dir / "manifest.json" + if not manifest_path.exists(): + raise FileNotFoundError(f"Missing topology manifest: {manifest_path}") + return RunManifest.model_validate(_read_json(manifest_path)) + + +def _load_output_tensor(topology_dir: Path, step: StepTrace): + import torch + + path = topology_dir / step.output_file + if not path.exists(): + raise FileNotFoundError(f"Missing output trace: {path}") + return torch.load(path, map_location="cpu") + + +def _load_safetensor_map(path: Path) -> dict[str, Any]: + from safetensors.torch import load_file + + if not path.exists(): + raise FileNotFoundError(f"Missing safetensor trace: {path}") + return load_file(str(path)) + + +def _tensor_error(reference, candidate) -> tuple[float, float]: + ref = reference.detach().float() + cand = candidate.detach().float() + if ref.shape != cand.shape: + return float("inf"), float("inf") + if ref.numel() == 0: + return 0.0, 0.0 + diff = (cand - ref).abs() + max_abs = float(diff.max().item()) + max_rel = float((diff / ref.abs().clamp_min(1e-12)).max().item()) + return max_abs, max_rel + + +def _build_failure( + *, + case_id: str, + topology: str, + metric: Literal["outputs", "losses", "grads", "lora_deltas"], + step_index: int, + key: str, + max_abs_error: float, + max_rel_error: float, + abs_tolerance: float, + rel_tolerance: float, + message: str, +) -> ComparisonFailure: + return ComparisonFailure( + case_id=case_id, + topology=topology, + oracle_topology=ORACLE_TOPOLOGY.slug(), + metric=metric, + step_index=step_index, + key=key, + max_abs_error=max_abs_error, + max_rel_error=max_rel_error, + abs_tolerance=abs_tolerance, + rel_tolerance=rel_tolerance, + message=message, + ) + + +def _compare_tensor_pair( + *, + case_id: str, + topology: str, + metric: Literal["outputs", "losses", "grads", "lora_deltas"], + step_index: int, + key: str, + reference, + candidate, + abs_tolerance: float, + rel_tolerance: float, +) -> ComparisonFailure | None: + max_abs, max_rel = _tensor_error(reference, candidate) + if max_abs <= abs_tolerance or max_rel <= rel_tolerance: + return None + return _build_failure( + case_id=case_id, + topology=topology, + metric=metric, + step_index=step_index, + key=key, + max_abs_error=max_abs, + max_rel_error=max_rel, + abs_tolerance=abs_tolerance, + rel_tolerance=rel_tolerance, + message=f"{metric} mismatch at step {step_index}, key '{key}'", + ) + + +def _compare_tensor_maps( + *, + case_id: str, + topology: str, + metric: Literal["grads", "lora_deltas"], + step_index: int, + reference: dict[str, Any], + candidate: dict[str, Any], + abs_tolerance: float, + rel_tolerance: float, +) -> ComparisonFailure | None: + ref_keys = set(reference.keys()) + cand_keys = set(candidate.keys()) + if ref_keys != cand_keys: + missing = sorted(ref_keys - cand_keys) + extra = sorted(cand_keys - ref_keys) + return _build_failure( + case_id=case_id, + topology=topology, + metric=metric, + step_index=step_index, + key="__keys__", + max_abs_error=float("inf"), + max_rel_error=float("inf"), + abs_tolerance=abs_tolerance, + rel_tolerance=rel_tolerance, + message=( + f"{metric} key mismatch at step {step_index}; " + f"missing={missing[:3]}, extra={extra[:3]}" + ), + ) + for key in sorted(ref_keys): + failure = _compare_tensor_pair( + case_id=case_id, + topology=topology, + metric=metric, + step_index=step_index, + key=key, + reference=reference[key], + candidate=candidate[key], + abs_tolerance=abs_tolerance, + rel_tolerance=rel_tolerance, + ) + if failure is not None: + return failure + return None + + +def _write_failure_report(topology_dir: Path, failure: ComparisonFailure) -> None: + _write_json(topology_dir / "failure_report.json", failure.model_dump(mode="json")) + + +def compare_topology_to_oracle( + *, + case_config: OracleCaseConfig, + topology: Topology, +) -> ComparisonFailure | None: + if topology.slug() == ORACLE_TOPOLOGY.slug(): + return None + + case_id = stable_case_id(case_config) + case_dir = ARTIFACT_ROOT / case_id + oracle_dir = case_dir / ORACLE_TOPOLOGY.slug() + topology_dir = case_dir / topology.slug() + + oracle_manifest = _load_manifest(oracle_dir) + topology_manifest = _load_manifest(topology_dir) + if len(oracle_manifest.steps) != len(topology_manifest.steps): + return _build_failure( + case_id=case_id, + topology=topology.slug(), + metric="losses", + step_index=0, + key="__step_count__", + max_abs_error=float("inf"), + max_rel_error=float("inf"), + abs_tolerance=case_config.tolerances.losses_abs, + rel_tolerance=case_config.tolerances.losses_rel, + message=( + "Step count mismatch: " + f"oracle={len(oracle_manifest.steps)} vs " + f"topology={len(topology_manifest.steps)}" + ), + ) + + import torch + + for oracle_step, topology_step in zip( + oracle_manifest.steps, topology_manifest.steps + ): + step_index = oracle_step.step_index + oracle_outputs = _load_output_tensor(oracle_dir, oracle_step) + topology_outputs = _load_output_tensor(topology_dir, topology_step) + failure = _compare_tensor_pair( + case_id=case_id, + topology=topology.slug(), + metric="outputs", + step_index=step_index, + key="logprobs", + reference=oracle_outputs, + candidate=topology_outputs, + abs_tolerance=case_config.tolerances.outputs_abs, + rel_tolerance=case_config.tolerances.outputs_rel, + ) + if failure is not None: + return failure + + oracle_loss = torch.tensor([oracle_step.loss], dtype=torch.float32) + topology_loss = torch.tensor([topology_step.loss], dtype=torch.float32) + failure = _compare_tensor_pair( + case_id=case_id, + topology=topology.slug(), + metric="losses", + step_index=step_index, + key="loss", + reference=oracle_loss, + candidate=topology_loss, + abs_tolerance=case_config.tolerances.losses_abs, + rel_tolerance=case_config.tolerances.losses_rel, + ) + if failure is not None: + return failure + + for metric, oracle_file, topo_file, abs_tol, rel_tol in ( + ( + "grads", + oracle_step.grads_file, + topology_step.grads_file, + case_config.tolerances.grads_abs, + case_config.tolerances.grads_rel, + ), + ( + "lora_deltas", + oracle_step.deltas_file, + topology_step.deltas_file, + case_config.tolerances.deltas_abs, + case_config.tolerances.deltas_rel, + ), + ): + failure = _compare_tensor_maps( + case_id=case_id, + topology=topology.slug(), + metric=metric, + step_index=step_index, + reference=_load_safetensor_map(oracle_dir / oracle_file), + candidate=_load_safetensor_map(topology_dir / topo_file), + abs_tolerance=abs_tol, + rel_tolerance=rel_tol, + ) + if failure is not None: + return failure + return None + + +def run_and_compare_topology( + *, + case_config: OracleCaseConfig, + topology: Topology, + regenerate: bool = False, +) -> None: + ensure_oracle_reference_artifacts( + case_config=case_config, + regenerate=regenerate and topology.slug() == ORACLE_TOPOLOGY.slug(), + ) + ensure_topology_artifacts( + case_config=case_config, + topology=topology, + regenerate=regenerate, + mutation=None, + ) + failure = compare_topology_to_oracle(case_config=case_config, topology=topology) + if failure is None: + return + topology_dir = ARTIFACT_ROOT / failure.case_id / topology.slug() + _write_failure_report(topology_dir, failure) + raise AssertionError( + "Megatron oracle mismatch: " + f"topology={failure.topology}, metric={failure.metric}, " + f"step={failure.step_index}, key={failure.key}, " + f"max_abs={failure.max_abs_error:.6g}, " + f"max_rel={failure.max_rel_error:.6g}, " + f"tol_abs={failure.abs_tolerance:.6g}, " + f"tol_rel={failure.rel_tolerance:.6g}" + ) + + +def run_sensitivity_check( + *, + case_config: OracleCaseConfig, + regenerate: bool = False, +) -> None: + mutation = sensitivity_mutation() + if mutation is None: + raise RuntimeError( + f"Sensitivity check requires {SENSITIVITY_MUTATION_ENV} to be set" + ) + + ensure_oracle_reference_artifacts( + case_config=case_config, + regenerate=regenerate, + ) + ensure_topology_artifacts( + case_config=case_config, + topology=SENSITIVITY_TOPOLOGY, + regenerate=True, + mutation=mutation, + ) + failure = compare_topology_to_oracle( + case_config=case_config, + topology=SENSITIVITY_TOPOLOGY, + ) + if failure is None: + raise AssertionError( + "Sensitivity mutation did not produce an oracle mismatch. " + f"mutation={mutation}, topology={SENSITIVITY_TOPOLOGY.slug()}" + ) + + +def _set_deterministic_seed(seed: int) -> None: + import torch + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def _merge_sharded_dicts(shards_by_rank: list[dict[str, Any]]) -> dict[str, Any]: + import torch + + merged: dict[str, list[Any]] = {} + for rank_shards in shards_by_rank: + for key, tensor in rank_shards.items(): + merged.setdefault(key, []).append(tensor.detach().cpu()) + full_state: dict[str, Any] = {} + for key, shards in merged.items(): + if len(shards) == 1: + full_state[key] = shards[0].contiguous() + continue + concat_dim = 1 if ".lora_A." in key else 0 + full_state[key] = torch.cat(shards, dim=concat_dim).contiguous() + return full_state + + +def _gather_full_state(local_state: dict[str, Any]) -> dict[str, Any] | None: + import torch + + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + gathered = [None for _ in range(world_size)] if rank == 0 else None + torch.distributed.gather_object(local_state, gathered, dst=0) + if rank != 0: + return None + assert gathered is not None + entries = [entry for entry in gathered if entry is not None] + return _merge_sharded_dicts(entries) + + +def _collect_lora_state(model_chunks: list[Any]) -> dict[str, Any] | None: + local_state: dict[str, Any] = {} + for chunk in model_chunks: + for module in chunk.modules(): + if not hasattr(module, "sharded_lora_state_dict"): + continue + module_state = module.sharded_lora_state_dict() + for key, value in module_state.items(): + if key in local_state: + raise RuntimeError( + f"Duplicate LoRA key while collecting state: {key}" + ) + local_state[key] = value.detach().cpu() + return _gather_full_state(local_state) + + +def _collect_lora_grads(model_chunks: list[Any]) -> dict[str, Any] | None: + from megatron.core import parallel_state as ps + + from art.megatron.lora import LoRA + + local_grads: dict[str, Any] = {} + for chunk in model_chunks: + for module in chunk.modules(): + if not isinstance(module, LoRA): + continue + grad_a = ( + module.A_T.grad + if module.A_T.grad is not None + else module.A_T.new_zeros(module.A_T.shape) + ) + grad_b = ( + module.B_T.grad + if module.B_T.grad is not None + else module.B_T.new_zeros(module.B_T.shape) + ) + if module.num_local_experts > 1: + if ps.get_expert_data_parallel_rank() != 0: + continue + for expert in range(module.num_local_experts): + prefix = module.adapter_model_prefix.format( + expert=expert + module._expert_offset + ) + local_grads[f"{prefix}.lora_A.weight"] = ( + grad_a[expert].detach().cpu().T + ) + local_grads[f"{prefix}.lora_B.weight"] = ( + grad_b[expert].detach().cpu().T + ) + else: + if ps.get_data_parallel_rank() != 0: + continue + local_grads[f"{module.adapter_model_prefix}.lora_A.weight"] = ( + grad_a.detach().cpu().T + ) + local_grads[f"{module.adapter_model_prefix}.lora_B.weight"] = ( + grad_b.detach().cpu().T + ) + return _gather_full_state(local_grads) + + +def _validate_adapter_exact( + expected_state: dict[str, Any], + adapter_model: dict[str, Any], +) -> None: + expected_keys = set(expected_state.keys()) + adapter_keys = set(adapter_model.keys()) + missing = sorted(expected_keys - adapter_keys) + extra = sorted(adapter_keys - expected_keys) + if missing or extra: + raise KeyError( + f"Adapter keys mismatch: missing={missing[:5]} extra={extra[:5]}" + ) + + +def _validate_loaded_state_matches_adapter( + loaded_state: dict[str, Any], + adapter_model: dict[str, Any], +) -> None: + import torch + + for key in sorted(adapter_model.keys()): + if key not in loaded_state: + raise KeyError(f"Loaded LoRA state missing key: {key}") + if not torch.equal(loaded_state[key].cpu(), adapter_model[key].cpu()): + max_abs, max_rel = _tensor_error(adapter_model[key], loaded_state[key]) + raise RuntimeError( + f"Loaded LoRA state mismatch for key '{key}' " + f"(max_abs={max_abs:.6g}, max_rel={max_rel:.6g})" + ) + + +def _configure_provider(provider: Any, topology: Topology) -> None: + provider.tensor_model_parallel_size = topology.tp + provider.expert_model_parallel_size = topology.ep + provider.expert_tensor_parallel_size = topology.etp + provider.pipeline_model_parallel_size = 1 + provider.context_parallel_size = 1 + provider.sequence_parallel = bool(topology.sp) + if hasattr(provider, "attention_dropout"): + provider.attention_dropout = 0.0 + if hasattr(provider, "hidden_dropout"): + provider.hidden_dropout = 0.0 + + +def _delta_state( + initial_state: dict[str, Any], + current_state: dict[str, Any], +) -> dict[str, Any]: + initial_keys = set(initial_state.keys()) + current_keys = set(current_state.keys()) + if initial_keys != current_keys: + missing = sorted(initial_keys - current_keys) + extra = sorted(current_keys - initial_keys) + raise KeyError( + f"LoRA state keys changed during training: missing={missing[:3]} extra={extra[:3]}" + ) + return { + key: current_state[key].detach().cpu() - initial_state[key].detach().cpu() + for key in sorted(initial_keys) + } + + +@contextmanager +def _mutation_hook( + megatron_train_module: Any, + mutation: SensitivityMutation | None, + pre_optimizer_step_hook: Callable[[], None] | None = None, +): + original_finalize = megatron_train_module._finalize_grads + original_optimizer_step = megatron_train_module._optimizer_step + + if mutation == "drop_finalize": + megatron_train_module._finalize_grads = lambda _model: None + elif mutation is not None: + raise ValueError(f"Unsupported mutation: {mutation}") + + if pre_optimizer_step_hook is not None: + + def _patched_optimizer_step(optimizer: Any, learning_rate: float): + pre_optimizer_step_hook() + return original_optimizer_step(optimizer, learning_rate) + + megatron_train_module._optimizer_step = _patched_optimizer_step + + if mutation is None: + if pre_optimizer_step_hook is None: + yield + return + try: + yield + finally: + megatron_train_module._finalize_grads = original_finalize + megatron_train_module._optimizer_step = original_optimizer_step + + +def _worker_run(request: WorkerRunRequest) -> None: + from megatron.core.optimizer import OptimizerConfig + from safetensors.torch import load_file, save_file + import torch + + from art import dev, types + from art.megatron import train as megatron_train + from art.preprocessing.pack import packed_tensors_from_dir + + local_rank = int(os.environ["LOCAL_RANK"]) + torch.cuda.set_device(local_rank) + torch.distributed.init_process_group(backend="nccl") + _set_deterministic_seed(request.case_config.seed) + + world_size = torch.distributed.get_world_size() + if world_size != request.topology.world_size(): + raise RuntimeError( + f"World size mismatch: expected {request.topology.world_size()}, got {world_size}" + ) + + runtime = megatron_train.build_training_runtime( + model_identifier=request.case_config.base_model, + provider_configure=lambda provider: _configure_provider( + provider, request.topology + ), + optimizer_config=OptimizerConfig( + bf16=True, + lr=request.case_config.learning_rate, + adam_beta1=0.9, + adam_beta2=0.99, + clip_grad=0.1, + weight_decay=0.1, + ), + print_env=False, + print_optimizer_stats=False, + ) + model_chunks = runtime.model + optimizer = runtime.optimizer + + topology_dir = Path(request.topology_dir) + traces_dir = topology_dir / "traces" + traces_dir.mkdir(parents=True, exist_ok=True) + + shared_init_path = Path(request.shared_init_adapter_path) + if not shared_init_path.exists(): + if not request.allow_create_shared_init: + raise FileNotFoundError( + f"Missing oracle shared adapter at {shared_init_path}" + ) + initial_state = _collect_lora_state(model_chunks) + if torch.distributed.get_rank() == 0: + assert initial_state is not None + shared_init_path.parent.mkdir(parents=True, exist_ok=True) + save_file(initial_state, str(shared_init_path)) + torch.distributed.barrier() + if not shared_init_path.exists(): + raise FileNotFoundError(f"Shared init adapter not created: {shared_init_path}") + + adapter_model = load_file(str(shared_init_path)) + expected_state = _collect_lora_state(model_chunks) + if torch.distributed.get_rank() == 0: + assert expected_state is not None + _validate_adapter_exact(expected_state, adapter_model) + torch.distributed.barrier() + + megatron_train.load_adapter_into_model(model_chunks, adapter_model) + loaded_state = _collect_lora_state(model_chunks) + if torch.distributed.get_rank() == 0: + assert loaded_state is not None + _validate_loaded_state_matches_adapter(loaded_state, adapter_model) + torch.distributed.barrier() + + packed_tensors = packed_tensors_from_dir( + **request.packed_tensors.model_dump(exclude_none=True) + ) + initial_lora_state = _collect_lora_state(model_chunks) + if torch.distributed.get_rank() == 0 and initial_lora_state is None: + raise RuntimeError("Failed to collect initial LoRA state on rank 0") + + train_config = types.TrainConfig( + learning_rate=request.case_config.learning_rate, + beta=request.case_config.beta, + kl_penalty_coef=0.0, + ) + experimental_config: dev.TrainConfig = {} + step_traces: list[StepTrace] = [] + captured_grads: dict[str, Any] | None = None + + def _capture_lora_grads() -> None: + nonlocal captured_grads + captured_grads = _collect_lora_grads(model_chunks) + + with _mutation_hook( + megatron_train, + request.mutation, + pre_optimizer_step_hook=_capture_lora_grads, + ): + for step_index in range(request.case_config.num_steps): + sample_index = step_index % request.packed_tensors.num_sequences + inputs = megatron_train.select_indexed_inputs(packed_tensors, sample_index) + captured_grads = None + + step_result = megatron_train.run_training_step( + model_chunks=model_chunks, + optimizer=optimizer, + learning_rate=train_config.learning_rate, + inputs=inputs, + config=train_config, + experimental_config=experimental_config, + ref_logprobs=None, + ) + if torch.distributed.get_rank() == 0 and captured_grads is None: + raise RuntimeError("Failed to collect LoRA grads on rank 0") + + current_lora_state = _collect_lora_state(model_chunks) + if torch.distributed.get_rank() == 0 and current_lora_state is None: + raise RuntimeError("Failed to collect current LoRA state on rank 0") + + if torch.distributed.get_rank() == 0: + assert captured_grads is not None + assert initial_lora_state is not None + assert current_lora_state is not None + output_rel = Path("traces") / f"output_step_{step_index:03d}.pt" + grads_rel = Path("traces") / f"grads_step_{step_index:03d}.safetensors" + deltas_rel = ( + Path("traces") / f"deltas_step_{step_index:03d}.safetensors" + ) + lora_rel = Path(f"lora_step_{step_index:03d}.safetensors") + + torch.save( + step_result.new_logprobs.detach().cpu().float(), + topology_dir / output_rel, + ) + save_file(captured_grads, str(topology_dir / grads_rel)) + deltas = _delta_state(initial_lora_state, current_lora_state) + save_file(deltas, str(topology_dir / deltas_rel)) + save_file(current_lora_state, str(topology_dir / lora_rel)) + + step_traces.append( + StepTrace( + step_index=step_index, + loss=float(step_result.reduced_loss.item()), + probs_corr=step_result.probs_corr, + output_file=str(output_rel), + grads_file=str(grads_rel), + deltas_file=str(deltas_rel), + lora_file=str(lora_rel), + ) + ) + torch.distributed.barrier() + + if torch.distributed.get_rank() == 0: + manifest = RunManifest( + case_id=request.case_id, + base_model=request.case_config.base_model, + topology=request.topology.slug(), + world_size=request.topology.world_size(), + seed=request.case_config.seed, + num_steps=request.case_config.num_steps, + packed_tensors=request.packed_tensors, + tolerances=request.case_config.tolerances, + steps=step_traces, + ) + _write_json(topology_dir / "manifest.json", manifest.model_dump(mode="json")) + torch.distributed.barrier() + torch.distributed.destroy_process_group() + + +def _run_worker_cli(run_request_path: Path) -> None: + request = WorkerRunRequest.model_validate(_read_json(run_request_path)) + _worker_run(request) + + +def _parse_args(argv: list[str]) -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Megatron oracle harness worker") + parser.add_argument("--worker-run", action="store_true") + parser.add_argument("--run-request", type=Path) + return parser.parse_args(argv) + + +def _main(argv: list[str]) -> int: + args = _parse_args(argv) + if not args.worker_run: + raise SystemExit("This module is intended for test imports or --worker-run") + if args.run_request is None: + raise SystemExit("--run-request is required with --worker-run") + _run_worker_cli(args.run_request) + return 0 + + +if __name__ == "__main__": + raise SystemExit(_main(sys.argv[1:])) diff --git a/tests/integration/test_megatron_lora_oracle_correctness.py b/tests/integration/test_megatron_lora_oracle_correctness.py new file mode 100644 index 00000000..f6949b81 --- /dev/null +++ b/tests/integration/test_megatron_lora_oracle_correctness.py @@ -0,0 +1,100 @@ +import pytest + +from .megatron_oracle_harness import ( + ORACLE_TOPOLOGY, + PHASE_A_TOPOLOGIES, + PHASE_B_TOPOLOGIES, + SENSITIVITY_MUTATION_ENV, + SENSITIVITY_TOPOLOGY, + available_gpu_count, + default_case_config, + ensure_oracle_reference_artifacts, + phase_b_dp_enabled, + regenerate_requested, + run_and_compare_topology, + run_sensitivity_check, + sensitivity_enabled, +) + + +def _require_gpus_for(topology_world_size: int) -> None: + gpu_count = available_gpu_count() + if gpu_count < topology_world_size: + pytest.skip( + f"Need {topology_world_size} GPUs for topology run, only found {gpu_count}" + ) + + +def _skip_if_sensitivity_mode() -> None: + if sensitivity_enabled(): + pytest.skip( + f"{SENSITIVITY_MUTATION_ENV} is enabled; running sensitivity check only." + ) + + +def _run_topology_case( # type: ignore[no-untyped-def] + topology, + case_config, + *, + regenerate: bool, +) -> None: + _require_gpus_for(topology.world_size()) + run_and_compare_topology( + case_config=case_config, + topology=topology, + regenerate=regenerate, + ) + + +def test_000_megatron_lora_oracle_sensitivity_check() -> None: + if not sensitivity_enabled(): + pytest.skip( + f"Set {SENSITIVITY_MUTATION_ENV}=drop_finalize to enable sensitivity check." + ) + _require_gpus_for(SENSITIVITY_TOPOLOGY.world_size()) + run_sensitivity_check( + case_config=default_case_config(), + regenerate=regenerate_requested(), + ) + + +def test_megatron_lora_oracle_phase_a_matrix() -> None: + _skip_if_sensitivity_mode() + case_config = default_case_config() + regenerate = regenerate_requested() + _require_gpus_for(ORACLE_TOPOLOGY.world_size()) + ensure_oracle_reference_artifacts( + case_config=case_config, + regenerate=regenerate, + ) + for topology in PHASE_A_TOPOLOGIES: + _run_topology_case( + topology, + case_config, + regenerate=regenerate and topology.slug() != ORACLE_TOPOLOGY.slug(), + ) + + +@pytest.mark.parametrize( + "topology_index", + range(len(PHASE_B_TOPOLOGIES)), + ids=[topology.slug() for topology in PHASE_B_TOPOLOGIES], +) +def test_megatron_lora_oracle_phase_b_dp_matrix(topology_index: int) -> None: + _skip_if_sensitivity_mode() + if not phase_b_dp_enabled(): + pytest.xfail( + "DP matrix currently blocked until Megatron backend DP support is enabled" + ) + case_config = default_case_config() + regenerate = regenerate_requested() + _require_gpus_for(ORACLE_TOPOLOGY.world_size()) + ensure_oracle_reference_artifacts( + case_config=case_config, + regenerate=regenerate, + ) + _run_topology_case( + PHASE_B_TOPOLOGIES[topology_index], + case_config, + regenerate=regenerate, + ) From fde2ff3c45ac98e3cb743299ba29742c2e246a81 Mon Sep 17 00:00:00 2001 From: FurtherAI Date: Tue, 10 Mar 2026 05:29:15 +0000 Subject: [PATCH 04/19] Minor typing changes --- src/art/megatron/train.py | 4 +-- tests/integration/megatron_oracle_harness.py | 27 +++++++++++++++----- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/src/art/megatron/train.py b/src/art/megatron/train.py index abc2ef7b..ef58f5d0 100644 --- a/src/art/megatron/train.py +++ b/src/art/megatron/train.py @@ -296,7 +296,7 @@ def run_training_step( attention_mask = torch.zeros((1, 1, 1, 1), dtype=torch.bool, device=device) for chunk in model_chunks: - cast(Any, chunk).zero_grad_buffer() + chunk.zero_grad_buffer() # ty: ignore[call-non-callable] new_logprobs: torch.Tensor = -model_chunks[0]( input_ids=inputs["tokens"], @@ -307,7 +307,7 @@ def run_training_step( ) loss_info = loss_fn( - cast(Any, inputs), + inputs, # ty: ignore[invalid-argument-type] new_logprobs, ref_logprobs, None, diff --git a/tests/integration/megatron_oracle_harness.py b/tests/integration/megatron_oracle_harness.py index 0d61b41f..273ac368 100644 --- a/tests/integration/megatron_oracle_harness.py +++ b/tests/integration/megatron_oracle_harness.py @@ -10,7 +10,7 @@ import shutil import subprocess import sys -from typing import Any, Callable, Literal, cast +from typing import Any, Callable, Literal, TypeVar, cast import numpy as np from pydantic import BaseModel, ConfigDict, Field @@ -162,6 +162,15 @@ class ComparisonFailure(BaseModel): message: str +T = TypeVar("T") + + +def _require_not_none(value: T | None, name: str) -> T: + if value is None: + raise RuntimeError(f"{name} is None") + return value + + PHASE_A_TOPOLOGIES = [ Topology(tp=1, ep=1, etp=1, dp=1, sp=0, phase="A"), Topology(tp=2, ep=1, etp=1, dp=1, sp=1, phase="A"), @@ -1114,9 +1123,13 @@ def _capture_lora_grads() -> None: raise RuntimeError("Failed to collect current LoRA state on rank 0") if torch.distributed.get_rank() == 0: - assert captured_grads is not None - assert initial_lora_state is not None - assert current_lora_state is not None + grads = _require_not_none(captured_grads, "captured_grads") + initial_state = _require_not_none( + initial_lora_state, "initial_lora_state" + ) + current_state = _require_not_none( + current_lora_state, "current_lora_state" + ) output_rel = Path("traces") / f"output_step_{step_index:03d}.pt" grads_rel = Path("traces") / f"grads_step_{step_index:03d}.safetensors" deltas_rel = ( @@ -1128,10 +1141,10 @@ def _capture_lora_grads() -> None: step_result.new_logprobs.detach().cpu().float(), topology_dir / output_rel, ) - save_file(captured_grads, str(topology_dir / grads_rel)) - deltas = _delta_state(initial_lora_state, current_lora_state) + save_file(grads, str(topology_dir / grads_rel)) + deltas = _delta_state(initial_state, current_state) save_file(deltas, str(topology_dir / deltas_rel)) - save_file(current_lora_state, str(topology_dir / lora_rel)) + save_file(current_state, str(topology_dir / lora_rel)) step_traces.append( StepTrace( From d2c11614d6407c1d71bbed09295d31f3e39ac95e Mon Sep 17 00:00:00 2001 From: FurtherAI Date: Thu, 12 Mar 2026 09:40:34 +0000 Subject: [PATCH 05/19] megatron: extend LoRA grad-sync semantics across tp/expert-tp --- src/art/megatron/finalize_grads.py | 38 ++++++------ src/art/megatron/lora.py | 97 +++++++++++++++++++++--------- 2 files changed, 89 insertions(+), 46 deletions(-) diff --git a/src/art/megatron/finalize_grads.py b/src/art/megatron/finalize_grads.py index 8c496667..83e8cc4f 100644 --- a/src/art/megatron/finalize_grads.py +++ b/src/art/megatron/finalize_grads.py @@ -8,14 +8,15 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors GradSyncDomain = Literal["tp_default", "expert_tp"] -GradSyncOp = Literal["none", "avg"] +GradSyncOp = Literal["none", "sum", "avg"] TP_DEFAULT_GRAD_SYNC_DOMAIN: GradSyncDomain = "tp_default" EXPERT_TP_GRAD_SYNC_DOMAIN: GradSyncDomain = "expert_tp" GRAD_SYNC_OP_NONE: GradSyncOp = "none" +GRAD_SYNC_OP_SUM: GradSyncOp = "sum" GRAD_SYNC_OP_AVG: GradSyncOp = "avg" VALID_DOMAINS = (TP_DEFAULT_GRAD_SYNC_DOMAIN, EXPERT_TP_GRAD_SYNC_DOMAIN) -VALID_SYNC_OPS = (GRAD_SYNC_OP_NONE, GRAD_SYNC_OP_AVG) +VALID_SYNC_OPS = (GRAD_SYNC_OP_NONE, GRAD_SYNC_OP_SUM, GRAD_SYNC_OP_AVG) def _iter_named_trainable_parameters( @@ -37,31 +38,36 @@ def _resolve_domain_group( domain: GradSyncDomain, ) -> torch.distributed.ProcessGroup | None: if domain == TP_DEFAULT_GRAD_SYNC_DOMAIN: - return None + group = ps.get_tensor_model_parallel_group(check_initialized=False) + if group is None or group.size() <= 1: + return None + return group if domain != EXPERT_TP_GRAD_SYNC_DOMAIN: raise RuntimeError(f"Unknown grad sync domain: {domain}") group = ps.get_expert_tensor_parallel_group(check_initialized=False) - if group is None: - return None - if group.size() <= 1: + if group is None or group.size() <= 1: return None return group def _resolve_reduce_op(op: GradSyncOp) -> Any: + if op == GRAD_SYNC_OP_SUM: + return torch.distributed.ReduceOp.SUM if op == GRAD_SYNC_OP_AVG: return torch.distributed.ReduceOp.AVG raise RuntimeError(f"Unknown grad sync op: {op}") def finalize_model_grads_extended(model: list[torch.nn.Module]) -> None: - """Run Megatron finalize, then apply non-default grad-sync reductions. + """Run Megatron finalize, then apply extra LoRA grad-sync reductions. - Megatron finalize handles DP/CP (and expert-DP via `param.allreduce=False`) internally. - This extension only handles extra reductions outside Megatron's default TP path, - currently expert-TP reductions for params annotated with grad_sync_* metadata. + Megatron finalize handles DP/CP(via `param.allreduce=True`)(and expert-DP via `param.allreduce=False`) internally. + This extension handles extra TP/expert-TP reductions for params annotated + with grad_sync_* metadata. """ + # All-reduce all model grads across DP replicas, layernorm grads for sequence parallelism, + # embedding grads across first and last pipeline stages (if not tied) finalize_model_grads(model) buckets: dict[ @@ -73,10 +79,8 @@ def finalize_model_grads_extended(model: list[torch.nn.Module]) -> None: domain: GradSyncDomain = getattr( param, "grad_sync_domain", TP_DEFAULT_GRAD_SYNC_DOMAIN ) - if domain == TP_DEFAULT_GRAD_SYNC_DOMAIN: + if _resolve_domain_group(domain) is None: continue - if domain not in VALID_DOMAINS: - raise RuntimeError(f"{name}: unsupported grad_sync_domain={domain}") op: GradSyncOp = getattr(param, "grad_sync_op", GRAD_SYNC_OP_NONE) if op not in VALID_SYNC_OPS: @@ -93,7 +97,7 @@ def finalize_model_grads_extended(model: list[torch.nn.Module]) -> None: raise RuntimeError( f"{name}: expected non-None main_grad for domain={domain} reduce_op={op}" ) - local_grad = cast( + local_grad = cast( # local part of dtensor torch.Tensor, grad._local_tensor if hasattr(grad, "_local_tensor") else grad ) buckets[(domain, op, local_grad.dtype, local_grad.device)].append( @@ -101,9 +105,9 @@ def finalize_model_grads_extended(model: list[torch.nn.Module]) -> None: ) for (domain, op, _dtype, _device), entries in buckets.items(): - group = _resolve_domain_group(domain) - if group is None: - continue + group = _resolve_domain_group( + domain + ) # already checked if the domain is one we are handling grads = [grad for _name, grad in entries] coalesced = _flatten_dense_tensors(grads) diff --git a/src/art/megatron/lora.py b/src/art/megatron/lora.py index 12a38dec..b594bf18 100644 --- a/src/art/megatron/lora.py +++ b/src/art/megatron/lora.py @@ -23,11 +23,12 @@ ShardDomain = Literal["tp", "expert_tp"] GradSyncDomain = Literal["tp_default", "expert_tp"] -GradSyncOp = Literal["none", "avg"] +GradSyncOp = Literal["none", "sum", "avg"] TP_DEFAULT_GRAD_SYNC_DOMAIN: GradSyncDomain = "tp_default" EXPERT_TP_GRAD_SYNC_DOMAIN: GradSyncDomain = "expert_tp" GRAD_SYNC_OP_NONE: GradSyncOp = "none" +GRAD_SYNC_OP_SUM: GradSyncOp = "sum" GRAD_SYNC_OP_AVG: GradSyncOp = "avg" @@ -38,7 +39,7 @@ class LoRAParallelSpec(BaseModel): shard_domain: ShardDomain = "tp" sharded: bool = False - shard_axis: int | None = None + shard_dim: int | None = None grad_sync_domain: GradSyncDomain = TP_DEFAULT_GRAD_SYNC_DOMAIN grad_sync_op: GradSyncOp = GRAD_SYNC_OP_NONE @@ -95,7 +96,7 @@ def _set_lora_parallel_metadata( setattr(param, "lora_shard_domain", parallel_spec.shard_domain) setattr(param, "lora_tp_sharded", parallel_spec.sharded) setattr(param, "lora_tp_replicated", replicated) - setattr(param, "lora_tp_shard_axis", parallel_spec.shard_axis) + setattr(param, "lora_tp_shard_dim", parallel_spec.shard_dim) setattr(param, "grad_sync_domain", parallel_spec.grad_sync_domain) setattr(param, "grad_sync_op", parallel_spec.grad_sync_op) # Megatron DDP routing flag: @@ -115,6 +116,21 @@ def _set_lora_parallel_metadata( ), ) + # Megatron optimizer and checkpoint logic rely on tensor model-parallel metadata + # to distinguish true shards from TP-duplicate params. + if parallel_spec.sharded: + setattr(param, "tensor_model_parallel", True) + setattr( + param, "partition_dim", _normalize_axis(parallel_spec.shard_dim, param.ndim) + ) + # stride > 1 means the dim is split into blocks and each tp rank holds a shard of the block + # this might happen for fused e.g. gate_(up|proj), but loras are individual per module + setattr(param, "partition_stride", 1) + else: + setattr(param, "tensor_model_parallel", False) + setattr(param, "partition_dim", -1) + setattr(param, "partition_stride", 1) + class LoRA(torch.nn.Module): def __init__( @@ -238,7 +254,7 @@ def load_weight(self, weight: torch.Tensor, *, into: torch.nn.Parameter) -> None domain = getattr(into, "lora_shard_domain") sharded = bool(getattr(into, "lora_tp_sharded")) if sharded: - axis = getattr(into, "lora_tp_shard_axis") + axis = getattr(into, "lora_tp_shard_dim") if axis is None: raise RuntimeError( f"{self.adapter_model_prefix}: missing shard axis for sharded parameter" @@ -290,11 +306,11 @@ def _should_export_parameter(self, param: torch.nn.Parameter) -> bool: def _manifest_for_param(self, param: torch.nn.Parameter) -> dict[str, Any]: domain = getattr(param, "lora_shard_domain") sharded = bool(getattr(param, "lora_tp_sharded", False)) - shard_axis = getattr(param, "lora_tp_shard_axis", None) + shard_dim = getattr(param, "lora_tp_shard_dim", None) return { "domain": domain, "sharded": sharded, - "shard_axis": shard_axis, + "shard_dim": shard_dim, "shard_world_size": _get_shard_world_size(domain) if sharded else 1, "shard_rank": _get_shard_rank(domain) if sharded else 0, } @@ -367,15 +383,15 @@ def __init__( a_parallel_spec = LoRAParallelSpec( shard_domain="tp", sharded=True, - shard_axis=-2, + shard_dim=-2, grad_sync_domain=TP_DEFAULT_GRAD_SYNC_DOMAIN, grad_sync_op=GRAD_SYNC_OP_NONE, # only need DP-type reductions ) b_parallel_spec = a_parallel_spec.model_copy( update={ "sharded": False, - "shard_axis": None, - "grad_sync_op": GRAD_SYNC_OP_AVG, # megatron reduces across TP ranks + "shard_dim": None, + "grad_sync_op": GRAD_SYNC_OP_SUM, # sum replicated TP contributions } ) self.lora = LoRA( @@ -423,6 +439,10 @@ def __init__( assert self.provider.kv_channels is not None assert self.provider.num_query_groups is not None assert self.provider.num_attention_heads is not None + if self.provider.num_attention_heads % self.provider.num_query_groups != 0: + raise ValueError( + "num_attention_heads must be divisible by num_query_groups for QKV LoRA" + ) q_out_features = self.provider.kv_channels * self.provider.num_attention_heads kv_out_features = self.provider.kv_channels * self.provider.num_query_groups tp_world_size = ps.get_tensor_model_parallel_world_size() @@ -434,6 +454,13 @@ def __init__( ) q_out_features_per_rank = q_out_features // tp_world_size kv_out_features_per_rank = kv_out_features // tp_world_size + self.num_query_groups_per_partition = ( + self.provider.num_query_groups // tp_world_size + ) + self.num_attention_heads_per_group = ( + self.provider.num_attention_heads // self.provider.num_query_groups + ) + self.hidden_size_per_attention_head = self.provider.kv_channels assert isinstance(linear_qkv.weight, torch.Tensor) self.q_proj_lora = self._build_qkv_lora( adapter_model_prefix=f"{adapter_model_prefix}.q_proj", @@ -470,14 +497,14 @@ def _build_qkv_lora( a_parallel_spec = LoRAParallelSpec( shard_domain="tp", sharded=False, - shard_axis=None, + shard_dim=None, grad_sync_domain=TP_DEFAULT_GRAD_SYNC_DOMAIN, - grad_sync_op=GRAD_SYNC_OP_AVG, # megatron reduces across TP ranks + grad_sync_op=GRAD_SYNC_OP_SUM, # sum replicated TP contributions ) b_parallel_spec = a_parallel_spec.model_copy( update={ "sharded": True, - "shard_axis": -1, + "shard_dim": -1, "grad_sync_op": GRAD_SYNC_OP_NONE, # only need DP-type reductions } ) @@ -508,20 +535,32 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]: query = self.q_proj_lora(layernorm_output) key = self.k_proj_lora(layernorm_output) value = self.v_proj_lora(layernorm_output) - - assert isinstance(self.linear_qkv.config.kv_channels, int) - query_4d = query.reshape( - query.shape[0], query.shape[1], -1, self.linear_qkv.config.kv_channels + # Match Megatron mixed_qkv layout: + # [S, B, nqg, (nah/nqg + 2), hn] where each query-group packs + # [all query heads for that group, key, value]. + query_5d = query.reshape( + query.shape[0], + query.shape[1], + self.num_query_groups_per_partition, + self.num_attention_heads_per_group, + self.hidden_size_per_attention_head, ) - key_4d = key.reshape( - key.shape[0], key.shape[1], -1, self.linear_qkv.config.kv_channels + key_5d = key.reshape( + key.shape[0], + key.shape[1], + self.num_query_groups_per_partition, + 1, + self.hidden_size_per_attention_head, ) - value_4d = value.reshape( - value.shape[0], value.shape[1], -1, self.linear_qkv.config.kv_channels + value_5d = value.reshape( + value.shape[0], + value.shape[1], + self.num_query_groups_per_partition, + 1, + self.hidden_size_per_attention_head, ) - - qkv_4d = torch.cat([query_4d, key_4d, value_4d], dim=2) - adapter_output = qkv_4d.reshape(qkv_4d.shape[0], qkv_4d.shape[1], -1) + qkv_5d = torch.cat([query_5d, key_5d, value_5d], dim=3) + adapter_output = qkv_5d.reshape(qkv_5d.shape[0], qkv_5d.shape[1], -1) return linear_output + adapter_output, bias @@ -566,14 +605,14 @@ def _build_fc1_lora( a_parallel_spec = LoRAParallelSpec( shard_domain="expert_tp", sharded=False, - shard_axis=None, + shard_dim=None, grad_sync_domain=EXPERT_TP_GRAD_SYNC_DOMAIN, - grad_sync_op=GRAD_SYNC_OP_AVG, # we handle this with extended finalize_grads + grad_sync_op=GRAD_SYNC_OP_SUM, # we handle this with extended finalize_grads ) b_parallel_spec = a_parallel_spec.model_copy( update={ "sharded": True, - "shard_axis": -1, + "shard_dim": -1, "grad_sync_domain": EXPERT_TP_GRAD_SYNC_DOMAIN, "grad_sync_op": GRAD_SYNC_OP_NONE, # only need DP-type reductions } @@ -619,16 +658,16 @@ def __init__( a_parallel_spec = LoRAParallelSpec( shard_domain="expert_tp", sharded=True, - shard_axis=-2, + shard_dim=-2, grad_sync_domain=EXPERT_TP_GRAD_SYNC_DOMAIN, grad_sync_op=GRAD_SYNC_OP_NONE, # only need DP-type reductions ) b_parallel_spec = a_parallel_spec.model_copy( update={ "sharded": False, - "shard_axis": None, + "shard_dim": None, "grad_sync_domain": EXPERT_TP_GRAD_SYNC_DOMAIN, - "grad_sync_op": GRAD_SYNC_OP_AVG, # we handle this with extended finalize_grads + "grad_sync_op": GRAD_SYNC_OP_SUM, # we handle this with extended finalize_grads } ) self.lora = LoRA( From e4180184bb649b58daef258503130356040d5656 Mon Sep 17 00:00:00 2001 From: FurtherAI Date: Thu, 12 Mar 2026 09:40:37 +0000 Subject: [PATCH 06/19] megatron: add MoE routing replay core and unit tests --- src/art/megatron/routing_replay.py | 832 ++++++++++++++++++++++++++ tests/unit/test_moe_routing_replay.py | 189 ++++++ 2 files changed, 1021 insertions(+) create mode 100644 src/art/megatron/routing_replay.py create mode 100644 tests/unit/test_moe_routing_replay.py diff --git a/src/art/megatron/routing_replay.py b/src/art/megatron/routing_replay.py new file mode 100644 index 00000000..91865b80 --- /dev/null +++ b/src/art/megatron/routing_replay.py @@ -0,0 +1,832 @@ +from __future__ import annotations + +import json +from pathlib import Path +import re +import types +from typing import Any, Protocol + +from pydantic import BaseModel, ConfigDict, model_validator +from safetensors.torch import load_file, save_file +import torch + +ROUTER_NAME_TOKEN = ".mlp.router" +ROUTER_KEY_FORMAT_VERSION = "moe_routing_replay_v1" +GLOBAL_TOKEN_UIDS_KEY = "global_token_uids" + +_ROUTER_LAYER_PATTERN = re.compile(r"decoder\.layers\.(?P\d+)\.mlp\.router$") +_TRACE_CHUNK_PREFIX_PATTERN = re.compile(r"^chunk(?P\d+)\.(?P.+)$") + + +def _to_tensor_cpu_contiguous( + tensor: torch.Tensor, *, dtype: torch.dtype +) -> torch.Tensor: + if not isinstance(tensor, torch.Tensor): + raise TypeError(f"Expected torch.Tensor, got {type(tensor)}") + return tensor.detach().to(device="cpu", dtype=dtype).contiguous() + + +def _normalize_step_index(step_index: int) -> str: + if step_index < 0: + raise ValueError(f"step_index must be non-negative, got {step_index}") + return f"{step_index:06d}" + + +def _build_tensor_key(router_key: str, call_index: int, field_name: str) -> str: + return f"{router_key}/call_{call_index}/{field_name}" + + +def _flatten_router_tensor(tensor: torch.Tensor) -> torch.Tensor: + if tensor.ndim < 2: + raise RuntimeError( + f"Router tensor must have rank >=2, got shape={tuple(tensor.shape)}" + ) + num_experts = int(tensor.shape[-1]) + return tensor.reshape(-1, num_experts).contiguous() + + +def _extract_router_output_tensors(output: Any) -> tuple[torch.Tensor, torch.Tensor]: + if isinstance(output, (list, tuple)) and len(output) >= 2: + probs, routing_map = output[0], output[1] + elif isinstance(output, dict): + probs = output.get("probs") + routing_map = output.get("routing_map") + else: + raise RuntimeError(f"Unsupported router output type: {type(output)}") + + if not isinstance(probs, torch.Tensor): + raise RuntimeError(f"Expected probs tensor, got {type(probs)}") + if not isinstance(routing_map, torch.Tensor): + raise RuntimeError(f"Expected routing_map tensor, got {type(routing_map)}") + + probs_2d = _flatten_router_tensor(probs.to(torch.float32)) + routing_map_2d = _flatten_router_tensor(routing_map.bool()) + if probs_2d.shape != routing_map_2d.shape: + raise RuntimeError( + "Router output shape mismatch: " + f"probs={tuple(probs_2d.shape)} routing_map={tuple(routing_map_2d.shape)}" + ) + return probs_2d, routing_map_2d + + +def build_router_key_from_module_name(*, chunk_index: int, module_name: str) -> str: + match = _ROUTER_LAYER_PATTERN.search(module_name) + if match is None: + raise RuntimeError( + f"Unable to derive router key from module name '{module_name}'. " + f"Expected suffix matching '{_ROUTER_LAYER_PATTERN.pattern}'." + ) + layer_index = int(match.group("layer")) + return f"chunk_{chunk_index:02d}.layer_{layer_index:04d}.mlp.router" + + +def build_router_key_from_trace_name(trace_module_name: str) -> str: + chunk_match = _TRACE_CHUNK_PREFIX_PATTERN.match(trace_module_name) + if chunk_match is None: + raise RuntimeError( + "Forward trace router module name must start with 'chunk.'; " + f"got '{trace_module_name}'" + ) + chunk_index = int(chunk_match.group("chunk")) + module_name = chunk_match.group("name") + return build_router_key_from_module_name( + chunk_index=chunk_index, + module_name=module_name, + ) + + +class ParallelTopology(BaseModel): + tp: int + ep: int + etp: int = 1 + dp: int = 1 + sp: bool = False + cp: int = 1 + pp: int = 1 + vpp: int = 1 + + +class RouterCallRoute(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + expert_indices: torch.Tensor + expert_probs: torch.Tensor + expert_mask: torch.Tensor + routing_map: torch.Tensor | None = None + num_experts: int + + @model_validator(mode="after") + def _validate(self) -> "RouterCallRoute": + self.expert_indices = _to_tensor_cpu_contiguous( + self.expert_indices, dtype=torch.int32 + ) + self.expert_probs = _to_tensor_cpu_contiguous( + self.expert_probs, dtype=torch.float32 + ) + self.expert_mask = _to_tensor_cpu_contiguous(self.expert_mask, dtype=torch.bool) + if self.routing_map is not None: + self.routing_map = _to_tensor_cpu_contiguous( + self.routing_map, dtype=torch.bool + ) + + if self.expert_indices.ndim != 2: + raise RuntimeError( + "expert_indices must have shape [num_tokens, max_topk], got " + f"{tuple(self.expert_indices.shape)}" + ) + if self.expert_probs.shape != self.expert_indices.shape: + raise RuntimeError( + "expert_probs shape must match expert_indices shape, got " + f"{tuple(self.expert_probs.shape)} vs {tuple(self.expert_indices.shape)}" + ) + if self.expert_mask.shape != self.expert_indices.shape: + raise RuntimeError( + "expert_mask shape must match expert_indices shape, got " + f"{tuple(self.expert_mask.shape)} vs {tuple(self.expert_indices.shape)}" + ) + if self.num_experts <= 0: + raise RuntimeError(f"num_experts must be >0, got {self.num_experts}") + if self.routing_map is not None: + expected = (self.expert_indices.shape[0], self.num_experts) + if tuple(self.routing_map.shape) != expected: + raise RuntimeError( + "routing_map shape mismatch: " + f"expected={expected}, got={tuple(self.routing_map.shape)}" + ) + return self + + @property + def num_global_tokens(self) -> int: + return int(self.expert_indices.shape[0]) + + @property + def max_topk(self) -> int: + return int(self.expert_indices.shape[1]) + + +class StepRouterRoutes(BaseModel): + calls: dict[int, RouterCallRoute] + + @model_validator(mode="after") + def _validate_calls(self) -> "StepRouterRoutes": + if not self.calls: + raise RuntimeError("StepRouterRoutes.calls cannot be empty") + for call_index in self.calls: + if call_index < 0: + raise RuntimeError(f"call_index must be >=0, got {call_index}") + return self + + +class StepRoutes(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + routers: dict[str, StepRouterRoutes] + global_token_uids: torch.Tensor + + @model_validator(mode="after") + def _validate(self) -> "StepRoutes": + if not self.routers: + raise RuntimeError("StepRoutes.routers cannot be empty") + self.global_token_uids = _to_tensor_cpu_contiguous( + self.global_token_uids, dtype=torch.int64 + ) + if self.global_token_uids.ndim != 1: + raise RuntimeError( + "global_token_uids must have shape [num_global_tokens], got " + f"{tuple(self.global_token_uids.shape)}" + ) + if int(torch.unique(self.global_token_uids).numel()) != int( + self.global_token_uids.numel() + ): + raise RuntimeError("global_token_uids must be unique per step") + expected_tokens = int(self.global_token_uids.numel()) + for router_key, step_router in self.routers.items(): + for call_index, route in step_router.calls.items(): + if route.num_global_tokens != expected_tokens: + raise RuntimeError( + "Route token count mismatch for " + f"router='{router_key}' call={call_index}: " + f"route_tokens={route.num_global_tokens}, " + f"expected_tokens={expected_tokens}" + ) + return self + + +class MoeRoutingReplayBundle(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + format_version: str = ROUTER_KEY_FORMAT_VERSION + topology: ParallelTopology + num_steps: int + max_topk: int + router_keys: list[str] + steps: dict[int, StepRoutes] + + @model_validator(mode="after") + def _validate(self) -> "MoeRoutingReplayBundle": + if self.format_version != ROUTER_KEY_FORMAT_VERSION: + raise RuntimeError( + f"Unsupported format_version={self.format_version}; " + f"expected={ROUTER_KEY_FORMAT_VERSION}" + ) + if self.num_steps <= 0: + raise RuntimeError(f"num_steps must be >0, got {self.num_steps}") + if self.max_topk < 0: + raise RuntimeError(f"max_topk must be >=0, got {self.max_topk}") + if set(self.steps.keys()) != set(range(self.num_steps)): + raise RuntimeError( + "steps must be indexed from 0..num_steps-1 without gaps: " + f"num_steps={self.num_steps}, step_keys={sorted(self.steps.keys())}" + ) + if not self.router_keys: + raise RuntimeError("router_keys cannot be empty") + router_key_set = set(self.router_keys) + for step_index, step_routes in self.steps.items(): + step_router_keys = set(step_routes.routers.keys()) + if step_router_keys != router_key_set: + raise RuntimeError( + f"Step {step_index} router set mismatch. " + f"expected={sorted(router_key_set)}, got={sorted(step_router_keys)}" + ) + return self + + @classmethod + def from_dir(cls, bundle_dir: str | Path) -> "MoeRoutingReplayBundle": + base_dir = Path(bundle_dir) + manifest_path = base_dir / "manifest.json" + if not manifest_path.exists(): + raise FileNotFoundError(f"Missing routing replay manifest: {manifest_path}") + with manifest_path.open("r", encoding="utf-8") as handle: + manifest = json.load(handle) + + if manifest.get("format_version") != ROUTER_KEY_FORMAT_VERSION: + raise RuntimeError( + "Unsupported routing replay manifest version: " + f"{manifest.get('format_version')}" + ) + + topology = ParallelTopology.model_validate(manifest["topology"]) + num_steps = int(manifest["num_steps"]) + max_topk = int(manifest["max_topk"]) + router_keys = [str(key) for key in manifest["router_keys"]] + manifest_steps = manifest["steps"] + + steps: dict[int, StepRoutes] = {} + for step_index in range(num_steps): + step_manifest = manifest_steps[str(step_index)] + step_file = base_dir / step_manifest["file"] + if not step_file.exists(): + raise FileNotFoundError( + f"Missing routing replay step file for step={step_index}: {step_file}" + ) + step_tensors = load_file(str(step_file)) + if GLOBAL_TOKEN_UIDS_KEY not in step_tensors: + raise RuntimeError( + f"Step file missing '{GLOBAL_TOKEN_UIDS_KEY}': {step_file}" + ) + global_token_uids = step_tensors[GLOBAL_TOKEN_UIDS_KEY] + + routers: dict[str, StepRouterRoutes] = {} + for router_key in router_keys: + router_step_manifest = step_manifest["routers"].get(router_key) + if router_step_manifest is None: + raise RuntimeError( + f"Step manifest missing router_key='{router_key}' for step={step_index}" + ) + calls: dict[int, RouterCallRoute] = {} + for call_index_raw, call_manifest in router_step_manifest.items(): + call_index = int(call_index_raw) + expert_indices_key = _build_tensor_key( + router_key, call_index, "expert_indices" + ) + expert_probs_key = _build_tensor_key( + router_key, call_index, "expert_probs" + ) + expert_mask_key = _build_tensor_key( + router_key, call_index, "expert_mask" + ) + routing_map_key = _build_tensor_key( + router_key, call_index, "routing_map" + ) + if expert_indices_key not in step_tensors: + raise RuntimeError( + f"Missing tensor key '{expert_indices_key}' in {step_file}" + ) + if expert_probs_key not in step_tensors: + raise RuntimeError( + f"Missing tensor key '{expert_probs_key}' in {step_file}" + ) + if expert_mask_key not in step_tensors: + raise RuntimeError( + f"Missing tensor key '{expert_mask_key}' in {step_file}" + ) + routing_map = ( + step_tensors[routing_map_key] + if routing_map_key in step_tensors + else None + ) + calls[call_index] = RouterCallRoute( + expert_indices=step_tensors[expert_indices_key], + expert_probs=step_tensors[expert_probs_key], + expert_mask=step_tensors[expert_mask_key], + routing_map=routing_map, + num_experts=int(call_manifest["num_experts"]), + ) + routers[router_key] = StepRouterRoutes(calls=calls) + steps[step_index] = StepRoutes( + routers=routers, + global_token_uids=global_token_uids, + ) + + return cls( + format_version=ROUTER_KEY_FORMAT_VERSION, + topology=topology, + num_steps=num_steps, + max_topk=max_topk, + router_keys=router_keys, + steps=steps, + ) + + def to_dir(self, bundle_dir: str | Path) -> None: + base_dir = Path(bundle_dir) + base_dir.mkdir(parents=True, exist_ok=True) + + manifest_steps: dict[str, dict[str, Any]] = {} + for step_index in range(self.num_steps): + step_routes = self.steps[step_index] + step_file_name = f"step_{_normalize_step_index(step_index)}.safetensors" + step_file_path = base_dir / step_file_name + step_tensors: dict[str, torch.Tensor] = { + GLOBAL_TOKEN_UIDS_KEY: _to_tensor_cpu_contiguous( + step_routes.global_token_uids, dtype=torch.int64 + ) + } + step_manifest_routers: dict[str, dict[str, dict[str, int]]] = {} + for router_key in self.router_keys: + router_routes = step_routes.routers[router_key] + call_manifest: dict[str, dict[str, int]] = {} + for call_index, route in sorted(router_routes.calls.items()): + step_tensors[ + _build_tensor_key(router_key, call_index, "expert_indices") + ] = _to_tensor_cpu_contiguous( + route.expert_indices, dtype=torch.int32 + ) + step_tensors[ + _build_tensor_key(router_key, call_index, "expert_probs") + ] = _to_tensor_cpu_contiguous( + route.expert_probs, dtype=torch.float32 + ) + step_tensors[ + _build_tensor_key(router_key, call_index, "expert_mask") + ] = _to_tensor_cpu_contiguous(route.expert_mask, dtype=torch.bool) + if route.routing_map is not None: + step_tensors[ + _build_tensor_key(router_key, call_index, "routing_map") + ] = _to_tensor_cpu_contiguous( + route.routing_map, dtype=torch.bool + ) + call_manifest[str(call_index)] = {"num_experts": route.num_experts} + step_manifest_routers[router_key] = call_manifest + save_file(step_tensors, str(step_file_path)) + manifest_steps[str(step_index)] = { + "file": step_file_name, + "routers": step_manifest_routers, + } + + manifest = { + "format_version": ROUTER_KEY_FORMAT_VERSION, + "topology": self.topology.model_dump(mode="json"), + "num_steps": self.num_steps, + "max_topk": self.max_topk, + "router_keys": self.router_keys, + "steps": manifest_steps, + } + with (base_dir / "manifest.json").open("w", encoding="utf-8") as handle: + json.dump(manifest, handle, indent=2, sort_keys=True) + + +class LocalTokenIndexer(Protocol): + def build_local_token_uids( + self, + *, + global_token_uids: torch.Tensor, + num_local_tokens: int, + sequence_parallel: bool, + context_parallel_size: int, + ) -> torch.Tensor: + """Build local token uid order for current rank.""" + + +class TopologyAwareLocalTokenIndexer: + def __init__(self, parallel_state_module: Any | None = None) -> None: + self._parallel_state = parallel_state_module + + def _ps(self) -> Any: + if self._parallel_state is not None: + return self._parallel_state + from megatron.core import parallel_state as ps + + self._parallel_state = ps + return ps + + def build_local_token_uids( + self, + *, + global_token_uids: torch.Tensor, + num_local_tokens: int, + sequence_parallel: bool, + context_parallel_size: int, + ) -> torch.Tensor: + ps = self._ps() + + local_uids = global_token_uids.to(dtype=torch.int64, device="cpu").view(1, -1) + + cp_size = int(ps.get_context_parallel_world_size()) + if context_parallel_size > 1 and cp_size > 1: + from megatron.core.utils import get_batch_on_this_cp_rank + + local_uids = get_batch_on_this_cp_rank({"tokens": local_uids})["tokens"] + + tp_size = int(ps.get_tensor_model_parallel_world_size()) + tp_rank = int(ps.get_tensor_model_parallel_rank()) if tp_size > 1 else 0 + if sequence_parallel and tp_size > 1: + tokens_per_tp_rank = local_uids.shape[1] // tp_size + start = tp_rank * tokens_per_tp_rank + local_uids = local_uids[:, start : start + tokens_per_tp_rank] + + return local_uids.reshape(-1).contiguous() + + +def _patch_alltoall_dispatcher_preprocess() -> None: + try: + from megatron.core.transformer.moe.token_dispatcher import ( + MoEAlltoAllTokenDispatcher, + ) + except Exception: + return + + if hasattr(MoEAlltoAllTokenDispatcher, "_art_router_replay_preprocess_patched"): + return + + original_preprocess = MoEAlltoAllTokenDispatcher.preprocess + + def patched_preprocess( + self: Any, routing_map: torch.Tensor, *args: Any, **kwargs: Any + ): + result = original_preprocess(self, routing_map, *args, **kwargs) + if ( + not getattr(self, "drop_and_pad", False) + and getattr(self.config, "moe_expert_capacity_factor", None) is None + and not ( + getattr(self.config, "moe_router_padding_for_quantization", None) + or getattr(self.config, "moe_router_padding_for_fp8", None) + ) + ): + self.num_out_tokens = int(routing_map.sum().item()) + return result + + setattr(MoEAlltoAllTokenDispatcher, "preprocess", patched_preprocess) + setattr(MoEAlltoAllTokenDispatcher, "_art_router_replay_preprocess_patched", True) + + +class MoeRoutingReplayController: + def __init__( + self, + *, + bundle: MoeRoutingReplayBundle, + strict: bool, + local_token_indexer: LocalTokenIndexer | None = None, + ) -> None: + self.bundle = bundle + self.strict = strict + self.local_token_indexer = ( + local_token_indexer or TopologyAwareLocalTokenIndexer() + ) + + self._active_step_index: int | None = None + self._active_sample_index: int | None = None + self._active_step_routes: StepRoutes | None = None + self._router_call_cursors: dict[str, int] = {} + self._global_uid_to_row_index: dict[int, int] = {} + self._local_router_keys: set[str] = set() + + self._patched_router_modules: list[dict[str, Any]] = [] + + def install_router_patches(self, model_chunks: list[Any]) -> None: + if self._patched_router_modules: + return + _patch_alltoall_dispatcher_preprocess() + + for chunk_index, chunk in enumerate(model_chunks): + for module_name, module in chunk.named_modules(): + if ROUTER_NAME_TOKEN not in module_name: + continue + if not hasattr(module, "routing"): + continue + router_key = build_router_key_from_module_name( + chunk_index=chunk_index, + module_name=module_name, + ) + if self.strict and router_key not in self.bundle.router_keys: + raise RuntimeError( + "Router key from model is missing in replay bundle: " + f"router_key='{router_key}'" + ) + + original_routing = module.routing + if getattr(module, "_art_router_replay_patched", False): + continue + + sequence_parallel = bool( + getattr(getattr(module, "config", None), "sequence_parallel", False) + ) + context_parallel_size = int( + getattr(getattr(module, "config", None), "context_parallel_size", 1) + ) + + def routing_wrapper( + _module: Any, + logits: torch.Tensor, + *args: Any, + _router_key: str = router_key, + _sequence_parallel: bool = sequence_parallel, + _context_parallel_size: int = context_parallel_size, + **kwargs: Any, + ) -> tuple[torch.Tensor, torch.Tensor]: + live_probs, live_routing_map = original_routing( + logits, *args, **kwargs + ) + replay_probs, replay_routing_map = self.get_route_for_router( + router_key=_router_key, + logits=live_probs, + sequence_parallel=_sequence_parallel, + context_parallel_size=_context_parallel_size, + ) + # same result, but autograd goes through + probs = ( + live_probs + + ( + replay_probs.to( + device=live_probs.device, + dtype=live_probs.dtype, + ) + - live_probs + ).detach() + ) + routing_map = replay_routing_map.to( + device=live_routing_map.device, + dtype=live_routing_map.dtype, + ) + return probs, routing_map + + module.routing = types.MethodType(routing_wrapper, module) + module._art_router_replay_patched = True + self._local_router_keys.add(router_key) + self._patched_router_modules.append( + { + "module": module, + "router_key": router_key, + "original_routing": original_routing, + } + ) + + def remove_router_patches(self) -> None: + for item in self._patched_router_modules: + module = item["module"] + module.routing = item["original_routing"] + if hasattr(module, "_art_router_replay_patched"): + delattr(module, "_art_router_replay_patched") + self._patched_router_modules.clear() + self._local_router_keys.clear() + + def set_step(self, *, step_index: int, sample_index: int) -> None: + if step_index not in self.bundle.steps: + raise RuntimeError( + f"Replay bundle missing step_index={step_index}. " + f"Available steps={sorted(self.bundle.steps.keys())}" + ) + step_routes = self.bundle.steps[step_index] + self._active_step_index = step_index + self._active_sample_index = sample_index + self._active_step_routes = step_routes + for local_router_key in sorted(self._local_router_keys): + if local_router_key not in step_routes.routers: + raise RuntimeError( + "Replay bundle step is missing local router key: " + f"step={step_index}, router='{local_router_key}'" + ) + self._router_call_cursors = { + router_key: 0 for router_key in sorted(self._local_router_keys) + } + self._global_uid_to_row_index = { + int(uid.item()): row_index + for row_index, uid in enumerate(step_routes.global_token_uids) + } + + def finalize_step(self) -> None: + if self._active_step_routes is None: + raise RuntimeError("finalize_step called before set_step") + for router_key in sorted(self._local_router_keys): + router_routes = self._active_step_routes.routers[router_key] + consumed = self._router_call_cursors.get(router_key, 0) + expected = len(router_routes.calls) + if consumed != expected: + raise RuntimeError( + "Routing replay step consumption mismatch: " + f"step={self._active_step_index}, router='{router_key}', " + f"consumed={consumed}, expected={expected}" + ) + self._active_step_index = None + self._active_sample_index = None + self._active_step_routes = None + self._router_call_cursors = {} + self._global_uid_to_row_index = {} + + def get_route_for_router( + self, + *, + router_key: str, + logits: torch.Tensor, + sequence_parallel: bool, + context_parallel_size: int, + ) -> tuple[torch.Tensor, torch.Tensor]: + step_routes = self._active_step_routes + call_index = self._router_call_cursors.get(router_key, 0) + router_calls = step_routes.routers[router_key].calls + route = router_calls[call_index] + self._router_call_cursors[router_key] = call_index + 1 + + num_local_tokens = int(logits.shape[0]) + num_experts = int(logits.shape[1]) + + local_uids = self.local_token_indexer.build_local_token_uids( + global_token_uids=step_routes.global_token_uids, + num_local_tokens=num_local_tokens, + sequence_parallel=sequence_parallel, + context_parallel_size=context_parallel_size, + ) + row_index_tensor = torch.tensor( + [self._global_uid_to_row_index[int(uid)] for uid in local_uids.tolist()], + dtype=torch.int64, + ) + + local_indices = route.expert_indices.index_select(0, row_index_tensor) + local_probs = route.expert_probs.index_select(0, row_index_tensor) + local_mask = route.expert_mask.index_select(0, row_index_tensor) + + probs = torch.zeros( + (num_local_tokens, num_experts), + dtype=logits.dtype, + device=logits.device, + ) + routing_map = torch.zeros( + (num_local_tokens, num_experts), + dtype=torch.bool, + device=logits.device, + ) + + if local_indices.numel() > 0: + indices_device = local_indices.to(device=logits.device, dtype=torch.long) + probs_device = local_probs.to(device=logits.device, dtype=logits.dtype) + mask_device = local_mask.to(device=logits.device, dtype=torch.bool) + row_index_device = ( + torch.arange(num_local_tokens, device=logits.device) + .unsqueeze(1) + .expand_as(indices_device) + ) + + selected_rows = row_index_device[mask_device] + selected_cols = indices_device[mask_device] + selected_probs = probs_device[mask_device] + + if selected_rows.numel() > 0: + probs[selected_rows, selected_cols] = selected_probs + routing_map[selected_rows, selected_cols] = True + + return probs, routing_map + + +def _compact_route_from_dense( + probs_2d: torch.Tensor, + routing_map_2d: torch.Tensor, +) -> RouterCallRoute: + num_tokens, num_experts = probs_2d.shape + if num_tokens == 0: + return RouterCallRoute( + expert_indices=torch.zeros((0, 0), dtype=torch.int32), + expert_probs=torch.zeros((0, 0), dtype=torch.float32), + expert_mask=torch.zeros((0, 0), dtype=torch.bool), + num_experts=num_experts, + ) + + max_topk = int(routing_map_2d.sum(dim=1).max().item()) + expert_indices = torch.zeros((num_tokens, max_topk), dtype=torch.int32) + expert_probs = torch.zeros((num_tokens, max_topk), dtype=torch.float32) + expert_mask = torch.zeros((num_tokens, max_topk), dtype=torch.bool) + for token_index in range(num_tokens): + expert_ids = torch.nonzero( + routing_map_2d[token_index], as_tuple=False + ).flatten() + slot_count = int(expert_ids.numel()) + if slot_count == 0: + continue + expert_indices[token_index, :slot_count] = expert_ids.to(torch.int32) + expert_probs[token_index, :slot_count] = probs_2d[token_index, expert_ids].to( + torch.float32 + ) + expert_mask[token_index, :slot_count] = True + + return RouterCallRoute( + expert_indices=expert_indices, + expert_probs=expert_probs, + expert_mask=expert_mask, + num_experts=num_experts, + ) + + +def build_bundle_from_forward_trace_dir( + *, + traces_dir: str | Path, + num_steps: int, + topology: ParallelTopology, +) -> MoeRoutingReplayBundle: + """Build a replay bundle from saved forward traces for the correctness harness. + + This helper is intended for testing/oracle routing replay workflows and is not + part of inference routing capture/export. + """ + trace_dir = Path(traces_dir) + steps: dict[int, StepRoutes] = {} + router_keys_union: set[str] = set() + max_topk = 0 + + for step_index in range(num_steps): + trace_path = trace_dir / f"forward_trace_step_{step_index:03d}.pt" + if not trace_path.exists(): + raise FileNotFoundError( + f"Missing forward trace for step={step_index}: {trace_path}" + ) + step_trace: dict[str, list[dict[str, Any]]] = torch.load( + trace_path, map_location="cpu", weights_only=False + ) + + step_routers: dict[str, StepRouterRoutes] = {} + step_global_tokens: int | None = None + for module_name in sorted(step_trace.keys()): + if ROUTER_NAME_TOKEN not in module_name: + continue + router_key = build_router_key_from_trace_name(module_name) + router_calls: dict[int, RouterCallRoute] = {} + for call_index, call_entry in enumerate(step_trace[module_name]): + output = call_entry.get("output") + probs_2d, routing_map_2d = _extract_router_output_tensors(output) + compact_route = _compact_route_from_dense(probs_2d, routing_map_2d) + router_calls[call_index] = compact_route + max_topk = max(max_topk, compact_route.max_topk) + token_count = compact_route.num_global_tokens + if step_global_tokens is None: + step_global_tokens = token_count + elif step_global_tokens != token_count: + raise RuntimeError( + "Inconsistent token count across routers within step: " + f"step={step_index}, expected={step_global_tokens}, got={token_count}, " + f"router='{router_key}', call={call_index}" + ) + + if not router_calls: + raise RuntimeError( + f"Router trace has no calls for module '{module_name}' at step={step_index}" + ) + step_routers[router_key] = StepRouterRoutes(calls=router_calls) + router_keys_union.add(router_key) + + if not step_routers: + raise RuntimeError( + f"No router traces found for step={step_index} in {trace_path}" + ) + if step_global_tokens is None: + raise RuntimeError( + f"Could not infer token count for step={step_index} from router traces" + ) + global_token_uids = torch.arange(step_global_tokens, dtype=torch.int64) + steps[step_index] = StepRoutes( + routers=step_routers, + global_token_uids=global_token_uids, + ) + + router_keys = sorted(router_keys_union) + for step_index, step_routes in steps.items(): + if set(step_routes.routers.keys()) != set(router_keys): + raise RuntimeError( + f"Step {step_index} router keys differ from global set: " + f"step_keys={sorted(step_routes.routers.keys())}, router_keys={router_keys}" + ) + + return MoeRoutingReplayBundle( + format_version=ROUTER_KEY_FORMAT_VERSION, + topology=topology, + num_steps=num_steps, + max_topk=max_topk, + router_keys=router_keys, + steps=steps, + ) diff --git a/tests/unit/test_moe_routing_replay.py b/tests/unit/test_moe_routing_replay.py new file mode 100644 index 00000000..980784c7 --- /dev/null +++ b/tests/unit/test_moe_routing_replay.py @@ -0,0 +1,189 @@ +from __future__ import annotations + +from pathlib import Path +import tempfile + +import pytest +import torch +from torch import nn + +from art.megatron.routing_replay import ( + MoeRoutingReplayBundle, + MoeRoutingReplayController, + ParallelTopology, + RouterCallRoute, + StepRouterRoutes, + StepRoutes, +) + + +def _dense_from_compact( + route: RouterCallRoute, + *, + dtype: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor]: + num_tokens = route.expert_indices.shape[0] + num_experts = route.num_experts + probs = torch.zeros((num_tokens, num_experts), dtype=dtype) + routing_map = torch.zeros((num_tokens, num_experts), dtype=torch.bool) + for token_idx in range(num_tokens): + for slot in range(route.expert_indices.shape[1]): + if not bool(route.expert_mask[token_idx, slot]): + continue + expert_idx = int(route.expert_indices[token_idx, slot].item()) + probs[token_idx, expert_idx] = route.expert_probs[token_idx, slot].to(dtype) + routing_map[token_idx, expert_idx] = True + return probs, routing_map + + +def _make_bundle() -> tuple[MoeRoutingReplayBundle, RouterCallRoute]: + router_key = "chunk_00.layer_0000.mlp.router" + route = RouterCallRoute( + expert_indices=torch.tensor( + [ + [0, 2], + [1, 0], + [2, 1], + [1, 0], + ], + dtype=torch.int32, + ), + expert_probs=torch.tensor( + [ + [0.70, 0.30], + [1.00, 0.00], + [0.65, 0.35], + [1.00, 0.00], + ], + dtype=torch.float32, + ), + expert_mask=torch.tensor( + [ + [True, True], + [True, False], + [True, True], + [True, False], + ], + dtype=torch.bool, + ), + num_experts=3, + ) + bundle = MoeRoutingReplayBundle( + topology=ParallelTopology(tp=1, ep=1, etp=1, dp=1, sp=False, cp=1, pp=1, vpp=1), + num_steps=1, + max_topk=2, + router_keys=[router_key], + steps={ + 0: StepRoutes( + routers={router_key: StepRouterRoutes(calls={0: route})}, + global_token_uids=torch.arange(4, dtype=torch.int64), + ) + }, + ) + return bundle, route + + +class _IdentityIndexer: + def build_local_token_uids( + self, + *, + global_token_uids: torch.Tensor, + num_local_tokens: int, + sequence_parallel: bool, + context_parallel_size: int, + ) -> torch.Tensor: + del sequence_parallel, context_parallel_size + if int(global_token_uids.numel()) < num_local_tokens: + raise RuntimeError("num_local_tokens exceeds global token count") + return global_token_uids[:num_local_tokens].clone() + + +class _FakeRouter(nn.Module): + def __init__(self) -> None: + super().__init__() + self.config = type( + "Config", + (), + {"sequence_parallel": False, "context_parallel_size": 1}, + )() + + def routing(self, logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + probs = torch.softmax(logits, dim=-1) + routing_map = torch.zeros_like(logits, dtype=torch.bool) + return probs, routing_map + + +class _FakeMlp(nn.Module): + def __init__(self) -> None: + super().__init__() + self.router = _FakeRouter() + + +class _FakeLayer(nn.Module): + def __init__(self) -> None: + super().__init__() + self.mlp = _FakeMlp() + + +class _FakeDecoder(nn.Module): + def __init__(self) -> None: + super().__init__() + self.layers = nn.ModuleList([_FakeLayer()]) + + +class _FakeChunk(nn.Module): + def __init__(self) -> None: + super().__init__() + self.decoder = _FakeDecoder() + + +def test_bundle_roundtrip_disk() -> None: + bundle, route = _make_bundle() + with tempfile.TemporaryDirectory() as tmp_dir: + bundle_path = Path(tmp_dir) + bundle.to_dir(bundle_path) + loaded = MoeRoutingReplayBundle.from_dir(bundle_path) + + assert loaded.num_steps == 1 + assert loaded.max_topk == 2 + assert loaded.router_keys == bundle.router_keys + loaded_route = loaded.steps[0].routers[bundle.router_keys[0]].calls[0] + assert torch.equal(loaded_route.expert_indices, route.expert_indices) + assert torch.equal(loaded_route.expert_probs, route.expert_probs) + assert torch.equal(loaded_route.expert_mask, route.expert_mask) + + +def test_controller_patches_router_and_replays() -> None: + bundle, route = _make_bundle() + controller = MoeRoutingReplayController( + bundle=bundle, + strict=True, + local_token_indexer=_IdentityIndexer(), + ) + chunk = _FakeChunk() + controller.install_router_patches([chunk]) + controller.set_step(step_index=0, sample_index=0) + + logits = torch.randn((4, 3), dtype=torch.float32) + replay_probs, replay_map = chunk.decoder.layers[0].mlp.router.routing(logits) + expected_probs, expected_map = _dense_from_compact(route, dtype=logits.dtype) + + assert torch.equal(replay_map.cpu(), expected_map) + assert torch.allclose(replay_probs.cpu(), expected_probs, atol=0.0, rtol=0.0) + + controller.finalize_step() + controller.remove_router_patches() + + +def test_controller_finalize_fails_when_unconsumed_calls_remain() -> None: + bundle, _route = _make_bundle() + controller = MoeRoutingReplayController( + bundle=bundle, + strict=True, + local_token_indexer=_IdentityIndexer(), + ) + chunk = _FakeChunk() + controller.install_router_patches([chunk]) + controller.set_step(step_index=0, sample_index=0) + with pytest.raises(RuntimeError, match="consumption mismatch"): + controller.finalize_step() From bc5e7a48dfbd6b5d9c7ef2e6d57c500c6d3cc044 Mon Sep 17 00:00:00 2001 From: FurtherAI Date: Thu, 12 Mar 2026 09:40:39 +0000 Subject: [PATCH 07/19] megatron runtime/service: wire routing replay into training jobs --- src/art/dev/train.py | 8 +++- src/art/megatron/service.py | 9 ++++ src/art/megatron/train.py | 83 +++++++++++++++++++++++++++++++++++-- 3 files changed, 96 insertions(+), 4 deletions(-) diff --git a/src/art/dev/train.py b/src/art/dev/train.py index b0e232c5..0ada9ccb 100644 --- a/src/art/dev/train.py +++ b/src/art/dev/train.py @@ -1,7 +1,10 @@ -from typing import Literal +from typing import TYPE_CHECKING, Literal from typing_extensions import TypedDict +if TYPE_CHECKING: + from art.megatron.routing_replay import MoeRoutingReplayBundle + class TrainConfig(TypedDict, total=False): advantage_balance: float @@ -22,6 +25,9 @@ class TrainConfig(TypedDict, total=False): logprob_calculation_chunk_size: int mask_prob_ratio: bool max_negative_advantage_importance_sampling_weight: float + moe_routing_replay_bundle: "MoeRoutingReplayBundle | None" + moe_routing_replay_path: str | None + moe_routing_replay_strict: bool num_trajectories_learning_rate_multiplier_power: float plot_tensors: bool ppo: bool diff --git a/src/art/megatron/service.py b/src/art/megatron/service.py index e4c99a98..42ec4f9a 100644 --- a/src/art/megatron/service.py +++ b/src/art/megatron/service.py @@ -36,6 +36,8 @@ class MegatronTrainingJob(BaseModel): disk_packed_tensors: DiskPackedTensors config: types.TrainConfig experimental_config: dev.TrainConfig + moe_routing_replay_path: str | None = None + moe_routing_replay_strict: bool = True @dataclass @@ -241,12 +243,19 @@ async def train( for job_name in os.listdir(jobs_dir): if job_name.endswith(".json"): os.remove(os.path.join(jobs_dir, job_name)) + if _config.get("moe_routing_replay_bundle") is not None: + raise RuntimeError( + "moe_routing_replay_bundle is only supported for in-process/runtime APIs; " + "MegatronService subprocess jobs must use moe_routing_replay_path." + ) job = MegatronTrainingJob( lora_path=lora_path, optimizer_state_path=self._optimizer_state_path, disk_packed_tensors=disk_packed_tensors, config=config, experimental_config=_config, + moe_routing_replay_path=_config.get("moe_routing_replay_path"), + moe_routing_replay_strict=_config.get("moe_routing_replay_strict", True), ) job_path = os.path.join(jobs_dir, f"{datetime.datetime.now().isoformat()}.json") with open(job_path, "w") as f: diff --git a/src/art/megatron/train.py b/src/art/megatron/train.py index ef58f5d0..33dc8172 100644 --- a/src/art/megatron/train.py +++ b/src/art/megatron/train.py @@ -39,6 +39,10 @@ def _set_cache_dir(env_var: str, default_path: str) -> None: from art.megatron.lora import apply_lora_adapters from art.megatron.offload import OffloadState, offload_to_cpu, reload_to_gpu from art.megatron.provider import get_provider +from art.megatron.routing_replay import ( + MoeRoutingReplayBundle, + MoeRoutingReplayController, +) from art.preprocessing.pack import ( DiskPackedTensors, PackedTensors, @@ -54,6 +58,8 @@ class TrainingJob(BaseModel): disk_packed_tensors: DiskPackedTensors config: types.TrainConfig experimental_config: dev.TrainConfig + moe_routing_replay_path: str | None = None + moe_routing_replay_strict: bool = True class TrainingRuntime(BaseModel): @@ -64,6 +70,7 @@ class TrainingRuntime(BaseModel): optimizer: Any rank: int world_size: int + moe_routing_replay_controller: MoeRoutingReplayController | None = None class TrainStepResult(BaseModel): @@ -129,11 +136,47 @@ def _default_optimizer_config() -> OptimizerConfig: ) +def configure_moe_routing_replay( + runtime: TrainingRuntime, + *, + replay_bundle_path: str | None = None, + replay_bundle: MoeRoutingReplayBundle | None = None, + strict: bool = True, +) -> None: + if runtime.moe_routing_replay_controller is not None: + runtime.moe_routing_replay_controller.remove_router_patches() + runtime.moe_routing_replay_controller = None + + if replay_bundle is not None and replay_bundle_path is not None: + raise RuntimeError( + "Provide either replay_bundle_path or replay_bundle, not both" + ) + if replay_bundle is None and replay_bundle_path is None: + return + + if replay_bundle is None: + if replay_bundle_path is None: + raise RuntimeError( + "replay_bundle_path is required when replay_bundle is None" + ) + replay_bundle = MoeRoutingReplayBundle.from_dir(replay_bundle_path) + + controller = MoeRoutingReplayController( + bundle=replay_bundle, + strict=strict, + ) + controller.install_router_patches(runtime.model) + runtime.moe_routing_replay_controller = controller + + def build_training_runtime( *, model_identifier: str | None = None, provider_configure: Callable[[Any], None] | None = None, optimizer_config: OptimizerConfig | None = None, + moe_routing_replay_path: str | None = None, + moe_routing_replay_bundle: MoeRoutingReplayBundle | None = None, + moe_routing_replay_strict: bool = True, print_env: bool = True, print_optimizer_stats: bool = True, ) -> TrainingRuntime: @@ -186,13 +229,20 @@ def build_training_runtime( percent = (num_params / total_params) * 100 if total_params > 0 else 0 print(f"Optimizer parameters as percent of total: {percent:0.2f}%") - return TrainingRuntime( + runtime = TrainingRuntime( provider=provider, model=model, optimizer=optimizer, rank=rank, world_size=world_size, ) + configure_moe_routing_replay( + runtime, + replay_bundle_path=moe_routing_replay_path, + replay_bundle=moe_routing_replay_bundle, + strict=moe_routing_replay_strict, + ) + return runtime def iter_modules(model_chunks: list[MegatronModule]) -> Any: @@ -204,12 +254,17 @@ def iter_modules(model_chunks: list[MegatronModule]) -> Any: def load_adapter_into_model( model_chunks: list[MegatronModule], adapter_model: dict[str, torch.Tensor], + optimizer: Any | None = None, ) -> None: with torch.no_grad(): for module in iter_modules(model_chunks): if hasattr(module, "load_lora"): module.load_lora(adapter_model) # type: ignore[attr-defined] + if optimizer is None: + return + optimizer.reload_model_params() + def collect_sharded_lora_state( model_chunks: list[MegatronModule], @@ -285,7 +340,17 @@ def run_training_step( config: types.TrainConfig, experimental_config: dev.TrainConfig, ref_logprobs: torch.Tensor | None = None, + step_index: int | None = None, + sample_index: int | None = None, + moe_routing_replay_controller: MoeRoutingReplayController | None = None, ) -> TrainStepResult: + if moe_routing_replay_controller is not None: + assert step_index is not None and sample_index is not None + moe_routing_replay_controller.set_step( + step_index=step_index, + sample_index=sample_index, + ) + device = next(model_chunks[0].parameters()).device _move_inputs_to_device(inputs, device) @@ -322,6 +387,9 @@ def run_training_step( ) reduced_loss = _reduce_loss(loss) + if moe_routing_replay_controller is not None: + moe_routing_replay_controller.finalize_step() + return TrainStepResult( reduced_loss=reduced_loss, probs_corr=float(loss_info.probs_corr.item()), @@ -360,6 +428,12 @@ def _run_service_loop(runtime: TrainingRuntime) -> None: config = job.config experimental_config = job.experimental_config + configure_moe_routing_replay( + runtime, + replay_bundle_path=job.moe_routing_replay_path, + strict=job.moe_routing_replay_strict, + ) + print0(runtime.rank, "Loaded job from", job_path) print0(runtime.rank, "Job:", job) @@ -368,7 +442,7 @@ def _run_service_loop(runtime: TrainingRuntime) -> None: raise FileNotFoundError(f"No adapter model found at {adapter_model_path}") print0(runtime.rank, "Loading adapter model from", adapter_model_path) adapter_model = load_file(adapter_model_path) - load_adapter_into_model(runtime.model, adapter_model) + load_adapter_into_model(runtime.model, adapter_model, runtime.optimizer) optimizer_shard_path = os.path.join( job.optimizer_state_path, @@ -401,7 +475,7 @@ def _run_service_loop(runtime: TrainingRuntime) -> None: repeat = math.ceil(num_indices / len(indices)) indices = (indices * repeat)[:num_indices] - for index in indices: + for step_index, index in enumerate(indices): inputs = select_indexed_inputs(packed_tensors, index) step_result = run_training_step( model_chunks=runtime.model, @@ -411,6 +485,9 @@ def _run_service_loop(runtime: TrainingRuntime) -> None: config=config, experimental_config=experimental_config, ref_logprobs=None, + step_index=step_index, + sample_index=index, + moe_routing_replay_controller=runtime.moe_routing_replay_controller, ) print0( runtime.rank, From c5e06d96fb7f378e65e553c8b88f1471ce83544c Mon Sep 17 00:00:00 2001 From: FurtherAI Date: Thu, 12 Mar 2026 09:40:47 +0000 Subject: [PATCH 08/19] oracle worker/trace: capture forward traces and emit replay bundles --- tests/integration/megatron_forward_trace.py | 489 ++++++++++++++++++ tests/integration/megatron_oracle_worker.py | 521 ++++++++++++++++++++ 2 files changed, 1010 insertions(+) create mode 100644 tests/integration/megatron_forward_trace.py create mode 100644 tests/integration/megatron_oracle_worker.py diff --git a/tests/integration/megatron_forward_trace.py b/tests/integration/megatron_forward_trace.py new file mode 100644 index 00000000..2ca418aa --- /dev/null +++ b/tests/integration/megatron_forward_trace.py @@ -0,0 +1,489 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Any, cast + +import torch + +CAPTURE_NAME_TOKENS = ( + ".self_attention.linear_qkv", + ".self_attention.linear_qkv.q_proj_lora", + ".self_attention.linear_qkv.k_proj_lora", + ".self_attention.linear_qkv.v_proj_lora", + ".self_attention.linear_proj", + ".self_attention.linear_proj.lora", + ".mlp.router", + ".mlp.experts.linear_fc1", + ".mlp.experts.linear_fc1.gate_lora", + ".mlp.experts.linear_fc1.up_lora", + ".mlp.experts.linear_fc2", + ".mlp.experts.linear_fc2.lora", +) +ROUTER_NAME_TOKEN = ".mlp.router" + + +def _safe_int(value: Any, default: int = 0) -> int: + """Coerces scalar values to int for trace metadata.""" + try: + return int(value) + except Exception: + return default + + +def _safe_ps_stat(name: str, default: int) -> int: + """Reads one Megatron parallel-state integer when available.""" + try: + from megatron.core import parallel_state as ps + + getter = getattr(ps, name) + return _safe_int(getter(), default) + except Exception: + return default + + +def _rank_metadata() -> dict[str, int]: + """Builds lightweight distributed metadata for one trace call.""" + rank = 0 + world_size = 1 + if torch.distributed.is_initialized(): + rank = _safe_int(torch.distributed.get_rank(), 0) + world_size = _safe_int(torch.distributed.get_world_size(), 1) + return { + "global_rank": rank, + "world_size": world_size, + "tp_rank": _safe_ps_stat("get_tensor_model_parallel_rank", 0), + "tp_world_size": _safe_ps_stat("get_tensor_model_parallel_world_size", 1), + "ep_rank": _safe_ps_stat("get_expert_model_parallel_rank", 0), + "ep_world_size": _safe_ps_stat("get_expert_model_parallel_world_size", 1), + "etp_rank": _safe_ps_stat("get_expert_tensor_parallel_rank", 0), + "etp_world_size": _safe_ps_stat("get_expert_tensor_parallel_world_size", 1), + "dp_rank": _safe_ps_stat("get_data_parallel_rank", 0), + "dp_world_size": _safe_ps_stat("get_data_parallel_world_size", 1), + "expert_dp_rank": _safe_ps_stat("get_expert_data_parallel_rank", 0), + "expert_dp_world_size": _safe_ps_stat("get_expert_data_parallel_world_size", 1), + } + + +def _shard_world_size_for_domain(domain: Any) -> int: + """Returns shard-group world size for one LoRA shard domain.""" + if domain == "tp": + return _safe_ps_stat("get_tensor_model_parallel_world_size", 1) + if domain == "expert_tp": + return _safe_ps_stat("get_expert_tensor_parallel_world_size", 1) + return 1 + + +def _extract_primary_tensor(value: Any) -> torch.Tensor | None: + if isinstance(value, torch.Tensor): + return value + if isinstance(value, dict): + for item in value.values(): + tensor = _extract_primary_tensor(item) + if tensor is not None: + return tensor + if isinstance(value, (list, tuple)): + for item in value: + tensor = _extract_primary_tensor(item) + if tensor is not None: + return tensor + return None + + +def _materialize_tensor(tensor: torch.Tensor) -> torch.Tensor: + if hasattr(tensor, "full_tensor"): + tensor = cast(torch.Tensor, tensor.full_tensor()) + elif hasattr(tensor, "to_local"): + tensor = cast(torch.Tensor, tensor.to_local()) + elif hasattr(tensor, "_local_tensor"): + tensor = cast(torch.Tensor, tensor._local_tensor) + return tensor.detach().cpu() + + +def _materialize_trace_value(value: Any) -> Any: + if isinstance(value, torch.Tensor): + return _materialize_tensor(value) + if isinstance(value, dict): + return {key: _materialize_trace_value(item) for key, item in value.items()} + if isinstance(value, list): + return [_materialize_trace_value(item) for item in value] + if isinstance(value, tuple): + return tuple(_materialize_trace_value(item) for item in value) + return value + + +def _extract_router_topk(output: Any) -> tuple[torch.Tensor, torch.Tensor] | None: + if not isinstance(output, tuple) or len(output) < 2: + return None + probs = output[0] + routing_map = output[1] + if not isinstance(probs, torch.Tensor) or not isinstance(routing_map, torch.Tensor): + return None + probs = _materialize_tensor(probs.float()) + routing_map = _materialize_tensor(routing_map) + topk = int(routing_map.sum(dim=-1).max().item()) + if topk < 0: + raise RuntimeError(f"Invalid router topk={topk}") + if topk == 0: + topk_scores = probs.new_zeros((probs.shape[0], 0)) + topk_ids = torch.zeros((probs.shape[0], 0), dtype=torch.int64) + else: + topk_scores, topk_ids = torch.topk(probs, k=topk, dim=-1) + return topk_ids.contiguous(), topk_scores.contiguous() + + +class ForwardTraceCapture: + def __init__( + self, + model_chunks: list[Any], + *, + enabled: bool, + capture_name_tokens: tuple[str, ...] = CAPTURE_NAME_TOKENS, + ) -> None: + self.enabled = enabled + self.capture_name_tokens = capture_name_tokens + self.current_step_index: int | None = None + self.current_step_trace: dict[str, list[dict[str, Any]]] = {} + self._hook_handles: list[Any] = [] + if not enabled: + return + self._register_hooks(model_chunks) + + def _register_hooks(self, model_chunks: list[Any]) -> None: + for chunk_index, chunk in enumerate(model_chunks): + for module_name, module in chunk.named_modules(): + trace_module_name = f"chunk{chunk_index}.{module_name}" + is_layer_output = ( + ".decoder.layers." in module_name + and module_name.rsplit(".", 1)[-1].isdigit() + ) + if not is_layer_output and not any( + module_name.endswith(token) for token in self.capture_name_tokens + ): + continue + self._hook_handles.append( + module.register_forward_hook( + self._make_hook(trace_module_name, module) + ) + ) + + @staticmethod + def _sequence_parallel_enabled(module: Any) -> bool: + """Returns sequence-parallel flag from module/provider/config when present.""" + for owner in ( + module, + getattr(module, "provider", None), + getattr(module, "config", None), + ): + if owner is None: + continue + value = getattr(owner, "sequence_parallel", None) + if isinstance(value, bool): + return value + return False + + @staticmethod + def _lora_primary_output_merge_hint(module: Any) -> dict[str, Any] | None: + """Infers the correct output merge op for LoRA modules.""" + if module.__class__.__name__ != "LoRA": + return None + lora_module = module + b_param = getattr(lora_module, "B_T", None) + if b_param is None: + return None + b_domain = getattr(b_param, "lora_shard_domain", None) + b_world_size = _shard_world_size_for_domain(b_domain) + if bool(getattr(b_param, "lora_tp_sharded", False)) and b_world_size > 1: + shard_dim = getattr(b_param, "lora_tp_shard_dim", None) + if isinstance(shard_dim, int): + return {"op": "concat", "dim": shard_dim} + a_param = getattr(lora_module, "A_T", None) + if a_param is None: + return None + a_domain = getattr(a_param, "lora_shard_domain", None) + a_world_size = _shard_world_size_for_domain(a_domain) + if bool(getattr(a_param, "lora_tp_sharded", False)) and a_world_size > 1: + return {"op": "sum"} + return None + + def _infer_primary_output_merge_hint( + self, name: str, module: Any + ) -> dict[str, Any] | None: + """Chooses canonical cross-rank concat axis for one module output.""" + if ROUTER_NAME_TOKEN in name: + return {"op": "concat", "dim": 0} + + lora_hint = self._lora_primary_output_merge_hint(module) + if lora_hint is not None: + return lora_hint + + gather_output = getattr(module, "gather_output", None) + if isinstance(gather_output, bool) and not gather_output: + return {"op": "concat", "dim": -1} + + if ".self_attention.linear_qkv" in name: + return {"op": "concat", "dim": -1} + + if ".mlp.experts." in name: + return {"op": "concat", "dim": 0} + + if bool( + getattr(module, "input_is_parallel", False) + ) and self._sequence_parallel_enabled(module): + return {"op": "concat", "dim": 0} + + return None + + def _build_merge_hints(self, name: str, module: Any) -> dict[str, dict[str, Any]]: + """Builds field-level tensor merge hints for one call record.""" + hints: dict[str, dict[str, Any]] = {} + primary_output_hint = self._infer_primary_output_merge_hint(name, module) + if primary_output_hint is not None: + hints["primary_output"] = primary_output_hint + if ROUTER_NAME_TOKEN in name: + concat_dim0 = {"op": "concat", "dim": 0} + hints["output"] = concat_dim0 + hints["router_topk_ids"] = concat_dim0 + hints["router_topk_scores"] = concat_dim0 + return hints + + def _make_hook(self, name: str, module: Any): + def _hook(_module: Any, inputs: Any, output: Any) -> None: + if self.current_step_index is None: + return + call_index = len(self.current_step_trace.get(name, [])) + trace_item: dict[str, Any] = { + "call_index": call_index, + "module_type": module.__class__.__name__, + "rank_meta": _rank_metadata(), + "merge_hints": self._build_merge_hints(name, module), + "inputs": _materialize_trace_value(inputs), + "output": _materialize_trace_value(output), + "primary_input": self.guess_primary_tensor(inputs), + "primary_output": self.guess_primary_tensor(output), + } + if ROUTER_NAME_TOKEN in name: + router_topk = _extract_router_topk(output) + if router_topk is not None: + topk_ids, topk_scores = router_topk + trace_item["router_topk_ids"] = topk_ids + trace_item["router_topk_scores"] = topk_scores + self.current_step_trace.setdefault(name, []).append(trace_item) + + return _hook + + @staticmethod + def guess_primary_tensor(value: Any) -> torch.Tensor | None: + tensor = _extract_primary_tensor(value) + if tensor is None: + return None + return _materialize_tensor(tensor) + + def set_step(self, step_index: int) -> None: + self.current_step_index = step_index + self.current_step_trace = {} + + @classmethod + def _merge_rank_values( + cls, + values_by_rank: list[Any], + *, + preferred_cat_dim: int | None = None, + preferred_reduce: str | None = None, + ) -> Any: + if not values_by_rank: + raise RuntimeError("Cannot merge empty rank value list") + if all(isinstance(value, torch.Tensor) for value in values_by_rank): + tensors = cast(list[torch.Tensor], values_by_rank) + if preferred_reduce == "sum" and all( + tensors[0].shape == tensor.shape for tensor in tensors[1:] + ): + return torch.stack(tensors, dim=0).sum(dim=0) + if ( + preferred_cat_dim is not None + and all(tensor.ndim > 0 for tensor in tensors) + and cls._can_cat_along_dim(tensors, dim=preferred_cat_dim) + ): + return torch.cat(tensors, dim=preferred_cat_dim) + if all( + tensors[0].shape == tensor.shape and torch.equal(tensors[0], tensor) + for tensor in tensors[1:] + ): + return tensors[0] + if all(tensor.ndim > 0 for tensor in tensors): + if cls._can_cat_along_dim(tensors, dim=0): + return torch.cat(tensors, dim=0) + if cls._can_cat_along_dim(tensors, dim=-1): + return torch.cat(tensors, dim=-1) + if all(tensors[0].shape == tensor.shape for tensor in tensors[1:]): + return torch.stack(tensors, dim=0) + return tensors + if all(isinstance(value, dict) for value in values_by_rank): + dicts = cast(list[dict[str, Any]], values_by_rank) + keys = sorted(set().union(*(value.keys() for value in dicts))) + return { + key: cls._merge_rank_values( + [value[key] for value in dicts if key in value], + preferred_cat_dim=preferred_cat_dim, + preferred_reduce=preferred_reduce, + ) + for key in keys + } + if all(isinstance(value, list) for value in values_by_rank): + lists = cast(list[list[Any]], values_by_rank) + if any(len(values) != len(lists[0]) for values in lists[1:]): + return lists + return [ + cls._merge_rank_values( + [value[index] for value in lists], + preferred_cat_dim=preferred_cat_dim, + preferred_reduce=preferred_reduce, + ) + for index in range(len(lists[0])) + ] + if all(isinstance(value, tuple) for value in values_by_rank): + tuples = cast(list[tuple[Any, ...]], values_by_rank) + if any(len(values) != len(tuples[0]) for values in tuples[1:]): + return tuples + return tuple( + cls._merge_rank_values( + [value[index] for value in tuples], + preferred_cat_dim=preferred_cat_dim, + preferred_reduce=preferred_reduce, + ) + for index in range(len(tuples[0])) + ) + if all(value == values_by_rank[0] for value in values_by_rank[1:]): + return values_by_rank[0] + return values_by_rank[0] + + @classmethod + def _merge_rank_call_entries( + cls, + rank_call_entries: list[dict[str, Any]], + ) -> dict[str, Any]: + """Merges one module call across ranks using per-field merge hints.""" + merged_call: dict[str, Any] = {} + keys = sorted(set().union(*(entry.keys() for entry in rank_call_entries))) + for key in keys: + values = [entry[key] for entry in rank_call_entries if key in entry] + if key == "rank_meta": + merged_call[key] = values + continue + preferred_cat_dim: int | None = None + preferred_reduce: str | None = None + if values and key not in {"merge_hints", "call_index", "module_type"}: + hint_values = [ + cast(dict[str, Any], entry["merge_hints"]).get(key) + for entry in rank_call_entries + if isinstance(entry.get("merge_hints"), dict) + ] + op_hints = [ + hint + for hint in hint_values + if isinstance(hint, dict) and isinstance(hint.get("op"), str) + ] + if op_hints: + selected_hint = op_hints[0] + op = selected_hint.get("op") + if op == "concat": + dim = selected_hint.get("dim") + if isinstance(dim, int): + preferred_cat_dim = dim + elif op == "sum": + preferred_reduce = "sum" + if ( + preferred_reduce is None + and preferred_cat_dim == 0 + and all(isinstance(value, torch.Tensor) for value in values) + ): + merged_call[f"{key}__row_splits"] = [ + int(cast(torch.Tensor, value).shape[0]) for value in values + ] + merged_call[key] = cls._merge_rank_values( + values, + preferred_cat_dim=preferred_cat_dim, + preferred_reduce=preferred_reduce, + ) + return merged_call + + @staticmethod + def _can_cat_along_dim(tensors: list[torch.Tensor], dim: int) -> bool: + if not tensors: + return False + if tensors[0].ndim == 0: + return False + ndim = tensors[0].ndim + axis = dim if dim >= 0 else ndim + dim + if axis < 0 or axis >= ndim: + return False + if any(tensor.ndim != ndim for tensor in tensors[1:]): + return False + for dim_index in range(ndim): + if dim_index == axis: + continue + dim_size = tensors[0].shape[dim_index] + if any(tensor.shape[dim_index] != dim_size for tensor in tensors[1:]): + return False + return True + + @classmethod + def _merge_rank_traces( + cls, + rank_traces: list[dict[str, list[dict[str, Any]]]], + ) -> dict[str, list[dict[str, Any]]]: + if len(rank_traces) == 1: + return rank_traces[0] + merged: dict[str, list[dict[str, Any]]] = {} + module_names = sorted(set().union(*(trace.keys() for trace in rank_traces))) + for module_name in module_names: + call_count = max(len(trace.get(module_name, [])) for trace in rank_traces) + module_calls: list[dict[str, Any]] = [] + for call_index in range(call_count): + rank_values = [ + trace[module_name][call_index] + for trace in rank_traces + if module_name in trace and call_index < len(trace[module_name]) + ] + if not rank_values: + continue + module_calls.append(cls._merge_rank_call_entries(rank_values)) + merged[module_name] = module_calls + return merged + + @staticmethod + def _gather_rank_traces( + local_trace: dict[str, list[dict[str, Any]]], + ) -> list[dict[str, list[dict[str, Any]]]] | None: + if ( + not torch.distributed.is_initialized() + or torch.distributed.get_world_size() == 1 + ): + return [local_trace] + gathered: list[dict[str, list[dict[str, Any]]] | None] = [ + None + ] * torch.distributed.get_world_size() + torch.distributed.all_gather_object(gathered, local_trace) + if torch.distributed.get_rank() != 0: + return None + return cast(list[dict[str, list[dict[str, Any]]]], gathered) + + def save_current_step(self, traces_dir: Path) -> Path | None: + if not self.enabled or self.current_step_index is None: + return None + gathered_traces = self._gather_rank_traces(self.current_step_trace) + if gathered_traces is None: + return None + merged_trace = self._merge_rank_traces(gathered_traces) + traces_dir.mkdir(parents=True, exist_ok=True) + trace_path = traces_dir / f"forward_trace_step_{self.current_step_index:03d}.pt" + torch.save(merged_trace, trace_path) + return trace_path + + @staticmethod + def load_trace(trace_path: Path) -> dict[str, list[dict[str, Any]]]: + return torch.load(trace_path, map_location="cpu", weights_only=False) + + def close(self) -> None: + for handle in self._hook_handles: + handle.remove() + self._hook_handles.clear() diff --git a/tests/integration/megatron_oracle_worker.py b/tests/integration/megatron_oracle_worker.py new file mode 100644 index 00000000..91c7647d --- /dev/null +++ b/tests/integration/megatron_oracle_worker.py @@ -0,0 +1,521 @@ +from __future__ import annotations + +import argparse +from contextlib import contextmanager +import os +from pathlib import Path +import random +import subprocess +import sys +from typing import Any, Callable + +import numpy as np + +from art.megatron.routing_replay import ( + ParallelTopology as ReplayParallelTopology, +) +from art.megatron.routing_replay import ( + build_bundle_from_forward_trace_dir, +) + +from .megatron_forward_trace import ForwardTraceCapture +from .megatron_oracle_harness import ( + OracleCaseConfig, + RunManifest, + SensitivityMutation, + StepTrace, + Topology, + WorkerRunRequest, + _read_json, + _require_not_none, + _write_json, +) + + +def run_worker_subprocess( + request: WorkerRunRequest, + topology_dir: Path, + *, + repo_root: Path, +) -> None: + """Runs one distributed worker subprocess and stores combined logs.""" + request_path = topology_dir / "run_request.json" + _write_json(request_path, request.model_dump(mode="json")) + worker_module = "integration.megatron_oracle_worker" + worker_cwd = repo_root / "tests" + + command = [ + sys.executable, + "-m", + "torch.distributed.run", + "--standalone", + "--nproc_per_node", + str(request.topology.world_size()), + "-m", + worker_module, + "--worker-run", + "--run-request", + str(request_path), + ] + run = subprocess.run( + command, + cwd=str(worker_cwd), + env={**os.environ, "PYTHONUNBUFFERED": "1"}, + capture_output=True, + text=True, + check=False, + ) + combined_output = f"{run.stdout}\n{run.stderr}".strip() + (topology_dir / "worker.log").write_text(combined_output + "\n", encoding="utf-8") + if run.returncode != 0: + tail = "\n".join(combined_output.splitlines()[-80:]) + raise RuntimeError( + f"Topology run failed for {request.topology.slug()} with exit code " + f"{run.returncode}.\n{tail}" + ) + + +def _set_deterministic_seed(seed: int) -> None: + import torch + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def _merge_sharded_dicts(shards_by_rank: list[dict[str, Any]]) -> dict[str, Any]: + """Merges rank-sharded LoRA tensors into a full state dict on rank 0.""" + import torch + + merged: dict[str, list[Any]] = {} + for rank_shards in shards_by_rank: + for key, tensor in rank_shards.items(): + merged.setdefault(key, []).append(tensor.detach().cpu()) + full_state: dict[str, Any] = {} + for key, shards in merged.items(): + if len(shards) == 1: + full_state[key] = shards[0].contiguous() + continue + concat_dim = 1 if ".lora_A." in key else 0 + full_state[key] = torch.cat(shards, dim=concat_dim).contiguous() + return full_state + + +def _gather_full_state(local_state: dict[str, Any]) -> dict[str, Any] | None: + """Gathers local state dicts to rank 0 and merges them.""" + import torch + + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + gathered = [None for _ in range(world_size)] if rank == 0 else None + torch.distributed.gather_object(local_state, gathered, dst=0) + if rank != 0: + return None + assert gathered is not None + entries = [entry for entry in gathered if entry is not None] + return _merge_sharded_dicts(entries) + + +def _collect_lora_state(model_chunks: list[Any]) -> dict[str, Any] | None: + """Collects full LoRA adapter state for validation and delta computation.""" + local_state: dict[str, Any] = {} + for chunk in model_chunks: + for module in chunk.modules(): + if not hasattr(module, "sharded_lora_state_dict"): + continue + module_state = module.sharded_lora_state_dict() + for key, value in module_state.items(): + if key in local_state: + raise RuntimeError( + f"Duplicate LoRA key while collecting state: {key}" + ) + local_state[key] = value.detach().cpu() + return _gather_full_state(local_state) + + +def _collect_lora_grads(model_chunks: list[Any]) -> dict[str, Any] | None: + """Collects full LoRA gradient tensors across all ranks.""" + from art.megatron.lora import LoRA + + local_grads: dict[str, Any] = {} + for chunk in model_chunks: + for module in chunk.modules(): + if not isinstance(module, LoRA): + continue + for key, param, expert in module._export_items(): # type: ignore[attr-defined] + if not hasattr(param, "main_grad"): + raise RuntimeError( + f"LoRA param missing main_grad attribute for key '{key}'" + ) + grad = param.main_grad + if grad is None: + raise RuntimeError(f"LoRA param main_grad is None for key '{key}'") + if hasattr(grad, "_local_tensor"): + grad = grad._local_tensor + local_grads[key] = ( + grad[expert].detach().cpu().T + if expert is not None + else grad.detach().cpu().T + ) + return _gather_full_state(local_grads) + + +def _validate_loaded_state_matches_adapter( + loaded_state: dict[str, Any], + adapter_model: dict[str, Any], +) -> None: + """Checks loaded model LoRA state exactly matches adapter tensors and keys.""" + import torch + + for key in sorted(adapter_model.keys()): + assert torch.equal(loaded_state[key].cpu(), adapter_model[key].cpu()), ( + f"Loaded LoRA state mismatch for key '{key}'" + ) + + +def _configure_provider( + provider: Any, + topology: Topology, + case_config: OracleCaseConfig, +) -> None: + """Applies deterministic topology/model overrides to provider config.""" + provider.tensor_model_parallel_size = topology.tp + provider.expert_model_parallel_size = topology.ep + provider.expert_tensor_parallel_size = topology.etp + # These are intentionally pinned to 1 for now; switching to topology-driven + # values is the single lever to start CP/PP coverage in the harness. + provider.pipeline_model_parallel_size = 1 + provider.context_parallel_size = 1 + provider.sequence_parallel = topology.sp + provider.num_layers = case_config.num_layers + if hasattr(provider, "attention_dropout"): + provider.attention_dropout = 0.0 + if hasattr(provider, "hidden_dropout"): + provider.hidden_dropout = 0.0 + + +def _build_optimizer_config(case_config: OracleCaseConfig): + """Builds Megatron optimizer settings for deterministic harness runs.""" + from megatron.core.optimizer import OptimizerConfig + + optimizer_kwargs = dict( + lr=case_config.learning_rate, + adam_beta1=0.9, + adam_beta2=0.99, + clip_grad=0.1, + weight_decay=0.1, + ) + return OptimizerConfig( + bf16=True, + **optimizer_kwargs, + ) + + +def _assert_runtime_configuration( + model_chunks: list[Any], + case_config: OracleCaseConfig, +) -> None: + """Validates runtime model depth equals requested oracle case config.""" + observed_num_layers: set[int] = set() + + for chunk in model_chunks: + module: Any = chunk + while hasattr(module, "module"): + module = module.module + config = getattr(module, "config", None) + if config is not None and hasattr(config, "num_layers"): + observed_num_layers.add(int(config.num_layers)) + + if observed_num_layers != {case_config.num_layers}: + raise RuntimeError( + "Runtime num_layers mismatch: " + f"requested={case_config.num_layers}, observed={sorted(observed_num_layers)}" + ) + + +def _delta_state( + initial_state: dict[str, Any], + current_state: dict[str, Any], +) -> dict[str, Any]: + """Computes LoRA parameter deltas while enforcing stable key sets.""" + initial_keys = set(initial_state.keys()) + current_keys = set(current_state.keys()) + if initial_keys != current_keys: + missing = sorted(initial_keys - current_keys) + extra = sorted(current_keys - initial_keys) + raise KeyError( + f"LoRA state keys changed during training: missing={missing[:3]} extra={extra[:3]}" + ) + return { + key: current_state[key].detach().cpu() - initial_state[key].detach().cpu() + for key in sorted(initial_keys) + } + + +@contextmanager +def _mutation_hook( + megatron_train_module: Any, + mutation: SensitivityMutation | None, + pre_optimizer_step_hook: Callable[[], None] | None = None, + loss_scale: float = 1.0, +): + """Applies optional sensitivity mutation hooks around training steps.""" + original_finalize = megatron_train_module._finalize_grads + original_optimizer_step = megatron_train_module._optimizer_step + original_loss_fn = megatron_train_module.loss_fn + + if mutation == "drop_finalize": + megatron_train_module._finalize_grads = lambda _model: None + elif mutation is not None: + raise ValueError(f"Unsupported mutation: {mutation}") + + if pre_optimizer_step_hook is not None: + + def _patched_optimizer_step(optimizer: Any, learning_rate: float): + pre_optimizer_step_hook() + return original_optimizer_step(optimizer, learning_rate) + + megatron_train_module._optimizer_step = _patched_optimizer_step + + if loss_scale <= 0: + raise ValueError(f"loss_scale must be > 0, got {loss_scale}") + if loss_scale != 1.0: + + def _scaled_loss_fn(*args: Any, **kwargs: Any): + loss = original_loss_fn(*args, **kwargs) + return loss.model_copy( + update={ + "mean_policy_loss": loss.mean_policy_loss * loss_scale, + "mean_kl": loss.mean_kl * loss_scale, + "policy_loss_sum": loss.policy_loss_sum * loss_scale, + } + ) + + megatron_train_module.loss_fn = _scaled_loss_fn + + if mutation is None: + if pre_optimizer_step_hook is None and loss_scale == 1.0: + yield + return + try: + yield + finally: + megatron_train_module._finalize_grads = original_finalize + megatron_train_module._optimizer_step = original_optimizer_step + megatron_train_module.loss_fn = original_loss_fn + + +def _worker_run(request: WorkerRunRequest) -> None: + """Executes one full distributed training trace generation worker run.""" + from safetensors.torch import load_file, save_file + import torch + + from art import dev, types + from art.megatron import train as megatron_train + from art.preprocessing.pack import packed_tensors_from_dir + + local_rank = int(os.environ["LOCAL_RANK"]) + torch.cuda.set_device(local_rank) + torch.distributed.init_process_group(backend="nccl") + _set_deterministic_seed(request.case_config.seed) + + runtime = megatron_train.build_training_runtime( + model_identifier=request.case_config.base_model, + provider_configure=lambda provider: _configure_provider( + provider, request.topology, request.case_config + ), + optimizer_config=_build_optimizer_config(request.case_config), + print_env=False, + print_optimizer_stats=False, + ) + model_chunks = runtime.model + optimizer = runtime.optimizer + megatron_train.configure_moe_routing_replay( + runtime, + replay_bundle_path=request.moe_routing_replay_path, + strict=request.moe_routing_replay_strict, + ) + _assert_runtime_configuration(model_chunks, request.case_config) + + topology_dir = Path(request.topology_dir) + traces_dir = topology_dir / "traces" + traces_dir.mkdir(parents=True, exist_ok=True) + + # setup the shared initial lora + shared_init_path = Path(request.shared_init_adapter_path) + if not shared_init_path.exists(): + initial_state = _collect_lora_state(model_chunks) + if torch.distributed.get_rank() == 0: + shared_init_path.parent.mkdir(parents=True, exist_ok=True) + save_file( + _require_not_none(initial_state, "initial_state"), + str(shared_init_path), + ) + torch.distributed.barrier() + + # load the shared initial lora into the model and validate we can collect it from the model + adapter_model = load_file(str(shared_init_path)) + megatron_train.load_adapter_into_model(model_chunks, adapter_model, optimizer) + loaded_state = _collect_lora_state(model_chunks) + if torch.distributed.get_rank() == 0: + _validate_loaded_state_matches_adapter( + _require_not_none(loaded_state, "loaded_state"), adapter_model + ) + torch.distributed.barrier() + + # load the inputs + packed_tensors = packed_tensors_from_dir( + **request.packed_tensors.model_dump(exclude_none=True) + ) + initial_lora_state = loaded_state + + train_config = types.TrainConfig( + learning_rate=request.case_config.learning_rate, + beta=request.case_config.beta, + kl_penalty_coef=0.0, + ) + experimental_config: dev.TrainConfig = {} + step_traces: list[StepTrace] = [] + captured_grads: dict[str, Any] | None = None + forward_trace_capture = ForwardTraceCapture(model_chunks, enabled=True) + + def _capture_lora_grads() -> None: + nonlocal captured_grads + captured_grads = _collect_lora_grads(model_chunks) + + with _mutation_hook( + megatron_train, + request.mutation, + pre_optimizer_step_hook=_capture_lora_grads, + loss_scale=request.case_config.loss_scale, + ): + for step_index in range(request.case_config.num_steps): + forward_trace_capture.set_step(step_index) + sample_index = step_index % request.packed_tensors.num_sequences + inputs = megatron_train.select_indexed_inputs(packed_tensors, sample_index) + captured_grads = None + + step_result = megatron_train.run_training_step( + model_chunks=model_chunks, + optimizer=optimizer, + learning_rate=train_config.learning_rate, + inputs=inputs, + config=train_config, + experimental_config=experimental_config, + ref_logprobs=None, + step_index=step_index, + sample_index=sample_index, + moe_routing_replay_controller=runtime.moe_routing_replay_controller, + ) + forward_trace_capture.save_current_step(traces_dir) + torch.distributed.barrier() + current_lora_state = _collect_lora_state(model_chunks) + + if torch.distributed.get_rank() == 0: + # save artifacts (outputs, grads, lora deltas, current lora) + grads = _require_not_none(captured_grads, "captured_grads") + initial_state = _require_not_none( + initial_lora_state, "initial_lora_state" + ) + current_state = _require_not_none( + current_lora_state, "current_lora_state" + ) + deltas = _delta_state(initial_state, current_state) + + output_rel = Path("traces") / f"output_step_{step_index:03d}.pt" + grads_rel = Path("traces") / f"grads_step_{step_index:03d}.safetensors" + deltas_rel = ( + Path("traces") / f"deltas_step_{step_index:03d}.safetensors" + ) + lora_rel = Path(f"lora_step_{step_index:03d}.safetensors") + + torch.save( + step_result.new_logprobs.detach().cpu().float(), + topology_dir / output_rel, + ) + save_file(grads, str(topology_dir / grads_rel)) + save_file(deltas, str(topology_dir / deltas_rel)) + save_file(current_state, str(topology_dir / lora_rel)) + + # build and append the step trace + step_traces.append( + StepTrace( + step_index=step_index, + loss=float( + step_result.reduced_loss.item() + / request.case_config.loss_scale + ), + probs_corr=step_result.probs_corr, + output_file=str(output_rel), + grads_file=str(grads_rel), + deltas_file=str(deltas_rel), + lora_file=str(lora_rel), + ) + ) + torch.distributed.barrier() + + forward_trace_capture.close() + + if torch.distributed.get_rank() == 0: + # build and save the moe routing replay bundle + if request.capture_moe_routing_bundle_path is not None: + replay_bundle = build_bundle_from_forward_trace_dir( + traces_dir=traces_dir, + num_steps=request.case_config.num_steps, + topology=ReplayParallelTopology.model_validate( + request.topology.model_dump( + include={"tp", "ep", "etp", "dp", "sp", "cp", "pp", "vpp"}, + mode="python", + ) + ), + ) + replay_bundle.to_dir(request.capture_moe_routing_bundle_path) + + # build and save the run manifest + manifest = RunManifest( + case_id=request.case_id, + base_model=request.case_config.base_model, + num_layers=request.case_config.num_layers, + topology=request.topology.slug(), + world_size=request.topology.world_size(), + seed=request.case_config.seed, + num_steps=request.case_config.num_steps, + packed_tensors=request.packed_tensors, + tolerances=request.case_config.tolerances, + steps=step_traces, + ) + _write_json(topology_dir / "manifest.json", manifest.model_dump(mode="json")) + torch.distributed.barrier() + torch.distributed.destroy_process_group() + + +def run_worker_cli(run_request_path: Path) -> None: + """Loads a worker request and dispatches worker execution.""" + request = WorkerRunRequest.model_validate(_read_json(run_request_path)) + _worker_run(request) + + +def _parse_args(argv: list[str]) -> argparse.Namespace: + """Parses worker CLI arguments.""" + parser = argparse.ArgumentParser(description="Megatron oracle harness worker") + parser.add_argument("--worker-run", action="store_true") + parser.add_argument("--run-request", type=Path) + return parser.parse_args(argv) + + +def _main(argv: list[str]) -> int: + """CLI entry for worker-only execution mode.""" + args = _parse_args(argv) + if not args.worker_run: + raise SystemExit("This module is intended for test imports or --worker-run") + if args.run_request is None: + raise SystemExit("--run-request is required with --worker-run") + run_worker_cli(args.run_request) + return 0 + + +if __name__ == "__main__": + raise SystemExit(_main(sys.argv[1:])) From a73ca1a54b4cd870fd24a6711f074a112c324740 Mon Sep 17 00:00:00 2001 From: FurtherAI Date: Thu, 12 Mar 2026 09:40:50 +0000 Subject: [PATCH 09/19] oracle harness/tests: refactor suite and add oracle-replay parity flow --- tests/integration/megatron_oracle_harness.py | 1800 +++++++++-------- .../test_megatron_lora_oracle_correctness.py | 102 +- 2 files changed, 1022 insertions(+), 880 deletions(-) diff --git a/tests/integration/megatron_oracle_harness.py b/tests/integration/megatron_oracle_harness.py index 273ac368..a66385b0 100644 --- a/tests/integration/megatron_oracle_harness.py +++ b/tests/integration/megatron_oracle_harness.py @@ -1,29 +1,35 @@ from __future__ import annotations -import argparse -from contextlib import contextmanager +from functools import partial import hashlib import json +import math import os from pathlib import Path -import random +import re import shutil -import subprocess -import sys -from typing import Any, Callable, Literal, TypeVar, cast +from typing import Any, Literal, TypeVar, cast -import numpy as np from pydantic import BaseModel, ConfigDict, Field +from rich import box +from rich.console import Console +from rich.table import Table +import torch + +from .megatron_forward_trace import ForwardTraceCapture REPO_ROOT = Path(__file__).resolve().parents[2] -ARTIFACT_ROOT = Path(REPO_ROOT / ".local/megatron_lora_oracles") +ARTIFACT_ROOT = Path(REPO_ROOT / ".local/megatron_lora_correctness") +ORACLE_MOE_ROUTING_BUNDLE_DIRNAME = "oracle_moe_routing_replay" +ORACLE_REPLAY_TOPOLOGY_SUFFIX = "oracle_replay" REGENERATE_ENV = "ART_REGENERATE_MEGATRON_ORACLE" -BASE_MODEL_ENV = "ART_MEGATRON_ORACLE_BASE_MODEL" -DP_SUPPORT_ENV = "ART_MEGATRON_ORACLE_ENABLE_DP_PHASE_B" +EXTENDED_TOPOLOGIES_ENV = "ART_MEGATRON_ORACLE_ENABLE_EXTENDED_TOPOLOGIES" SENSITIVITY_MUTATION_ENV = "ART_MEGATRON_ORACLE_MUTATION" -SensitivityMutation = Literal["drop_finalize"] +DEFAULT_SENSITIVITY_MUTATION = "drop_finalize" +SUPPORTED_SENSITIVITY_MUTATIONS = (DEFAULT_SENSITIVITY_MUTATION,) +SensitivityMutation = str REQUIRED_PACKED_TENSOR_FILES = ( "tokens.pt", @@ -35,26 +41,68 @@ "advantages.pt", "weights.pt", ) +NON_FINITE_METRIC_VALUE = 1e30 +EXPERT_TABLE_ROW_LIMIT = 8 +EXPERT_TRIPLET_PARAM_RE = re.compile( + r"layers\.(?P\d+)\.mlp\.experts\.(?P\d+)\." + r"(?Pgate_proj|up_proj|down_proj)\." +) +PHASE_PRINT_ORDER = { + "forward": 0, + "router_scores": 1, + "router_topk_ids": 2, + "outputs": 3, + "losses": 4, + "grads": 5, + "deltas": 6, +} class Topology(BaseModel): + """Defines distributed topology settings for one Megatron run variant.""" + model_config = ConfigDict(frozen=True) tp: int ep: int etp: int = 1 dp: int = 1 - sp: int = 0 - phase: Literal["A", "B"] = "A" + sp: bool = False + cp: int = 1 + pp: int = 1 + vpp: int = 1 + + def resolved_expert_dp(self) -> int: + """Derives expert data parallel size from topology/world-size constraints.""" + attention_world = self.tp * self.cp * self.pp * self.dp + expert_divisor = self.etp * self.ep * self.pp + if attention_world % expert_divisor != 0: + raise ValueError( + "Invalid topology for Megatron expert parallelism: " + f"world_size={attention_world} is not divisible by " + f"etp*ep*pp={expert_divisor}." + ) + return attention_world // expert_divisor def slug(self) -> str: - return f"tp{self.tp}_ep{self.ep}_etp{self.etp}_dp{self.dp}_sp{self.sp}" + return ( + f"tp{self.tp}_ep{self.ep}_etp{self.etp}" + f"_dp{self.dp}_edp{self.resolved_expert_dp()}" + f"_cp{self.cp}_pp{self.pp}_vpp{self.vpp}_sp{int(self.sp)}" + ) def world_size(self) -> int: - return self.tp * self.ep * self.etp * self.dp + # Mirrors Megatron parallel-state sizing: + # attention side: world = tp * pp * cp * dp + # expert side must also divide this world size (validated in resolved_expert_dp()). + attention_world = self.tp * self.cp * self.pp * self.dp + self.resolved_expert_dp() + return attention_world class PackedTensorConfig(BaseModel): + """Controls synthetic packed tensor generation used by oracle harness runs.""" + num_sequences: int = 8 sequence_length: int = 256 prefill_tokens: int = 64 @@ -63,6 +111,8 @@ class PackedTensorConfig(BaseModel): class LoraConfig(BaseModel): + """Configures LoRA adapter dimensions and targeted module families.""" + rank: int = 1 alpha: int = 32 target_modules: list[str] = Field( @@ -79,28 +129,30 @@ class LoraConfig(BaseModel): class ToleranceProfile(BaseModel): - outputs_abs: float = 1e-2 - outputs_rel: float = 1e-2 - losses_abs: float = 1e-4 - losses_rel: float = 1e-4 - grads_abs: float = 1e-2 - grads_rel: float = 1e-2 - deltas_abs: float = 1e-2 - deltas_rel: float = 1e-2 + """Defines row-level pass/fail thresholds for variant comparison phases.""" + + relative_l2: float = 1e-2 + mean_abs_pct: float = 1.0 class OracleCaseConfig(BaseModel): + """Contains all deterministic run parameters for one oracle case.""" + base_model: str + num_layers: int = 4 seed: int = 20260305 - num_steps: int = 3 - learning_rate: float = 5e-6 + num_steps: int = 2 + learning_rate: float = 1e-3 beta: float = 0.0 + loss_scale: float = 1e4 packed_tensors: PackedTensorConfig = Field(default_factory=PackedTensorConfig) lora: LoraConfig = Field(default_factory=LoraConfig) tolerances: ToleranceProfile = Field(default_factory=ToleranceProfile) class DiskPackedTensorsSpec(BaseModel): + """Describes packed tensor artifacts persisted on disk for reuse.""" + dir: str num_sequences: int sequence_length: int @@ -109,6 +161,8 @@ class DiskPackedTensorsSpec(BaseModel): class CaseArtifacts(BaseModel): + """Holds stable case-level artifact paths used by all variants.""" + case_id: str case_dir: str packed_tensors: DiskPackedTensorsSpec @@ -116,17 +170,23 @@ class CaseArtifacts(BaseModel): class WorkerRunRequest(BaseModel): + """Defines one distributed worker invocation for generating variant artifacts.""" + case_id: str case_config: OracleCaseConfig topology: Topology topology_dir: str packed_tensors: DiskPackedTensorsSpec shared_init_adapter_path: str - allow_create_shared_init: bool = False mutation: SensitivityMutation | None = None + moe_routing_replay_path: str | None = None + moe_routing_replay_strict: bool = True + capture_moe_routing_bundle_path: str | None = None class StepTrace(BaseModel): + """Tracks per-step trace artifact filenames and loss metadata.""" + step_index: int loss: float probs_corr: float @@ -137,8 +197,11 @@ class StepTrace(BaseModel): class RunManifest(BaseModel): + """Records run metadata and per-step trace references for one topology output.""" + case_id: str base_model: str + num_layers: int topology: str world_size: int seed: int @@ -148,18 +211,142 @@ class RunManifest(BaseModel): steps: list[StepTrace] -class ComparisonFailure(BaseModel): +class MetricRow(BaseModel): + """Represents one comparable unit (param/module/global) for one phase and step.""" + case_id: str + variant: str topology: str oracle_topology: str - metric: Literal["outputs", "losses", "grads", "lora_deltas"] step_index: int - key: str - max_abs_error: float - max_rel_error: float - abs_tolerance: float - rel_tolerance: float - message: str + phase: str + param: str + numel: float + mean_abs_diff: float + relative_l2: float + typical_abs_scale: float + mean_abs_pct: float + topk_mismatch_fraction: float | None = None + top1_mismatch_fraction: float | None = None + thresholds: dict[str, float] = Field(default_factory=dict) + pass_signal: bool = True + failure_reasons: list[str] = Field(default_factory=list) + + +class VariantSpec(BaseModel): + """Declares how to execute and evaluate one candidate variant against the oracle.""" + + name: str + topology: Topology + thresholds_by_phase: dict[str, dict[str, float]] + output_slug: str | None = None + reference_slug: str | None = None + mutation: SensitivityMutation | None = None + expected_signal: Literal["pass", "fail"] = "pass" + + def resolved_output_slug(self) -> str: + if self.output_slug is not None: + return self.output_slug + return _topology_output_slug(self.topology, self.mutation) + + def resolved_reference_slug(self) -> str: + if self.reference_slug is not None: + return self.reference_slug + return ORACLE_TOPOLOGY.slug() + + +class VariantReport(BaseModel): + """Captures full comparison output for one variant run.""" + + case_id: str + variant: str + topology: str + reference_topology: str + expected_signal: Literal["pass", "fail"] + signal: Literal["pass", "fail"] + pass_count: int + fail_count: int + step_summaries: dict[int, dict[str, Any]] + metrics: list[MetricRow] + + +class DiffAccumulator: + """Accumulates diff statistics across tensors and router-id mismatch counters.""" + + def __init__(self) -> None: + self.numel = 0 + self.abs_sum = 0.0 + self.diff_sq_sum = 0.0 + self.ref_sq_sum = 0.0 + self.ref_abs_sum = 0.0 + self.router_topk_total = 0 + self.router_topk_mismatch = 0 + self.router_top1_total = 0 + self.router_top1_mismatch = 0 + + def update(self, reference, candidate) -> None: # type: ignore[no-untyped-def] + """Adds one tensor pair into the accumulator.""" + ref = reference.detach().float() + cand = candidate.detach().float() + diff = (cand - ref).abs() + if diff.numel() == 0: + return + self.numel += int(diff.numel()) + self.abs_sum += float(diff.sum().item()) + self.diff_sq_sum += float((cand - ref).square().sum().item()) + self.ref_sq_sum += float(ref.square().sum().item()) + self.ref_abs_sum += float(ref.abs().sum().item()) + + def update_router_ids(self, reference_ids, candidate_ids) -> None: # type: ignore[no-untyped-def] + """Adds router top-k id mismatch counts into the accumulator.""" + self.router_topk_total += int(reference_ids.numel()) + self.router_topk_mismatch += int((reference_ids != candidate_ids).sum().item()) + if reference_ids.ndim >= 2 and reference_ids.shape[1] > 0: + self.router_top1_total += int(reference_ids.shape[0]) + self.router_top1_mismatch += int( + (reference_ids[:, 0] != candidate_ids[:, 0]).sum().item() + ) + + def as_summary(self) -> dict[str, float]: + """Returns normalized summary values for one row.""" + if self.numel == 0: + topk_fraction = 0.0 + top1_fraction = 0.0 + else: + topk_fraction = ( + self.router_topk_mismatch / self.router_topk_total + if self.router_topk_total > 0 + else 0.0 + ) + top1_fraction = ( + self.router_top1_mismatch / self.router_top1_total + if self.router_top1_total > 0 + else 0.0 + ) + if self.numel == 0: + return { + "numel": 0.0, + "mean_abs_diff": 0.0, + "relative_l2": 0.0, + "typical_abs_scale": 0.0, + "mean_abs_pct": 0.0, + "topk_mismatch_fraction": topk_fraction, + "top1_mismatch_fraction": top1_fraction, + } + mean_abs = self.abs_sum / self.numel + typical_abs = self.ref_abs_sum / self.numel + mean_abs_pct = (mean_abs / (typical_abs + 1e-12)) * 100.0 + return { + "numel": _finite_metric(float(self.numel), default=0.0), + "mean_abs_diff": _finite_metric(mean_abs), + "relative_l2": _finite_metric( + (self.diff_sq_sum**0.5) / max(self.ref_sq_sum**0.5, 1e-12) + ), + "typical_abs_scale": _finite_metric(typical_abs, default=0.0), + "mean_abs_pct": _finite_metric(mean_abs_pct), + "topk_mismatch_fraction": _finite_metric(topk_fraction, default=1.0), + "top1_mismatch_fraction": _finite_metric(top1_fraction, default=1.0), + } T = TypeVar("T") @@ -171,18 +358,18 @@ def _require_not_none(value: T | None, name: str) -> T: return value -PHASE_A_TOPOLOGIES = [ - Topology(tp=1, ep=1, etp=1, dp=1, sp=0, phase="A"), - Topology(tp=2, ep=1, etp=1, dp=1, sp=1, phase="A"), - Topology(tp=1, ep=2, etp=1, dp=1, sp=0, phase="A"), - Topology(tp=2, ep=2, etp=1, dp=1, sp=1, phase="A"), +TOPOLOGIES = [ + Topology(tp=1, ep=1, etp=1, dp=1, sp=False), + Topology(tp=2, ep=1, etp=1, dp=1, sp=True), + Topology(tp=1, ep=2, etp=1, dp=2, sp=False), + Topology(tp=2, ep=2, etp=1, dp=2, sp=True), ] -PHASE_B_TOPOLOGIES = [ - Topology(tp=1, ep=1, etp=1, dp=2, sp=0, phase="B"), - Topology(tp=2, ep=1, etp=1, dp=2, sp=1, phase="B"), +EXTENDED_TOPOLOGIES = [ + Topology(tp=1, ep=1, etp=1, dp=2, sp=False), + Topology(tp=2, ep=1, etp=1, dp=2, sp=True), ] -ORACLE_TOPOLOGY = PHASE_A_TOPOLOGIES[0] -SENSITIVITY_TOPOLOGY = PHASE_A_TOPOLOGIES[1] +ORACLE_TOPOLOGY = TOPOLOGIES[0] +SENSITIVITY_TOPOLOGY = TOPOLOGIES[1] def _truthy(value: str | None) -> bool: @@ -191,69 +378,57 @@ def _truthy(value: str | None) -> bool: return value.strip().lower() in {"1", "true", "yes", "on"} -def sensitivity_mutation() -> SensitivityMutation | None: +def sensitivity_mutations() -> list[SensitivityMutation]: + """Parses sensitivity mutation selectors from env as a CSV list.""" raw = os.environ.get(SENSITIVITY_MUTATION_ENV) if raw is None or raw.strip() == "": - return None + return [] normalized = raw.strip().lower() if normalized in {"1", "true", "yes", "on"}: - return "drop_finalize" - if normalized == "drop_finalize": - return "drop_finalize" + return [DEFAULT_SENSITIVITY_MUTATION] + mutations = [item.strip().lower() for item in raw.split(",") if item.strip()] + unsupported = [ + mutation + for mutation in mutations + if mutation not in SUPPORTED_SENSITIVITY_MUTATIONS + ] + if not unsupported: + return mutations + supported = ", ".join(SUPPORTED_SENSITIVITY_MUTATIONS) raise ValueError( f"Unsupported {SENSITIVITY_MUTATION_ENV} value '{raw}'. " - "Supported values: drop_finalize, 1/true/yes/on." + f"Supported values: {supported}, CSV of supported values, 1/true/yes/on." ) def sensitivity_enabled() -> bool: - return sensitivity_mutation() is not None + return bool(sensitivity_mutations()) -def phase_b_dp_enabled() -> bool: - return _truthy(os.environ.get(DP_SUPPORT_ENV)) +def extended_topologies_enabled() -> bool: + """Returns whether extended topologies are enabled for the suite.""" + return _truthy(os.environ.get(EXTENDED_TOPOLOGIES_ENV)) def regenerate_requested() -> bool: return _truthy(os.environ.get(REGENERATE_ENV)) -def default_case_config() -> OracleCaseConfig: - def _env_float(name: str, default: str) -> float: - return float(os.environ.get(name, default)) - - tolerances = ToleranceProfile( - outputs_abs=_env_float("ART_MEGATRON_ORACLE_OUTPUTS_ABS_TOL", "1e-2"), - outputs_rel=_env_float("ART_MEGATRON_ORACLE_OUTPUTS_REL_TOL", "1e-2"), - losses_abs=_env_float("ART_MEGATRON_ORACLE_LOSSES_ABS_TOL", "1e-4"), - losses_rel=_env_float("ART_MEGATRON_ORACLE_LOSSES_REL_TOL", "1e-4"), - grads_abs=_env_float("ART_MEGATRON_ORACLE_GRADS_ABS_TOL", "1e-2"), - grads_rel=_env_float("ART_MEGATRON_ORACLE_GRADS_REL_TOL", "1e-2"), - deltas_abs=_env_float("ART_MEGATRON_ORACLE_DELTAS_ABS_TOL", "1e-2"), - deltas_rel=_env_float("ART_MEGATRON_ORACLE_DELTAS_REL_TOL", "1e-2"), - ) - return OracleCaseConfig( - base_model=os.environ.get( - BASE_MODEL_ENV, - "Qwen/Qwen3-30B-A3B-Instruct-2507", - ), - seed=int(os.environ.get("ART_MEGATRON_ORACLE_SEED", "20260305")), - num_steps=int(os.environ.get("ART_MEGATRON_ORACLE_NUM_STEPS", "3")), - learning_rate=float(os.environ.get("ART_MEGATRON_ORACLE_LR", "5e-6")), - beta=float(os.environ.get("ART_MEGATRON_ORACLE_BETA", "0.0")), - tolerances=tolerances, - ) +def case_config( + base_model: str = "Qwen/Qwen3-30B-A3B-Instruct-2507", +) -> OracleCaseConfig: + """Builds the deterministic default oracle case config.""" + return OracleCaseConfig(base_model=base_model) def available_gpu_count() -> int: import torch - if not torch.cuda.is_available(): - return 0 return int(torch.cuda.device_count()) def stable_case_id(case_config: OracleCaseConfig) -> str: + """Builds a deterministic case id from case config contents.""" payload = case_config.model_dump(mode="json") encoded = json.dumps(payload, sort_keys=True, separators=(",", ":")) digest = hashlib.sha256(encoded.encode("utf-8")).hexdigest()[:16] @@ -269,7 +444,7 @@ def stable_case_id(case_config: OracleCaseConfig) -> str: def _write_json(path: Path, payload: Any) -> None: path.parent.mkdir(parents=True, exist_ok=True) with path.open("w", encoding="utf-8") as handle: - json.dump(payload, handle, indent=2, sort_keys=True) + json.dump(payload, handle, indent=2, sort_keys=True, allow_nan=False) def _read_json(path: Path) -> dict[str, Any]: @@ -281,6 +456,7 @@ def _build_packed_tensors( config: PackedTensorConfig, seed: int, ) -> dict[str, Any]: + """Generates deterministic synthetic packed tensors used in integration runs.""" import torch if config.num_sequences <= 1: @@ -351,6 +527,7 @@ def _create_packed_tensors( case_config: OracleCaseConfig, packed_dir: Path, ) -> DiskPackedTensorsSpec: + """Persists packed tensors to disk and returns their descriptor.""" from art.preprocessing.pack import PackedTensors, packed_tensors_to_dir packed_tensors = cast( @@ -361,15 +538,8 @@ def _create_packed_tensors( return DiskPackedTensorsSpec.model_validate(descriptor) -def _validate_packed_tensor_files(spec: DiskPackedTensorsSpec) -> None: - tensor_dir = Path(spec.dir) - for filename in REQUIRED_PACKED_TENSOR_FILES: - file_path = tensor_dir / filename - if not file_path.exists(): - raise FileNotFoundError(f"Missing packed tensor file: {file_path}") - - def ensure_case_artifacts(case_config: OracleCaseConfig) -> CaseArtifacts: + """Ensures stable case-level artifacts (input tensors) are present and reusable.""" case_id = stable_case_id(case_config) case_dir = ARTIFACT_ROOT / case_id case_dir.mkdir(parents=True, exist_ok=True) @@ -378,7 +548,6 @@ def ensure_case_artifacts(case_config: OracleCaseConfig) -> CaseArtifacts: descriptor_path = case_dir / "packed_tensors.json" if descriptor_path.exists(): packed_spec = DiskPackedTensorsSpec.model_validate(_read_json(descriptor_path)) - _validate_packed_tensor_files(packed_spec) else: packed_spec = _create_packed_tensors(case_config, case_dir / "packed_tensors") _write_json(descriptor_path, packed_spec.model_dump(mode="json")) @@ -394,809 +563,822 @@ def ensure_case_artifacts(case_config: OracleCaseConfig) -> CaseArtifacts: def _replace_topology_dir(path: Path) -> None: + """Resets one topology output directory before regeneration.""" if path.exists(): shutil.rmtree(path) path.mkdir(parents=True, exist_ok=True) (path / "traces").mkdir(parents=True, exist_ok=True) -def _run_worker_subprocess(request: WorkerRunRequest, topology_dir: Path) -> None: - request_path = topology_dir / "run_request.json" - _write_json(request_path, request.model_dump(mode="json")) - - command = [ - sys.executable, - "-m", - "torch.distributed.run", - "--standalone", - "--nproc_per_node", - str(request.topology.world_size()), - str(Path(__file__).resolve()), - "--worker-run", - "--run-request", - str(request_path), - ] - run = subprocess.run( - command, - cwd=str(REPO_ROOT), - env={**os.environ, "PYTHONUNBUFFERED": "1"}, - capture_output=True, - text=True, - check=False, - ) - combined_output = f"{run.stdout}\n{run.stderr}".strip() - (topology_dir / "worker.log").write_text(combined_output + "\n", encoding="utf-8") - if run.returncode != 0: - tail = "\n".join(combined_output.splitlines()[-80:]) - raise RuntimeError( - f"Topology run failed for {request.topology.slug()} with exit code " - f"{run.returncode}.\n{tail}" - ) - - -def ensure_topology_artifacts( - case_config: OracleCaseConfig, +def _topology_output_slug( topology: Topology, - *, - regenerate: bool = False, mutation: SensitivityMutation | None = None, -) -> Path: - case_artifacts = ensure_case_artifacts(case_config) - case_dir = Path(case_artifacts.case_dir) - topology_dir = case_dir / topology.slug() - manifest_path = topology_dir / "manifest.json" - if manifest_path.exists() and not regenerate: - return topology_dir - - _replace_topology_dir(topology_dir) - shared_init_path = Path(case_artifacts.shared_init_adapter_path) - allow_create_shared_init = topology.slug() == ORACLE_TOPOLOGY.slug() - if not allow_create_shared_init and not shared_init_path.exists(): - ensure_topology_artifacts( - case_config=case_config, - topology=ORACLE_TOPOLOGY, - regenerate=False, - mutation=None, - ) - if not allow_create_shared_init and not shared_init_path.exists(): - raise FileNotFoundError( - f"Oracle shared adapter missing after oracle generation: {shared_init_path}" - ) - if mutation is not None and topology.slug() == ORACLE_TOPOLOGY.slug(): - raise RuntimeError("Sensitivity mutation cannot be applied to oracle topology") - - request = WorkerRunRequest( - case_id=case_artifacts.case_id, - case_config=case_config, - topology=topology, - topology_dir=str(topology_dir), - packed_tensors=case_artifacts.packed_tensors, - shared_init_adapter_path=str(shared_init_path), - allow_create_shared_init=allow_create_shared_init, - mutation=mutation, - ) - _run_worker_subprocess(request, topology_dir) - if not manifest_path.exists(): - raise RuntimeError(f"Missing manifest after run: {manifest_path}") - return topology_dir - - -def ensure_oracle_reference_artifacts( - *, - case_config: OracleCaseConfig, - regenerate: bool = False, -) -> Path: - return ensure_topology_artifacts( - case_config=case_config, - topology=ORACLE_TOPOLOGY, - regenerate=regenerate, - mutation=None, - ) +) -> str: + """Builds output slug for a topology and optional mutation variant.""" + return topology.slug() if mutation is None else f"{topology.slug()}__{mutation}" def _load_manifest(topology_dir: Path) -> RunManifest: + """Loads one run manifest for a topology output directory.""" manifest_path = topology_dir / "manifest.json" - if not manifest_path.exists(): - raise FileNotFoundError(f"Missing topology manifest: {manifest_path}") return RunManifest.model_validate(_read_json(manifest_path)) def _load_output_tensor(topology_dir: Path, step: StepTrace): + """Loads one output trace tensor referenced by a step trace entry.""" import torch path = topology_dir / step.output_file - if not path.exists(): - raise FileNotFoundError(f"Missing output trace: {path}") return torch.load(path, map_location="cpu") def _load_safetensor_map(path: Path) -> dict[str, Any]: + """Loads one safetensor map from disk.""" from safetensors.torch import load_file - if not path.exists(): - raise FileNotFoundError(f"Missing safetensor trace: {path}") return load_file(str(path)) -def _tensor_error(reference, candidate) -> tuple[float, float]: - ref = reference.detach().float() - cand = candidate.detach().float() - if ref.shape != cand.shape: - return float("inf"), float("inf") - if ref.numel() == 0: - return 0.0, 0.0 - diff = (cand - ref).abs() - max_abs = float(diff.max().item()) - max_rel = float((diff / ref.abs().clamp_min(1e-12)).max().item()) - return max_abs, max_rel +def _align_sequence_parallel(reference, candidate): # type: ignore[no-untyped-def] + """Aligns sequence-parallel-shaped tensors so diff computation is topology-agnostic.""" + if reference.shape == candidate.shape: + return candidate + if ( + candidate.ndim == reference.ndim + 1 + and candidate.shape[0] * candidate.shape[1] == reference.shape[0] + and tuple(candidate.shape[2:]) == tuple(reference.shape[1:]) + ): + return candidate.reshape(reference.shape) + return None -def _build_failure( - *, - case_id: str, - topology: str, - metric: Literal["outputs", "losses", "grads", "lora_deltas"], - step_index: int, - key: str, - max_abs_error: float, - max_rel_error: float, - abs_tolerance: float, - rel_tolerance: float, - message: str, -) -> ComparisonFailure: - return ComparisonFailure( - case_id=case_id, - topology=topology, - oracle_topology=ORACLE_TOPOLOGY.slug(), - metric=metric, - step_index=step_index, - key=key, - max_abs_error=max_abs_error, - max_rel_error=max_rel_error, - abs_tolerance=abs_tolerance, - rel_tolerance=rel_tolerance, - message=message, - ) +def _is_moe_base_forward_param(name: str) -> bool: + """Returns whether this forward param is a base MoE expert internal tensor.""" + if ".mlp.experts." not in name: + return False + if any(token in name for token in (".router", ".gate_lora", ".up_lora", ".lora")): + return False + return ".linear_fc1" in name or ".linear_fc2" in name -def _compare_tensor_pair( - *, - case_id: str, - topology: str, - metric: Literal["outputs", "losses", "grads", "lora_deltas"], - step_index: int, - key: str, - reference, - candidate, - abs_tolerance: float, - rel_tolerance: float, -) -> ComparisonFailure | None: - max_abs, max_rel = _tensor_error(reference, candidate) - if max_abs <= abs_tolerance or max_rel <= rel_tolerance: +def _lookup_call_by_index( + trace: dict[str, list[dict[str, Any]]], + module_name: str, + call_index: int, +) -> dict[str, Any] | None: + calls = trace.get(module_name) + if calls is None: return None - return _build_failure( - case_id=case_id, - topology=topology, - metric=metric, - step_index=step_index, - key=key, - max_abs_error=max_abs, - max_rel_error=max_rel, - abs_tolerance=abs_tolerance, - rel_tolerance=rel_tolerance, - message=f"{metric} mismatch at step {step_index}, key '{key}'", - ) - - -def _compare_tensor_maps( - *, - case_id: str, - topology: str, - metric: Literal["grads", "lora_deltas"], - step_index: int, - reference: dict[str, Any], - candidate: dict[str, Any], - abs_tolerance: float, - rel_tolerance: float, -) -> ComparisonFailure | None: - ref_keys = set(reference.keys()) - cand_keys = set(candidate.keys()) - if ref_keys != cand_keys: - missing = sorted(ref_keys - cand_keys) - extra = sorted(cand_keys - ref_keys) - return _build_failure( - case_id=case_id, - topology=topology, - metric=metric, - step_index=step_index, - key="__keys__", - max_abs_error=float("inf"), - max_rel_error=float("inf"), - abs_tolerance=abs_tolerance, - rel_tolerance=rel_tolerance, - message=( - f"{metric} key mismatch at step {step_index}; " - f"missing={missing[:3]}, extra={extra[:3]}" - ), - ) - for key in sorted(ref_keys): - failure = _compare_tensor_pair( - case_id=case_id, - topology=topology, - metric=metric, - step_index=step_index, - key=key, - reference=reference[key], - candidate=candidate[key], - abs_tolerance=abs_tolerance, - rel_tolerance=rel_tolerance, - ) - if failure is not None: - return failure + for call in calls: + if int(call.get("call_index", -1)) == call_index: + return call + if 0 <= call_index < len(calls): + return calls[call_index] return None -def _write_failure_report(topology_dir: Path, failure: ComparisonFailure) -> None: - _write_json(topology_dir / "failure_report.json", failure.model_dump(mode="json")) +def _router_module_name_for_expert_module(module_name: str) -> str | None: + if ".mlp.experts.linear_fc1" in module_name: + return module_name.replace(".mlp.experts.linear_fc1", ".mlp.router") + if ".mlp.experts.linear_fc2" in module_name: + return module_name.replace(".mlp.experts.linear_fc2", ".mlp.router") + return None -def compare_topology_to_oracle( +def _build_moe_row_identities( *, - case_config: OracleCaseConfig, - topology: Topology, -) -> ComparisonFailure | None: - if topology.slug() == ORACLE_TOPOLOGY.slug(): + module_name: str, + call_index: int, + trace: dict[str, list[dict[str, Any]]], + row_splits: list[int] | None, +) -> list[tuple[int, int, int]] | None: + router_module_name = _router_module_name_for_expert_module(module_name) + if router_module_name is None: return None - - case_id = stable_case_id(case_config) - case_dir = ARTIFACT_ROOT / case_id - oracle_dir = case_dir / ORACLE_TOPOLOGY.slug() - topology_dir = case_dir / topology.slug() - - oracle_manifest = _load_manifest(oracle_dir) - topology_manifest = _load_manifest(topology_dir) - if len(oracle_manifest.steps) != len(topology_manifest.steps): - return _build_failure( - case_id=case_id, - topology=topology.slug(), - metric="losses", - step_index=0, - key="__step_count__", - max_abs_error=float("inf"), - max_rel_error=float("inf"), - abs_tolerance=case_config.tolerances.losses_abs, - rel_tolerance=case_config.tolerances.losses_rel, - message=( - "Step count mismatch: " - f"oracle={len(oracle_manifest.steps)} vs " - f"topology={len(topology_manifest.steps)}" - ), - ) - - import torch - - for oracle_step, topology_step in zip( - oracle_manifest.steps, topology_manifest.steps - ): - step_index = oracle_step.step_index - oracle_outputs = _load_output_tensor(oracle_dir, oracle_step) - topology_outputs = _load_output_tensor(topology_dir, topology_step) - failure = _compare_tensor_pair( - case_id=case_id, - topology=topology.slug(), - metric="outputs", - step_index=step_index, - key="logprobs", - reference=oracle_outputs, - candidate=topology_outputs, - abs_tolerance=case_config.tolerances.outputs_abs, - rel_tolerance=case_config.tolerances.outputs_rel, - ) - if failure is not None: - return failure - - oracle_loss = torch.tensor([oracle_step.loss], dtype=torch.float32) - topology_loss = torch.tensor([topology_step.loss], dtype=torch.float32) - failure = _compare_tensor_pair( - case_id=case_id, - topology=topology.slug(), - metric="losses", - step_index=step_index, - key="loss", - reference=oracle_loss, - candidate=topology_loss, - abs_tolerance=case_config.tolerances.losses_abs, - rel_tolerance=case_config.tolerances.losses_rel, - ) - if failure is not None: - return failure - - for metric, oracle_file, topo_file, abs_tol, rel_tol in ( - ( - "grads", - oracle_step.grads_file, - topology_step.grads_file, - case_config.tolerances.grads_abs, - case_config.tolerances.grads_rel, - ), - ( - "lora_deltas", - oracle_step.deltas_file, - topology_step.deltas_file, - case_config.tolerances.deltas_abs, - case_config.tolerances.deltas_rel, - ), - ): - failure = _compare_tensor_maps( - case_id=case_id, - topology=topology.slug(), - metric=metric, - step_index=step_index, - reference=_load_safetensor_map(oracle_dir / oracle_file), - candidate=_load_safetensor_map(topology_dir / topo_file), - abs_tolerance=abs_tol, - rel_tolerance=rel_tol, - ) - if failure is not None: - return failure - return None + router_call = _lookup_call_by_index(trace, router_module_name, call_index) + if router_call is None: + return None + router_topk_ids = router_call.get("router_topk_ids") + if not isinstance(router_topk_ids, torch.Tensor) or router_topk_ids.ndim != 2: + return None + token_splits_raw = router_call.get("router_topk_ids__row_splits") + if row_splits is None: + if isinstance(token_splits_raw, list): + row_splits = [ + int(v) * int(router_topk_ids.shape[1]) for v in token_splits_raw + ] + else: + row_splits = [int(router_topk_ids.numel())] + if isinstance(token_splits_raw, list): + token_splits = [int(v) for v in token_splits_raw] + else: + topk = int(router_topk_ids.shape[1]) + token_splits = [int(v) // topk for v in row_splits] + if len(row_splits) != len(token_splits): + return None + row_cursor = 0 + token_cursor = 0 + identities: list[tuple[int, int, int]] = [] + for row_count, token_count in zip(row_splits, token_splits): + local_ids = router_topk_ids[token_cursor : token_cursor + token_count] + token_cursor += token_count + local_identities: list[tuple[int, int, int]] = [] + max_expert = int(local_ids.max().item()) if local_ids.numel() > 0 else -1 + for expert_id in range(max_expert + 1): + expert_rows = (local_ids == expert_id).nonzero(as_tuple=False) + for token_offset, slot_index in expert_rows.tolist(): + local_identities.append( + (expert_id, token_cursor - token_count + token_offset, slot_index) + ) + if len(local_identities) != row_count: + return None + identities.extend(local_identities) + row_cursor += row_count + if row_cursor != sum(row_splits): + return None + return identities -def run_and_compare_topology( +def _canonicalize_moe_base_forward_tensor( *, - case_config: OracleCaseConfig, - topology: Topology, - regenerate: bool = False, -) -> None: - ensure_oracle_reference_artifacts( - case_config=case_config, - regenerate=regenerate and topology.slug() == ORACLE_TOPOLOGY.slug(), + module_name: str, + call_index: int, + tensor: torch.Tensor, + trace: dict[str, list[dict[str, Any]]], + call: dict[str, Any], +) -> torch.Tensor: + if not _is_moe_base_forward_param(module_name): + return tensor + if tensor.ndim != 2: + return tensor + row_splits_raw = call.get("primary_output__row_splits") + row_splits = ( + [int(v) for v in row_splits_raw] if isinstance(row_splits_raw, list) else None ) - ensure_topology_artifacts( - case_config=case_config, - topology=topology, - regenerate=regenerate, - mutation=None, + identities = _build_moe_row_identities( + module_name=module_name, + call_index=call_index, + trace=trace, + row_splits=row_splits, ) - failure = compare_topology_to_oracle(case_config=case_config, topology=topology) - if failure is None: - return - topology_dir = ARTIFACT_ROOT / failure.case_id / topology.slug() - _write_failure_report(topology_dir, failure) - raise AssertionError( - "Megatron oracle mismatch: " - f"topology={failure.topology}, metric={failure.metric}, " - f"step={failure.step_index}, key={failure.key}, " - f"max_abs={failure.max_abs_error:.6g}, " - f"max_rel={failure.max_rel_error:.6g}, " - f"tol_abs={failure.abs_tolerance:.6g}, " - f"tol_rel={failure.rel_tolerance:.6g}" - ) - + if identities is None or len(identities) != int(tensor.shape[0]): + return tensor + order = sorted(range(len(identities)), key=lambda index: identities[index]) + return tensor[order] -def run_sensitivity_check( - *, - case_config: OracleCaseConfig, - regenerate: bool = False, -) -> None: - mutation = sensitivity_mutation() - if mutation is None: - raise RuntimeError( - f"Sensitivity check requires {SENSITIVITY_MUTATION_ENV} to be set" - ) - ensure_oracle_reference_artifacts( - case_config=case_config, - regenerate=regenerate, - ) - ensure_topology_artifacts( - case_config=case_config, - topology=SENSITIVITY_TOPOLOGY, - regenerate=True, - mutation=mutation, - ) - failure = compare_topology_to_oracle( - case_config=case_config, - topology=SENSITIVITY_TOPOLOGY, +def _minimal_param_name(name: str) -> str: + """Returns a shorter but 1:1 param/module identifier for report readability.""" + return name.removeprefix("base_model.model.model.").replace( + "module.module.decoder.", "" ) - if failure is None: - raise AssertionError( - "Sensitivity mutation did not produce an oracle mismatch. " - f"mutation={mutation}, topology={SENSITIVITY_TOPOLOGY.slug()}" - ) - -def _set_deterministic_seed(seed: int) -> None: - import torch - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False +def _load_forward_trace( + topology_dir: Path, step_index: int +) -> dict[str, list[dict[str, Any]]]: + """Loads one merged forward-trace file for a given step.""" + trace_path = topology_dir / "traces" / f"forward_trace_step_{step_index:03d}.pt" + return ForwardTraceCapture.load_trace(trace_path) -def _merge_sharded_dicts(shards_by_rank: list[dict[str, Any]]) -> dict[str, Any]: - import torch +def _threshold_string(thresholds: dict[str, float]) -> str: + """Formats threshold dicts into compact table cells.""" + if not thresholds: + return "-" + return ", ".join(f"{key}<={value:.3g}" for key, value in sorted(thresholds.items())) - merged: dict[str, list[Any]] = {} - for rank_shards in shards_by_rank: - for key, tensor in rank_shards.items(): - merged.setdefault(key, []).append(tensor.detach().cpu()) - full_state: dict[str, Any] = {} - for key, shards in merged.items(): - if len(shards) == 1: - full_state[key] = shards[0].contiguous() - continue - concat_dim = 1 if ".lora_A." in key else 0 - full_state[key] = torch.cat(shards, dim=concat_dim).contiguous() - return full_state - - -def _gather_full_state(local_state: dict[str, Any]) -> dict[str, Any] | None: - import torch - rank = torch.distributed.get_rank() - world_size = torch.distributed.get_world_size() - gathered = [None for _ in range(world_size)] if rank == 0 else None - torch.distributed.gather_object(local_state, gathered, dst=0) - if rank != 0: - return None - assert gathered is not None - entries = [entry for entry in gathered if entry is not None] - return _merge_sharded_dicts(entries) +def _finite_metric(value: float, *, default: float = NON_FINITE_METRIC_VALUE) -> float: + """Maps NaN/Inf metric values to a large finite sentinel for JSON-safe reports.""" + value_f = float(value) + if math.isnan(value_f): + return default + if math.isinf(value_f): + return default if value_f > 0 else -default + return value_f -def _collect_lora_state(model_chunks: list[Any]) -> dict[str, Any] | None: - local_state: dict[str, Any] = {} - for chunk in model_chunks: - for module in chunk.modules(): - if not hasattr(module, "sharded_lora_state_dict"): - continue - module_state = module.sharded_lora_state_dict() - for key, value in module_state.items(): - if key in local_state: - raise RuntimeError( - f"Duplicate LoRA key while collecting state: {key}" - ) - local_state[key] = value.detach().cpu() - return _gather_full_state(local_state) - +def _triplet_expert_key(param: str) -> tuple[int, int] | None: + """Returns (layer, expert_id) for expert up/gate/down params.""" + match = EXPERT_TRIPLET_PARAM_RE.search(param) + if match is None: + return None + return int(match.group("layer")), int(match.group("expert")) + + +class VariantRunner: + """Runs oracle/candidate variants and emits row-level comparison reports.""" + + def __init__( + self, + *, + case_config: OracleCaseConfig, + console: Console | None = None, + ) -> None: + self.case_config = case_config + self.case_artifacts = ensure_case_artifacts(case_config) + self.case_id = self.case_artifacts.case_id + self.case_dir = Path(self.case_artifacts.case_dir) + self.oracle_slug = ORACLE_TOPOLOGY.slug() + self.oracle_dir = self.case_dir / self.oracle_slug + self.oracle_routing_bundle_dir = ( + self.case_dir / ORACLE_MOE_ROUTING_BUNDLE_DIRNAME + ) + self.shared_init_path = Path(self.case_artifacts.shared_init_adapter_path) + self.console = console or Console(width=160) + self._oracle_initialized = False + self._oracle_regenerated = False + + def _run_topology( + self, + *, + topology: Topology, + output_slug: str, + mutation: SensitivityMutation | None, + replay_bundle_dir: Path | None, + capture_bundle_dir: Path | None, + regenerate: bool, + ) -> Path: + """Executes one topology worker run and returns its output directory.""" + topology_dir = self.case_dir / output_slug + manifest_path = topology_dir / "manifest.json" + if manifest_path.exists() and not regenerate: + return topology_dir + _replace_topology_dir(topology_dir) + request = WorkerRunRequest( + case_id=self.case_id, + case_config=self.case_config, + topology=topology, + topology_dir=str(topology_dir), + packed_tensors=self.case_artifacts.packed_tensors, + shared_init_adapter_path=str(self.shared_init_path), + mutation=mutation, + moe_routing_replay_path=( + None if replay_bundle_dir is None else str(replay_bundle_dir) + ), + moe_routing_replay_strict=True, + capture_moe_routing_bundle_path=( + None if capture_bundle_dir is None else str(capture_bundle_dir) + ), + ) + from .megatron_oracle_worker import run_worker_subprocess -def _collect_lora_grads(model_chunks: list[Any]) -> dict[str, Any] | None: - from megatron.core import parallel_state as ps + run_worker_subprocess(request, topology_dir, repo_root=REPO_ROOT) + return topology_dir - from art.megatron.lora import LoRA + def ensure_oracle(self) -> Path: + """Ensures oracle capture and canonical replay artifacts exist exactly once per session.""" + regenerate = regenerate_requested() + if self._oracle_initialized and (not regenerate or self._oracle_regenerated): + return self.oracle_dir + if regenerate and self.shared_init_path.exists(): + self.shared_init_path.unlink() + bundle_manifest = self.oracle_routing_bundle_dir / "manifest.json" + oracle_manifest = self.oracle_dir / "manifest.json" + need_capture = ( + regenerate + or not bundle_manifest.exists() + or not self.shared_init_path.exists() + ) + run_oracle_topology = partial( + self._run_topology, + topology=ORACLE_TOPOLOGY, + mutation=None, + regenerate=True, + ) + if need_capture: + run_oracle_topology( + output_slug=f"{self.oracle_slug}__oracle_capture", + replay_bundle_dir=None, + capture_bundle_dir=self.oracle_routing_bundle_dir, + ) + if regenerate or not oracle_manifest.exists(): + run_oracle_topology( + output_slug=self.oracle_slug, + replay_bundle_dir=self.oracle_routing_bundle_dir, + capture_bundle_dir=None, + ) + self._oracle_initialized = True + self._oracle_regenerated = self._oracle_regenerated or regenerate + return self.oracle_dir + + def ensure_variant_artifacts( + self, + variant: VariantSpec, + ) -> Path: + """Ensures oracle prerequisites and candidate artifacts for one variant.""" + self.ensure_oracle() + output_slug = variant.resolved_output_slug() + if output_slug == self.oracle_slug and variant.mutation is None: + return self.oracle_dir + return self._run_topology( + topology=variant.topology, + output_slug=output_slug, + mutation=variant.mutation, + replay_bundle_dir=self.oracle_routing_bundle_dir, + capture_bundle_dir=None, + regenerate=True, + ) - local_grads: dict[str, Any] = {} - for chunk in model_chunks: - for module in chunk.modules(): - if not isinstance(module, LoRA): + @staticmethod + def _apply_thresholds(row: MetricRow, thresholds: dict[str, float]) -> None: + """Evaluates row thresholds using AND semantics over all configured keys.""" + row.thresholds = dict(thresholds) + if not thresholds: + row.pass_signal = True + row.failure_reasons = [] + return + payload = row.model_dump(mode="python") + reasons: list[str] = [] + for key, limit in sorted(thresholds.items()): + value = payload.get(key) + if not isinstance(value, (int, float)): + reasons.append(f"{key}=missing") continue - grad_a = ( - module.A_T.grad - if module.A_T.grad is not None - else module.A_T.new_zeros(module.A_T.shape) - ) - grad_b = ( - module.B_T.grad - if module.B_T.grad is not None - else module.B_T.new_zeros(module.B_T.shape) + if float(value) > float(limit): + reasons.append(f"{key}={float(value):.6g}>{float(limit):.6g}") + row.pass_signal = len(reasons) == 0 + row.failure_reasons = reasons + + @staticmethod + def _inf_summary() -> dict[str, float]: + """Builds a large-error finite summary for structural mismatches.""" + return { + "numel": 0.0, + "mean_abs_diff": NON_FINITE_METRIC_VALUE, + "relative_l2": NON_FINITE_METRIC_VALUE, + "typical_abs_scale": 0.0, + "mean_abs_pct": NON_FINITE_METRIC_VALUE, + "topk_mismatch_fraction": 1.0, + "top1_mismatch_fraction": 1.0, + } + + def _build_metric_row( + self, + *, + variant: VariantSpec, + step_index: int, + phase: str, + param: str, + summary: dict[str, float], + structural_failure: str | None = None, + ) -> MetricRow: + """Builds one metric row and applies per-phase thresholds.""" + row = MetricRow( + case_id=self.case_id, + variant=variant.name, + topology=variant.resolved_output_slug(), + oracle_topology=variant.resolved_reference_slug(), + step_index=step_index, + phase=phase, + param=param, + numel=summary["numel"], + mean_abs_diff=summary["mean_abs_diff"], + relative_l2=summary["relative_l2"], + typical_abs_scale=summary["typical_abs_scale"], + mean_abs_pct=summary["mean_abs_pct"], + topk_mismatch_fraction=summary.get("topk_mismatch_fraction"), + top1_mismatch_fraction=summary.get("top1_mismatch_fraction"), + ) + self._apply_thresholds(row, variant.thresholds_by_phase.get(phase, {})) + if structural_failure is not None: + row.pass_signal = False + row.failure_reasons = [structural_failure, *row.failure_reasons] + return row + + def _build_metric_rows_from_tensor_pairs( + self, + *, + variant: VariantSpec, + step_index: int, + phase: str, + pairs: list[tuple[str, Any, Any]], + router_ids: bool = False, + ) -> list[MetricRow]: + """Builds rows from named tensor pairs with one shared diff path.""" + rows: list[MetricRow] = [] + for name, reference, candidate in pairs: + shared_kwargs = { + "variant": variant, + "step_index": step_index, + "phase": phase, + "param": _minimal_param_name(name), + } + reference_aligned = reference + candidate_aligned = candidate + aligned_candidate = _align_sequence_parallel( + reference_aligned, candidate_aligned ) - if module.num_local_experts > 1: - if ps.get_expert_data_parallel_rank() != 0: - continue - for expert in range(module.num_local_experts): - prefix = module.adapter_model_prefix.format( - expert=expert + module._expert_offset - ) - local_grads[f"{prefix}.lora_A.weight"] = ( - grad_a[expert].detach().cpu().T - ) - local_grads[f"{prefix}.lora_B.weight"] = ( - grad_b[expert].detach().cpu().T + if aligned_candidate is None: + rows.append( + self._build_metric_row( + summary=self._inf_summary(), + structural_failure="shape mismatch", + **shared_kwargs, ) + ) + continue + accumulator = DiffAccumulator() + if router_ids: + accumulator.update_router_ids(reference_aligned, aligned_candidate) else: - if ps.get_data_parallel_rank() != 0: - continue - local_grads[f"{module.adapter_model_prefix}.lora_A.weight"] = ( - grad_a.detach().cpu().T + accumulator.update(reference_aligned, aligned_candidate) + rows.append( + self._build_metric_row( + summary=accumulator.as_summary(), **shared_kwargs ) - local_grads[f"{module.adapter_model_prefix}.lora_B.weight"] = ( - grad_b.detach().cpu().T + ) + return rows + + def _check_matching_keys( + self, + reference: dict[str, Any], + candidate: dict[str, Any], + variant: VariantSpec, + step_index: int, + phase: str, + ) -> tuple[bool, list[MetricRow] | None]: + """Checks if the keys of two tensor maps match and builds a metric row if they don't.""" + reference_keys = set(reference.keys()) + candidate_keys = set(candidate.keys()) + if reference_keys != candidate_keys: + missing = sorted(reference_keys - candidate_keys) + extra = sorted(candidate_keys - reference_keys) + return False, [ + self._build_metric_row( + variant=variant, + step_index=step_index, + phase=phase, + param="__keys__", + summary=self._inf_summary(), + structural_failure=f"missing={missing[:5]} extra={extra[:5]}", ) - return _gather_full_state(local_grads) - - -def _validate_adapter_exact( - expected_state: dict[str, Any], - adapter_model: dict[str, Any], -) -> None: - expected_keys = set(expected_state.keys()) - adapter_keys = set(adapter_model.keys()) - missing = sorted(expected_keys - adapter_keys) - extra = sorted(adapter_keys - expected_keys) - if missing or extra: - raise KeyError( - f"Adapter keys mismatch: missing={missing[:5]} extra={extra[:5]}" + ] + return True, None + + def _build_metric_rows_from_tensor_maps( + self, + *, + variant: VariantSpec, + step_index: int, + phase: str, + reference: dict[str, Any], + candidate: dict[str, Any], + router_ids: bool = False, + ) -> list[MetricRow]: + """Builds rows from two keyed tensor maps through a unified compare path.""" + matching, rows = self._check_matching_keys( + reference, candidate, variant, step_index, phase + ) + if not matching: + return rows + pairs = [ + (key, reference[key], candidate[key]) + for key in sorted(set(reference.keys())) + ] + return self._build_metric_rows_from_tensor_pairs( + variant=variant, + step_index=step_index, + phase=phase, + pairs=pairs, + router_ids=router_ids, ) - -def _validate_loaded_state_matches_adapter( - loaded_state: dict[str, Any], - adapter_model: dict[str, Any], -) -> None: - import torch - - for key in sorted(adapter_model.keys()): - if key not in loaded_state: - raise KeyError(f"Loaded LoRA state missing key: {key}") - if not torch.equal(loaded_state[key].cpu(), adapter_model[key].cpu()): - max_abs, max_rel = _tensor_error(adapter_model[key], loaded_state[key]) - raise RuntimeError( - f"Loaded LoRA state mismatch for key '{key}' " - f"(max_abs={max_abs:.6g}, max_rel={max_rel:.6g})" + @staticmethod + def _flatten_forward_trace_tensors( + trace: dict[str, list[dict[str, Any]]], + *, + value_key: str, + ) -> dict[str, Any]: + """Flattens per-module forward trace calls into a deterministic tensor map.""" + flattened: dict[str, Any] = {} + for module_name in sorted(trace.keys()): + for call_offset, call in enumerate(trace[module_name]): + tensor = call.get(value_key) + if tensor is None: + continue + call_index = call.get("call_index", call_offset) + if value_key == "primary_output" and isinstance(tensor, torch.Tensor): + tensor = _canonicalize_moe_base_forward_tensor( + module_name=module_name, + call_index=int(call_index), + tensor=tensor, + trace=trace, + call=call, + ) + flattened[f"{module_name}.call_{call_index}"] = tensor + return flattened + + @staticmethod + def _build_step_summaries(rows: list[MetricRow]) -> dict[int, dict[str, Any]]: + """Builds step-indexed payloads directly from row model dumps.""" + step_summaries: dict[int, dict[str, Any]] = {} + for row in rows: + step_entry = step_summaries.setdefault(row.step_index, {}) + phase_entry = cast(dict[str, Any], step_entry.setdefault(row.phase, {})) + phase_entry[row.param] = row.model_dump(mode="json") + return step_summaries + + def compare_variant(self, variant: VariantSpec) -> VariantReport: + """Compares one candidate variant against its reference topology.""" + reference_slug = variant.resolved_reference_slug() + topology_slug = variant.resolved_output_slug() + reference_dir = self.case_dir / reference_slug + topology_dir = self.case_dir / topology_slug + reference_manifest = _load_manifest(reference_dir) + topology_manifest = _load_manifest(topology_dir) + rows: list[MetricRow] = [] + if len(reference_manifest.steps) != len(topology_manifest.steps): + rows.append( + self._build_metric_row( + variant=variant, + step_index=0, + phase="step_count", + param="__step_count__", + summary=self._inf_summary(), + structural_failure=( + f"reference={len(reference_manifest.steps)} " + f"candidate={len(topology_manifest.steps)}" + ), + ) ) + import torch -def _configure_provider(provider: Any, topology: Topology) -> None: - provider.tensor_model_parallel_size = topology.tp - provider.expert_model_parallel_size = topology.ep - provider.expert_tensor_parallel_size = topology.etp - provider.pipeline_model_parallel_size = 1 - provider.context_parallel_size = 1 - provider.sequence_parallel = bool(topology.sp) - if hasattr(provider, "attention_dropout"): - provider.attention_dropout = 0.0 - if hasattr(provider, "hidden_dropout"): - provider.hidden_dropout = 0.0 - - -def _delta_state( - initial_state: dict[str, Any], - current_state: dict[str, Any], -) -> dict[str, Any]: - initial_keys = set(initial_state.keys()) - current_keys = set(current_state.keys()) - if initial_keys != current_keys: - missing = sorted(initial_keys - current_keys) - extra = sorted(current_keys - initial_keys) - raise KeyError( - f"LoRA state keys changed during training: missing={missing[:3]} extra={extra[:3]}" + for reference_step, topology_step in zip( + reference_manifest.steps, topology_manifest.steps + ): + step_index = reference_step.step_index + reference_trace = _load_forward_trace(reference_dir, step_index) + topology_trace = _load_forward_trace(topology_dir, step_index) + map_phase_inputs = [ + ( + "outputs", + {"logprobs": _load_output_tensor(reference_dir, reference_step)}, + {"logprobs": _load_output_tensor(topology_dir, topology_step)}, + False, + ), + ( + "losses", + {"loss": torch.tensor([reference_step.loss], dtype=torch.float32)}, + {"loss": torch.tensor([topology_step.loss], dtype=torch.float32)}, + False, + ), + ( + "grads", + _load_safetensor_map(reference_dir / reference_step.grads_file), + _load_safetensor_map(topology_dir / topology_step.grads_file), + False, + ), + ( + "deltas", + _load_safetensor_map(reference_dir / reference_step.deltas_file), + _load_safetensor_map(topology_dir / topology_step.deltas_file), + False, + ), + *[ + ( + phase, + self._flatten_forward_trace_tensors( + reference_trace, + value_key=value_key, + ), + self._flatten_forward_trace_tensors( + topology_trace, + value_key=value_key, + ), + phase == "router_topk_ids", + ) + for phase, value_key in ( + ("forward", "primary_output"), + ("router_scores", "router_topk_scores"), + ("router_topk_ids", "router_topk_ids"), + ) + ], + ] + for phase, reference_map, candidate_map, router_ids in map_phase_inputs: + rows.extend( + self._build_metric_rows_from_tensor_maps( + variant=variant, + step_index=step_index, + phase=phase, + reference=reference_map, + candidate=candidate_map, + router_ids=router_ids, + ) + ) + pass_count = sum(1 for row in rows if row.pass_signal) + fail_count = len(rows) - pass_count + signal: Literal["pass", "fail"] = "pass" if fail_count == 0 else "fail" + return VariantReport( + case_id=self.case_id, + variant=variant.name, + topology=topology_slug, + reference_topology=reference_slug, + expected_signal=variant.expected_signal, + signal=signal, + pass_count=pass_count, + fail_count=fail_count, + step_summaries=self._build_step_summaries(rows), + metrics=rows, ) - return { - key: current_state[key].detach().cpu() - initial_state[key].detach().cpu() - for key in sorted(initial_keys) - } - - -@contextmanager -def _mutation_hook( - megatron_train_module: Any, - mutation: SensitivityMutation | None, - pre_optimizer_step_hook: Callable[[], None] | None = None, -): - original_finalize = megatron_train_module._finalize_grads - original_optimizer_step = megatron_train_module._optimizer_step - - if mutation == "drop_finalize": - megatron_train_module._finalize_grads = lambda _model: None - elif mutation is not None: - raise ValueError(f"Unsupported mutation: {mutation}") - if pre_optimizer_step_hook is not None: - - def _patched_optimizer_step(optimizer: Any, learning_rate: float): - pre_optimizer_step_hook() - return original_optimizer_step(optimizer, learning_rate) - - megatron_train_module._optimizer_step = _patched_optimizer_step - - if mutation is None: - if pre_optimizer_step_hook is None: - yield + @staticmethod + def assert_expected_signal(report: VariantReport, context: str) -> None: + """Raises when observed run signal diverges from variant expectation.""" + if report.signal == report.expected_signal: return - try: - yield - finally: - megatron_train_module._finalize_grads = original_finalize - megatron_train_module._optimizer_step = original_optimizer_step - + if report.signal == "fail": + first_failure = next(row for row in report.metrics if not row.pass_signal) + raise AssertionError( + f"{context}: topology={report.topology} phase={first_failure.phase} " + f"step={first_failure.step_index} param={first_failure.param} " + f"reasons={'; '.join(first_failure.failure_reasons)}" + ) + raise AssertionError( + f"{context}: expected_signal={report.expected_signal} " + f"observed_signal={report.signal} topology={report.topology}" + ) -def _worker_run(request: WorkerRunRequest) -> None: - from megatron.core.optimizer import OptimizerConfig - from safetensors.torch import load_file, save_file - import torch + def _write_variant_report(self, topology_dir: Path, report: VariantReport) -> None: + """Persists full variant report JSON for debugging and regression inspection.""" + _write_json( + topology_dir / "variant_report.json", report.model_dump(mode="json") + ) - from art import dev, types - from art.megatron import train as megatron_train - from art.preprocessing.pack import packed_tensors_from_dir + def print_report(self, report: VariantReport) -> None: + """Prints a row-level table with expert rows subsampled by highest relative_l2.""" + non_expert_rows: list[MetricRow] = [] + triplet_rows: list[tuple[tuple[int, int], MetricRow]] = [] + for row in report.metrics: + expert_key = _triplet_expert_key(row.param) + if expert_key is None: + non_expert_rows.append(row) + continue + triplet_rows.append((expert_key, row)) - local_rank = int(os.environ["LOCAL_RANK"]) - torch.cuda.set_device(local_rank) - torch.distributed.init_process_group(backend="nccl") - _set_deterministic_seed(request.case_config.seed) + scores_by_layer: dict[int, dict[int, float]] = {} + for (layer, expert_id), row in triplet_rows: + layer_scores = scores_by_layer.setdefault(layer, {}) + layer_scores[expert_id] = max( + layer_scores.get(expert_id, float("-inf")), row.relative_l2 + ) - world_size = torch.distributed.get_world_size() - if world_size != request.topology.world_size(): - raise RuntimeError( - f"World size mismatch: expected {request.topology.world_size()}, got {world_size}" + selected_experts: set[tuple[int, int]] = set() + for layer, expert_scores in scores_by_layer.items(): + top_experts = sorted( + expert_scores.items(), + key=lambda item: item[1], + reverse=True, + )[:EXPERT_TABLE_ROW_LIMIT] + for expert_id, _score in top_experts: + selected_experts.add((layer, expert_id)) + + selected_triplet_rows = [ + row for expert_key, row in triplet_rows if expert_key in selected_experts + ] + table_rows = non_expert_rows + selected_triplet_rows + detail_table = Table( + title=( + f"Variant Report | variant={report.variant} " + f"| topology={report.topology} | signal={report.signal} " + f"| selected_experts={len(selected_experts)} " + f"(top {EXPERT_TABLE_ROW_LIMIT} per layer)" + ), + box=box.SIMPLE_HEAVY, + show_lines=False, ) - - runtime = megatron_train.build_training_runtime( - model_identifier=request.case_config.base_model, - provider_configure=lambda provider: _configure_provider( - provider, request.topology - ), - optimizer_config=OptimizerConfig( - bf16=True, - lr=request.case_config.learning_rate, - adam_beta1=0.9, - adam_beta2=0.99, - clip_grad=0.1, - weight_decay=0.1, - ), - print_env=False, - print_optimizer_stats=False, - ) - model_chunks = runtime.model - optimizer = runtime.optimizer - - topology_dir = Path(request.topology_dir) - traces_dir = topology_dir / "traces" - traces_dir.mkdir(parents=True, exist_ok=True) - - shared_init_path = Path(request.shared_init_adapter_path) - if not shared_init_path.exists(): - if not request.allow_create_shared_init: - raise FileNotFoundError( - f"Missing oracle shared adapter at {shared_init_path}" + detail_table.add_column("Step", justify="right") + detail_table.add_column("Phase", style="cyan") + detail_table.add_column("Param") + detail_table.add_column("Status") + detail_table.add_column("relative_l2", justify="right") + detail_table.add_column("mean_abs_pct", justify="right") + detail_table.add_column("typical_abs", justify="right") + # detail_table.add_column("Thresholds") + detail_table.add_column("Failure") + sorted_rows = sorted( + table_rows, + key=lambda row: ( + row.step_index, + PHASE_PRINT_ORDER.get(row.phase, 999), + row.param, + row.pass_signal, + ), + ) + for row in sorted_rows: + status_text = ( + "[green]PASS[/green]" if row.pass_signal else "[red]FAIL[/red]" ) - initial_state = _collect_lora_state(model_chunks) - if torch.distributed.get_rank() == 0: - assert initial_state is not None - shared_init_path.parent.mkdir(parents=True, exist_ok=True) - save_file(initial_state, str(shared_init_path)) - torch.distributed.barrier() - if not shared_init_path.exists(): - raise FileNotFoundError(f"Shared init adapter not created: {shared_init_path}") - - adapter_model = load_file(str(shared_init_path)) - expected_state = _collect_lora_state(model_chunks) - if torch.distributed.get_rank() == 0: - assert expected_state is not None - _validate_adapter_exact(expected_state, adapter_model) - torch.distributed.barrier() - - megatron_train.load_adapter_into_model(model_chunks, adapter_model) - loaded_state = _collect_lora_state(model_chunks) - if torch.distributed.get_rank() == 0: - assert loaded_state is not None - _validate_loaded_state_matches_adapter(loaded_state, adapter_model) - torch.distributed.barrier() - - packed_tensors = packed_tensors_from_dir( - **request.packed_tensors.model_dump(exclude_none=True) - ) - initial_lora_state = _collect_lora_state(model_chunks) - if torch.distributed.get_rank() == 0 and initial_lora_state is None: - raise RuntimeError("Failed to collect initial LoRA state on rank 0") - - train_config = types.TrainConfig( - learning_rate=request.case_config.learning_rate, - beta=request.case_config.beta, - kl_penalty_coef=0.0, - ) - experimental_config: dev.TrainConfig = {} - step_traces: list[StepTrace] = [] - captured_grads: dict[str, Any] | None = None - - def _capture_lora_grads() -> None: - nonlocal captured_grads - captured_grads = _collect_lora_grads(model_chunks) - - with _mutation_hook( - megatron_train, - request.mutation, - pre_optimizer_step_hook=_capture_lora_grads, - ): - for step_index in range(request.case_config.num_steps): - sample_index = step_index % request.packed_tensors.num_sequences - inputs = megatron_train.select_indexed_inputs(packed_tensors, sample_index) - captured_grads = None - - step_result = megatron_train.run_training_step( - model_chunks=model_chunks, - optimizer=optimizer, - learning_rate=train_config.learning_rate, - inputs=inputs, - config=train_config, - experimental_config=experimental_config, - ref_logprobs=None, + failure_text = "" if row.pass_signal else "; ".join(row.failure_reasons) + detail_table.add_row( + str(row.step_index), + row.phase, + row.param, + status_text, + f"{row.relative_l2:.6g}", + f"{row.mean_abs_pct:.6g}%", + f"{row.typical_abs_scale:.6g}", + # _threshold_string(row.thresholds), # disabled for now to avoid clutter, neat to keep though + failure_text, ) - if torch.distributed.get_rank() == 0 and captured_grads is None: - raise RuntimeError("Failed to collect LoRA grads on rank 0") - - current_lora_state = _collect_lora_state(model_chunks) - if torch.distributed.get_rank() == 0 and current_lora_state is None: - raise RuntimeError("Failed to collect current LoRA state on rank 0") - - if torch.distributed.get_rank() == 0: - grads = _require_not_none(captured_grads, "captured_grads") - initial_state = _require_not_none( - initial_lora_state, "initial_lora_state" - ) - current_state = _require_not_none( - current_lora_state, "current_lora_state" - ) - output_rel = Path("traces") / f"output_step_{step_index:03d}.pt" - grads_rel = Path("traces") / f"grads_step_{step_index:03d}.safetensors" - deltas_rel = ( - Path("traces") / f"deltas_step_{step_index:03d}.safetensors" - ) - lora_rel = Path(f"lora_step_{step_index:03d}.safetensors") + self.console.print(detail_table) + + def run_variant( + self, + variant: VariantSpec, + ) -> VariantReport: + """Runs a variant end-to-end, writes JSON report, and prints row table.""" + topology_dir = self.ensure_variant_artifacts(variant) + report = self.compare_variant(variant) + self._write_variant_report(topology_dir, report) + self.print_report(report) + return report + + def run_suite( + self, + variants: list[VariantSpec], + ) -> list[VariantReport]: + """Runs variants in order and stops at the first unexpected signal.""" + reports: list[VariantReport] = [] + for variant in variants: + report = self.run_variant(variant) + reports.append(report) + self.assert_expected_signal(report, "Megatron oracle suite mismatch") + return reports + + +def _default_phase_thresholds( + case_cfg: OracleCaseConfig, +) -> dict[str, dict[str, float]]: + """Builds default per-phase (fwd, grad, outputs, losses, deltas) threshold dictionaries.""" + default = { + "relative_l2": case_cfg.tolerances.relative_l2, + "mean_abs_pct": case_cfg.tolerances.mean_abs_pct, + } + return { + key: default for key in ["outputs", "losses", "grads", "deltas", "forward"] + } | { + "router_scores": {"mean_abs_pct": 0.0}, + "router_topk_ids": { + "topk_mismatch_fraction": 0.0, + "top1_mismatch_fraction": 0.0, + }, + } - torch.save( - step_result.new_logprobs.detach().cpu().float(), - topology_dir / output_rel, - ) - save_file(grads, str(topology_dir / grads_rel)) - deltas = _delta_state(initial_state, current_state) - save_file(deltas, str(topology_dir / deltas_rel)) - save_file(current_state, str(topology_dir / lora_rel)) - step_traces.append( - StepTrace( - step_index=step_index, - loss=float(step_result.reduced_loss.item()), - probs_corr=step_result.probs_corr, - output_file=str(output_rel), - grads_file=str(grads_rel), - deltas_file=str(deltas_rel), - lora_file=str(lora_rel), - ) - ) - torch.distributed.barrier() - - if torch.distributed.get_rank() == 0: - manifest = RunManifest( - case_id=request.case_id, - base_model=request.case_config.base_model, - topology=request.topology.slug(), - world_size=request.topology.world_size(), - seed=request.case_config.seed, - num_steps=request.case_config.num_steps, - packed_tensors=request.packed_tensors, - tolerances=request.case_config.tolerances, - steps=step_traces, +def _suite_variants(case_cfg: OracleCaseConfig) -> list[VariantSpec]: + """Builds the standard oracle suite variant ordering.""" + thresholds = _default_phase_thresholds(case_cfg) + variants = [ + VariantSpec( + name="oracle_replay_parity", + topology=ORACLE_TOPOLOGY, + output_slug=_topology_output_slug( + ORACLE_TOPOLOGY, ORACLE_REPLAY_TOPOLOGY_SUFFIX + ), + thresholds_by_phase=thresholds, ) - _write_json(topology_dir / "manifest.json", manifest.model_dump(mode="json")) - torch.distributed.barrier() - torch.distributed.destroy_process_group() - - -def _run_worker_cli(run_request_path: Path) -> None: - request = WorkerRunRequest.model_validate(_read_json(run_request_path)) - _worker_run(request) - - -def _parse_args(argv: list[str]) -> argparse.Namespace: - parser = argparse.ArgumentParser(description="Megatron oracle harness worker") - parser.add_argument("--worker-run", action="store_true") - parser.add_argument("--run-request", type=Path) - return parser.parse_args(argv) + ] + for topology in TOPOLOGIES[1:] + ( + EXTENDED_TOPOLOGIES if extended_topologies_enabled() else [] + ): + variants.append( + VariantSpec( + name=f"topology_{topology.slug()}", + topology=topology, + thresholds_by_phase=thresholds, + ) + ) + return variants -def _main(argv: list[str]) -> int: - args = _parse_args(argv) - if not args.worker_run: - raise SystemExit("This module is intended for test imports or --worker-run") - if args.run_request is None: - raise SystemExit("--run-request is required with --worker-run") - _run_worker_cli(args.run_request) - return 0 +def run_suite( + *, + case_config: OracleCaseConfig, +) -> list[VariantReport]: + """Runs replay parity and topology variants with fail-fast assertions.""" + runner = VariantRunner(case_config=case_config) + return runner.run_suite(_suite_variants(case_config)) -if __name__ == "__main__": - raise SystemExit(_main(sys.argv[1:])) +def run_sensitivity_suite( + *, + case_config: OracleCaseConfig, + mutations: list[SensitivityMutation], +) -> list[VariantReport]: + """Runs a list of sensitivity mutations and expects each to fail.""" + runner = VariantRunner(case_config=case_config) + thresholds = _default_phase_thresholds(case_config) + variants = [ + VariantSpec( + name=f"sensitivity_{mutation}", + topology=SENSITIVITY_TOPOLOGY, + mutation=mutation, + expected_signal="fail", + thresholds_by_phase=thresholds, + ) + for mutation in mutations + ] + return runner.run_suite(variants) diff --git a/tests/integration/test_megatron_lora_oracle_correctness.py b/tests/integration/test_megatron_lora_oracle_correctness.py index f6949b81..f7e4dbbc 100644 --- a/tests/integration/test_megatron_lora_oracle_correctness.py +++ b/tests/integration/test_megatron_lora_oracle_correctness.py @@ -1,19 +1,17 @@ import pytest from .megatron_oracle_harness import ( - ORACLE_TOPOLOGY, - PHASE_A_TOPOLOGIES, - PHASE_B_TOPOLOGIES, + EXTENDED_TOPOLOGIES, SENSITIVITY_MUTATION_ENV, SENSITIVITY_TOPOLOGY, + TOPOLOGIES, available_gpu_count, - default_case_config, - ensure_oracle_reference_artifacts, - phase_b_dp_enabled, - regenerate_requested, - run_and_compare_topology, - run_sensitivity_check, + case_config, + extended_topologies_enabled, + run_sensitivity_suite, + run_suite, sensitivity_enabled, + sensitivity_mutations, ) @@ -25,76 +23,38 @@ def _require_gpus_for(topology_world_size: int) -> None: ) -def _skip_if_sensitivity_mode() -> None: - if sensitivity_enabled(): - pytest.skip( - f"{SENSITIVITY_MUTATION_ENV} is enabled; running sensitivity check only." - ) +def _suite_world_size() -> int: + suite_topologies = list(TOPOLOGIES) + if extended_topologies_enabled(): + suite_topologies.extend(EXTENDED_TOPOLOGIES) + return max(topology.world_size() for topology in suite_topologies) -def _run_topology_case( # type: ignore[no-untyped-def] - topology, - case_config, - *, - regenerate: bool, -) -> None: - _require_gpus_for(topology.world_size()) - run_and_compare_topology( - case_config=case_config, - topology=topology, - regenerate=regenerate, - ) - +def test_megatron_lora_diff_sensitivity() -> None: + """ + Runs a each of the sensitivity mutations (e.g. drop megatron finalize grads) + and expects each to fail (numerical differences larger than our thresholds) -def test_000_megatron_lora_oracle_sensitivity_check() -> None: + This test ensures we can catch errors we know of (implying we will be able to catch unknown errors as well) + """ if not sensitivity_enabled(): pytest.skip( - f"Set {SENSITIVITY_MUTATION_ENV}=drop_finalize to enable sensitivity check." + f"Set {SENSITIVITY_MUTATION_ENV}=drop_finalize (or CSV) to enable sensitivity check." ) _require_gpus_for(SENSITIVITY_TOPOLOGY.world_size()) - run_sensitivity_check( - case_config=default_case_config(), - regenerate=regenerate_requested(), + mutations = sensitivity_mutations() + assert mutations + run_sensitivity_suite( + case_config=case_config(), + mutations=mutations, ) -def test_megatron_lora_oracle_phase_a_matrix() -> None: - _skip_if_sensitivity_mode() - case_config = default_case_config() - regenerate = regenerate_requested() - _require_gpus_for(ORACLE_TOPOLOGY.world_size()) - ensure_oracle_reference_artifacts( - case_config=case_config, - regenerate=regenerate, - ) - for topology in PHASE_A_TOPOLOGIES: - _run_topology_case( - topology, - case_config, - regenerate=regenerate and topology.slug() != ORACLE_TOPOLOGY.slug(), - ) - - -@pytest.mark.parametrize( - "topology_index", - range(len(PHASE_B_TOPOLOGIES)), - ids=[topology.slug() for topology in PHASE_B_TOPOLOGIES], -) -def test_megatron_lora_oracle_phase_b_dp_matrix(topology_index: int) -> None: - _skip_if_sensitivity_mode() - if not phase_b_dp_enabled(): - pytest.xfail( - "DP matrix currently blocked until Megatron backend DP support is enabled" - ) - case_config = default_case_config() - regenerate = regenerate_requested() - _require_gpus_for(ORACLE_TOPOLOGY.world_size()) - ensure_oracle_reference_artifacts( - case_config=case_config, - regenerate=regenerate, - ) - _run_topology_case( - PHASE_B_TOPOLOGIES[topology_index], - case_config, - regenerate=regenerate, +def test_megatron_lora_topology_suite() -> None: + """ + Runs the suite of topologies and expects each to pass (numerical differences within our thresholds) + """ + _require_gpus_for(_suite_world_size()) + run_suite( + case_config=case_config(), ) From ec8371629d06ed8f9fe7d3f3d1ea08084ff861d6 Mon Sep 17 00:00:00 2001 From: FurtherAI Date: Thu, 12 Mar 2026 09:41:59 +0000 Subject: [PATCH 10/19] typing: clear blocking ty errors in oracle replay and LoRA paths --- src/art/megatron/lora.py | 7 ++++--- tests/integration/megatron_forward_trace.py | 17 +++++++++++------ tests/integration/megatron_oracle_harness.py | 20 +++++++++++--------- tests/integration/megatron_oracle_worker.py | 7 ++----- tests/unit/test_moe_routing_replay.py | 4 +++- 5 files changed, 31 insertions(+), 24 deletions(-) diff --git a/src/art/megatron/lora.py b/src/art/megatron/lora.py index b594bf18..ca578e5f 100644 --- a/src/art/megatron/lora.py +++ b/src/art/megatron/lora.py @@ -119,10 +119,11 @@ def _set_lora_parallel_metadata( # Megatron optimizer and checkpoint logic rely on tensor model-parallel metadata # to distinguish true shards from TP-duplicate params. if parallel_spec.sharded: + shard_dim = parallel_spec.shard_dim + if shard_dim is None: + raise ValueError("LoRAParallelSpec.shard_dim must be set when sharded=True") setattr(param, "tensor_model_parallel", True) - setattr( - param, "partition_dim", _normalize_axis(parallel_spec.shard_dim, param.ndim) - ) + setattr(param, "partition_dim", _normalize_axis(shard_dim, param.ndim)) # stride > 1 means the dim is split into blocks and each tp rank holds a shard of the block # this might happen for fused e.g. gate_(up|proj), but loras are individual per module setattr(param, "partition_stride", 1) diff --git a/tests/integration/megatron_forward_trace.py b/tests/integration/megatron_forward_trace.py index 2ca418aa..5da6f0d6 100644 --- a/tests/integration/megatron_forward_trace.py +++ b/tests/integration/megatron_forward_trace.py @@ -90,12 +90,17 @@ def _extract_primary_tensor(value: Any) -> torch.Tensor | None: def _materialize_tensor(tensor: torch.Tensor) -> torch.Tensor: - if hasattr(tensor, "full_tensor"): - tensor = cast(torch.Tensor, tensor.full_tensor()) - elif hasattr(tensor, "to_local"): - tensor = cast(torch.Tensor, tensor.to_local()) - elif hasattr(tensor, "_local_tensor"): - tensor = cast(torch.Tensor, tensor._local_tensor) + full_tensor = getattr(tensor, "full_tensor", None) + if callable(full_tensor): + tensor = cast(torch.Tensor, full_tensor()) + else: + to_local = getattr(tensor, "to_local", None) + if callable(to_local): + tensor = cast(torch.Tensor, to_local()) + else: + local_tensor = getattr(tensor, "_local_tensor", None) + if isinstance(local_tensor, torch.Tensor): + tensor = local_tensor return tensor.detach().cpu() diff --git a/tests/integration/megatron_oracle_harness.py b/tests/integration/megatron_oracle_harness.py index a66385b0..4e6567f9 100644 --- a/tests/integration/megatron_oracle_harness.py +++ b/tests/integration/megatron_oracle_harness.py @@ -959,12 +959,7 @@ def _build_metric_rows_from_tensor_pairs( """Builds rows from named tensor pairs with one shared diff path.""" rows: list[MetricRow] = [] for name, reference, candidate in pairs: - shared_kwargs = { - "variant": variant, - "step_index": step_index, - "phase": phase, - "param": _minimal_param_name(name), - } + param_name = _minimal_param_name(name) reference_aligned = reference candidate_aligned = candidate aligned_candidate = _align_sequence_parallel( @@ -973,9 +968,12 @@ def _build_metric_rows_from_tensor_pairs( if aligned_candidate is None: rows.append( self._build_metric_row( + variant=variant, + step_index=step_index, + phase=phase, + param=param_name, summary=self._inf_summary(), structural_failure="shape mismatch", - **shared_kwargs, ) ) continue @@ -986,7 +984,11 @@ def _build_metric_rows_from_tensor_pairs( accumulator.update(reference_aligned, aligned_candidate) rows.append( self._build_metric_row( - summary=accumulator.as_summary(), **shared_kwargs + variant=variant, + step_index=step_index, + phase=phase, + param=param_name, + summary=accumulator.as_summary(), ) ) return rows @@ -1032,7 +1034,7 @@ def _build_metric_rows_from_tensor_maps( reference, candidate, variant, step_index, phase ) if not matching: - return rows + return rows if rows is not None else [] pairs = [ (key, reference[key], candidate[key]) for key in sorted(set(reference.keys())) diff --git a/tests/integration/megatron_oracle_worker.py b/tests/integration/megatron_oracle_worker.py index 91c7647d..a5a3ed66 100644 --- a/tests/integration/megatron_oracle_worker.py +++ b/tests/integration/megatron_oracle_worker.py @@ -201,17 +201,14 @@ def _build_optimizer_config(case_config: OracleCaseConfig): """Builds Megatron optimizer settings for deterministic harness runs.""" from megatron.core.optimizer import OptimizerConfig - optimizer_kwargs = dict( + return OptimizerConfig( + bf16=True, lr=case_config.learning_rate, adam_beta1=0.9, adam_beta2=0.99, clip_grad=0.1, weight_decay=0.1, ) - return OptimizerConfig( - bf16=True, - **optimizer_kwargs, - ) def _assert_runtime_configuration( diff --git a/tests/unit/test_moe_routing_replay.py b/tests/unit/test_moe_routing_replay.py index 980784c7..15d1ebc6 100644 --- a/tests/unit/test_moe_routing_replay.py +++ b/tests/unit/test_moe_routing_replay.py @@ -2,6 +2,7 @@ from pathlib import Path import tempfile +from typing import cast import pytest import torch @@ -165,7 +166,8 @@ def test_controller_patches_router_and_replays() -> None: controller.set_step(step_index=0, sample_index=0) logits = torch.randn((4, 3), dtype=torch.float32) - replay_probs, replay_map = chunk.decoder.layers[0].mlp.router.routing(logits) + router = cast(_FakeRouter, chunk.decoder.layers[0].mlp.router) + replay_probs, replay_map = router.routing(logits) expected_probs, expected_map = _dense_from_compact(route, dtype=logits.dtype) assert torch.equal(replay_map.cpu(), expected_map) From 83d871bcb75610518e224dd0253ad270aa623191 Mon Sep 17 00:00:00 2001 From: FurtherAI Date: Sat, 14 Mar 2026 08:20:41 +0000 Subject: [PATCH 11/19] megatron: reduce oracle variance with sequence grad accumulation Use per-step micro-accumulation over multiple packed sequences so updates are less sensitive to sparse expert token assignment. Also make backend progress accounting accumulation-aware. --- src/art/local/backend.py | 5 +- src/art/megatron/train.py | 109 ++++++++++++++++++++++++++------------ src/art/types.py | 1 + 3 files changed, 79 insertions(+), 36 deletions(-) diff --git a/src/art/local/backend.py b/src/art/local/backend.py index b74c0b05..f75a77a2 100644 --- a/src/art/local/backend.py +++ b/src/art/local/backend.py @@ -696,7 +696,10 @@ async def _train_model( ) # Note: scale_learning_rate_by_reward_std_dev is now handled by the frontend (Model.train()) results: list[dict[str, float]] = [] - estimated_gradient_steps = disk_packed_tensors["num_sequences"] + grad_accumulation_sequences = max(1, int(config.grad_accumulation_sequences)) + estimated_gradient_steps = math.ceil( + disk_packed_tensors["num_sequences"] / grad_accumulation_sequences + ) pbar = tqdm.tqdm(total=estimated_gradient_steps, desc="train") async for result in service.train( disk_packed_tensors, config, dev_config, verbose diff --git a/src/art/megatron/train.py b/src/art/megatron/train.py index 33dc8172..a66156bc 100644 --- a/src/art/megatron/train.py +++ b/src/art/megatron/train.py @@ -133,6 +133,7 @@ def _default_optimizer_config() -> OptimizerConfig: adam_beta2=0.99, clip_grad=0.1, weight_decay=0.1, + adam_eps=1e-13, ) @@ -336,63 +337,94 @@ def run_training_step( model_chunks: list[MegatronModule], optimizer: Any, learning_rate: float, - inputs: PackedTensors, + inputs: PackedTensors | list[PackedTensors], config: types.TrainConfig, experimental_config: dev.TrainConfig, ref_logprobs: torch.Tensor | None = None, step_index: int | None = None, - sample_index: int | None = None, + sample_index: int | list[int] | None = None, moe_routing_replay_controller: MoeRoutingReplayController | None = None, ) -> TrainStepResult: + micro_inputs = inputs if isinstance(inputs, list) else [inputs] + if not micro_inputs: + raise ValueError("run_training_step requires at least one packed sequence") + + if isinstance(sample_index, list): + if len(sample_index) != len(micro_inputs): + raise ValueError( + "sample_index list length must match number of micro inputs: " + f"{len(sample_index)} != {len(micro_inputs)}" + ) + micro_sample_indices = sample_index + elif sample_index is None: + micro_sample_indices = [0] * len(micro_inputs) + else: + micro_sample_indices = [sample_index] * len(micro_inputs) + if moe_routing_replay_controller is not None: - assert step_index is not None and sample_index is not None + assert step_index is not None moe_routing_replay_controller.set_step( step_index=step_index, - sample_index=sample_index, + sample_index=micro_sample_indices[0], ) device = next(model_chunks[0].parameters()).device - _move_inputs_to_device(inputs, device) - - attention_state = create_shared_prefix_attention_state( - group_ids=inputs["group_ids"], - parent_ids=inputs["parent_ids"], - ) - attention_mask = torch.zeros((1, 1, 1, 1), dtype=torch.bool, device=device) for chunk in model_chunks: chunk.zero_grad_buffer() # ty: ignore[call-non-callable] - new_logprobs: torch.Tensor = -model_chunks[0]( - input_ids=inputs["tokens"], - position_ids=inputs["input_pos"], - attention_mask=attention_mask, - labels=shift_tensor(inputs["tokens"], 0), - extra_block_kwargs={"attention_bias": attention_state}, - ) + micro_count = len(micro_inputs) + loss_sum: torch.Tensor | None = None + probs_corr_sum = 0.0 + new_logprobs: torch.Tensor | None = None + + for micro in micro_inputs: + _move_inputs_to_device(micro, device) + attention_state = create_shared_prefix_attention_state( + group_ids=micro["group_ids"], + parent_ids=micro["parent_ids"], + ) + attention_mask = torch.zeros((1, 1, 1, 1), dtype=torch.bool, device=device) + + new_logprobs = -model_chunks[0]( + input_ids=micro["tokens"], + position_ids=micro["input_pos"], + attention_mask=attention_mask, + labels=shift_tensor(micro["tokens"], 0), + extra_block_kwargs={"attention_bias": attention_state}, + ) + + loss_info = loss_fn( + micro, # ty: ignore[invalid-argument-type] + new_logprobs, + ref_logprobs, + None, + experimental_config, + ) + micro_loss = loss_info.mean_policy_loss + config.beta * loss_info.mean_kl + (micro_loss / micro_count).backward() + probs_corr_sum += float(loss_info.probs_corr.item()) + if loss_sum is None: + loss_sum = micro_loss.detach() + else: + loss_sum = loss_sum + micro_loss.detach() + + if new_logprobs is None or loss_sum is None: + raise RuntimeError("run_training_step did not produce outputs") - loss_info = loss_fn( - inputs, # ty: ignore[invalid-argument-type] - new_logprobs, - ref_logprobs, - None, - experimental_config, - ) - loss = loss_info.mean_policy_loss + config.beta * loss_info.mean_kl - loss.backward() _finalize_grads(model_chunks) update_successful, grad_norm, num_zeros_in_grad = _optimizer_step( optimizer, learning_rate, ) - reduced_loss = _reduce_loss(loss) + reduced_loss = _reduce_loss(loss_sum / micro_count) if moe_routing_replay_controller is not None: moe_routing_replay_controller.finalize_step() return TrainStepResult( reduced_loss=reduced_loss, - probs_corr=float(loss_info.probs_corr.item()), + probs_corr=probs_corr_sum / micro_count, new_logprobs=new_logprobs, update_successful=update_successful, grad_norm=grad_norm, @@ -475,18 +507,25 @@ def _run_service_loop(runtime: TrainingRuntime) -> None: repeat = math.ceil(num_indices / len(indices)) indices = (indices * repeat)[:num_indices] - for step_index, index in enumerate(indices): - inputs = select_indexed_inputs(packed_tensors, index) + grad_accumulation_sequences = max(1, int(config.grad_accumulation_sequences)) + for step_index, start in enumerate( + range(0, len(indices), grad_accumulation_sequences) + ): + micro_indices = indices[start : start + grad_accumulation_sequences] + micro_inputs = [ + select_indexed_inputs(packed_tensors, sample_index) + for sample_index in micro_indices + ] step_result = run_training_step( model_chunks=runtime.model, optimizer=runtime.optimizer, learning_rate=config.learning_rate, - inputs=inputs, + inputs=micro_inputs, config=config, experimental_config=experimental_config, ref_logprobs=None, step_index=step_index, - sample_index=index, + sample_index=micro_indices, moe_routing_replay_controller=runtime.moe_routing_replay_controller, ) print0( @@ -535,8 +574,8 @@ def _run_service_loop(runtime: TrainingRuntime) -> None: del packed_tensors del adapter_model - if "inputs" in locals(): - del inputs + if "micro_inputs" in locals(): + del micro_inputs gc.collect() torch.cuda.empty_cache() diff --git a/src/art/types.py b/src/art/types.py index 017f05c7..be31db23 100644 --- a/src/art/types.py +++ b/src/art/types.py @@ -18,6 +18,7 @@ class TrainConfig(pydantic.BaseModel): learning_rate: float = 5e-6 beta: float = 0.0 kl_penalty_coef: float = 0.0 + grad_accumulation_sequences: int = pydantic.Field(default=1, ge=1) class TrainSFTConfig(pydantic.BaseModel): From 84e2ea76d8811754d75ec7ca4eddd02897e3a78b Mon Sep 17 00:00:00 2001 From: FurtherAI Date: Sat, 14 Mar 2026 08:21:17 +0000 Subject: [PATCH 12/19] megatron lora: fix TP/EP export participation rules Correct LoRA shard export behavior so non-zero TP ranks in EP/ETP topologies contribute when required, while still filtering replicated-only entries. --- src/art/megatron/lora.py | 53 +++++++++++++++++----------------------- 1 file changed, 23 insertions(+), 30 deletions(-) diff --git a/src/art/megatron/lora.py b/src/art/megatron/lora.py index ca578e5f..247389c3 100644 --- a/src/art/megatron/lora.py +++ b/src/art/megatron/lora.py @@ -182,9 +182,9 @@ def num_local_experts(self) -> int: return self.A_T.shape[0] if self.A_T.ndim == 3 else 1 def _broadcast_if_replicated(self, param: torch.nn.Parameter) -> None: - if not getattr(param, "lora_tp_replicated", False): + if not param.lora_tp_replicated: return - domain = getattr(param, "lora_shard_domain") + domain = param.lora_shard_domain world_size = _get_shard_world_size(domain) if world_size <= 1: return @@ -252,14 +252,9 @@ def load_weights( self.load_weight(weight, into=into) def load_weight(self, weight: torch.Tensor, *, into: torch.nn.Parameter) -> None: - domain = getattr(into, "lora_shard_domain") - sharded = bool(getattr(into, "lora_tp_sharded")) - if sharded: - axis = getattr(into, "lora_tp_shard_dim") - if axis is None: - raise RuntimeError( - f"{self.adapter_model_prefix}: missing shard axis for sharded parameter" - ) + domain = into.lora_shard_domain + if into.lora_tp_sharded: + axis = into.lora_tp_shard_dim axis = _normalize_axis(axis, weight.ndim) world_size = _get_shard_world_size(domain) rank = _get_shard_rank(domain) @@ -283,37 +278,35 @@ def load_weight(self, weight: torch.Tensor, *, into: torch.nn.Parameter) -> None into.requires_grad = True def _should_export_parameter(self, param: torch.nn.Parameter) -> bool: + """ + Determine if the given LoRA param should be exported in the sharded LoRA state dict + (drop replicated ranks/params). + """ if self.num_local_experts > 1: # self is a MoE layer if ps.get_expert_data_parallel_rank() != 0: return False else: # self is a non-MoE layer - if ps.get_data_parallel_rank() != 0: - return False - # Non-MoE layers are replicated across expert-model-parallel ranks. - if ( - ps.get_expert_model_parallel_world_size() > 1 - and ps.get_expert_model_parallel_rank() != 0 - ): + # dp x cp rank 0 participates + if ps.get_data_parallel_rank(with_context_parallel=True) != 0: return False - if getattr(param, "lora_tp_sharded", False): - # this param is fully sharded, all shard ranks participate + # this param is fully sharded, all shard ranks participate + if param.lora_tp_sharded: return True - - domain = getattr(param, "lora_shard_domain") # param is replicated, tp rank 0 or etp rank 0 participates - return _get_shard_rank(domain) == 0 + return _get_shard_rank(param.lora_shard_domain) == 0 def _manifest_for_param(self, param: torch.nn.Parameter) -> dict[str, Any]: - domain = getattr(param, "lora_shard_domain") - sharded = bool(getattr(param, "lora_tp_sharded", False)) - shard_dim = getattr(param, "lora_tp_shard_dim", None) return { - "domain": domain, - "sharded": sharded, - "shard_dim": shard_dim, - "shard_world_size": _get_shard_world_size(domain) if sharded else 1, - "shard_rank": _get_shard_rank(domain) if sharded else 0, + "domain": param.lora_shard_domain, + "sharded": param.lora_tp_sharded, + "shard_dim": param.lora_tp_shard_dim, + "shard_world_size": _get_shard_world_size(param.lora_shard_domain) + if param.lora_tp_sharded + else 1, + "shard_rank": _get_shard_rank(param.lora_shard_domain) + if param.lora_tp_sharded + else 0, } def _lora_params(self) -> list[tuple[str, torch.nn.Parameter]]: From 0bc99194d162a8a803a1cb684a2d9d799722fa8e Mon Sep 17 00:00:00 2001 From: FurtherAI Date: Sat, 14 Mar 2026 08:21:25 +0000 Subject: [PATCH 13/19] oracle trace: canonicalize MoE outputs across arbitrary topologies Move normalization logic into ForwardTraceCapture so saved traces are canonicalized toward world-size-1 semantics (expert row identity/order and ETP fc1 layout). --- tests/integration/megatron_forward_trace.py | 272 +++++++++++++++++++- 1 file changed, 268 insertions(+), 4 deletions(-) diff --git a/tests/integration/megatron_forward_trace.py b/tests/integration/megatron_forward_trace.py index 5da6f0d6..d3befd9f 100644 --- a/tests/integration/megatron_forward_trace.py +++ b/tests/integration/megatron_forward_trace.py @@ -20,6 +20,7 @@ ".mlp.experts.linear_fc2.lora", ) ROUTER_NAME_TOKEN = ".mlp.router" +PRIMARY_OUTPUT_CANONICAL_KEY = "primary_output__is_canonical" def _safe_int(value: Any, default: int = 0) -> int: @@ -221,6 +222,25 @@ def _infer_primary_output_merge_hint( if lora_hint is not None: return lora_hint + # Base MoE expert linears need expert-TP aware merge semantics. + # With etp>1: + # - FC1 (column-parallel) shards output features -> concat on feature dim. + # - FC2 (row-parallel) emits partial output contributions -> sum across ranks. + # With etp==1, keep the existing token-row concat behavior. + etp_world_size = _safe_ps_stat("get_expert_tensor_parallel_world_size", 1) + if ".mlp.experts.linear_fc1" in name and ".lora" not in name: + if etp_world_size > 1: + return { + "op": "concat", + "dim": -1, + "layout": "gate_up_rank_interleaved", + } + return {"op": "concat", "dim": 0} + if ".mlp.experts.linear_fc2" in name and ".lora" not in name: + if etp_world_size > 1: + return {"op": "sum"} + return {"op": "concat", "dim": 0} + gather_output = getattr(module, "gather_output", None) if isinstance(gather_output, bool) and not gather_output: return {"op": "concat", "dim": -1} @@ -287,6 +307,249 @@ def set_step(self, step_index: int) -> None: self.current_step_index = step_index self.current_step_trace = {} + @staticmethod + def _is_moe_expert_forward_module(module_name: str) -> bool: + """Returns whether one module emits MoE expert forward outputs.""" + if ".mlp.experts." not in module_name: + return False + if ".mlp.router" in module_name: + return False + return ".linear_fc1" in module_name or ".linear_fc2" in module_name + + @staticmethod + def _primary_output_merge_hint(call: dict[str, Any]) -> dict[str, Any] | None: + """Reads primary-output merge metadata from one call payload.""" + merge_hints = call.get("merge_hints") + if not isinstance(merge_hints, dict): + return None + primary_hint = merge_hints.get("primary_output") + if not isinstance(primary_hint, dict): + return None + return primary_hint + + @classmethod + def _lookup_call_by_index( + cls, + trace: dict[str, list[dict[str, Any]]], + module_name: str, + call_index: int, + ) -> dict[str, Any] | None: + """Finds one call entry by call-index with positional fallback.""" + calls = trace.get(module_name) + if calls is None: + return None + for call in calls: + if int(call.get("call_index", -1)) == call_index: + return call + if 0 <= call_index < len(calls): + return calls[call_index] + return None + + @staticmethod + def _router_module_name_for_expert_module(module_name: str) -> str | None: + """Maps one expert module name to its layer router module name.""" + for token in (".mlp.experts.linear_fc1", ".mlp.experts.linear_fc2"): + token_index = module_name.find(token) + if token_index != -1: + return f"{module_name[:token_index]}.mlp.router" + return None + + @classmethod + def _build_moe_row_identities( + cls, + *, + module_name: str, + call_index: int, + trace: dict[str, list[dict[str, Any]]], + row_splits: list[int] | None, + ) -> list[tuple[int, int, int]] | None: + """Builds stable `(expert_id, token_index, topk_slot)` identities for MoE rows.""" + router_module_name = cls._router_module_name_for_expert_module(module_name) + if router_module_name is None: + return None + router_call = cls._lookup_call_by_index(trace, router_module_name, call_index) + if router_call is None: + return None + router_topk_ids = router_call.get("router_topk_ids") + if not isinstance(router_topk_ids, torch.Tensor) or router_topk_ids.ndim != 2: + return None + token_splits_raw = router_call.get("router_topk_ids__row_splits") + if row_splits is None: + if isinstance(token_splits_raw, list): + row_splits = [ + int(v) * int(router_topk_ids.shape[1]) for v in token_splits_raw + ] + else: + row_splits = [int(router_topk_ids.numel())] + if isinstance(token_splits_raw, list): + token_splits = [int(v) for v in token_splits_raw] + else: + topk = int(router_topk_ids.shape[1]) + token_splits = [int(v) // topk for v in row_splits] + if len(row_splits) != len(token_splits): + return None + row_cursor = 0 + token_cursor = 0 + identities: list[tuple[int, int, int]] = [] + for row_count, token_count in zip(row_splits, token_splits): + local_ids = router_topk_ids[token_cursor : token_cursor + token_count] + token_cursor += token_count + local_identities: list[tuple[int, int, int]] = [] + max_expert = int(local_ids.max().item()) if local_ids.numel() > 0 else -1 + for expert_id in range(max_expert + 1): + expert_rows = (local_ids == expert_id).nonzero(as_tuple=False) + for token_offset, slot_index in expert_rows.tolist(): + local_identities.append( + (expert_id, token_cursor - token_count + token_offset, slot_index) + ) + if len(local_identities) != row_count: + return None + identities.extend(local_identities) + row_cursor += row_count + if row_cursor != sum(row_splits): + return None + return identities + + @classmethod + def _canonicalize_etp_fc1_feature_layout( + cls, + *, + module_name: str, + tensor: torch.Tensor, + call: dict[str, Any], + ) -> torch.Tensor: + """Normalizes expert-TP fc1 feature order to a topology-independent layout.""" + if ".mlp.experts.linear_fc1" not in module_name or ".lora" in module_name: + return tensor + if tensor.ndim != 2: + return tensor + primary_hint = cls._primary_output_merge_hint(call) + if not isinstance(primary_hint, dict): + return tensor + if primary_hint.get("layout") != "gate_up_rank_interleaved": + return tensor + rank_meta = call.get("rank_meta") + etp_world_size = None + if isinstance(rank_meta, list) and rank_meta: + first_meta = rank_meta[0] + if isinstance(first_meta, dict): + etp_world_size = first_meta.get("etp_world_size") + elif isinstance(rank_meta, dict): + etp_world_size = rank_meta.get("etp_world_size") + if not isinstance(etp_world_size, int) or etp_world_size <= 1: + return tensor + block_count = 2 * etp_world_size + if tensor.shape[1] % block_count != 0: + return tensor + blocks = torch.chunk(tensor, block_count, dim=1) + reordered = [blocks[index] for index in range(0, block_count, 2)] + [ + blocks[index] for index in range(1, block_count, 2) + ] + return torch.cat(reordered, dim=1).contiguous() + + @classmethod + def _canonicalize_moe_expert_row_order( + cls, + *, + module_name: str, + call_index: int, + tensor: torch.Tensor, + trace: dict[str, list[dict[str, Any]]], + call: dict[str, Any], + ) -> torch.Tensor: + """Canonicalizes MoE expert-row ordering using router replay identities.""" + if not cls._is_moe_expert_forward_module(module_name): + return tensor + if tensor.ndim != 2: + return tensor + primary_hint = cls._primary_output_merge_hint(call) + if isinstance(primary_hint, dict) and ( + primary_hint.get("op") != "concat" or primary_hint.get("dim") != 0 + ): + return tensor + row_splits_raw = call.get("primary_output__row_splits") + row_splits = ( + [int(v) for v in row_splits_raw] if isinstance(row_splits_raw, list) else None + ) + identities = cls._build_moe_row_identities( + module_name=module_name, + call_index=call_index, + trace=trace, + row_splits=row_splits, + ) + if identities is None or len(identities) != int(tensor.shape[0]): + return tensor + order = sorted(range(len(identities)), key=lambda index: identities[index]) + return tensor[order] + + @classmethod + def _canonicalize_primary_output_tensor( + cls, + *, + module_name: str, + call_index: int, + tensor: torch.Tensor, + trace: dict[str, list[dict[str, Any]]], + call: dict[str, Any], + ) -> torch.Tensor: + """Runs all primary-output canonicalization passes for one call tensor.""" + tensor = cls._canonicalize_etp_fc1_feature_layout( + module_name=module_name, + tensor=tensor, + call=call, + ) + return cls._canonicalize_moe_expert_row_order( + module_name=module_name, + call_index=call_index, + tensor=tensor, + trace=trace, + call=call, + ) + + @classmethod + def canonicalize_trace( + cls, + trace: dict[str, list[dict[str, Any]]], + ) -> dict[str, list[dict[str, Any]]]: + """Canonicalizes topology-dependent trace outputs in place.""" + for module_name in sorted(trace.keys()): + calls = trace[module_name] + for call_offset, call in enumerate(calls): + if bool(call.get(PRIMARY_OUTPUT_CANONICAL_KEY)): + continue + call_index = int(call.get("call_index", call_offset)) + tensor = call.get("primary_output") + if isinstance(tensor, torch.Tensor): + call["primary_output"] = cls._canonicalize_primary_output_tensor( + module_name=module_name, + call_index=call_index, + tensor=tensor, + trace=trace, + call=call, + ) + call[PRIMARY_OUTPUT_CANONICAL_KEY] = True + return trace + + @classmethod + def flatten_trace_tensors( + cls, + trace: dict[str, list[dict[str, Any]]], + *, + value_key: str, + ) -> dict[str, Any]: + """Flattens trace calls into deterministic key->value tensor maps.""" + if value_key == "primary_output": + cls.canonicalize_trace(trace) + flattened: dict[str, Any] = {} + for module_name in sorted(trace.keys()): + for call_offset, call in enumerate(trace[module_name]): + tensor = call.get(value_key) + if tensor is None: + continue + call_index = call.get("call_index", call_offset) + flattened[f"{module_name}.call_{call_index}"] = tensor + return flattened + @classmethod def _merge_rank_values( cls, @@ -478,15 +741,16 @@ def save_current_step(self, traces_dir: Path) -> Path | None: gathered_traces = self._gather_rank_traces(self.current_step_trace) if gathered_traces is None: return None - merged_trace = self._merge_rank_traces(gathered_traces) + merged_trace = self.canonicalize_trace(self._merge_rank_traces(gathered_traces)) traces_dir.mkdir(parents=True, exist_ok=True) trace_path = traces_dir / f"forward_trace_step_{self.current_step_index:03d}.pt" torch.save(merged_trace, trace_path) return trace_path - @staticmethod - def load_trace(trace_path: Path) -> dict[str, list[dict[str, Any]]]: - return torch.load(trace_path, map_location="cpu", weights_only=False) + @classmethod + def load_trace(cls, trace_path: Path) -> dict[str, list[dict[str, Any]]]: + trace = torch.load(trace_path, map_location="cpu", weights_only=False) + return cls.canonicalize_trace(trace) def close(self) -> None: for handle in self._hook_handles: From 8370c7d912a182e892ca092770f3de3d2dfdb81c Mon Sep 17 00:00:00 2001 From: FurtherAI Date: Sat, 14 Mar 2026 08:21:31 +0000 Subject: [PATCH 14/19] oracle harness: stabilize scoring and expand sensitivity mutations Rework oracle pass/fail evaluation with per-phase functions, layer-averaged metrics, deterministic init, expanded sensitivity mutations, and smaller Adam epsilon for tiny-gradient regimes. --- tests/integration/megatron_oracle_harness.py | 499 +++++++++---------- tests/integration/megatron_oracle_worker.py | 299 +++++++++-- 2 files changed, 518 insertions(+), 280 deletions(-) diff --git a/tests/integration/megatron_oracle_harness.py b/tests/integration/megatron_oracle_harness.py index 4e6567f9..785e8618 100644 --- a/tests/integration/megatron_oracle_harness.py +++ b/tests/integration/megatron_oracle_harness.py @@ -8,7 +8,7 @@ from pathlib import Path import re import shutil -from typing import Any, Literal, TypeVar, cast +from typing import Any, Callable, Literal, TypeVar, cast from pydantic import BaseModel, ConfigDict, Field from rich import box @@ -23,12 +23,21 @@ ORACLE_MOE_ROUTING_BUNDLE_DIRNAME = "oracle_moe_routing_replay" ORACLE_REPLAY_TOPOLOGY_SUFFIX = "oracle_replay" -REGENERATE_ENV = "ART_REGENERATE_MEGATRON_ORACLE" -EXTENDED_TOPOLOGIES_ENV = "ART_MEGATRON_ORACLE_ENABLE_EXTENDED_TOPOLOGIES" -SENSITIVITY_MUTATION_ENV = "ART_MEGATRON_ORACLE_MUTATION" - -DEFAULT_SENSITIVITY_MUTATION = "drop_finalize" -SUPPORTED_SENSITIVITY_MUTATIONS = (DEFAULT_SENSITIVITY_MUTATION,) +REGENERATE_ENV = "ART_REGENERATE_ORACLE" +EXTENDED_TOPOLOGIES_ENV = "ART_ENABLE_EXTENDED_TOPOLOGIES" +SENSITIVITY_MUTATION_ENV = "ART_SENSITIVITY_MUTATIONS" + +DEFAULT_SENSITIVITY_MUTATION = "skip_finalize" +SUPPORTED_SENSITIVITY_MUTATIONS = ( + DEFAULT_SENSITIVITY_MUTATION, + "fwd_skip_o_proj_tp_reduce", + "fwd_o_proj_tp_reduce_avg_not_sum", + "bwd_skip_sync_qkv_a", + "bwd_skip_sync_o_proj_b", + "bwd_skip_sync_fc1_a", + "save_drop_nonzero_ranked_tp_shards", + "save_duplicate_replicated_entries", +) SensitivityMutation = str REQUIRED_PACKED_TENSOR_FILES = ( @@ -44,9 +53,10 @@ NON_FINITE_METRIC_VALUE = 1e30 EXPERT_TABLE_ROW_LIMIT = 8 EXPERT_TRIPLET_PARAM_RE = re.compile( - r"layers\.(?P\d+)\.mlp\.experts\.(?P\d+)\." + r"layers\.(?P\d+|__layer_avg__)\.mlp\.experts\.(?P\d+)\." r"(?Pgate_proj|up_proj|down_proj)\." ) +LAYER_INDEX_RE = re.compile(r"layers\.(\d+)\.") PHASE_PRINT_ORDER = { "forward": 0, "router_scores": 1, @@ -85,6 +95,7 @@ def resolved_expert_dp(self) -> int: return attention_world // expert_divisor def slug(self) -> str: + """Builds a deterministic topology identifier used for output directories.""" return ( f"tp{self.tp}_ep{self.ep}_etp{self.etp}" f"_dp{self.dp}_edp{self.resolved_expert_dp()}" @@ -103,7 +114,7 @@ def world_size(self) -> int: class PackedTensorConfig(BaseModel): """Controls synthetic packed tensor generation used by oracle harness runs.""" - num_sequences: int = 8 + num_sequences: int = 4 sequence_length: int = 256 prefill_tokens: int = 64 decode_tokens: int = 64 @@ -128,11 +139,30 @@ class LoraConfig(BaseModel): ) -class ToleranceProfile(BaseModel): - """Defines row-level pass/fail thresholds for variant comparison phases.""" +MetricSummary = dict[str, float] +PhasePassFn = Callable[[MetricSummary], bool] - relative_l2: float = 1e-2 - mean_abs_pct: float = 1.0 + +class MetricThresholdRule(BaseModel): + """Callable row pass rule that AND-checks configured metric upper bounds.""" + + limits: dict[str, float] = Field(default_factory=dict) + + def failure_reasons(self, summary: MetricSummary) -> list[str]: + """Builds readable failure reasons for this threshold rule.""" + reasons: list[str] = [] + for key, limit in sorted(self.limits.items()): + value = summary.get(key) + if not isinstance(value, (int, float)): + reasons.append(f"{key}=missing") + continue + if float(value) > float(limit): + reasons.append(f"{key}={float(value):.6g}>{float(limit):.6g}") + return reasons + + def __call__(self, summary: MetricSummary) -> bool: + """Evaluates whether the summary satisfies all configured bounds.""" + return len(self.failure_reasons(summary)) == 0 class OracleCaseConfig(BaseModel): @@ -141,13 +171,13 @@ class OracleCaseConfig(BaseModel): base_model: str num_layers: int = 4 seed: int = 20260305 - num_steps: int = 2 - learning_rate: float = 1e-3 + num_steps: int = 1 + grad_accumulation_sequences: int = Field(default=4, ge=1) + learning_rate: float = 5e-6 beta: float = 0.0 - loss_scale: float = 1e4 + loss_scale: float = 1 packed_tensors: PackedTensorConfig = Field(default_factory=PackedTensorConfig) lora: LoraConfig = Field(default_factory=LoraConfig) - tolerances: ToleranceProfile = Field(default_factory=ToleranceProfile) class DiskPackedTensorsSpec(BaseModel): @@ -207,7 +237,6 @@ class RunManifest(BaseModel): seed: int num_steps: int packed_tensors: DiskPackedTensorsSpec - tolerances: ToleranceProfile steps: list[StepTrace] @@ -228,7 +257,6 @@ class MetricRow(BaseModel): mean_abs_pct: float topk_mismatch_fraction: float | None = None top1_mismatch_fraction: float | None = None - thresholds: dict[str, float] = Field(default_factory=dict) pass_signal: bool = True failure_reasons: list[str] = Field(default_factory=list) @@ -238,18 +266,25 @@ class VariantSpec(BaseModel): name: str topology: Topology - thresholds_by_phase: dict[str, dict[str, float]] + pass_fn_by_phase: dict[str, PhasePassFn] = Field( + default_factory=dict, + repr=False, + exclude=True, + ) output_slug: str | None = None reference_slug: str | None = None mutation: SensitivityMutation | None = None expected_signal: Literal["pass", "fail"] = "pass" + force_regenerate: bool = True def resolved_output_slug(self) -> str: + """Resolves the artifact slug for this run, including mutation suffix when present.""" if self.output_slug is not None: return self.output_slug return _topology_output_slug(self.topology, self.mutation) def resolved_reference_slug(self) -> str: + """Resolves which topology slug should be treated as the comparison oracle.""" if self.reference_slug is not None: return self.reference_slug return ORACLE_TOPOLOGY.slug() @@ -266,8 +301,8 @@ class VariantReport(BaseModel): signal: Literal["pass", "fail"] pass_count: int fail_count: int - step_summaries: dict[int, dict[str, Any]] - metrics: list[MetricRow] + step_summaries: dict[int, dict[str, Any]] = Field(repr=False) + metrics: list[MetricRow] = Field(repr=False) class DiffAccumulator: @@ -297,6 +332,20 @@ def update(self, reference, candidate) -> None: # type: ignore[no-untyped-def] self.ref_sq_sum += float(ref.square().sum().item()) self.ref_abs_sum += float(ref.abs().sum().item()) + @staticmethod + def layer_averaged_summary(reference_stack, candidate_stack) -> dict[str, float]: # type: ignore[no-untyped-def] + """Computes normal per-layer summaries, then averages those summaries.""" + ref = reference_stack.detach().float() + cand = candidate_stack.detach().float() + layer_count = int(ref.shape[0]) + metrics = {k: 0.0 for k in ["numel", "mean_abs_diff", "relative_l2", "typical_abs_scale", "mean_abs_pct"]} + for layer_index in range(layer_count): + layer_accumulator = DiffAccumulator() + layer_accumulator.update(ref[layer_index], cand[layer_index]) + layer_summary = layer_accumulator.as_summary() + metrics = {k: metrics[k] + layer_summary[k] for k in metrics.keys()} + return {k: _finite_metric(metrics[k] / layer_count) for k in metrics.keys()} + def update_router_ids(self, reference_ids, candidate_ids) -> None: # type: ignore[no-untyped-def] """Adds router top-k id mismatch counts into the accumulator.""" self.router_topk_total += int(reference_ids.numel()) @@ -353,6 +402,7 @@ def as_summary(self) -> dict[str, float]: def _require_not_none(value: T | None, name: str) -> T: + """Asserts non-None values for required artifacts and raises a named runtime error.""" if value is None: raise RuntimeError(f"{name} is None") return value @@ -361,18 +411,23 @@ def _require_not_none(value: T | None, name: str) -> T: TOPOLOGIES = [ Topology(tp=1, ep=1, etp=1, dp=1, sp=False), Topology(tp=2, ep=1, etp=1, dp=1, sp=True), - Topology(tp=1, ep=2, etp=1, dp=2, sp=False), - Topology(tp=2, ep=2, etp=1, dp=2, sp=True), + Topology(tp=2, ep=2, etp=1, dp=1, sp=True), + Topology(tp=2, ep=1, etp=2, dp=1, sp=True), ] EXTENDED_TOPOLOGIES = [ Topology(tp=1, ep=1, etp=1, dp=2, sp=False), - Topology(tp=2, ep=1, etp=1, dp=2, sp=True), + Topology(tp=1, ep=2, etp=1, dp=2, sp=True), ] ORACLE_TOPOLOGY = TOPOLOGIES[0] -SENSITIVITY_TOPOLOGY = TOPOLOGIES[1] +SENSITIVITY_TOPOLOGY = Topology(tp=2, ep=2, etp=1, dp=1, sp=True) +SENSITIVITY_TOPOLOGY_BY_MUTATION: dict[SensitivityMutation, Topology] = { + mutation: SENSITIVITY_TOPOLOGY for mutation in SUPPORTED_SENSITIVITY_MUTATIONS +} +SENSITIVITY_TOPOLOGY_BY_MUTATION["bwd_skip_sync_fc1_a"] = Topology(tp=2, ep=1, etp=2, dp=1, sp=True) def _truthy(value: str | None) -> bool: + """Parses env-var style booleans using a small accepted truthy set.""" if value is None: return False return value.strip().lower() in {"1", "true", "yes", "on"} @@ -384,6 +439,8 @@ def sensitivity_mutations() -> list[SensitivityMutation]: if raw is None or raw.strip() == "": return [] normalized = raw.strip().lower() + if normalized == "all": + return list(SUPPORTED_SENSITIVITY_MUTATIONS) if normalized in {"1", "true", "yes", "on"}: return [DEFAULT_SENSITIVITY_MUTATION] mutations = [item.strip().lower() for item in raw.split(",") if item.strip()] @@ -397,20 +454,35 @@ def sensitivity_mutations() -> list[SensitivityMutation]: supported = ", ".join(SUPPORTED_SENSITIVITY_MUTATIONS) raise ValueError( f"Unsupported {SENSITIVITY_MUTATION_ENV} value '{raw}'. " - f"Supported values: {supported}, CSV of supported values, 1/true/yes/on." + f"Supported values: {supported}, CSV of supported values, all, 1/true/yes/on." ) def sensitivity_enabled() -> bool: + """Returns whether any sensitivity mutation has been requested via environment.""" return bool(sensitivity_mutations()) +def sensitivity_topology_for_mutation(mutation: SensitivityMutation) -> Topology: + """Returns the sensitivity topology required for one mutation.""" + return SENSITIVITY_TOPOLOGY_BY_MUTATION[mutation] + + +def sensitivity_required_world_size(mutations: list[SensitivityMutation]) -> int: + """Returns the max world-size required by a selected mutation set.""" + return max( + sensitivity_topology_for_mutation(mutation).world_size() + for mutation in mutations + ) + + def extended_topologies_enabled() -> bool: """Returns whether extended topologies are enabled for the suite.""" return _truthy(os.environ.get(EXTENDED_TOPOLOGIES_ENV)) def regenerate_requested() -> bool: + """Returns whether regeneration mode is enabled for oracle artifacts.""" return _truthy(os.environ.get(REGENERATE_ENV)) @@ -422,6 +494,7 @@ def case_config( def available_gpu_count() -> int: + """Reports visible CUDA device count for topology scheduling and test skips.""" import torch return int(torch.cuda.device_count()) @@ -442,12 +515,14 @@ def stable_case_id(case_config: OracleCaseConfig) -> str: def _write_json(path: Path, payload: Any) -> None: + """Writes canonical pretty JSON to disk, creating parent directories as needed.""" path.parent.mkdir(parents=True, exist_ok=True) with path.open("w", encoding="utf-8") as handle: json.dump(payload, handle, indent=2, sort_keys=True, allow_nan=False) def _read_json(path: Path) -> dict[str, Any]: + """Loads a JSON object from disk.""" with path.open("r", encoding="utf-8") as handle: return json.load(handle) @@ -612,128 +687,6 @@ def _align_sequence_parallel(reference, candidate): # type: ignore[no-untyped-d return None -def _is_moe_base_forward_param(name: str) -> bool: - """Returns whether this forward param is a base MoE expert internal tensor.""" - if ".mlp.experts." not in name: - return False - if any(token in name for token in (".router", ".gate_lora", ".up_lora", ".lora")): - return False - return ".linear_fc1" in name or ".linear_fc2" in name - - -def _lookup_call_by_index( - trace: dict[str, list[dict[str, Any]]], - module_name: str, - call_index: int, -) -> dict[str, Any] | None: - calls = trace.get(module_name) - if calls is None: - return None - for call in calls: - if int(call.get("call_index", -1)) == call_index: - return call - if 0 <= call_index < len(calls): - return calls[call_index] - return None - - -def _router_module_name_for_expert_module(module_name: str) -> str | None: - if ".mlp.experts.linear_fc1" in module_name: - return module_name.replace(".mlp.experts.linear_fc1", ".mlp.router") - if ".mlp.experts.linear_fc2" in module_name: - return module_name.replace(".mlp.experts.linear_fc2", ".mlp.router") - return None - - -def _build_moe_row_identities( - *, - module_name: str, - call_index: int, - trace: dict[str, list[dict[str, Any]]], - row_splits: list[int] | None, -) -> list[tuple[int, int, int]] | None: - router_module_name = _router_module_name_for_expert_module(module_name) - if router_module_name is None: - return None - router_call = _lookup_call_by_index(trace, router_module_name, call_index) - if router_call is None: - return None - router_topk_ids = router_call.get("router_topk_ids") - if not isinstance(router_topk_ids, torch.Tensor) or router_topk_ids.ndim != 2: - return None - token_splits_raw = router_call.get("router_topk_ids__row_splits") - if row_splits is None: - if isinstance(token_splits_raw, list): - row_splits = [ - int(v) * int(router_topk_ids.shape[1]) for v in token_splits_raw - ] - else: - row_splits = [int(router_topk_ids.numel())] - if isinstance(token_splits_raw, list): - token_splits = [int(v) for v in token_splits_raw] - else: - topk = int(router_topk_ids.shape[1]) - token_splits = [int(v) // topk for v in row_splits] - if len(row_splits) != len(token_splits): - return None - row_cursor = 0 - token_cursor = 0 - identities: list[tuple[int, int, int]] = [] - for row_count, token_count in zip(row_splits, token_splits): - local_ids = router_topk_ids[token_cursor : token_cursor + token_count] - token_cursor += token_count - local_identities: list[tuple[int, int, int]] = [] - max_expert = int(local_ids.max().item()) if local_ids.numel() > 0 else -1 - for expert_id in range(max_expert + 1): - expert_rows = (local_ids == expert_id).nonzero(as_tuple=False) - for token_offset, slot_index in expert_rows.tolist(): - local_identities.append( - (expert_id, token_cursor - token_count + token_offset, slot_index) - ) - if len(local_identities) != row_count: - return None - identities.extend(local_identities) - row_cursor += row_count - if row_cursor != sum(row_splits): - return None - return identities - - -def _canonicalize_moe_base_forward_tensor( - *, - module_name: str, - call_index: int, - tensor: torch.Tensor, - trace: dict[str, list[dict[str, Any]]], - call: dict[str, Any], -) -> torch.Tensor: - if not _is_moe_base_forward_param(module_name): - return tensor - if tensor.ndim != 2: - return tensor - row_splits_raw = call.get("primary_output__row_splits") - row_splits = ( - [int(v) for v in row_splits_raw] if isinstance(row_splits_raw, list) else None - ) - identities = _build_moe_row_identities( - module_name=module_name, - call_index=call_index, - trace=trace, - row_splits=row_splits, - ) - if identities is None or len(identities) != int(tensor.shape[0]): - return tensor - order = sorted(range(len(identities)), key=lambda index: identities[index]) - return tensor[order] - - -def _minimal_param_name(name: str) -> str: - """Returns a shorter but 1:1 param/module identifier for report readability.""" - return name.removeprefix("base_model.model.model.").replace( - "module.module.decoder.", "" - ) - - def _load_forward_trace( topology_dir: Path, step_index: int ) -> dict[str, list[dict[str, Any]]]: @@ -742,13 +695,6 @@ def _load_forward_trace( return ForwardTraceCapture.load_trace(trace_path) -def _threshold_string(thresholds: dict[str, float]) -> str: - """Formats threshold dicts into compact table cells.""" - if not thresholds: - return "-" - return ", ".join(f"{key}<={value:.3g}" for key, value in sorted(thresholds.items())) - - def _finite_metric(value: float, *, default: float = NON_FINITE_METRIC_VALUE) -> float: """Maps NaN/Inf metric values to a large finite sentinel for JSON-safe reports.""" value_f = float(value) @@ -759,12 +705,50 @@ def _finite_metric(value: float, *, default: float = NON_FINITE_METRIC_VALUE) -> return value_f -def _triplet_expert_key(param: str) -> tuple[int, int] | None: - """Returns (layer, expert_id) for expert up/gate/down params.""" +def _triplet_expert_key(param: str) -> tuple[str, int] | None: + """Returns (projection, expert_id) for expert gate/up/down params.""" match = EXPERT_TRIPLET_PARAM_RE.search(param) if match is None: return None - return int(match.group("layer")), int(match.group("expert")) + return match.group("proj"), int(match.group("expert")) + + +def _layer_agnostic_param_key(param: str) -> str | None: + """Normalizes one parameter name by stripping the explicit layer index.""" + if LAYER_INDEX_RE.search(param) is None: + return None + return LAYER_INDEX_RE.sub("layers.__layer_avg__.", param, count=1) + + +def _stacked_layers( + pairs: list[tuple[str, Any, Any]], +) -> list[tuple[str, Any, Any]]: + """Builds layer-stacked tensor pairs keyed without explicit layer index.""" + import torch + + grouped: dict[str, list[tuple[Any, Any]]] = {} + for name, reference, candidate in pairs: + normalized = _layer_agnostic_param_key(name) + if normalized is None: + raise RuntimeError( + "Expected all compared params to include a layer index, " + f"got '{name}'." + ) + grouped.setdefault(normalized, []).append( + (reference.detach().float(), candidate.detach().float()) + ) + + stacked_pairs: list[tuple[str, Any, Any]] = [] + for normalized in sorted(grouped): + group = grouped[normalized] + stacked_pairs.append( + ( + normalized, + torch.stack([reference for reference, _ in group], dim=0), + torch.stack([candidate for _, candidate in group], dim=0), + ) + ) + return stacked_pairs class VariantRunner: @@ -806,9 +790,10 @@ def _run_topology( if manifest_path.exists() and not regenerate: return topology_dir _replace_topology_dir(topology_dir) + run_case_config = self.case_config request = WorkerRunRequest( case_id=self.case_id, - case_config=self.case_config, + case_config=run_case_config, topology=topology, topology_dir=str(topology_dir), packed_tensors=self.case_artifacts.packed_tensors, @@ -878,28 +863,33 @@ def ensure_variant_artifacts( mutation=variant.mutation, replay_bundle_dir=self.oracle_routing_bundle_dir, capture_bundle_dir=None, - regenerate=True, + regenerate=variant.force_regenerate, ) @staticmethod - def _apply_thresholds(row: MetricRow, thresholds: dict[str, float]) -> None: - """Evaluates row thresholds using AND semantics over all configured keys.""" - row.thresholds = dict(thresholds) - if not thresholds: + def _apply_phase_pass( + *, + row: MetricRow, + phase: str, + summary: MetricSummary, + pass_fn_by_phase: dict[str, PhasePassFn], + ) -> None: + """Evaluates a per-phase pass function against one summary payload.""" + pass_fn = pass_fn_by_phase.get(phase) + if pass_fn is None: row.pass_signal = True row.failure_reasons = [] return - payload = row.model_dump(mode="python") - reasons: list[str] = [] - for key, limit in sorted(thresholds.items()): - value = payload.get(key) - if not isinstance(value, (int, float)): - reasons.append(f"{key}=missing") - continue - if float(value) > float(limit): - reasons.append(f"{key}={float(value):.6g}>{float(limit):.6g}") - row.pass_signal = len(reasons) == 0 - row.failure_reasons = reasons + row.pass_signal = bool(pass_fn(summary)) + if row.pass_signal: + row.failure_reasons = [] + return + explain = getattr(pass_fn, "failure_reasons", None) + if callable(explain): + reasons = explain(summary) + row.failure_reasons = reasons if reasons else ["phase pass function returned false"] + return + row.failure_reasons = ["phase pass function returned false"] @staticmethod def _inf_summary() -> dict[str, float]: @@ -924,7 +914,7 @@ def _build_metric_row( summary: dict[str, float], structural_failure: str | None = None, ) -> MetricRow: - """Builds one metric row and applies per-phase thresholds.""" + """Builds one metric row and applies per-phase pass evaluation.""" row = MetricRow( case_id=self.case_id, variant=variant.name, @@ -941,7 +931,12 @@ def _build_metric_row( topk_mismatch_fraction=summary.get("topk_mismatch_fraction"), top1_mismatch_fraction=summary.get("top1_mismatch_fraction"), ) - self._apply_thresholds(row, variant.thresholds_by_phase.get(phase, {})) + self._apply_phase_pass( + row=row, + phase=phase, + summary=summary, + pass_fn_by_phase=variant.pass_fn_by_phase, + ) if structural_failure is not None: row.pass_signal = False row.failure_reasons = [structural_failure, *row.failure_reasons] @@ -955,11 +950,11 @@ def _build_metric_rows_from_tensor_pairs( phase: str, pairs: list[tuple[str, Any, Any]], router_ids: bool = False, + layer_averaged: bool = False, ) -> list[MetricRow]: """Builds rows from named tensor pairs with one shared diff path.""" rows: list[MetricRow] = [] for name, reference, candidate in pairs: - param_name = _minimal_param_name(name) reference_aligned = reference candidate_aligned = candidate aligned_candidate = _align_sequence_parallel( @@ -971,24 +966,30 @@ def _build_metric_rows_from_tensor_pairs( variant=variant, step_index=step_index, phase=phase, - param=param_name, + param=name, summary=self._inf_summary(), structural_failure="shape mismatch", ) ) continue - accumulator = DiffAccumulator() + summary: dict[str, float] if router_ids: + accumulator = DiffAccumulator() accumulator.update_router_ids(reference_aligned, aligned_candidate) + summary = accumulator.as_summary() + elif layer_averaged: + summary = DiffAccumulator.layer_averaged_summary(reference_aligned, aligned_candidate) else: + accumulator = DiffAccumulator() accumulator.update(reference_aligned, aligned_candidate) + summary = accumulator.as_summary() rows.append( self._build_metric_row( variant=variant, step_index=step_index, phase=phase, - param=param_name, - summary=accumulator.as_summary(), + param=name, + summary=summary, ) ) return rows @@ -1039,39 +1040,17 @@ def _build_metric_rows_from_tensor_maps( (key, reference[key], candidate[key]) for key in sorted(set(reference.keys())) ] + if phase in {"forward", "grads", "deltas"}: + pairs = _stacked_layers(pairs) return self._build_metric_rows_from_tensor_pairs( variant=variant, step_index=step_index, phase=phase, pairs=pairs, router_ids=router_ids, + layer_averaged=phase in {"forward", "grads", "deltas"}, ) - @staticmethod - def _flatten_forward_trace_tensors( - trace: dict[str, list[dict[str, Any]]], - *, - value_key: str, - ) -> dict[str, Any]: - """Flattens per-module forward trace calls into a deterministic tensor map.""" - flattened: dict[str, Any] = {} - for module_name in sorted(trace.keys()): - for call_offset, call in enumerate(trace[module_name]): - tensor = call.get(value_key) - if tensor is None: - continue - call_index = call.get("call_index", call_offset) - if value_key == "primary_output" and isinstance(tensor, torch.Tensor): - tensor = _canonicalize_moe_base_forward_tensor( - module_name=module_name, - call_index=int(call_index), - tensor=tensor, - trace=trace, - call=call, - ) - flattened[f"{module_name}.call_{call_index}"] = tensor - return flattened - @staticmethod def _build_step_summaries(rows: list[MetricRow]) -> dict[int, dict[str, Any]]: """Builds step-indexed payloads directly from row model dumps.""" @@ -1142,11 +1121,11 @@ def compare_variant(self, variant: VariantSpec) -> VariantReport: *[ ( phase, - self._flatten_forward_trace_tensors( + ForwardTraceCapture.flatten_trace_tensors( reference_trace, value_key=value_key, ), - self._flatten_forward_trace_tensors( + ForwardTraceCapture.flatten_trace_tensors( topology_trace, value_key=value_key, ), @@ -1187,7 +1166,12 @@ def compare_variant(self, variant: VariantSpec) -> VariantReport: ) @staticmethod - def assert_expected_signal(report: VariantReport, context: str) -> None: + def assert_expected_signal( + report: VariantReport, + context: str, + *, + report_path: Path, + ) -> None: """Raises when observed run signal diverges from variant expectation.""" if report.signal == report.expected_signal: return @@ -1196,11 +1180,13 @@ def assert_expected_signal(report: VariantReport, context: str) -> None: raise AssertionError( f"{context}: topology={report.topology} phase={first_failure.phase} " f"step={first_failure.step_index} param={first_failure.param} " - f"reasons={'; '.join(first_failure.failure_reasons)}" + f"reasons={'; '.join(first_failure.failure_reasons)} " + f"report={report_path}" ) raise AssertionError( f"{context}: expected_signal={report.expected_signal} " - f"observed_signal={report.signal} topology={report.topology}" + f"observed_signal={report.signal} topology={report.topology} " + f"report={report_path}" ) def _write_variant_report(self, topology_dir: Path, report: VariantReport) -> None: @@ -1210,9 +1196,9 @@ def _write_variant_report(self, topology_dir: Path, report: VariantReport) -> No ) def print_report(self, report: VariantReport) -> None: - """Prints a row-level table with expert rows subsampled by highest relative_l2.""" + """Prints a row-level table with expert rows subsampled by highest mean_abs_pct.""" non_expert_rows: list[MetricRow] = [] - triplet_rows: list[tuple[tuple[int, int], MetricRow]] = [] + triplet_rows: list[tuple[tuple[str, int], MetricRow]] = [] for row in report.metrics: expert_key = _triplet_expert_key(row.param) if expert_key is None: @@ -1220,22 +1206,22 @@ def print_report(self, report: VariantReport) -> None: continue triplet_rows.append((expert_key, row)) - scores_by_layer: dict[int, dict[int, float]] = {} - for (layer, expert_id), row in triplet_rows: - layer_scores = scores_by_layer.setdefault(layer, {}) - layer_scores[expert_id] = max( - layer_scores.get(expert_id, float("-inf")), row.relative_l2 + scores_by_proj: dict[str, dict[int, float]] = {} + for (projection, expert_id), row in triplet_rows: + projection_scores = scores_by_proj.setdefault(projection, {}) + projection_scores[expert_id] = max( + projection_scores.get(expert_id, float("-inf")), row.mean_abs_pct ) - selected_experts: set[tuple[int, int]] = set() - for layer, expert_scores in scores_by_layer.items(): + selected_experts: set[tuple[str, int]] = set() + for projection, expert_scores in scores_by_proj.items(): top_experts = sorted( expert_scores.items(), key=lambda item: item[1], reverse=True, )[:EXPERT_TABLE_ROW_LIMIT] for expert_id, _score in top_experts: - selected_experts.add((layer, expert_id)) + selected_experts.add((projection, expert_id)) selected_triplet_rows = [ row for expert_key, row in triplet_rows if expert_key in selected_experts @@ -1244,20 +1230,20 @@ def print_report(self, report: VariantReport) -> None: detail_table = Table( title=( f"Variant Report | variant={report.variant} " - f"| topology={report.topology} | signal={report.signal} " f"| selected_experts={len(selected_experts)} " - f"(top {EXPERT_TABLE_ROW_LIMIT} per layer)" + f"(top {EXPERT_TABLE_ROW_LIMIT} per projection by mean_abs_pct)" ), box=box.SIMPLE_HEAVY, show_lines=False, ) detail_table.add_column("Step", justify="right") detail_table.add_column("Phase", style="cyan") - detail_table.add_column("Param") + detail_table.add_column("Param", overflow="fold") detail_table.add_column("Status") detail_table.add_column("relative_l2", justify="right") detail_table.add_column("mean_abs_pct", justify="right") detail_table.add_column("typical_abs", justify="right") + detail_table.add_column("mean_abs_diff", justify="right") # detail_table.add_column("Thresholds") detail_table.add_column("Failure") sorted_rows = sorted( @@ -1282,7 +1268,7 @@ def print_report(self, report: VariantReport) -> None: f"{row.relative_l2:.6g}", f"{row.mean_abs_pct:.6g}%", f"{row.typical_abs_scale:.6g}", - # _threshold_string(row.thresholds), # disabled for now to avoid clutter, neat to keep though + f"{row.mean_abs_diff:.6g}", failure_text, ) self.console.print(detail_table) @@ -1307,32 +1293,42 @@ def run_suite( for variant in variants: report = self.run_variant(variant) reports.append(report) - self.assert_expected_signal(report, "Megatron oracle suite mismatch") + self.assert_expected_signal( + report, + "Megatron correctness suite mismatch", + report_path=self.case_dir + / variant.resolved_output_slug() + / "variant_report.json", + ) return reports -def _default_phase_thresholds( - case_cfg: OracleCaseConfig, -) -> dict[str, dict[str, float]]: - """Builds default per-phase (fwd, grad, outputs, losses, deltas) threshold dictionaries.""" - default = { - "relative_l2": case_cfg.tolerances.relative_l2, - "mean_abs_pct": case_cfg.tolerances.mean_abs_pct, - } - return { - key: default for key in ["outputs", "losses", "grads", "deltas", "forward"] - } | { - "router_scores": {"mean_abs_pct": 0.0}, - "router_topk_ids": { +def _default_phase_pass_fns() -> dict[str, PhasePassFn]: + """Builds default per-phase pass functions over diff summaries.""" + # note the metrics get averaged across layers to reduce noise + # we don't expect particular layers to see errors as opposed to the others so this is helpful + fwd_out_loss = MetricThresholdRule(limits={"relative_l2": 3e-2, "mean_abs_pct": 3.0}) + grads = lambda summary: ( + summary["mean_abs_pct"] < 5.0 + or (summary["typical_abs_scale"] < 1e-6 and summary["mean_abs_diff"] < 2e-8 and summary["relative_l2"] < 1.0) + ) + deltas = lambda summary: ( + summary["mean_abs_pct"] < 15.0 + ) + router_topk_rule = MetricThresholdRule( # should be no mismatch due to router replay + limits={ "topk_mismatch_fraction": 0.0, "top1_mismatch_fraction": 0.0, - }, - } + } + ) + return { + key: fwd_out_loss for key in ["forward", "outputs", "losses"] + } | {"grads": grads, "deltas": deltas, "router_topk_ids": router_topk_rule} -def _suite_variants(case_cfg: OracleCaseConfig) -> list[VariantSpec]: +def _suite_variants() -> list[VariantSpec]: """Builds the standard oracle suite variant ordering.""" - thresholds = _default_phase_thresholds(case_cfg) + phase_pass = _default_phase_pass_fns() variants = [ VariantSpec( name="oracle_replay_parity", @@ -1340,7 +1336,8 @@ def _suite_variants(case_cfg: OracleCaseConfig) -> list[VariantSpec]: output_slug=_topology_output_slug( ORACLE_TOPOLOGY, ORACLE_REPLAY_TOPOLOGY_SUFFIX ), - thresholds_by_phase=thresholds, + pass_fn_by_phase=phase_pass, + force_regenerate=regenerate_requested(), ) ] for topology in TOPOLOGIES[1:] + ( @@ -1350,7 +1347,7 @@ def _suite_variants(case_cfg: OracleCaseConfig) -> list[VariantSpec]: VariantSpec( name=f"topology_{topology.slug()}", topology=topology, - thresholds_by_phase=thresholds, + pass_fn_by_phase=phase_pass, ) ) return variants @@ -1362,7 +1359,7 @@ def run_suite( ) -> list[VariantReport]: """Runs replay parity and topology variants with fail-fast assertions.""" runner = VariantRunner(case_config=case_config) - return runner.run_suite(_suite_variants(case_config)) + return runner.run_suite(_suite_variants()) def run_sensitivity_suite( @@ -1372,14 +1369,14 @@ def run_sensitivity_suite( ) -> list[VariantReport]: """Runs a list of sensitivity mutations and expects each to fail.""" runner = VariantRunner(case_config=case_config) - thresholds = _default_phase_thresholds(case_config) + phase_pass = _default_phase_pass_fns() variants = [ VariantSpec( name=f"sensitivity_{mutation}", - topology=SENSITIVITY_TOPOLOGY, + topology=sensitivity_topology_for_mutation(mutation), mutation=mutation, expected_signal="fail", - thresholds_by_phase=thresholds, + pass_fn_by_phase=phase_pass, ) for mutation in mutations ] diff --git a/tests/integration/megatron_oracle_worker.py b/tests/integration/megatron_oracle_worker.py index a5a3ed66..1d734d91 100644 --- a/tests/integration/megatron_oracle_worker.py +++ b/tests/integration/megatron_oracle_worker.py @@ -1,15 +1,18 @@ from __future__ import annotations import argparse -from contextlib import contextmanager +from contextlib import ExitStack, contextmanager +import hashlib import os from pathlib import Path import random import subprocess import sys +from types import MethodType from typing import Any, Callable import numpy as np +import torch from art.megatron.routing_replay import ( ParallelTopology as ReplayParallelTopology, @@ -20,6 +23,7 @@ from .megatron_forward_trace import ForwardTraceCapture from .megatron_oracle_harness import ( + SUPPORTED_SENSITIVITY_MUTATIONS, OracleCaseConfig, RunManifest, SensitivityMutation, @@ -104,7 +108,9 @@ def _merge_sharded_dicts(shards_by_rank: list[dict[str, Any]]) -> dict[str, Any] return full_state -def _gather_full_state(local_state: dict[str, Any]) -> dict[str, Any] | None: +def _gather_full_state( + local_state: dict[str, Any], +) -> dict[str, Any] | None: """Gathers local state dicts to rank 0 and merges them.""" import torch @@ -119,7 +125,9 @@ def _gather_full_state(local_state: dict[str, Any]) -> dict[str, Any] | None: return _merge_sharded_dicts(entries) -def _collect_lora_state(model_chunks: list[Any]) -> dict[str, Any] | None: +def _collect_lora_state( + model_chunks: list[Any], +) -> dict[str, Any] | None: """Collects full LoRA adapter state for validation and delta computation.""" local_state: dict[str, Any] = {} for chunk in model_chunks: @@ -163,6 +171,49 @@ def _collect_lora_grads(model_chunks: list[Any]) -> dict[str, Any] | None: return _gather_full_state(local_grads) +def _apply_save_mutation_to_tensor_map( + tensor_map: dict[str, Any], + *, + mutation: SensitivityMutation | None, +) -> dict[str, Any]: + """Applies save-only mutation transforms to already-collected full tensor maps.""" + if mutation == "save_drop_nonzero_ranked_tp_shards": + mutated: dict[str, Any] = {} + for key, value in tensor_map.items(): + if not isinstance(value, torch.Tensor): + mutated[key] = value + continue + if ".lora_A." in key and value.ndim >= 2 and value.shape[1] > 1: + keep = max(1, value.shape[1] // 2) + mutated[key] = value.narrow(1, 0, keep).contiguous() + continue + if ".lora_B." in key and value.ndim >= 2 and value.shape[0] > 1: + keep = max(1, value.shape[0] // 2) + mutated[key] = value.narrow(0, 0, keep).contiguous() + continue + mutated[key] = value + return mutated + + if mutation == "save_duplicate_replicated_entries": + mutated = dict(tensor_map) + source_by_bucket: dict[tuple[tuple[int, ...], str], torch.Tensor] = {} + for key in sorted(mutated.keys()): + value = mutated[key] + if not isinstance(value, torch.Tensor): + continue + if not key.endswith(".weight"): + continue + bucket = (tuple(value.shape), str(value.dtype)) + source = source_by_bucket.get(bucket) + if source is None: + source_by_bucket[bucket] = value + continue + mutated[key] = source.clone().contiguous() + return mutated + + return tensor_map + + def _validate_loaded_state_matches_adapter( loaded_state: dict[str, Any], adapter_model: dict[str, Any], @@ -176,6 +227,29 @@ def _validate_loaded_state_matches_adapter( ) +def _build_deterministic_shared_init( + initial_state: dict[str, Any], + *, + seed: int, +) -> dict[str, Any]: + """Builds deterministic nonzero LoRA init values for both A and B tensors.""" + initialized: dict[str, Any] = {} + for key in sorted(initial_state.keys()): + value = initial_state[key] + if not isinstance(value, torch.Tensor): + raise TypeError(f"Expected tensor value for key '{key}', got {type(value)}") + digest = hashlib.sha256(f"{seed}:{key}".encode("utf-8")).digest() + key_seed = int.from_bytes(digest[:8], "little") % (2**31) + generator = torch.Generator(device="cpu").manual_seed(key_seed) + random_values = torch.randn( + value.shape, + generator=generator, + dtype=torch.float32, + ) + initialized[key] = (0.01 * random_values).to(dtype=value.dtype).contiguous() + return initialized + + def _configure_provider( provider: Any, topology: Topology, @@ -207,7 +281,8 @@ def _build_optimizer_config(case_config: OracleCaseConfig): adam_beta1=0.9, adam_beta2=0.99, clip_grad=0.1, - weight_decay=0.1, + weight_decay=0.0, + adam_eps=1e-13, ) @@ -252,9 +327,144 @@ def _delta_state( } +def _iter_named_unique_parameters( + model_chunks: list[Any], +) -> list[tuple[str, torch.nn.Parameter]]: + seen: set[int] = set() + params: list[tuple[str, torch.nn.Parameter]] = [] + for chunk_index, chunk in enumerate(model_chunks): + for name, param in chunk.named_parameters(): + param_id = id(param) + if param_id in seen: + continue + seen.add(param_id) + params.append((f"chunk{chunk_index}.{name}", param)) + return params + + +def _matches_grad_sync_skip_mutation(param_name: str, mutation: SensitivityMutation) -> bool: + if mutation == "bwd_skip_sync_qkv_a": + return any( + token in param_name + for token in ( + ".self_attention.linear_qkv.q_proj_lora.A_T", + ".self_attention.linear_qkv.k_proj_lora.A_T", + ".self_attention.linear_qkv.v_proj_lora.A_T", + ) + ) + if mutation == "bwd_skip_sync_o_proj_b": + return ".self_attention.linear_proj.lora.B_T" in param_name + if mutation == "bwd_skip_sync_fc1_a": + return ( + ".mlp.experts.linear_fc1.gate_lora.A_T" in param_name + or ".mlp.experts.linear_fc1.up_lora.A_T" in param_name + ) + return False + + +@contextmanager +def _apply_grad_sync_skip_mutation( + model_chunks: list[Any], + mutation: SensitivityMutation | None, +): + if mutation not in { + "bwd_skip_sync_qkv_a", + "bwd_skip_sync_o_proj_b", + "bwd_skip_sync_fc1_a", + }: + yield + return + + saved_attrs: list[tuple[Any, str, Any]] = [] + for param_name, param in _iter_named_unique_parameters(model_chunks): + # this only passes lora params atm, so we assume lora params below + if not _matches_grad_sync_skip_mutation(param_name, mutation): + continue + if ( + mutation == "bwd_skip_sync_fc1_a" + and param.grad_sync_domain != "expert_tp" + ): + continue + + # For fc1 A params, extended finalize handles expert-TP sync via grad_sync_op. + saved_attrs.append((param, "grad_sync_op", param.grad_sync_op)) + param.grad_sync_op = "none" + + # Megatron native TP finalize uses this only for tp_default-domain params. + if param.average_gradients_across_tp_domain and param.grad_sync_domain == "tp_default": + saved_attrs.append((param, "average_gradients_across_tp_domain", param.average_gradients_across_tp_domain)) + param.average_gradients_across_tp_domain = False + try: + yield + finally: + for param, attr, value in reversed(saved_attrs): + setattr(param, attr, value) + + +@contextmanager +def _apply_o_proj_forward_mutation( + model_chunks: list[Any], + mutation: SensitivityMutation | None, +): + if mutation not in { + "fwd_skip_o_proj_tp_reduce", + "fwd_o_proj_tp_reduce_avg_not_sum", + }: + yield + return + + from megatron.core import parallel_state as ps + from megatron.core.tensor_parallel.mappings import ( + reduce_from_tensor_model_parallel_region, + reduce_scatter_to_sequence_parallel_region, + ) + + from art.megatron.lora import SelfAttentionLinearProjLoRA + + original_forwards: list[tuple[Any, Any]] = [] + for chunk in model_chunks: + for module in chunk.modules(): + if not isinstance(module, SelfAttentionLinearProjLoRA): + continue + original_forwards.append((module, module.forward)) + + def _mutated_forward(self: Any, x: Any): + base_output, bias_output = self.linear_proj(x) + lora_output = self.lora(x) + tp_size = self.provider.tensor_model_parallel_size + if tp_size > 1: + if mutation == "fwd_o_proj_tp_reduce_avg_not_sum": + if self.provider.sequence_parallel: + lora_output = reduce_scatter_to_sequence_parallel_region( + lora_output + ) + else: + lora_output = reduce_from_tensor_model_parallel_region( + lora_output + ) + lora_output = lora_output / tp_size + elif mutation == "fwd_skip_o_proj_tp_reduce": + if self.provider.sequence_parallel: + seq_per_rank = lora_output.shape[0] // tp_size + tp_rank = ps.get_tensor_model_parallel_rank() + lora_output = lora_output.narrow( + 0, tp_rank * seq_per_rank, seq_per_rank + ) + return base_output + lora_output, bias_output + + module.forward = MethodType(_mutated_forward, module) + + try: + yield + finally: + for module, original_forward in reversed(original_forwards): + module.forward = original_forward + + @contextmanager def _mutation_hook( megatron_train_module: Any, + model_chunks: list[Any], mutation: SensitivityMutation | None, pre_optimizer_step_hook: Callable[[], None] | None = None, loss_scale: float = 1.0, @@ -264,45 +474,54 @@ def _mutation_hook( original_optimizer_step = megatron_train_module._optimizer_step original_loss_fn = megatron_train_module.loss_fn - if mutation == "drop_finalize": - megatron_train_module._finalize_grads = lambda _model: None - elif mutation is not None: + known_mutations = {None, *SUPPORTED_SENSITIVITY_MUTATIONS} + if mutation not in known_mutations: raise ValueError(f"Unsupported mutation: {mutation}") + if mutation == "skip_finalize": + megatron_train_module._finalize_grads = lambda _model: None + if pre_optimizer_step_hook is not None: def _patched_optimizer_step(optimizer: Any, learning_rate: float): - pre_optimizer_step_hook() + if pre_optimizer_step_hook is not None: + pre_optimizer_step_hook() return original_optimizer_step(optimizer, learning_rate) megatron_train_module._optimizer_step = _patched_optimizer_step - if loss_scale <= 0: - raise ValueError(f"loss_scale must be > 0, got {loss_scale}") - if loss_scale != 1.0: + effective_loss_scale = loss_scale + if effective_loss_scale <= 0: + raise ValueError( + f"effective_loss_scale must be > 0, got {effective_loss_scale}" + ) + if effective_loss_scale != 1.0: def _scaled_loss_fn(*args: Any, **kwargs: Any): loss = original_loss_fn(*args, **kwargs) return loss.model_copy( update={ - "mean_policy_loss": loss.mean_policy_loss * loss_scale, - "mean_kl": loss.mean_kl * loss_scale, - "policy_loss_sum": loss.policy_loss_sum * loss_scale, + "mean_policy_loss": loss.mean_policy_loss * effective_loss_scale, + "mean_kl": loss.mean_kl * effective_loss_scale, + "policy_loss_sum": loss.policy_loss_sum * effective_loss_scale, } ) megatron_train_module.loss_fn = _scaled_loss_fn if mutation is None: - if pre_optimizer_step_hook is None and loss_scale == 1.0: + if pre_optimizer_step_hook is None and effective_loss_scale == 1.0: yield return - try: - yield - finally: - megatron_train_module._finalize_grads = original_finalize - megatron_train_module._optimizer_step = original_optimizer_step - megatron_train_module.loss_fn = original_loss_fn + with ExitStack() as stack: + stack.enter_context(_apply_o_proj_forward_mutation(model_chunks, mutation)) + stack.enter_context(_apply_grad_sync_skip_mutation(model_chunks, mutation)) + try: + yield + finally: + megatron_train_module._finalize_grads = original_finalize + megatron_train_module._optimizer_step = original_optimizer_step + megatron_train_module.loss_fn = original_loss_fn def _worker_run(request: WorkerRunRequest) -> None: @@ -347,8 +566,12 @@ def _worker_run(request: WorkerRunRequest) -> None: initial_state = _collect_lora_state(model_chunks) if torch.distributed.get_rank() == 0: shared_init_path.parent.mkdir(parents=True, exist_ok=True) - save_file( + deterministic_init = _build_deterministic_shared_init( _require_not_none(initial_state, "initial_state"), + seed=request.case_config.seed, + ) + save_file( + deterministic_init, str(shared_init_path), ) torch.distributed.barrier() @@ -373,6 +596,7 @@ def _worker_run(request: WorkerRunRequest) -> None: learning_rate=request.case_config.learning_rate, beta=request.case_config.beta, kl_penalty_coef=0.0, + grad_accumulation_sequences=request.case_config.grad_accumulation_sequences, ) experimental_config: dev.TrainConfig = {} step_traces: list[StepTrace] = [] @@ -385,26 +609,36 @@ def _capture_lora_grads() -> None: with _mutation_hook( megatron_train, + model_chunks, request.mutation, pre_optimizer_step_hook=_capture_lora_grads, loss_scale=request.case_config.loss_scale, ): for step_index in range(request.case_config.num_steps): forward_trace_capture.set_step(step_index) - sample_index = step_index % request.packed_tensors.num_sequences - inputs = megatron_train.select_indexed_inputs(packed_tensors, sample_index) + base_sample_index = ( + step_index * request.case_config.grad_accumulation_sequences + ) + micro_sample_indices = [ + (base_sample_index + offset) % request.packed_tensors.num_sequences + for offset in range(request.case_config.grad_accumulation_sequences) + ] + micro_inputs = [ + megatron_train.select_indexed_inputs(packed_tensors, sample_index) + for sample_index in micro_sample_indices + ] captured_grads = None step_result = megatron_train.run_training_step( model_chunks=model_chunks, optimizer=optimizer, learning_rate=train_config.learning_rate, - inputs=inputs, + inputs=micro_inputs, config=train_config, experimental_config=experimental_config, ref_logprobs=None, step_index=step_index, - sample_index=sample_index, + sample_index=micro_sample_indices, moe_routing_replay_controller=runtime.moe_routing_replay_controller, ) forward_trace_capture.save_current_step(traces_dir) @@ -421,6 +655,14 @@ def _capture_lora_grads() -> None: current_lora_state, "current_lora_state" ) deltas = _delta_state(initial_state, current_state) + saved_deltas = _apply_save_mutation_to_tensor_map( + deltas, + mutation=request.mutation, + ) + saved_current_state = _apply_save_mutation_to_tensor_map( + current_state, + mutation=request.mutation, + ) output_rel = Path("traces") / f"output_step_{step_index:03d}.pt" grads_rel = Path("traces") / f"grads_step_{step_index:03d}.safetensors" @@ -434,8 +676,8 @@ def _capture_lora_grads() -> None: topology_dir / output_rel, ) save_file(grads, str(topology_dir / grads_rel)) - save_file(deltas, str(topology_dir / deltas_rel)) - save_file(current_state, str(topology_dir / lora_rel)) + save_file(saved_deltas, str(topology_dir / deltas_rel)) + save_file(saved_current_state, str(topology_dir / lora_rel)) # build and append the step trace step_traces.append( @@ -481,7 +723,6 @@ def _capture_lora_grads() -> None: seed=request.case_config.seed, num_steps=request.case_config.num_steps, packed_tensors=request.packed_tensors, - tolerances=request.case_config.tolerances, steps=step_traces, ) _write_json(topology_dir / "manifest.json", manifest.model_dump(mode="json")) From d396bfd60bad639fc5124db046b3d6e7640508f6 Mon Sep 17 00:00:00 2001 From: FurtherAI Date: Sat, 14 Mar 2026 08:21:40 +0000 Subject: [PATCH 15/19] oracle tests: write suite output tables to log files Redirect suite stdout/stderr into local correctness/sensitivity logs and make skip/report messaging point to those artifacts instead of terminal output. --- src/art/megatron/lora.py | 28 +++--- tests/integration/megatron_oracle_worker.py | 25 ++++-- .../test_megatron_lora_oracle_correctness.py | 88 ++++++++++++++++--- 3 files changed, 108 insertions(+), 33 deletions(-) diff --git a/src/art/megatron/lora.py b/src/art/megatron/lora.py index 247389c3..63f10c85 100644 --- a/src/art/megatron/lora.py +++ b/src/art/megatron/lora.py @@ -182,9 +182,9 @@ def num_local_experts(self) -> int: return self.A_T.shape[0] if self.A_T.ndim == 3 else 1 def _broadcast_if_replicated(self, param: torch.nn.Parameter) -> None: - if not param.lora_tp_replicated: + if not param.lora_tp_replicated: # ty: ignore[unresolved-attribute] return - domain = param.lora_shard_domain + domain = param.lora_shard_domain # ty: ignore[unresolved-attribute] world_size = _get_shard_world_size(domain) if world_size <= 1: return @@ -252,9 +252,9 @@ def load_weights( self.load_weight(weight, into=into) def load_weight(self, weight: torch.Tensor, *, into: torch.nn.Parameter) -> None: - domain = into.lora_shard_domain - if into.lora_tp_sharded: - axis = into.lora_tp_shard_dim + domain = into.lora_shard_domain # ty: ignore[unresolved-attribute] + if into.lora_tp_sharded: # ty: ignore[unresolved-attribute] + axis = into.lora_tp_shard_dim # ty: ignore[unresolved-attribute] axis = _normalize_axis(axis, weight.ndim) world_size = _get_shard_world_size(domain) rank = _get_shard_rank(domain) @@ -291,21 +291,21 @@ def _should_export_parameter(self, param: torch.nn.Parameter) -> bool: return False # this param is fully sharded, all shard ranks participate - if param.lora_tp_sharded: + if param.lora_tp_sharded: # ty: ignore[unresolved-attribute] return True # param is replicated, tp rank 0 or etp rank 0 participates - return _get_shard_rank(param.lora_shard_domain) == 0 + return _get_shard_rank(param.lora_shard_domain) == 0 # ty: ignore[unresolved-attribute] def _manifest_for_param(self, param: torch.nn.Parameter) -> dict[str, Any]: return { - "domain": param.lora_shard_domain, - "sharded": param.lora_tp_sharded, - "shard_dim": param.lora_tp_shard_dim, - "shard_world_size": _get_shard_world_size(param.lora_shard_domain) - if param.lora_tp_sharded + "domain": param.lora_shard_domain, # ty: ignore[unresolved-attribute] + "sharded": param.lora_tp_sharded, # ty: ignore[unresolved-attribute] + "shard_dim": param.lora_tp_shard_dim, # ty: ignore[unresolved-attribute] + "shard_world_size": _get_shard_world_size(param.lora_shard_domain) # ty: ignore[unresolved-attribute] + if param.lora_tp_sharded # ty: ignore[unresolved-attribute] else 1, - "shard_rank": _get_shard_rank(param.lora_shard_domain) - if param.lora_tp_sharded + "shard_rank": _get_shard_rank(param.lora_shard_domain) # ty: ignore[unresolved-attribute] + if param.lora_tp_sharded # ty: ignore[unresolved-attribute] else 0, } diff --git a/tests/integration/megatron_oracle_worker.py b/tests/integration/megatron_oracle_worker.py index 1d734d91..316d39d8 100644 --- a/tests/integration/megatron_oracle_worker.py +++ b/tests/integration/megatron_oracle_worker.py @@ -342,7 +342,9 @@ def _iter_named_unique_parameters( return params -def _matches_grad_sync_skip_mutation(param_name: str, mutation: SensitivityMutation) -> bool: +def _matches_grad_sync_skip_mutation( + param_name: str, mutation: SensitivityMutation +) -> bool: if mutation == "bwd_skip_sync_qkv_a": return any( token in param_name @@ -381,19 +383,26 @@ def _apply_grad_sync_skip_mutation( if not _matches_grad_sync_skip_mutation(param_name, mutation): continue if ( - mutation == "bwd_skip_sync_fc1_a" - and param.grad_sync_domain != "expert_tp" + mutation == "bwd_skip_sync_fc1_a" and param.grad_sync_domain != "expert_tp" # ty: ignore[unresolved-attribute] ): continue # For fc1 A params, extended finalize handles expert-TP sync via grad_sync_op. - saved_attrs.append((param, "grad_sync_op", param.grad_sync_op)) - param.grad_sync_op = "none" + saved_attrs.append((param, "grad_sync_op", param.grad_sync_op)) # ty: ignore[unresolved-attribute] + param.grad_sync_op = "none" # ty: ignore[unresolved-attribute] # Megatron native TP finalize uses this only for tp_default-domain params. - if param.average_gradients_across_tp_domain and param.grad_sync_domain == "tp_default": - saved_attrs.append((param, "average_gradients_across_tp_domain", param.average_gradients_across_tp_domain)) - param.average_gradients_across_tp_domain = False + average_gradients_across_tp_domain = param.average_gradients_across_tp_domain # ty: ignore[unresolved-attribute] + grad_sync_domain = param.grad_sync_domain # ty: ignore[unresolved-attribute] + if average_gradients_across_tp_domain and grad_sync_domain == "tp_default": + saved_attrs.append( + ( + param, + "average_gradients_across_tp_domain", + average_gradients_across_tp_domain, + ) + ) + param.average_gradients_across_tp_domain = False # ty: ignore[unresolved-attribute] try: yield finally: diff --git a/tests/integration/test_megatron_lora_oracle_correctness.py b/tests/integration/test_megatron_lora_oracle_correctness.py index f7e4dbbc..67c35adb 100644 --- a/tests/integration/test_megatron_lora_oracle_correctness.py +++ b/tests/integration/test_megatron_lora_oracle_correctness.py @@ -1,9 +1,12 @@ +from contextlib import redirect_stderr, redirect_stdout +from pathlib import Path +from typing import Callable + import pytest from .megatron_oracle_harness import ( EXTENDED_TOPOLOGIES, SENSITIVITY_MUTATION_ENV, - SENSITIVITY_TOPOLOGY, TOPOLOGIES, available_gpu_count, case_config, @@ -12,8 +15,33 @@ run_suite, sensitivity_enabled, sensitivity_mutations, + sensitivity_required_world_size, ) +REPO_ROOT = Path(__file__).resolve().parents[2] +CORRECTNESS_LOG_PATH = REPO_ROOT / ".local" / "correctness.log" +SENSITIVITY_LOG_PATH = REPO_ROOT / ".local" / "sensitivity.log" + + +def _run_suite_with_log( + *, + log_path: Path, + run: Callable[[], object], +) -> None: + log_path.parent.mkdir(parents=True, exist_ok=True) + with log_path.open("w", encoding="utf-8") as log_file: + with redirect_stdout(log_file), redirect_stderr(log_file): + run() + + +def _announce_report_log( + *, + log_path: Path, + capsys: pytest.CaptureFixture[str], +) -> None: + with capsys.disabled(): + print(f"\nMegatron LoRA oracle report log: {log_path}", flush=True) + def _require_gpus_for(topology_world_size: int) -> None: gpu_count = available_gpu_count() @@ -30,31 +58,69 @@ def _suite_world_size() -> int: return max(topology.world_size() for topology in suite_topologies) -def test_megatron_lora_diff_sensitivity() -> None: +def test_megatron_lora_diff_sensitivity(capsys: pytest.CaptureFixture[str]) -> None: """ Runs a each of the sensitivity mutations (e.g. drop megatron finalize grads) and expects each to fail (numerical differences larger than our thresholds) This test ensures we can catch errors we know of (implying we will be able to catch unknown errors as well) """ + _announce_report_log(log_path=SENSITIVITY_LOG_PATH, capsys=capsys) if not sensitivity_enabled(): + SENSITIVITY_LOG_PATH.parent.mkdir(parents=True, exist_ok=True) + SENSITIVITY_LOG_PATH.write_text( + ( + "Sensitivity suite skipped. " + f"Set {SENSITIVITY_MUTATION_ENV}=all (or one mutation / CSV).\n" + ), + encoding="utf-8", + ) pytest.skip( - f"Set {SENSITIVITY_MUTATION_ENV}=drop_finalize (or CSV) to enable sensitivity check." + f"Set {SENSITIVITY_MUTATION_ENV}=all (or one mutation / CSV) to enable sensitivity check." ) - _require_gpus_for(SENSITIVITY_TOPOLOGY.world_size()) mutations = sensitivity_mutations() assert mutations - run_sensitivity_suite( - case_config=case_config(), - mutations=mutations, + sensitivity_world_size = sensitivity_required_world_size(mutations) + gpu_count = available_gpu_count() + if gpu_count < sensitivity_world_size: + SENSITIVITY_LOG_PATH.parent.mkdir(parents=True, exist_ok=True) + SENSITIVITY_LOG_PATH.write_text( + ( + "Sensitivity suite skipped. " + f"Need {sensitivity_world_size} GPUs, found {gpu_count}.\n" + ), + encoding="utf-8", + ) + _require_gpus_for(sensitivity_world_size) + _run_suite_with_log( + log_path=SENSITIVITY_LOG_PATH, + run=lambda: run_sensitivity_suite( + case_config=case_config(), + mutations=mutations, + ), ) -def test_megatron_lora_topology_suite() -> None: +def test_megatron_lora_topology_suite(capsys: pytest.CaptureFixture[str]) -> None: """ Runs the suite of topologies and expects each to pass (numerical differences within our thresholds) """ - _require_gpus_for(_suite_world_size()) - run_suite( - case_config=case_config(), + _announce_report_log(log_path=CORRECTNESS_LOG_PATH, capsys=capsys) + suite_world_size = _suite_world_size() + gpu_count = available_gpu_count() + if gpu_count < suite_world_size: + CORRECTNESS_LOG_PATH.parent.mkdir(parents=True, exist_ok=True) + CORRECTNESS_LOG_PATH.write_text( + ( + "Topology suite skipped. " + f"Need {suite_world_size} GPUs, found {gpu_count}.\n" + ), + encoding="utf-8", + ) + _require_gpus_for(suite_world_size) + _run_suite_with_log( + log_path=CORRECTNESS_LOG_PATH, + run=lambda: run_suite( + case_config=case_config(), + ), ) From 5385fbbc1f6959d6d5341f8fa304cef57b6370b8 Mon Sep 17 00:00:00 2001 From: FurtherAI Date: Mon, 16 Mar 2026 18:40:13 +0000 Subject: [PATCH 16/19] Add correct data parallelism. --- src/art/loss.py | 14 +- src/art/megatron/routing_replay.py | 42 ++++- src/art/megatron/train.py | 180 +++++++++++++++---- tests/integration/megatron_forward_trace.py | 17 +- tests/integration/megatron_oracle_harness.py | 146 +++++++++++---- tests/integration/megatron_oracle_worker.py | 70 ++++++-- 6 files changed, 373 insertions(+), 96 deletions(-) diff --git a/src/art/loss.py b/src/art/loss.py index a22cca3f..7d25a8eb 100644 --- a/src/art/loss.py +++ b/src/art/loss.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal from pydantic import BaseModel, ConfigDict import torch @@ -27,6 +27,7 @@ def loss_fn( ref_logprobs: torch.Tensor | None, entropies: torch.Tensor | None, experimental_config: dev.TrainConfig, + reduction: Literal["mean", "sum"] = "mean", ) -> Loss: old_logprobs = shift_tensor(inputs["logprobs"], float("nan")) advantages = shift_tensor(inputs["advantages"], 0.0) @@ -132,14 +133,15 @@ def loss_fn( kl_div = torch.zeros_like(policy_loss) policy_loss = policy_loss * weights * assistant_mask kl_div = kl_div * weights * assistant_mask - mean_policy_loss = policy_loss.sum() / (assistant_mask.sum() + 1e-6) - mean_kl = kl_div.sum() / (assistant_mask.sum() + 1e-6) + denominator = assistant_mask.sum() + 1e-6 if reduction == "mean" else 1.0 + mean_policy_loss = policy_loss.sum() / denominator + mean_kl = kl_div.sum() / denominator # Compute mean entropy for the current step if entropies is not None: shifted_entropies = shift_tensor(entropies, 0.0) - mean_entropy = (shifted_entropies * weights * assistant_mask).sum() / ( - assistant_mask.sum() + 1e-6 - ) + mean_entropy = ( + shifted_entropies * weights * assistant_mask + ).sum() / denominator else: mean_entropy = None return Loss( diff --git a/src/art/megatron/routing_replay.py b/src/art/megatron/routing_replay.py index 91865b80..463e5258 100644 --- a/src/art/megatron/routing_replay.py +++ b/src/art/megatron/routing_replay.py @@ -507,6 +507,7 @@ def __init__( self._active_sample_index: int | None = None self._active_step_routes: StepRoutes | None = None self._router_call_cursors: dict[str, int] = {} + self._router_call_limits: dict[str, int] = {} self._global_uid_to_row_index: dict[int, int] = {} self._local_router_keys: set[str] = set() @@ -600,6 +601,8 @@ def remove_router_patches(self) -> None: self._local_router_keys.clear() def set_step(self, *, step_index: int, sample_index: int) -> None: + from megatron.core import parallel_state as ps + if step_index not in self.bundle.steps: raise RuntimeError( f"Replay bundle missing step_index={step_index}. " @@ -615,9 +618,26 @@ def set_step(self, *, step_index: int, sample_index: int) -> None: "Replay bundle step is missing local router key: " f"step={step_index}, router='{local_router_key}'" ) - self._router_call_cursors = { - router_key: 0 for router_key in sorted(self._local_router_keys) - } + dp_world_size = int(ps.get_data_parallel_world_size(with_context_parallel=True)) + dp_rank = int(ps.get_data_parallel_rank(with_context_parallel=True)) + self._router_call_cursors = {} + self._router_call_limits = {} + for router_key in sorted(self._local_router_keys): + total_calls = len(step_routes.routers[router_key].calls) + call_start = 0 + call_limit = total_calls + if dp_world_size > 1: + if total_calls % dp_world_size != 0: + raise RuntimeError( + "Replay router call count is not divisible by DP world size: " + f"step={step_index}, router='{router_key}', " + f"calls={total_calls}, dp_world_size={dp_world_size}" + ) + calls_per_dp_rank = total_calls // dp_world_size + call_start = dp_rank * calls_per_dp_rank + call_limit = call_start + calls_per_dp_rank + self._router_call_cursors[router_key] = call_start + self._router_call_limits[router_key] = call_limit self._global_uid_to_row_index = { int(uid.item()): row_index for row_index, uid in enumerate(step_routes.global_token_uids) @@ -627,9 +647,13 @@ def finalize_step(self) -> None: if self._active_step_routes is None: raise RuntimeError("finalize_step called before set_step") for router_key in sorted(self._local_router_keys): - router_routes = self._active_step_routes.routers[router_key] consumed = self._router_call_cursors.get(router_key, 0) - expected = len(router_routes.calls) + expected = self._router_call_limits.get(router_key) + if expected is None: + raise RuntimeError( + "Routing replay call limits missing for router key: " + f"step={self._active_step_index}, router='{router_key}'" + ) if consumed != expected: raise RuntimeError( "Routing replay step consumption mismatch: " @@ -640,6 +664,7 @@ def finalize_step(self) -> None: self._active_sample_index = None self._active_step_routes = None self._router_call_cursors = {} + self._router_call_limits = {} self._global_uid_to_row_index = {} def get_route_for_router( @@ -652,7 +677,14 @@ def get_route_for_router( ) -> tuple[torch.Tensor, torch.Tensor]: step_routes = self._active_step_routes call_index = self._router_call_cursors.get(router_key, 0) + call_limit = self._router_call_limits.get(router_key) router_calls = step_routes.routers[router_key].calls + if call_limit is not None and call_index >= call_limit: + raise RuntimeError( + "Routing replay call cursor exceeded local call range: " + f"step={self._active_step_index}, router='{router_key}', " + f"call_index={call_index}, limit={call_limit}" + ) route = router_calls[call_index] self._router_call_cursors[router_key] = call_index + 1 diff --git a/src/art/megatron/train.py b/src/art/megatron/train.py index a66156bc..c08394e1 100644 --- a/src/art/megatron/train.py +++ b/src/art/megatron/train.py @@ -291,6 +291,7 @@ def collect_sharded_lora_state( return sharded_state_dict, sharded_state_manifest +@torch.no_grad() def select_indexed_inputs(packed_tensors: PackedTensors, index: int) -> PackedTensors: return PackedTensors( # type: ignore[call-arg] **{ @@ -303,6 +304,76 @@ def select_indexed_inputs(packed_tensors: PackedTensors, index: int) -> PackedTe ) +@torch.no_grad() +def _clone_packed_tensors(inputs: PackedTensors) -> PackedTensors: + return PackedTensors( # type: ignore[call-arg] + **{ + key: value.clone() + for key, value in inputs.items() + if isinstance(value, torch.Tensor) + }, + pixel_values=[None], + image_grid_thw=[None], + ) + + +@torch.no_grad() +def _zero_contribution_inputs(template: PackedTensors) -> PackedTensors: + dummy = _clone_packed_tensors(template) + dummy["assistant_mask"].zero_() + return dummy + + +def resolve_local_grad_accumulation_sequences( + global_grad_accumulation_sequences: int, +) -> int: + dp_world_size = ps.get_data_parallel_world_size(with_context_parallel=True) + if ( + global_grad_accumulation_sequences <= 0 + or global_grad_accumulation_sequences % dp_world_size != 0 + ): + raise RuntimeError( + "Invalid global grad accumulation / DP world size combination: " + f"global_grad_accumulation_sequences={global_grad_accumulation_sequences}, " + f"dp_world_size={dp_world_size}" + ) + return global_grad_accumulation_sequences // dp_world_size + + +def build_micro_sample_indices( + step_index: int, + num_sequences: int, + global_grad_accumulation_sequences: int, +) -> list[int | None]: + dp_rank = ps.get_data_parallel_rank(with_context_parallel=True) + local_grad_accumulation_sequences = resolve_local_grad_accumulation_sequences( + global_grad_accumulation_sequences=global_grad_accumulation_sequences, + ) + base_global_sample_index = step_index * global_grad_accumulation_sequences + global_step_indices: list[int | None] = [] + for offset in range(global_grad_accumulation_sequences): + global_sample_index = base_global_sample_index + offset + global_step_indices.append( + global_sample_index if global_sample_index < num_sequences else None + ) + rank_start = dp_rank * local_grad_accumulation_sequences + rank_end = rank_start + local_grad_accumulation_sequences + return global_step_indices[rank_start:rank_end] + + +def select_micro_inputs( + packed_tensors: PackedTensors, + sample_indices: list[int | None], + zero_template: PackedTensors, +) -> list[PackedTensors]: + return [ + _clone_packed_tensors(zero_template) + if sample_index is None + else select_indexed_inputs(packed_tensors, sample_index) + for sample_index in sample_indices + ] + + def _move_inputs_to_device(inputs: PackedTensors, device: torch.device) -> None: for key, value in inputs.items(): if isinstance(value, torch.Tensor): @@ -332,6 +403,46 @@ def _reduce_loss(loss: torch.Tensor) -> torch.Tensor: return reduced_loss +def _count_trainable_tokens(inputs: PackedTensors) -> float: + assistant_mask = shift_tensor(inputs["assistant_mask"], False) + return float(assistant_mask.sum().item()) + + +def _global_token_normalization_scale( + micro_inputs: list[PackedTensors], + device: torch.device, +) -> float: + """ + Data parallel grad normalization scale + dp_world_size / global_micro_batch_token_count, where dp_world_size cancels out + the dp grad averaging, since we divide by global rather than local token count. + Using reduction="sum" and dividing by global token count means each rank is normalized + correctly. + """ + local_token_total = sum(_count_trainable_tokens(micro) for micro in micro_inputs) + dp_world_size = 1 + global_token_total = local_token_total + + dp_world_size = ps.get_data_parallel_world_size(with_context_parallel=True) + if dp_world_size > 1: + dp_group = ps.get_data_parallel_group(with_context_parallel=True) + + global_token_tensor = torch.tensor( + [local_token_total], device=device, dtype=torch.float32 + ) + torch.distributed.all_reduce( + global_token_tensor, + op=torch.distributed.ReduceOp.SUM, + group=dp_group, + ) + global_token_total = float(global_token_tensor.item()) + + if global_token_total <= 0.0: + return 0.0 + + return float(dp_world_size) / global_token_total + + def run_training_step( *, model_chunks: list[MegatronModule], @@ -340,9 +451,9 @@ def run_training_step( inputs: PackedTensors | list[PackedTensors], config: types.TrainConfig, experimental_config: dev.TrainConfig, + step_index: int, + sample_index: int | list[int | None], ref_logprobs: torch.Tensor | None = None, - step_index: int | None = None, - sample_index: int | list[int] | None = None, moe_routing_replay_controller: MoeRoutingReplayController | None = None, ) -> TrainStepResult: micro_inputs = inputs if isinstance(inputs, list) else [inputs] @@ -356,16 +467,18 @@ def run_training_step( f"{len(sample_index)} != {len(micro_inputs)}" ) micro_sample_indices = sample_index - elif sample_index is None: - micro_sample_indices = [0] * len(micro_inputs) else: - micro_sample_indices = [sample_index] * len(micro_inputs) + assert len(micro_inputs) == 1 + micro_sample_indices = [sample_index] if moe_routing_replay_controller is not None: - assert step_index is not None + step_sample_index = next( + (index for index in micro_sample_indices if index is not None), + 0, + ) moe_routing_replay_controller.set_step( step_index=step_index, - sample_index=micro_sample_indices[0], + sample_index=step_sample_index, ) device = next(model_chunks[0].parameters()).device @@ -374,7 +487,8 @@ def run_training_step( chunk.zero_grad_buffer() # ty: ignore[call-non-callable] micro_count = len(micro_inputs) - loss_sum: torch.Tensor | None = None + normalization_scale = _global_token_normalization_scale(micro_inputs, device=device) + normalized_loss: torch.Tensor | None = None probs_corr_sum = 0.0 new_logprobs: torch.Tensor | None = None @@ -400,16 +514,20 @@ def run_training_step( ref_logprobs, None, experimental_config, + reduction="sum", ) - micro_loss = loss_info.mean_policy_loss + config.beta * loss_info.mean_kl - (micro_loss / micro_count).backward() + micro_loss = ( + loss_info.mean_policy_loss + config.beta * loss_info.mean_kl + ) * normalization_scale + micro_loss.backward() probs_corr_sum += float(loss_info.probs_corr.item()) - if loss_sum is None: - loss_sum = micro_loss.detach() + detached_micro_loss = micro_loss.detach() + if normalized_loss is None: + normalized_loss = detached_micro_loss else: - loss_sum = loss_sum + micro_loss.detach() + normalized_loss = normalized_loss + detached_micro_loss - if new_logprobs is None or loss_sum is None: + if new_logprobs is None or normalized_loss is None: raise RuntimeError("run_training_step did not produce outputs") _finalize_grads(model_chunks) @@ -417,7 +535,7 @@ def run_training_step( optimizer, learning_rate, ) - reduced_loss = _reduce_loss(loss_sum / micro_count) + reduced_loss = _reduce_loss(normalized_loss) if moe_routing_replay_controller is not None: moe_routing_replay_controller.finalize_step() @@ -496,26 +614,20 @@ def _run_service_loop(runtime: TrainingRuntime) -> None: runtime.rank, "Loading packed tensors from", job.disk_packed_tensors["dir"] ) packed_tensors = packed_tensors_from_dir(**job.disk_packed_tensors) + template = select_indexed_inputs(packed_tensors, 0) + zero_template = _zero_contribution_inputs(template) num_sequences = job.disk_packed_tensors["num_sequences"] - - dp_rank = ps.get_data_parallel_rank() - dp_world_size = ps.get_data_parallel_world_size() - num_indices = math.ceil(num_sequences / dp_world_size) - indices = list(range(dp_rank, num_sequences, dp_world_size)) - if not indices: - indices = [dp_rank % num_sequences] - repeat = math.ceil(num_indices / len(indices)) - indices = (indices * repeat)[:num_indices] - - grad_accumulation_sequences = max(1, int(config.grad_accumulation_sequences)) - for step_index, start in enumerate( - range(0, len(indices), grad_accumulation_sequences) - ): - micro_indices = indices[start : start + grad_accumulation_sequences] - micro_inputs = [ - select_indexed_inputs(packed_tensors, sample_index) - for sample_index in micro_indices - ] + global_grad_accumulation_sequences = config.grad_accumulation_sequences + num_steps = math.ceil(num_sequences / global_grad_accumulation_sequences) + for step_index in range(num_steps): + micro_indices = build_micro_sample_indices( + step_index=step_index, + num_sequences=num_sequences, + global_grad_accumulation_sequences=global_grad_accumulation_sequences, + ) + micro_inputs = select_micro_inputs( + packed_tensors, micro_indices, zero_template + ) step_result = run_training_step( model_chunks=runtime.model, optimizer=runtime.optimizer, diff --git a/tests/integration/megatron_forward_trace.py b/tests/integration/megatron_forward_trace.py index d3befd9f..90c4a3af 100644 --- a/tests/integration/megatron_forward_trace.py +++ b/tests/integration/megatron_forward_trace.py @@ -400,7 +400,11 @@ def _build_moe_row_identities( expert_rows = (local_ids == expert_id).nonzero(as_tuple=False) for token_offset, slot_index in expert_rows.tolist(): local_identities.append( - (expert_id, token_cursor - token_count + token_offset, slot_index) + ( + expert_id, + token_cursor - token_count + token_offset, + slot_index, + ) ) if len(local_identities) != row_count: return None @@ -469,7 +473,9 @@ def _canonicalize_moe_expert_row_order( return tensor row_splits_raw = call.get("primary_output__row_splits") row_splits = ( - [int(v) for v in row_splits_raw] if isinstance(row_splits_raw, list) else None + [int(v) for v in row_splits_raw] + if isinstance(row_splits_raw, list) + else None ) identities = cls._build_moe_row_identities( module_name=module_name, @@ -572,11 +578,6 @@ def _merge_rank_values( and cls._can_cat_along_dim(tensors, dim=preferred_cat_dim) ): return torch.cat(tensors, dim=preferred_cat_dim) - if all( - tensors[0].shape == tensor.shape and torch.equal(tensors[0], tensor) - for tensor in tensors[1:] - ): - return tensors[0] if all(tensor.ndim > 0 for tensor in tensors): if cls._can_cat_along_dim(tensors, dim=0): return torch.cat(tensors, dim=0) @@ -622,7 +623,7 @@ def _merge_rank_values( ) if all(value == values_by_rank[0] for value in values_by_rank[1:]): return values_by_rank[0] - return values_by_rank[0] + return values_by_rank @classmethod def _merge_rank_call_entries( diff --git a/tests/integration/megatron_oracle_harness.py b/tests/integration/megatron_oracle_harness.py index 785e8618..ad19c194 100644 --- a/tests/integration/megatron_oracle_harness.py +++ b/tests/integration/megatron_oracle_harness.py @@ -37,6 +37,8 @@ "bwd_skip_sync_fc1_a", "save_drop_nonzero_ranked_tp_shards", "save_duplicate_replicated_entries", + "dp_grad_accumulation_seqs", + "dp_local_token_normalization", ) SensitivityMutation = str @@ -118,6 +120,7 @@ class PackedTensorConfig(BaseModel): sequence_length: int = 256 prefill_tokens: int = 64 decode_tokens: int = 64 + decode_tokens_jitter: int = Field(default=32, ge=0) vocab_high: int = 8192 @@ -170,7 +173,7 @@ class OracleCaseConfig(BaseModel): base_model: str num_layers: int = 4 - seed: int = 20260305 + seed: int = 20260304 num_steps: int = 1 grad_accumulation_sequences: int = Field(default=4, ge=1) learning_rate: float = 5e-6 @@ -338,7 +341,16 @@ def layer_averaged_summary(reference_stack, candidate_stack) -> dict[str, float] ref = reference_stack.detach().float() cand = candidate_stack.detach().float() layer_count = int(ref.shape[0]) - metrics = {k: 0.0 for k in ["numel", "mean_abs_diff", "relative_l2", "typical_abs_scale", "mean_abs_pct"]} + metrics = { + k: 0.0 + for k in [ + "numel", + "mean_abs_diff", + "relative_l2", + "typical_abs_scale", + "mean_abs_pct", + ] + } for layer_index in range(layer_count): layer_accumulator = DiffAccumulator() layer_accumulator.update(ref[layer_index], cand[layer_index]) @@ -423,7 +435,13 @@ def _require_not_none(value: T | None, name: str) -> T: SENSITIVITY_TOPOLOGY_BY_MUTATION: dict[SensitivityMutation, Topology] = { mutation: SENSITIVITY_TOPOLOGY for mutation in SUPPORTED_SENSITIVITY_MUTATIONS } -SENSITIVITY_TOPOLOGY_BY_MUTATION["bwd_skip_sync_fc1_a"] = Topology(tp=2, ep=1, etp=2, dp=1, sp=True) +SENSITIVITY_TOPOLOGY_BY_MUTATION["bwd_skip_sync_fc1_a"] = Topology( + tp=2, ep=1, etp=2, dp=1, sp=True +) +SENSITIVITY_TOPOLOGY_BY_MUTATION |= { + k: Topology(tp=1, ep=2, etp=1, dp=2, sp=False) + for k in ["dp_grad_accumulation_seqs", "dp_local_token_normalization"] +} def _truthy(value: str | None) -> bool: @@ -545,6 +563,15 @@ def _build_packed_tensors( dtype=torch.long, generator=generator, ) + # Ensure paired cross-DP rows are never token-identical. + half = config.num_sequences // 2 + if half > 0 and config.num_sequences % 2 == 0: + for pair_index in range(half): + left_index = pair_index + right_index = pair_index + half + if torch.equal(tokens[left_index], tokens[right_index]): + token_span = max(1, config.vocab_high - 10) + tokens[right_index] = ((tokens[right_index] - 10 + 1) % token_span) + 10 group_ids = torch.zeros(shape, dtype=torch.long) parent_ids = torch.full(shape, -1, dtype=torch.long) input_pos = ( @@ -554,17 +581,57 @@ def _build_packed_tensors( .clone() ) prefix_length = max(1, min(config.sequence_length - 1, config.prefill_tokens)) - decode_span = max(1, config.decode_tokens) - cursor = prefix_length - branch = 1 - while cursor < config.sequence_length: - end = min(config.sequence_length, cursor + decode_span) - group_ids[:, cursor:end] = branch - parent_ids[:, cursor:end] = 0 - cursor = end - branch += 1 assistant_mask = torch.zeros(shape, dtype=torch.bool) - assistant_mask[:, prefix_length:] = True + max_decode_tokens = max(1, config.sequence_length - prefix_length) + base_decode_tokens = max(1, min(config.decode_tokens, max_decode_tokens)) + jitter_width = min(config.decode_tokens_jitter, max_decode_tokens - 1) + candidate_decode_lengths: list[int] = [] + for _ in range(config.num_sequences): + if jitter_width > 0: + jitter = int( + torch.randint( + low=-jitter_width, + high=jitter_width + 1, + size=(1,), + generator=generator, + dtype=torch.long, + ).item() + ) + else: + jitter = 0 + decode_length = max( + 1, + min(max_decode_tokens, base_decode_tokens + jitter), + ) + candidate_decode_lengths.append(decode_length) + # Keep jitter local around the configured decode length, but force pairwise + # differences across halves so default DP rank shards see different lengths. + if half > 0 and config.num_sequences % 2 == 0: + for pair_index in range(half): + left_index = pair_index + right_index = pair_index + half + if ( + candidate_decode_lengths[left_index] + != candidate_decode_lengths[right_index] + ): + continue + if candidate_decode_lengths[right_index] < max_decode_tokens: + candidate_decode_lengths[right_index] += 1 + elif candidate_decode_lengths[right_index] > 1: + candidate_decode_lengths[right_index] -= 1 + + for sequence_index, decode_length in enumerate(candidate_decode_lengths): + active_stop = prefix_length + decode_length + assistant_mask[sequence_index, prefix_length:active_stop] = True + decode_span = max(1, min(config.decode_tokens, decode_length)) + cursor = prefix_length + branch = 1 + while cursor < active_stop: + end = min(active_stop, cursor + decode_span) + group_ids[sequence_index, cursor:end] = branch + parent_ids[sequence_index, cursor:end] = 0 + cursor = end + branch += 1 logprobs = ( torch.randn( shape, @@ -619,12 +686,16 @@ def ensure_case_artifacts(case_config: OracleCaseConfig) -> CaseArtifacts: case_dir = ARTIFACT_ROOT / case_id case_dir.mkdir(parents=True, exist_ok=True) _write_json(case_dir / "case_config.json", case_config.model_dump(mode="json")) + regenerate = regenerate_requested() descriptor_path = case_dir / "packed_tensors.json" - if descriptor_path.exists(): + packed_dir = case_dir / "packed_tensors" + if descriptor_path.exists() and not regenerate: packed_spec = DiskPackedTensorsSpec.model_validate(_read_json(descriptor_path)) else: - packed_spec = _create_packed_tensors(case_config, case_dir / "packed_tensors") + if packed_dir.exists(): + shutil.rmtree(packed_dir) + packed_spec = _create_packed_tensors(case_config, packed_dir) _write_json(descriptor_path, packed_spec.model_dump(mode="json")) shared_init_path = case_dir / "shared_init" / "adapter_model.safetensors" @@ -731,8 +802,7 @@ def _stacked_layers( normalized = _layer_agnostic_param_key(name) if normalized is None: raise RuntimeError( - "Expected all compared params to include a layer index, " - f"got '{name}'." + f"Expected all compared params to include a layer index, got '{name}'." ) grouped.setdefault(normalized, []).append( (reference.detach().float(), candidate.detach().float()) @@ -887,7 +957,9 @@ def _apply_phase_pass( explain = getattr(pass_fn, "failure_reasons", None) if callable(explain): reasons = explain(summary) - row.failure_reasons = reasons if reasons else ["phase pass function returned false"] + row.failure_reasons = ( + reasons if reasons else ["phase pass function returned false"] + ) return row.failure_reasons = ["phase pass function returned false"] @@ -978,7 +1050,9 @@ def _build_metric_rows_from_tensor_pairs( accumulator.update_router_ids(reference_aligned, aligned_candidate) summary = accumulator.as_summary() elif layer_averaged: - summary = DiffAccumulator.layer_averaged_summary(reference_aligned, aligned_candidate) + summary = DiffAccumulator.layer_averaged_summary( + reference_aligned, aligned_candidate + ) else: accumulator = DiffAccumulator() accumulator.update(reference_aligned, aligned_candidate) @@ -1307,23 +1381,31 @@ def _default_phase_pass_fns() -> dict[str, PhasePassFn]: """Builds default per-phase pass functions over diff summaries.""" # note the metrics get averaged across layers to reduce noise # we don't expect particular layers to see errors as opposed to the others so this is helpful - fwd_out_loss = MetricThresholdRule(limits={"relative_l2": 3e-2, "mean_abs_pct": 3.0}) + fwd_out_loss = MetricThresholdRule( + limits={"relative_l2": 3e-2, "mean_abs_pct": 3.0} + ) grads = lambda summary: ( summary["mean_abs_pct"] < 5.0 - or (summary["typical_abs_scale"] < 1e-6 and summary["mean_abs_diff"] < 2e-8 and summary["relative_l2"] < 1.0) - ) - deltas = lambda summary: ( - summary["mean_abs_pct"] < 15.0 + or ( + summary["typical_abs_scale"] < 1e-6 + and summary["mean_abs_diff"] < 2e-8 + and summary["relative_l2"] < 1.0 + ) ) - router_topk_rule = MetricThresholdRule( # should be no mismatch due to router replay - limits={ - "topk_mismatch_fraction": 0.0, - "top1_mismatch_fraction": 0.0, - } + deltas = lambda summary: summary["mean_abs_pct"] < 15.0 + router_topk_rule = ( + MetricThresholdRule( # should be no mismatch due to router replay + limits={ + "topk_mismatch_fraction": 0.0, + "top1_mismatch_fraction": 0.0, + } + ) ) - return { - key: fwd_out_loss for key in ["forward", "outputs", "losses"] - } | {"grads": grads, "deltas": deltas, "router_topk_ids": router_topk_rule} + return {key: fwd_out_loss for key in ["forward", "outputs", "losses"]} | { + "grads": grads, + "deltas": deltas, + "router_topk_ids": router_topk_rule, + } def _suite_variants() -> list[VariantSpec]: diff --git a/tests/integration/megatron_oracle_worker.py b/tests/integration/megatron_oracle_worker.py index 316d39d8..d3c5b836 100644 --- a/tests/integration/megatron_oracle_worker.py +++ b/tests/integration/megatron_oracle_worker.py @@ -475,6 +475,7 @@ def _mutation_hook( megatron_train_module: Any, model_chunks: list[Any], mutation: SensitivityMutation | None, + topology: Topology, pre_optimizer_step_hook: Callable[[], None] | None = None, loss_scale: float = 1.0, ): @@ -482,6 +483,9 @@ def _mutation_hook( original_finalize = megatron_train_module._finalize_grads original_optimizer_step = megatron_train_module._optimizer_step original_loss_fn = megatron_train_module.loss_fn + original_token_normalization_scale = ( + megatron_train_module._global_token_normalization_scale + ) known_mutations = {None, *SUPPORTED_SENSITIVITY_MUTATIONS} if mutation not in known_mutations: @@ -490,6 +494,46 @@ def _mutation_hook( if mutation == "skip_finalize": megatron_train_module._finalize_grads = lambda _model: None + if mutation == "dp_local_token_normalization": + + def _wrong_local_token_normalization_scale( + micro_inputs: list[Any], + device: torch.device, + ) -> float: + del device + local_token_total = sum( + megatron_train_module._count_trainable_tokens(micro) + for micro in micro_inputs + ) + if local_token_total <= 0.0: + return 0.0 + # Intentionally wrong normalization: use only local token total. + dp_world_size = int( + megatron_train_module.ps.get_data_parallel_world_size( + with_context_parallel=True + ) + ) + return float(dp_world_size) / float(local_token_total) + + megatron_train_module._global_token_normalization_scale = ( + _wrong_local_token_normalization_scale + ) + + if mutation == "dp_grad_accumulation_seqs": + + def _wrong_resolve_local_grad_accumulation_sequences( + global_grad_accumulation_sequences: int, + ) -> int: + return megatron_train_module.resolve_local_grad_accumulation_sequences( + global_grad_accumulation_sequences=( + topology.dp * global_grad_accumulation_sequences + ) + ) + + megatron_train_module.resolve_local_grad_accumulation_sequences = ( + _wrong_resolve_local_grad_accumulation_sequences + ) + if pre_optimizer_step_hook is not None: def _patched_optimizer_step(optimizer: Any, learning_rate: float): @@ -531,6 +575,9 @@ def _scaled_loss_fn(*args: Any, **kwargs: Any): megatron_train_module._finalize_grads = original_finalize megatron_train_module._optimizer_step = original_optimizer_step megatron_train_module.loss_fn = original_loss_fn + megatron_train_module._global_token_normalization_scale = ( + original_token_normalization_scale + ) def _worker_run(request: WorkerRunRequest) -> None: @@ -599,13 +646,16 @@ def _worker_run(request: WorkerRunRequest) -> None: packed_tensors = packed_tensors_from_dir( **request.packed_tensors.model_dump(exclude_none=True) ) + template = megatron_train.select_indexed_inputs(packed_tensors, 0) + zero_template = megatron_train._zero_contribution_inputs(template) initial_lora_state = loaded_state + global_grad_accumulation_sequences = request.case_config.grad_accumulation_sequences train_config = types.TrainConfig( learning_rate=request.case_config.learning_rate, beta=request.case_config.beta, kl_penalty_coef=0.0, - grad_accumulation_sequences=request.case_config.grad_accumulation_sequences, + grad_accumulation_sequences=global_grad_accumulation_sequences, ) experimental_config: dev.TrainConfig = {} step_traces: list[StepTrace] = [] @@ -620,22 +670,20 @@ def _capture_lora_grads() -> None: megatron_train, model_chunks, request.mutation, + request.topology, pre_optimizer_step_hook=_capture_lora_grads, loss_scale=request.case_config.loss_scale, ): for step_index in range(request.case_config.num_steps): forward_trace_capture.set_step(step_index) - base_sample_index = ( - step_index * request.case_config.grad_accumulation_sequences + micro_sample_indices = megatron_train.build_micro_sample_indices( + step_index=step_index, + num_sequences=request.packed_tensors.num_sequences, + global_grad_accumulation_sequences=global_grad_accumulation_sequences, + ) + micro_inputs = megatron_train.select_micro_inputs( + packed_tensors, micro_sample_indices, zero_template ) - micro_sample_indices = [ - (base_sample_index + offset) % request.packed_tensors.num_sequences - for offset in range(request.case_config.grad_accumulation_sequences) - ] - micro_inputs = [ - megatron_train.select_indexed_inputs(packed_tensors, sample_index) - for sample_index in micro_sample_indices - ] captured_grads = None step_result = megatron_train.run_training_step( From 75255671712a4e07442e21932933adc3e652d49c Mon Sep 17 00:00:00 2001 From: FurtherAI Date: Tue, 17 Mar 2026 20:33:44 +0000 Subject: [PATCH 17/19] Fix per-token DP normalization in Megatron training --- src/art/loss.py | 26 ++++---- src/art/megatron/finalize_grads.py | 18 ++++-- src/art/megatron/lora.py | 2 + src/art/megatron/provider.py | 48 +++++++++++++- src/art/megatron/train.py | 100 ++++++++++++----------------- src/art/tinker/service.py | 2 +- src/art/unsloth/train.py | 10 +-- 7 files changed, 122 insertions(+), 84 deletions(-) diff --git a/src/art/loss.py b/src/art/loss.py index 7d25a8eb..0aab6084 100644 --- a/src/art/loss.py +++ b/src/art/loss.py @@ -13,9 +13,10 @@ class Loss(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) - mean_policy_loss: torch.Tensor - mean_kl: torch.Tensor - mean_entropy: torch.Tensor | None + reduction: Literal["mean", "sum"] + policy_loss: torch.Tensor + kl: torch.Tensor + entropy: torch.Tensor | None policy_loss_sum: torch.Tensor probs_corr: torch.Tensor kl_policy_ref: torch.Tensor | None = None @@ -134,20 +135,19 @@ def loss_fn( policy_loss = policy_loss * weights * assistant_mask kl_div = kl_div * weights * assistant_mask denominator = assistant_mask.sum() + 1e-6 if reduction == "mean" else 1.0 - mean_policy_loss = policy_loss.sum() / denominator - mean_kl = kl_div.sum() / denominator - # Compute mean entropy for the current step + reduced_policy_loss = policy_loss.sum() / denominator + kl = kl_div.sum() / denominator + # Compute reduced entropy for the current step. if entropies is not None: shifted_entropies = shift_tensor(entropies, 0.0) - mean_entropy = ( - shifted_entropies * weights * assistant_mask - ).sum() / denominator + entropy = (shifted_entropies * weights * assistant_mask).sum() / denominator else: - mean_entropy = None + entropy = None return Loss( - mean_policy_loss=mean_policy_loss, - mean_kl=mean_kl, - mean_entropy=mean_entropy, + reduction=reduction, + policy_loss=reduced_policy_loss, + kl=kl, + entropy=entropy, policy_loss_sum=policy_loss.sum(), probs_corr=probs_corr, kl_policy_ref=kl_policy_ref, diff --git a/src/art/megatron/finalize_grads.py b/src/art/megatron/finalize_grads.py index 83e8cc4f..6fce32c3 100644 --- a/src/art/megatron/finalize_grads.py +++ b/src/art/megatron/finalize_grads.py @@ -59,7 +59,10 @@ def _resolve_reduce_op(op: GradSyncOp) -> Any: raise RuntimeError(f"Unknown grad sync op: {op}") -def finalize_model_grads_extended(model: list[torch.nn.Module]) -> None: +def finalize_model_grads_extended( + model: list[torch.nn.Module], + num_tokens: torch.Tensor | None = None, +) -> None: """Run Megatron finalize, then apply extra LoRA grad-sync reductions. Megatron finalize handles DP/CP(via `param.allreduce=True`)(and expert-DP via `param.allreduce=False`) internally. @@ -68,7 +71,7 @@ def finalize_model_grads_extended(model: list[torch.nn.Module]) -> None: """ # All-reduce all model grads across DP replicas, layernorm grads for sequence parallelism, # embedding grads across first and last pipeline stages (if not tied) - finalize_model_grads(model) + finalize_model_grads(model, num_tokens=num_tokens) buckets: dict[ tuple[GradSyncDomain, GradSyncOp, torch.dtype, torch.device], @@ -111,6 +114,13 @@ def finalize_model_grads_extended(model: list[torch.nn.Module]) -> None: grads = [grad for _name, grad in entries] coalesced = _flatten_dense_tensors(grads) - torch.distributed.all_reduce(coalesced, op=_resolve_reduce_op(op), group=group) - for grad, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): + reduced = ( + coalesced.float() + if torch.is_floating_point(coalesced) and coalesced.dtype != torch.float32 + else coalesced + ) + torch.distributed.all_reduce(reduced, op=_resolve_reduce_op(op), group=group) + if reduced is not coalesced: + reduced = reduced.to(dtype=coalesced.dtype) + for grad, synced in zip(grads, _unflatten_dense_tensors(reduced, grads)): grad.copy_(synced) diff --git a/src/art/megatron/lora.py b/src/art/megatron/lora.py index 63f10c85..fd62a249 100644 --- a/src/art/megatron/lora.py +++ b/src/art/megatron/lora.py @@ -684,6 +684,8 @@ def forward( ) -> tuple[torch.Tensor, torch.Tensor | None]: base_out, bias_out = self.linear_fc2(x, tokens_per_expert) adapter_out = self.lora(x, tokens_per_expert=tokens_per_expert) + # the reason there is no TP comm here is because the MoE token routing handles + # expert TP comm externally return base_out + adapter_out, bias_out diff --git a/src/art/megatron/provider.py b/src/art/megatron/provider.py index d1029b35..acd2eda1 100644 --- a/src/art/megatron/provider.py +++ b/src/art/megatron/provider.py @@ -5,6 +5,11 @@ from megatron.bridge import AutoBridge from megatron.bridge.models.gpt_provider import GPTModelProvider +from megatron.bridge.models.hf_pretrained.state import ( + SafeTensorsStateSource, + StateDict, + StateSource, +) from megatron.bridge.models.qwen.qwen3_moe_bridge import Qwen3MoEBridge from megatron.core.transformer.enums import AttnBackend from megatron.core.transformer.spec_utils import ModuleSpec @@ -28,15 +33,49 @@ def _resolve_layer_spec( return base_layer_spec(config, **kwargs) -def get_provider(model: str) -> GPTModelProvider: +class _CastingStateSource(StateSource): + def __init__(self, source: StateSource, *, dtype: torch.dtype): + self._source = source + self._dtype = dtype + + def get_all_keys(self) -> list[str]: + return self._source.get_all_keys() + + def load_tensors(self, keys: list[str]) -> dict[str, torch.Tensor]: + loaded = self._source.load_tensors(keys) + return { + key: ( + value.to(dtype=self._dtype) + if torch.is_floating_point(value) and value.dtype != self._dtype + else value + ) + for key, value in loaded.items() + } + + def has_glob(self, pattern: str) -> bool: + return self._source.has_glob(pattern) + + +def get_provider( + model: str, + *, + torch_dtype: torch.dtype = torch.bfloat16, +) -> GPTModelProvider: bridge = AutoBridge.from_hf_pretrained( model, - torch_dtype=torch.bfloat16, + dtype=torch_dtype, trust_remote_code=True, ) assert isinstance(bridge._model_bridge, Qwen3MoEBridge), ( "Only Qwen3 MoE models are supported" ) + if torch_dtype != torch.bfloat16: + bridge.hf_pretrained._state_dict_accessor = StateDict( + _CastingStateSource( + SafeTensorsStateSource(bridge.hf_pretrained.model_name_or_path), + dtype=torch_dtype, + ) + ) provider = bridge.to_megatron_provider() base_layer_spec = provider.transformer_layer_spec @@ -62,6 +101,11 @@ def _flex_attention_layer_spec( provider.expert_tensor_parallel_size = 1 provider.moe_shared_expert_overlap = True provider.moe_router_dtype = "fp32" + # params are disabled anyways, but should know about this if we switch to full FT + # because DP 'dummy' microbatches will unintentionally have loss for this + provider.moe_aux_loss_coeff = 0.0 + # effectively just a flag modifying finalize_model_grads behavior for DPxCP + provider.calculate_per_token_loss = True if provider.tensor_model_parallel_size > 1: provider.sequence_parallel = True return provider diff --git a/src/art/megatron/train.py b/src/art/megatron/train.py index c08394e1..a67b6eea 100644 --- a/src/art/megatron/train.py +++ b/src/art/megatron/train.py @@ -173,6 +173,7 @@ def configure_moe_routing_replay( def build_training_runtime( *, model_identifier: str | None = None, + provider_torch_dtype: torch.dtype = torch.bfloat16, provider_configure: Callable[[Any], None] | None = None, optimizer_config: OptimizerConfig | None = None, moe_routing_replay_path: str | None = None, @@ -182,7 +183,9 @@ def build_training_runtime( print_optimizer_stats: bool = True, ) -> TrainingRuntime: provider = get_provider( - model_identifier or os.environ.get("MODEL_IDENTIFIER", DEFAULT_MODEL_IDENTIFIER) + model_identifier + or os.environ.get("MODEL_IDENTIFIER", DEFAULT_MODEL_IDENTIFIER), + torch_dtype=provider_torch_dtype, ) if provider_configure is not None: provider_configure(provider) @@ -194,7 +197,11 @@ def build_training_runtime( model = cast( list[MegatronModule], provider.provide_distributed_model( - ddp_config=DistributedDataParallelConfig(), + ddp_config=DistributedDataParallelConfig( + # memory and comm for this should be small anyways cause lora + grad_reduce_in_fp32=True, + average_in_collective=False, + ), data_parallel_random_init=False, ), ) @@ -327,7 +334,7 @@ def _zero_contribution_inputs(template: PackedTensors) -> PackedTensors: def resolve_local_grad_accumulation_sequences( global_grad_accumulation_sequences: int, ) -> int: - dp_world_size = ps.get_data_parallel_world_size(with_context_parallel=True) + dp_world_size = ps.get_data_parallel_world_size() if ( global_grad_accumulation_sequences <= 0 or global_grad_accumulation_sequences % dp_world_size != 0 @@ -345,7 +352,8 @@ def build_micro_sample_indices( num_sequences: int, global_grad_accumulation_sequences: int, ) -> list[int | None]: - dp_rank = ps.get_data_parallel_rank(with_context_parallel=True) + dp_rank = ps.get_data_parallel_rank() + dp_world_size = ps.get_data_parallel_world_size() local_grad_accumulation_sequences = resolve_local_grad_accumulation_sequences( global_grad_accumulation_sequences=global_grad_accumulation_sequences, ) @@ -356,9 +364,10 @@ def build_micro_sample_indices( global_step_indices.append( global_sample_index if global_sample_index < num_sequences else None ) - rank_start = dp_rank * local_grad_accumulation_sequences - rank_end = rank_start + local_grad_accumulation_sequences - return global_step_indices[rank_start:rank_end] + return [ + global_step_indices[offset * dp_world_size + dp_rank] + for offset in range(local_grad_accumulation_sequences) + ] def select_micro_inputs( @@ -380,10 +389,6 @@ def _move_inputs_to_device(inputs: PackedTensors, device: torch.device) -> None: inputs[key] = value.to(device) # type: ignore[index] -def _finalize_grads(model_chunks: list[MegatronModule]) -> None: - finalize_model_grads_extended(cast(list[torch.nn.Module], model_chunks)) - - def _optimizer_step( optimizer: Any, learning_rate: float, @@ -397,9 +402,13 @@ def _optimizer_step( return update_successful, grad_norm, num_zeros_in_grad -def _reduce_loss(loss: torch.Tensor) -> torch.Tensor: +def _reduce_loss( + loss: torch.Tensor, + op: torch.distributed.ReduceOp.RedOpType = torch.distributed.ReduceOp.AVG, + group: torch.distributed.ProcessGroup | None = None, +) -> torch.Tensor: reduced_loss = loss.detach().clone() - torch.distributed.all_reduce(reduced_loss, op=torch.distributed.ReduceOp.AVG) + torch.distributed.all_reduce(reduced_loss, op=op, group=group) return reduced_loss @@ -408,39 +417,12 @@ def _count_trainable_tokens(inputs: PackedTensors) -> float: return float(assistant_mask.sum().item()) -def _global_token_normalization_scale( +def _local_trainable_token_count_tensor( micro_inputs: list[PackedTensors], device: torch.device, -) -> float: - """ - Data parallel grad normalization scale - dp_world_size / global_micro_batch_token_count, where dp_world_size cancels out - the dp grad averaging, since we divide by global rather than local token count. - Using reduction="sum" and dividing by global token count means each rank is normalized - correctly. - """ +) -> torch.Tensor: local_token_total = sum(_count_trainable_tokens(micro) for micro in micro_inputs) - dp_world_size = 1 - global_token_total = local_token_total - - dp_world_size = ps.get_data_parallel_world_size(with_context_parallel=True) - if dp_world_size > 1: - dp_group = ps.get_data_parallel_group(with_context_parallel=True) - - global_token_tensor = torch.tensor( - [local_token_total], device=device, dtype=torch.float32 - ) - torch.distributed.all_reduce( - global_token_tensor, - op=torch.distributed.ReduceOp.SUM, - group=dp_group, - ) - global_token_total = float(global_token_tensor.item()) - - if global_token_total <= 0.0: - return 0.0 - - return float(dp_world_size) / global_token_total + return torch.tensor([local_token_total], device=device, dtype=torch.float32) def run_training_step( @@ -472,13 +454,10 @@ def run_training_step( micro_sample_indices = [sample_index] if moe_routing_replay_controller is not None: - step_sample_index = next( - (index for index in micro_sample_indices if index is not None), - 0, - ) moe_routing_replay_controller.set_step( step_index=step_index, - sample_index=step_sample_index, + sample_index=micro_sample_indices, + global_grad_accumulation_sequences=config.grad_accumulation_sequences, ) device = next(model_chunks[0].parameters()).device @@ -487,8 +466,8 @@ def run_training_step( chunk.zero_grad_buffer() # ty: ignore[call-non-callable] micro_count = len(micro_inputs) - normalization_scale = _global_token_normalization_scale(micro_inputs, device=device) - normalized_loss: torch.Tensor | None = None + raw_loss_sum: torch.Tensor | None = None + num_tokens = _local_trainable_token_count_tensor(micro_inputs, device=device) probs_corr_sum = 0.0 new_logprobs: torch.Tensor | None = None @@ -516,26 +495,29 @@ def run_training_step( experimental_config, reduction="sum", ) - micro_loss = ( - loss_info.mean_policy_loss + config.beta * loss_info.mean_kl - ) * normalization_scale + micro_loss = loss_info.policy_loss + config.beta * loss_info.kl micro_loss.backward() probs_corr_sum += float(loss_info.probs_corr.item()) detached_micro_loss = micro_loss.detach() - if normalized_loss is None: - normalized_loss = detached_micro_loss + if raw_loss_sum is None: + raw_loss_sum = detached_micro_loss else: - normalized_loss = normalized_loss + detached_micro_loss + raw_loss_sum = raw_loss_sum + detached_micro_loss - if new_logprobs is None or normalized_loss is None: + if new_logprobs is None or raw_loss_sum is None: raise RuntimeError("run_training_step did not produce outputs") - _finalize_grads(model_chunks) + finalize_model_grads_extended(model_chunks, num_tokens=num_tokens) update_successful, grad_norm, num_zeros_in_grad = _optimizer_step( optimizer, learning_rate, ) - reduced_loss = _reduce_loss(normalized_loss) + global_num_tokens = max(num_tokens.item(), 1.0) + reduced_loss = _reduce_loss( + raw_loss_sum / global_num_tokens, + op=torch.distributed.ReduceOp.SUM, + group=ps.get_data_parallel_group(with_context_parallel=True), + ) if moe_routing_replay_controller is not None: moe_routing_replay_controller.finalize_step() diff --git a/src/art/tinker/service.py b/src/art/tinker/service.py index ba6768eb..d1bc7444 100644 --- a/src/art/tinker/service.py +++ b/src/art/tinker/service.py @@ -80,7 +80,7 @@ def custom_loss_fn( for mask, lp in zip(masks, logprobs_list): logprobs[mask] = lp loss = loss_fn(inputs, logprobs.unsqueeze(0), None, None, _config) - return loss.mean_policy_loss, {"policy_loss": loss.mean_policy_loss.item()} + return loss.policy_loss, {"policy_loss": loss.policy_loss.item()} shifted_tokens = shift_tensor(packed_tensors["tokens"], 0) diff --git a/src/art/unsloth/train.py b/src/art/unsloth/train.py index 34dbc5cd..91d76be3 100644 --- a/src/art/unsloth/train.py +++ b/src/art/unsloth/train.py @@ -170,14 +170,14 @@ def compute_loss( ) trainer._metrics["train"]["learning_rate"].append(config.learning_rate) - trainer._metrics["train"]["policy_loss"].append(loss.mean_policy_loss.item()) - if loss.mean_entropy is not None: - trainer._metrics["train"]["entropy"].append(loss.mean_entropy.item()) + trainer._metrics["train"]["policy_loss"].append(loss.policy_loss.item()) + if loss.entropy is not None: + trainer._metrics["train"]["entropy"].append(loss.entropy.item()) if config.beta > 0.0: - trainer._metrics["train"]["kl_div"].append(loss.mean_kl.item()) + trainer._metrics["train"]["kl_div"].append(loss.kl.item()) if loss.kl_policy_ref is not None: trainer._metrics["train"]["kl_policy_ref"].append(loss.kl_policy_ref.item()) - return loss.mean_policy_loss + config.beta * loss.mean_kl + return loss.policy_loss + config.beta * loss.kl return compute_loss From 7eb96e562acb81badf95888a00f2bacdbac52bd7 Mon Sep 17 00:00:00 2001 From: FurtherAI Date: Tue, 17 Mar 2026 20:35:23 +0000 Subject: [PATCH 18/19] Expand the oracle harness for DP correctness checks --- src/art/megatron/routing_replay.py | 661 ++++++++++++++++++- tests/integration/megatron_forward_trace.py | 491 ++++++++++---- tests/integration/megatron_oracle_harness.py | 78 +-- tests/integration/megatron_oracle_worker.py | 227 +++++-- 4 files changed, 1209 insertions(+), 248 deletions(-) diff --git a/src/art/megatron/routing_replay.py b/src/art/megatron/routing_replay.py index 463e5258..104fe185 100644 --- a/src/art/megatron/routing_replay.py +++ b/src/art/megatron/routing_replay.py @@ -1,11 +1,17 @@ from __future__ import annotations +from collections import defaultdict import json from pathlib import Path import re import types from typing import Any, Protocol +from megatron.core.tensor_parallel import ( + all_to_all, + gather_from_sequence_parallel_region, +) +from megatron.core.transformer.moe.moe_utils import permute, sort_chunks_by_idxs from pydantic import BaseModel, ConfigDict, model_validator from safetensors.torch import load_file, save_file import torch @@ -13,6 +19,8 @@ ROUTER_NAME_TOKEN = ".mlp.router" ROUTER_KEY_FORMAT_VERSION = "moe_routing_replay_v1" GLOBAL_TOKEN_UIDS_KEY = "global_token_uids" +TRACE_ROW_TOKEN_UIDS_ATTR = "_art_trace_row_token_uids" +TRACE_UID_SPAN_ATTR = "_art_trace_uid_span" _ROUTER_LAYER_PATTERN = re.compile(r"decoder\.layers\.(?P\d+)\.mlp\.router$") _TRACE_CHUNK_PREFIX_PATTERN = re.compile(r"^chunk(?P\d+)\.(?P.+)$") @@ -69,6 +77,40 @@ def _extract_router_output_tensors(output: Any) -> tuple[torch.Tensor, torch.Ten return probs_2d, routing_map_2d +def _extract_dp_slot_from_rank_meta(rank_meta: Any) -> tuple[int, int] | None: + if isinstance(rank_meta, dict): + rank_meta = [rank_meta] + if not isinstance(rank_meta, list) or not rank_meta: + return None + dp_ranks = { + int(item["dp_rank"]) + for item in rank_meta + if isinstance(item, dict) and "dp_rank" in item + } + dp_world_sizes = { + int(item["dp_world_size"]) + for item in rank_meta + if isinstance(item, dict) and "dp_world_size" in item + } + if len(dp_ranks) != 1 or len(dp_world_sizes) != 1: + return None + return next(iter(dp_ranks)), next(iter(dp_world_sizes)) + + +def _trace_call_route_metadata( + call_entry: dict[str, Any], +) -> tuple[int | None, int | None]: + sample_index = call_entry.get("micro_sample_index") + if isinstance(sample_index, int): + return int(sample_index), None + dp_slot = _extract_dp_slot_from_rank_meta(call_entry.get("rank_meta")) + micro_order = int(call_entry.get("micro_order", 0)) + if dp_slot is None: + return None, micro_order + dp_rank, dp_world_size = dp_slot + return None, micro_order * dp_world_size + dp_rank + + def build_router_key_from_module_name(*, chunk_index: int, module_name: str) -> str: match = _ROUTER_LAYER_PATTERN.search(module_name) if match is None: @@ -114,6 +156,8 @@ class RouterCallRoute(BaseModel): expert_mask: torch.Tensor routing_map: torch.Tensor | None = None num_experts: int + sample_index: int | None = None + micro_slot: int | None = None @model_validator(mode="after") def _validate(self) -> "RouterCallRoute": @@ -146,6 +190,10 @@ def _validate(self) -> "RouterCallRoute": ) if self.num_experts <= 0: raise RuntimeError(f"num_experts must be >0, got {self.num_experts}") + if self.sample_index is not None: + self.sample_index = int(self.sample_index) + if self.micro_slot is not None: + self.micro_slot = int(self.micro_slot) if self.routing_map is not None: expected = (self.expert_indices.shape[0], self.num_experts) if tuple(self.routing_map.shape) != expected: @@ -331,6 +379,8 @@ def from_dir(cls, bundle_dir: str | Path) -> "MoeRoutingReplayBundle": expert_mask=step_tensors[expert_mask_key], routing_map=routing_map, num_experts=int(call_manifest["num_experts"]), + sample_index=call_manifest.get("sample_index"), + micro_slot=call_manifest.get("micro_slot"), ) routers[router_key] = StepRouterRoutes(calls=calls) steps[step_index] = StepRoutes( @@ -385,7 +435,12 @@ def to_dir(self, bundle_dir: str | Path) -> None: ] = _to_tensor_cpu_contiguous( route.routing_map, dtype=torch.bool ) - call_manifest[str(call_index)] = {"num_experts": route.num_experts} + call_entry: dict[str, int] = {"num_experts": route.num_experts} + if route.sample_index is not None: + call_entry["sample_index"] = int(route.sample_index) + if route.micro_slot is not None: + call_entry["micro_slot"] = int(route.micro_slot) + call_manifest[str(call_index)] = call_entry step_manifest_routers[router_key] = call_manifest save_file(step_tensors, str(step_file_path)) manifest_steps[str(step_index)] = { @@ -457,11 +512,182 @@ def build_local_token_uids( return local_uids.reshape(-1).contiguous() +_ACTIVE_ROUTING_REPLAY_CONTROLLER: MoeRoutingReplayController | None = None + + +def _active_routing_replay_controller() -> MoeRoutingReplayController | None: + return _ACTIVE_ROUTING_REPLAY_CONTROLLER + + +def _dispatcher_local_token_uids( + controller: MoeRoutingReplayController, + dispatcher: Any, + *, + num_local_tokens: int, +) -> torch.Tensor: + step_routes = controller._active_step_routes + if step_routes is None: + raise RuntimeError("Routing replay dispatcher used without an active step") + local_uids = controller.local_token_indexer.build_local_token_uids( + global_token_uids=step_routes.global_token_uids, + num_local_tokens=num_local_tokens, + sequence_parallel=bool( + getattr(getattr(dispatcher, "config", None), "sequence_parallel", False) + ), + context_parallel_size=int( + getattr(getattr(dispatcher, "config", None), "context_parallel_size", 1) + ), + ) + if int(local_uids.numel()) != num_local_tokens: + raise RuntimeError( + "Local routing replay uid count mismatch: " + f"expected={num_local_tokens}, got={int(local_uids.numel())}" + ) + sample_index = getattr(controller, "_active_sample_index", None) + uid_span = int(step_routes.global_token_uids.numel()) + if isinstance(sample_index, int) and sample_index >= 0 and uid_span > 0: + local_uids = local_uids + sample_index * uid_span + return local_uids + + +def _trace_row_uids_from_source(source: Any) -> tuple[torch.Tensor | None, int | None]: + row_token_uids = getattr(source, TRACE_ROW_TOKEN_UIDS_ATTR, None) + if not isinstance(row_token_uids, torch.Tensor): + return None, None + uid_span = getattr(source, TRACE_UID_SPAN_ATTR, None) + uid_span_int = uid_span if isinstance(uid_span, int) and uid_span > 0 else None + return row_token_uids, uid_span_int + + +def _attach_trace_row_uids( + target: Any, + *, + row_token_uids: torch.Tensor, + uid_span: int | None, +) -> None: + setattr( + target, + TRACE_ROW_TOKEN_UIDS_ATTR, + row_token_uids.detach().to(device="cpu", dtype=torch.int64).reshape(-1), + ) + setattr(target, TRACE_UID_SPAN_ATTR, uid_span) + + +def _canonicalize_expert_token_order( + expert_inputs: torch.Tensor, + expert_probs: torch.Tensor, + expert_token_uids: torch.Tensor, + *, + tokens_per_expert: torch.Tensor | list[int], +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + if isinstance(tokens_per_expert, torch.Tensor): + counts = [int(count) for count in tokens_per_expert.tolist()] + else: + counts = [int(count) for count in tokens_per_expert] + + if sum(counts) != int(expert_token_uids.numel()): + raise RuntimeError( + "Expert token uid count mismatch after dispatch: " + f"uids={int(expert_token_uids.numel())}, " + f"tokens_per_expert_sum={sum(counts)}" + ) + + order_segments: list[torch.Tensor] = [] + cursor = 0 + for count in counts: + if count <= 1: + order_segments.append( + torch.arange(cursor, cursor + count, dtype=torch.long) + ) + cursor += count + continue + segment_uids = expert_token_uids[cursor : cursor + count].to(device="cpu") + segment_order = torch.argsort(segment_uids, stable=True) + cursor + order_segments.append(segment_order) + cursor += count + + if not order_segments: + empty = torch.empty(0, dtype=torch.long) + return expert_inputs, expert_probs, expert_token_uids, empty + + canonical_order_cpu = torch.cat(order_segments, dim=0) + inverse_order_cpu = torch.empty_like(canonical_order_cpu) + inverse_order_cpu[canonical_order_cpu] = torch.arange( + canonical_order_cpu.numel(), dtype=torch.long + ) + + canonical_order = canonical_order_cpu.to( + device=expert_inputs.device, dtype=torch.long + ) + reordered_inputs = expert_inputs.index_select(0, canonical_order) + reordered_probs = expert_probs.index_select(0, canonical_order) + reordered_uids = expert_token_uids.index_select( + 0, + canonical_order_cpu.to(device=expert_token_uids.device, dtype=torch.long), + ) + return ( + reordered_inputs, + reordered_probs, + reordered_uids, + inverse_order_cpu, + ) + + +def _canonical_trace_row_uids( + expert_token_uids: torch.Tensor, + *, + tokens_per_expert: torch.Tensor | list[int], + local_expert_indices: list[int] | tuple[int, ...] | None, + sample_uid_span: int, + num_experts: int, +) -> tuple[torch.Tensor, int]: + if isinstance(tokens_per_expert, torch.Tensor): + counts = [int(count) for count in tokens_per_expert.tolist()] + else: + counts = [int(count) for count in tokens_per_expert] + + expert_indices = ( + [int(expert_index) for expert_index in local_expert_indices] + if local_expert_indices is not None + else list(range(len(counts))) + ) + if len(expert_indices) != len(counts): + raise RuntimeError( + "Local expert index metadata mismatch: " + f"num_expert_indices={len(expert_indices)}, num_counts={len(counts)}" + ) + row_uid_span = sample_uid_span * max(int(num_experts), 1) + row_uid_chunks: list[torch.Tensor] = [] + cursor = 0 + for global_expert_id, count in zip(expert_indices, counts): + count_int = int(count) + segment = expert_token_uids[cursor : cursor + count_int].to(dtype=torch.int64) + sample_ids = torch.div(segment, sample_uid_span, rounding_mode="floor") + local_token_ids = torch.remainder(segment, sample_uid_span) + row_uid_chunks.append( + sample_ids * row_uid_span + + int(global_expert_id) * sample_uid_span + + local_token_ids + ) + cursor += count_int + if cursor != int(expert_token_uids.numel()): + raise RuntimeError( + "Canonical trace row uid construction did not consume all expert rows: " + f"consumed={cursor}, total={int(expert_token_uids.numel())}" + ) + if not row_uid_chunks: + return expert_token_uids.new_empty((0,), dtype=torch.int64), row_uid_span + return torch.cat(row_uid_chunks, dim=0).contiguous(), row_uid_span + + def _patch_alltoall_dispatcher_preprocess() -> None: try: + from megatron.core.transformer.moe.experts import TEGroupedMLP from megatron.core.transformer.moe.token_dispatcher import ( MoEAlltoAllTokenDispatcher, ) + + from art.megatron.lora import MLPExpertsLinearFC2LoRA except Exception: return @@ -469,6 +695,12 @@ def _patch_alltoall_dispatcher_preprocess() -> None: return original_preprocess = MoEAlltoAllTokenDispatcher.preprocess + original_dispatch_preprocess = MoEAlltoAllTokenDispatcher.dispatch_preprocess + original_token_dispatch = MoEAlltoAllTokenDispatcher.token_dispatch + original_dispatch_postprocess = MoEAlltoAllTokenDispatcher.dispatch_postprocess + original_combine_preprocess = MoEAlltoAllTokenDispatcher.combine_preprocess + original_te_grouped_mlp_forward = TEGroupedMLP.forward + original_fc2_forward = MLPExpertsLinearFC2LoRA.forward def patched_preprocess( self: Any, routing_map: torch.Tensor, *args: Any, **kwargs: Any @@ -485,7 +717,212 @@ def patched_preprocess( self.num_out_tokens = int(routing_map.sum().item()) return result + def patched_dispatch_preprocess( + self: Any, + hidden_states: torch.Tensor, + routing_map: torch.Tensor, + probs: torch.Tensor, + ): + result = original_dispatch_preprocess(self, hidden_states, routing_map, probs) + self._art_replay_permuted_local_token_uids = None + self._art_replay_global_input_token_uids = None + self._art_replay_expert_input_inverse_permutation = None + + controller = _active_routing_replay_controller() + if controller is None: + return result + + local_token_uids = _dispatcher_local_token_uids( + controller, + self, + num_local_tokens=int(routing_map.shape[0]), + ) + permuted_local_uids, _, _ = permute( + local_token_uids.to( + device=hidden_states.device, dtype=torch.int64 + ).unsqueeze(-1), + self.routing_map, + num_out_tokens=self.num_out_tokens, + fused=False, + drop_and_pad=self.drop_and_pad, + ) + self._art_replay_permuted_local_token_uids = permuted_local_uids.reshape( + -1 + ).contiguous() + return result + + def patched_token_dispatch( + self: Any, + permutated_local_input_tokens: torch.Tensor, + permuted_probs: torch.Tensor, + ): + result = original_token_dispatch( + self, + permutated_local_input_tokens, + permuted_probs, + ) + controller = _active_routing_replay_controller() + permuted_local_token_uids = getattr( + self, "_art_replay_permuted_local_token_uids", None + ) + if controller is None or permuted_local_token_uids is None: + return result + + global_token_uids = permuted_local_token_uids.to( + device=permutated_local_input_tokens.device, dtype=torch.int64 + ).unsqueeze(-1) + if self.ep_size > 1: + global_token_uids = all_to_all( + self.ep_group, + global_token_uids, + self.output_splits, + self.input_splits, + ) + if self.tp_size > 1: + output_split_sizes = ( + None + if self.output_splits_tp is None + else self.output_splits_tp.tolist() + ) + global_token_uids = gather_from_sequence_parallel_region( + global_token_uids, + group=self.tp_group, + output_split_sizes=output_split_sizes, + ) + self._art_replay_global_input_token_uids = global_token_uids.reshape( + -1 + ).contiguous() + return result + + def patched_dispatch_postprocess( + self: Any, + global_input_tokens: torch.Tensor, + global_probs: torch.Tensor, + ): + expert_inputs, tokens_per_expert, expert_probs = original_dispatch_postprocess( + self, + global_input_tokens, + global_probs, + ) + controller = _active_routing_replay_controller() + global_input_token_uids = getattr( + self, "_art_replay_global_input_token_uids", None + ) + if controller is None or global_input_token_uids is None or self.drop_and_pad: + return expert_inputs, tokens_per_expert, expert_probs + + expert_token_uids = global_input_token_uids + if self.num_local_experts > 1: + sorted_token_uids, _ = sort_chunks_by_idxs( + expert_token_uids.unsqueeze(-1), + self.num_global_tokens_per_local_expert.ravel(), + self.sort_input_by_local_experts, + fused=False, + ) + expert_token_uids = sorted_token_uids.reshape(-1).contiguous() + + ( + expert_inputs, + expert_probs, + canonical_expert_token_uids, + inverse_order_cpu, + ) = _canonicalize_expert_token_order( + expert_inputs, + expert_probs, + expert_token_uids, + tokens_per_expert=tokens_per_expert, + ) + self._art_replay_expert_input_inverse_permutation = inverse_order_cpu + trace_row_uids, trace_uid_span = _canonical_trace_row_uids( + canonical_expert_token_uids, + tokens_per_expert=tokens_per_expert, + local_expert_indices=getattr(self, "local_expert_indices", None), + sample_uid_span=int( + controller._active_step_routes.global_token_uids.numel() + ), + num_experts=int(getattr(self, "num_experts", 1)), + ) + _attach_trace_row_uids( + expert_inputs, + row_token_uids=trace_row_uids, + uid_span=trace_uid_span, + ) + return expert_inputs, tokens_per_expert, expert_probs + + def patched_combine_preprocess(self: Any, hidden_states: torch.Tensor): + inverse_order_cpu = getattr( + self, "_art_replay_expert_input_inverse_permutation", None + ) + if inverse_order_cpu is not None and inverse_order_cpu.numel() > 0: + hidden_states = hidden_states.index_select( + 0, + inverse_order_cpu.to(device=hidden_states.device, dtype=torch.long), + ) + self._art_replay_expert_input_inverse_permutation = None + return original_combine_preprocess(self, hidden_states) + + def patched_te_grouped_mlp_forward( + self: Any, + permuted_local_hidden_states: torch.Tensor, + tokens_per_expert: torch.Tensor, + permuted_probs: torch.Tensor, + ): + row_token_uids, uid_span = _trace_row_uids_from_source( + permuted_local_hidden_states + ) + if row_token_uids is not None: + _attach_trace_row_uids( + self.linear_fc2, + row_token_uids=row_token_uids, + uid_span=uid_span, + ) + return original_te_grouped_mlp_forward( + self, + permuted_local_hidden_states, + tokens_per_expert, + permuted_probs, + ) + + def patched_fc2_forward( + self: Any, + x: torch.Tensor, + tokens_per_expert: list[int] | torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + row_token_uids, uid_span = _trace_row_uids_from_source(x) + if row_token_uids is None: + row_token_uids, uid_span = _trace_row_uids_from_source(self) + if row_token_uids is not None: + _attach_trace_row_uids( + self.linear_fc2, + row_token_uids=row_token_uids, + uid_span=uid_span, + ) + _attach_trace_row_uids( + self.lora, + row_token_uids=row_token_uids, + uid_span=uid_span, + ) + return original_fc2_forward(self, x, tokens_per_expert) + setattr(MoEAlltoAllTokenDispatcher, "preprocess", patched_preprocess) + setattr( + MoEAlltoAllTokenDispatcher, + "dispatch_preprocess", + patched_dispatch_preprocess, + ) + setattr(MoEAlltoAllTokenDispatcher, "token_dispatch", patched_token_dispatch) + setattr( + MoEAlltoAllTokenDispatcher, + "dispatch_postprocess", + patched_dispatch_postprocess, + ) + setattr( + MoEAlltoAllTokenDispatcher, + "combine_preprocess", + patched_combine_preprocess, + ) + setattr(TEGroupedMLP, "forward", patched_te_grouped_mlp_forward) + setattr(MLPExpertsLinearFC2LoRA, "forward", patched_fc2_forward) setattr(MoEAlltoAllTokenDispatcher, "_art_router_replay_preprocess_patched", True) @@ -507,9 +944,10 @@ def __init__( self._active_sample_index: int | None = None self._active_step_routes: StepRoutes | None = None self._router_call_cursors: dict[str, int] = {} - self._router_call_limits: dict[str, int] = {} + self._router_call_sequences: dict[str, list[int]] = {} self._global_uid_to_row_index: dict[int, int] = {} self._local_router_keys: set[str] = set() + self._active_micro_order: int | None = None self._patched_router_modules: list[dict[str, Any]] = [] @@ -592,6 +1030,7 @@ def routing_wrapper( ) def remove_router_patches(self) -> None: + global _ACTIVE_ROUTING_REPLAY_CONTROLLER for item in self._patched_router_modules: module = item["module"] module.routing = item["original_routing"] @@ -599,9 +1038,21 @@ def remove_router_patches(self) -> None: delattr(module, "_art_router_replay_patched") self._patched_router_modules.clear() self._local_router_keys.clear() + if _ACTIVE_ROUTING_REPLAY_CONTROLLER is self: + _ACTIVE_ROUTING_REPLAY_CONTROLLER = None - def set_step(self, *, step_index: int, sample_index: int) -> None: - from megatron.core import parallel_state as ps + def begin_micro(self, sample_index: int | None, micro_order: int) -> None: + self._active_sample_index = sample_index + self._active_micro_order = micro_order + + def set_step( + self, + *, + step_index: int, + sample_index: int | list[int | None], + global_grad_accumulation_sequences: int | None = None, + ) -> None: + global _ACTIVE_ROUTING_REPLAY_CONTROLLER if step_index not in self.bundle.steps: raise RuntimeError( @@ -610,7 +1061,14 @@ def set_step(self, *, step_index: int, sample_index: int) -> None: ) step_routes = self.bundle.steps[step_index] self._active_step_index = step_index - self._active_sample_index = sample_index + if isinstance(sample_index, list): + self._active_sample_index = next( + (index for index in sample_index if index is not None), + None, + ) + else: + self._active_sample_index = sample_index + self._active_micro_order = None self._active_step_routes = step_routes for local_router_key in sorted(self._local_router_keys): if local_router_key not in step_routes.routers: @@ -618,54 +1076,177 @@ def set_step(self, *, step_index: int, sample_index: int) -> None: "Replay bundle step is missing local router key: " f"step={step_index}, router='{local_router_key}'" ) - dp_world_size = int(ps.get_data_parallel_world_size(with_context_parallel=True)) - dp_rank = int(ps.get_data_parallel_rank(with_context_parallel=True)) self._router_call_cursors = {} - self._router_call_limits = {} + self._router_call_sequences = {} + local_call_keys = self._build_local_call_keys( + sample_index=sample_index, + ) for router_key in sorted(self._local_router_keys): - total_calls = len(step_routes.routers[router_key].calls) - call_start = 0 - call_limit = total_calls - if dp_world_size > 1: - if total_calls % dp_world_size != 0: - raise RuntimeError( - "Replay router call count is not divisible by DP world size: " - f"step={step_index}, router='{router_key}', " - f"calls={total_calls}, dp_world_size={dp_world_size}" - ) - calls_per_dp_rank = total_calls // dp_world_size - call_start = dp_rank * calls_per_dp_rank - call_limit = call_start + calls_per_dp_rank - self._router_call_cursors[router_key] = call_start - self._router_call_limits[router_key] = call_limit + router_calls = step_routes.routers[router_key].calls + if all( + self._router_call_key(route) is not None + for route in router_calls.values() + ): + calls_by_key: dict[tuple[str, int], list[int]] = defaultdict(list) + for call_index, route in sorted(router_calls.items()): + call_key = self._router_call_key(route) + assert call_key is not None + calls_by_key[call_key].append(call_index) + call_sequence = [] + for call_key in local_call_keys: + if call_key is None: + continue + matching_call_indices = calls_by_key.get(call_key) + if not matching_call_indices: + raise RuntimeError( + "Replay router call sequence is missing local micro metadata: " + f"step={step_index}, router='{router_key}', call_key={call_key}" + ) + call_sequence.extend(matching_call_indices) + else: + call_sequence = self._legacy_router_call_sequence( + step_index=step_index, + router_key=router_key, + sample_index=sample_index, + global_grad_accumulation_sequences=global_grad_accumulation_sequences, + total_calls=len(router_calls), + ) + self._router_call_cursors[router_key] = 0 + self._router_call_sequences[router_key] = call_sequence self._global_uid_to_row_index = { int(uid.item()): row_index for row_index, uid in enumerate(step_routes.global_token_uids) } + _ACTIVE_ROUTING_REPLAY_CONTROLLER = self + + def _build_local_call_keys( + self, + *, + sample_index: int | list[int | None], + ) -> list[tuple[str, int] | None]: + if not isinstance(sample_index, list): + if sample_index is None: + return [self._dummy_micro_call_key(local_micro_index=0)] + return [("sample", int(sample_index))] + return [ + self._sample_or_dummy_call_key( + global_sample_index=global_sample_index, + local_micro_index=local_micro_index, + ) + for local_micro_index, global_sample_index in enumerate(sample_index) + ] + + def _sample_or_dummy_call_key( + self, + *, + global_sample_index: int | None, + local_micro_index: int, + ) -> tuple[str, int] | None: + if global_sample_index is not None: + return ("sample", int(global_sample_index)) + return self._dummy_micro_call_key(local_micro_index=local_micro_index) + + def _dummy_micro_call_key( + self, + *, + local_micro_index: int, + ) -> tuple[str, int]: + from megatron.core import parallel_state as ps + + dp_rank = int(ps.get_data_parallel_rank()) + dp_world_size = int(ps.get_data_parallel_world_size()) + micro_slot = local_micro_index * dp_world_size + dp_rank + return ("dummy_micro_slot", micro_slot) + + @staticmethod + def _router_call_key(route: RouterCallRoute) -> tuple[str, int] | None: + if route.sample_index is not None: + return ("sample", int(route.sample_index)) + if route.micro_slot is not None: + return ("dummy_micro_slot", int(route.micro_slot)) + return None + + @staticmethod + def _legacy_router_call_sequence( + *, + step_index: int, + router_key: str, + sample_index: int | list[int | None], + global_grad_accumulation_sequences: int | None, + total_calls: int, + ) -> list[int]: + step_sample_count = global_grad_accumulation_sequences + if step_sample_count is None: + if isinstance(sample_index, list): + step_sample_count = len( + [index for index in sample_index if index is not None] + ) + else: + step_sample_count = 1 + if step_sample_count <= 0 or total_calls % step_sample_count != 0: + raise RuntimeError( + "Replay router call count is not divisible by step sample count: " + f"step={step_index}, router='{router_key}', " + f"total_calls={total_calls}, step_sample_count={step_sample_count}" + ) + calls_per_sample = total_calls // step_sample_count + step_base_sample_index = step_index * step_sample_count + if isinstance(sample_index, list): + call_sequence: list[int] = [] + for global_sample_index in sample_index: + if global_sample_index is None: + continue + sample_offset = int(global_sample_index) - step_base_sample_index + if sample_offset < 0 or sample_offset >= step_sample_count: + raise RuntimeError( + "Replay router call index is outside the step-local range: " + f"step={step_index}, router='{router_key}', " + f"global_sample_index={global_sample_index}, " + f"step_base_sample_index={step_base_sample_index}, " + f"step_sample_count={step_sample_count}" + ) + call_start = sample_offset * calls_per_sample + call_sequence.extend(range(call_start, call_start + calls_per_sample)) + return call_sequence + + sample_offset = int(sample_index) - step_base_sample_index + if sample_offset < 0 or sample_offset >= step_sample_count: + raise RuntimeError( + "Replay router call index is outside the step-local range: " + f"step={step_index}, router='{router_key}', " + f"sample_index={sample_index}, " + f"step_sample_count={step_sample_count}" + ) + call_start = sample_offset * calls_per_sample + return list(range(call_start, call_start + calls_per_sample)) def finalize_step(self) -> None: + global _ACTIVE_ROUTING_REPLAY_CONTROLLER if self._active_step_routes is None: raise RuntimeError("finalize_step called before set_step") for router_key in sorted(self._local_router_keys): consumed = self._router_call_cursors.get(router_key, 0) - expected = self._router_call_limits.get(router_key) - if expected is None: + call_sequence = self._router_call_sequences.get(router_key) + if call_sequence is None: raise RuntimeError( - "Routing replay call limits missing for router key: " + "Routing replay call sequence missing for router key: " f"step={self._active_step_index}, router='{router_key}'" ) - if consumed != expected: + if consumed != len(call_sequence): raise RuntimeError( "Routing replay step consumption mismatch: " f"step={self._active_step_index}, router='{router_key}', " - f"consumed={consumed}, expected={expected}" + f"consumed={consumed}, expected={len(call_sequence)}" ) self._active_step_index = None self._active_sample_index = None self._active_step_routes = None self._router_call_cursors = {} - self._router_call_limits = {} + self._router_call_sequences = {} self._global_uid_to_row_index = {} + self._active_micro_order = None + if _ACTIVE_ROUTING_REPLAY_CONTROLLER is self: + _ACTIVE_ROUTING_REPLAY_CONTROLLER = None def get_route_for_router( self, @@ -676,17 +1257,22 @@ def get_route_for_router( context_parallel_size: int, ) -> tuple[torch.Tensor, torch.Tensor]: step_routes = self._active_step_routes - call_index = self._router_call_cursors.get(router_key, 0) - call_limit = self._router_call_limits.get(router_key) + call_cursor = self._router_call_cursors.get(router_key, 0) + call_sequence = self._router_call_sequences.get(router_key) + if call_sequence is None: + raise RuntimeError( + "Routing replay call sequence missing for router key: " + f"step={self._active_step_index}, router='{router_key}'" + ) router_calls = step_routes.routers[router_key].calls - if call_limit is not None and call_index >= call_limit: + if call_cursor >= len(call_sequence): raise RuntimeError( - "Routing replay call cursor exceeded local call range: " + "Routing replay call cursor exceeded local call sequence: " f"step={self._active_step_index}, router='{router_key}', " - f"call_index={call_index}, limit={call_limit}" + f"call_cursor={call_cursor}, sequence_length={len(call_sequence)}" ) - route = router_calls[call_index] - self._router_call_cursors[router_key] = call_index + 1 + route = router_calls[call_sequence[call_cursor]] + self._router_call_cursors[router_key] = call_cursor + 1 num_local_tokens = int(logits.shape[0]) num_experts = int(logits.shape[1]) @@ -813,6 +1399,9 @@ def build_bundle_from_forward_trace_dir( output = call_entry.get("output") probs_2d, routing_map_2d = _extract_router_output_tensors(output) compact_route = _compact_route_from_dense(probs_2d, routing_map_2d) + sample_index, micro_slot = _trace_call_route_metadata(call_entry) + compact_route.sample_index = sample_index + compact_route.micro_slot = micro_slot router_calls[call_index] = compact_route max_topk = max(max_topk, compact_route.max_topk) token_count = compact_route.num_global_tokens diff --git a/tests/integration/megatron_forward_trace.py b/tests/integration/megatron_forward_trace.py index 90c4a3af..5e36fc87 100644 --- a/tests/integration/megatron_forward_trace.py +++ b/tests/integration/megatron_forward_trace.py @@ -1,7 +1,8 @@ from __future__ import annotations +import os from pathlib import Path -from typing import Any, cast +from typing import Any, Callable, cast import torch @@ -65,6 +66,58 @@ def _rank_metadata() -> dict[str, int]: } +def _extract_dp_slot_from_rank_meta(rank_meta: Any) -> tuple[int, int] | None: + """Returns one stable `(dp_rank, dp_world_size)` pair from merged rank metadata.""" + if isinstance(rank_meta, dict): + rank_meta = [rank_meta] + if not isinstance(rank_meta, list) or not rank_meta: + return None + dp_ranks = { + _safe_int(item.get("dp_rank"), 0) + for item in rank_meta + if isinstance(item, dict) and "dp_rank" in item + } + dp_world_sizes = { + _safe_int(item.get("dp_world_size"), 1) + for item in rank_meta + if isinstance(item, dict) and "dp_world_size" in item + } + if len(dp_ranks) != 1 or len(dp_world_sizes) != 1: + return None + return next(iter(dp_ranks)), next(iter(dp_world_sizes)) + + +def _trace_call_sort_key(call: dict[str, Any]) -> tuple[int, int]: + """Builds a stable micro identity for merged trace ordering.""" + sample_index = call.get("micro_sample_index") + if isinstance(sample_index, int): + return 0, int(sample_index) + micro_order = _safe_int(call.get("micro_order"), 0) + dp_slot = _extract_dp_slot_from_rank_meta(call.get("rank_meta")) + if dp_slot is None: + return 1, micro_order + dp_rank, dp_world_size = dp_slot + return 1, micro_order * dp_world_size + dp_rank + + +def _local_dummy_micro_slot(micro_order: int) -> int: + """Builds the stable dummy-micro slot used when one micro has no sample id.""" + dp_rank = _safe_ps_stat("get_data_parallel_rank", 0) + dp_world_size = _safe_ps_stat("get_data_parallel_world_size", 1) + return micro_order * dp_world_size + dp_rank + + +def _captured_output_sort_key( + sample_index: int | None, + micro_order: int, + micro_slot: int | None, +) -> tuple[int, int, int]: + """Builds the deterministic ordering used for captured root outputs.""" + if isinstance(sample_index, int): + return 0, int(sample_index), micro_order + return 1, _safe_int(micro_slot, micro_order), 0 + + def _shard_world_size_for_domain(domain: Any) -> int: """Returns shard-group world size for one LoRA shard domain.""" if domain == "tp": @@ -117,6 +170,22 @@ def _materialize_trace_value(value: Any) -> Any: return value +def _extract_tensor_attr(value: Any, attr_name: str) -> Any: + if isinstance(value, torch.Tensor): + return getattr(value, attr_name, None) + if isinstance(value, dict): + for item in value.values(): + attr_value = _extract_tensor_attr(item, attr_name) + if attr_value is not None: + return attr_value + if isinstance(value, (list, tuple)): + for item in value: + attr_value = _extract_tensor_attr(item, attr_name) + if attr_value is not None: + return attr_value + return None + + def _extract_router_topk(output: Any) -> tuple[torch.Tensor, torch.Tensor] | None: if not isinstance(output, tuple) or len(output) < 2: return None @@ -144,17 +213,36 @@ def __init__( *, enabled: bool, capture_name_tokens: tuple[str, ...] = CAPTURE_NAME_TOKENS, + micro_start_callback: Callable[[int | None, int], None] | None = None, ) -> None: self.enabled = enabled self.capture_name_tokens = capture_name_tokens + self.micro_start_callback = micro_start_callback self.current_step_index: int | None = None self.current_step_trace: dict[str, list[dict[str, Any]]] = {} + self.current_micro_sample_index: int | None = None + self.current_micro_order = 0 + self.current_micro_module_call_counts: dict[str, int] = {} + self.current_step_sample_indices: list[int | None] = [] + self.current_step_outputs: list[ + tuple[int | None, int, int | None, torch.Tensor] + ] = [] + self._next_micro_order = 0 self._hook_handles: list[Any] = [] if not enabled: return self._register_hooks(model_chunks) def _register_hooks(self, model_chunks: list[Any]) -> None: + if not model_chunks: + raise RuntimeError("Expected at least one model chunk for forward tracing") + root_module = model_chunks[0] + self._hook_handles.append( + root_module.register_forward_pre_hook(self._root_pre_hook) + ) + self._hook_handles.append( + root_module.register_forward_hook(self._root_post_hook) + ) for chunk_index, chunk in enumerate(model_chunks): for module_name, module in chunk.named_modules(): trace_module_name = f"chunk{chunk_index}.{module_name}" @@ -275,9 +363,12 @@ def _make_hook(self, name: str, module: Any): def _hook(_module: Any, inputs: Any, output: Any) -> None: if self.current_step_index is None: return - call_index = len(self.current_step_trace.get(name, [])) + micro_call_index = self.current_micro_module_call_counts.get(name, 0) + self.current_micro_module_call_counts[name] = micro_call_index + 1 trace_item: dict[str, Any] = { - "call_index": call_index, + "micro_call_index": micro_call_index, + "micro_order": self.current_micro_order, + "micro_sample_index": self.current_micro_sample_index, "module_type": module.__class__.__name__, "rank_meta": _rank_metadata(), "merge_hints": self._build_merge_hints(name, module), @@ -292,7 +383,16 @@ def _hook(_module: Any, inputs: Any, output: Any) -> None: topk_ids, topk_scores = router_topk trace_item["router_topk_ids"] = topk_ids trace_item["router_topk_scores"] = topk_scores - self.current_step_trace.setdefault(name, []).append(trace_item) + trace_items = self._split_expert_trace_items( + module_name=name, + module=module, + inputs=inputs, + trace_item=trace_item, + ) + trace_calls = self.current_step_trace.setdefault(name, []) + for split_item in trace_items: + split_item["call_index"] = len(trace_calls) + trace_calls.append(split_item) return _hook @@ -303,9 +403,185 @@ def guess_primary_tensor(value: Any) -> torch.Tensor | None: return None return _materialize_tensor(tensor) - def set_step(self, step_index: int) -> None: + def _sample_index_for_micro(self, micro_order: int) -> int | None: + if micro_order < len(self.current_step_sample_indices): + return self.current_step_sample_indices[micro_order] + return None + + def _root_pre_hook(self, _module: Any, _args: Any) -> None: + if self.current_step_index is None: + return + micro_order = self._next_micro_order + sample_index = self._sample_index_for_micro(micro_order) + self.begin_micro(sample_index=sample_index, micro_order=micro_order) + + def _root_post_hook(self, _module: Any, _inputs: Any, output: Any) -> None: + if self.current_step_index is None: + return + output_tensor = self.guess_primary_tensor(output) + if output_tensor is None: + raise RuntimeError( + f"Expected root forward output to contain a tensor, got {type(output)}" + ) + sample_index = self.current_micro_sample_index + micro_order = self.current_micro_order + self.current_step_outputs.append( + ( + sample_index, + micro_order, + None + if sample_index is not None + else _local_dummy_micro_slot(micro_order), + output_tensor.float(), + ) + ) + self._next_micro_order = micro_order + 1 + + def set_step( + self, + step_index: int, + sample_indices: list[int | None] | None = None, + ) -> None: self.current_step_index = step_index self.current_step_trace = {} + self.current_step_sample_indices = list(sample_indices or []) + self.current_step_outputs = [] + self.current_micro_sample_index = None + self.current_micro_order = 0 + self.current_micro_module_call_counts = {} + self._next_micro_order = 0 + + def begin_micro(self, sample_index: int | None, micro_order: int) -> None: + self.current_micro_sample_index = sample_index + self.current_micro_order = micro_order + self.current_micro_module_call_counts = {} + if self.micro_start_callback is not None: + self.micro_start_callback(sample_index, micro_order) + + @staticmethod + def _row_token_uids_for_trace( + *, + inputs: Any, + module: Any, + ) -> tuple[torch.Tensor | None, int | None]: + row_token_uids = _extract_tensor_attr(inputs, "_art_trace_row_token_uids") + if row_token_uids is None: + row_token_uids = getattr(module, "_art_trace_row_token_uids", None) + if not isinstance(row_token_uids, torch.Tensor): + return None, None + + uid_span = _extract_tensor_attr(inputs, "_art_trace_uid_span") + if uid_span is None: + uid_span = getattr(module, "_art_trace_uid_span", None) + uid_span_int = uid_span if isinstance(uid_span, int) and uid_span > 0 else None + return ( + row_token_uids.detach().to(device="cpu", dtype=torch.int64).reshape(-1), + uid_span_int, + ) + + @classmethod + def _slice_row_aligned_value( + cls, + value: Any, + *, + row_indices: torch.Tensor, + total_rows: int, + ) -> Any: + if isinstance(value, torch.Tensor): + if value.ndim > 0 and int(value.shape[0]) == total_rows: + return value.index_select(0, row_indices) + return value + if isinstance(value, dict): + return { + key: cls._slice_row_aligned_value( + item, + row_indices=row_indices, + total_rows=total_rows, + ) + for key, item in value.items() + } + if isinstance(value, list): + return [ + cls._slice_row_aligned_value( + item, + row_indices=row_indices, + total_rows=total_rows, + ) + for item in value + ] + if isinstance(value, tuple): + return tuple( + cls._slice_row_aligned_value( + item, + row_indices=row_indices, + total_rows=total_rows, + ) + for item in value + ) + return value + + @classmethod + def _split_expert_trace_items( + cls, + *, + module_name: str, + module: Any, + inputs: Any, + trace_item: dict[str, Any], + ) -> list[dict[str, Any]]: + if not cls._is_moe_expert_forward_module(module_name): + return [trace_item] + + primary_output = trace_item.get("primary_output") + if not isinstance(primary_output, torch.Tensor) or primary_output.ndim == 0: + return [trace_item] + + row_token_uids, uid_span = cls._row_token_uids_for_trace( + inputs=inputs, + module=module, + ) + if row_token_uids is None: + return [trace_item] + + total_rows = int(row_token_uids.numel()) + if total_rows == 0 or int(primary_output.shape[0]) != total_rows: + return [trace_item] + + trace_item["row_token_uids"] = row_token_uids + if uid_span is None: + return [trace_item] + + sample_ids = torch.div(row_token_uids, uid_span, rounding_mode="floor") + ordered_sample_ids: list[int] = [] + seen_sample_ids: set[int] = set() + for sample_id in sample_ids.tolist(): + sample_id_int = int(sample_id) + if sample_id_int in seen_sample_ids: + continue + seen_sample_ids.add(sample_id_int) + ordered_sample_ids.append(sample_id_int) + + if len(ordered_sample_ids) <= 1: + if ordered_sample_ids: + trace_item["micro_sample_index"] = ordered_sample_ids[0] + return [trace_item] + + split_items: list[dict[str, Any]] = [] + for sample_id in ordered_sample_ids: + row_indices = (sample_ids == sample_id).nonzero(as_tuple=False).reshape(-1) + split_item = { + key: cls._slice_row_aligned_value( + value, + row_indices=row_indices, + total_rows=total_rows, + ) + for key, value in trace_item.items() + if key not in {"call_index", "micro_sample_index", "row_token_uids"} + } + split_item["micro_sample_index"] = sample_id + split_item["row_token_uids"] = row_token_uids.index_select(0, row_indices) + split_items.append(split_item) + return split_items @staticmethod def _is_moe_expert_forward_module(module_name: str) -> bool: @@ -327,93 +603,6 @@ def _primary_output_merge_hint(call: dict[str, Any]) -> dict[str, Any] | None: return None return primary_hint - @classmethod - def _lookup_call_by_index( - cls, - trace: dict[str, list[dict[str, Any]]], - module_name: str, - call_index: int, - ) -> dict[str, Any] | None: - """Finds one call entry by call-index with positional fallback.""" - calls = trace.get(module_name) - if calls is None: - return None - for call in calls: - if int(call.get("call_index", -1)) == call_index: - return call - if 0 <= call_index < len(calls): - return calls[call_index] - return None - - @staticmethod - def _router_module_name_for_expert_module(module_name: str) -> str | None: - """Maps one expert module name to its layer router module name.""" - for token in (".mlp.experts.linear_fc1", ".mlp.experts.linear_fc2"): - token_index = module_name.find(token) - if token_index != -1: - return f"{module_name[:token_index]}.mlp.router" - return None - - @classmethod - def _build_moe_row_identities( - cls, - *, - module_name: str, - call_index: int, - trace: dict[str, list[dict[str, Any]]], - row_splits: list[int] | None, - ) -> list[tuple[int, int, int]] | None: - """Builds stable `(expert_id, token_index, topk_slot)` identities for MoE rows.""" - router_module_name = cls._router_module_name_for_expert_module(module_name) - if router_module_name is None: - return None - router_call = cls._lookup_call_by_index(trace, router_module_name, call_index) - if router_call is None: - return None - router_topk_ids = router_call.get("router_topk_ids") - if not isinstance(router_topk_ids, torch.Tensor) or router_topk_ids.ndim != 2: - return None - token_splits_raw = router_call.get("router_topk_ids__row_splits") - if row_splits is None: - if isinstance(token_splits_raw, list): - row_splits = [ - int(v) * int(router_topk_ids.shape[1]) for v in token_splits_raw - ] - else: - row_splits = [int(router_topk_ids.numel())] - if isinstance(token_splits_raw, list): - token_splits = [int(v) for v in token_splits_raw] - else: - topk = int(router_topk_ids.shape[1]) - token_splits = [int(v) // topk for v in row_splits] - if len(row_splits) != len(token_splits): - return None - row_cursor = 0 - token_cursor = 0 - identities: list[tuple[int, int, int]] = [] - for row_count, token_count in zip(row_splits, token_splits): - local_ids = router_topk_ids[token_cursor : token_cursor + token_count] - token_cursor += token_count - local_identities: list[tuple[int, int, int]] = [] - max_expert = int(local_ids.max().item()) if local_ids.numel() > 0 else -1 - for expert_id in range(max_expert + 1): - expert_rows = (local_ids == expert_id).nonzero(as_tuple=False) - for token_offset, slot_index in expert_rows.tolist(): - local_identities.append( - ( - expert_id, - token_cursor - token_count + token_offset, - slot_index, - ) - ) - if len(local_identities) != row_count: - return None - identities.extend(local_identities) - row_cursor += row_count - if row_cursor != sum(row_splits): - return None - return identities - @classmethod def _canonicalize_etp_fc1_feature_layout( cls, @@ -456,12 +645,10 @@ def _canonicalize_moe_expert_row_order( cls, *, module_name: str, - call_index: int, tensor: torch.Tensor, - trace: dict[str, list[dict[str, Any]]], call: dict[str, Any], ) -> torch.Tensor: - """Canonicalizes MoE expert-row ordering using router replay identities.""" + """Canonicalizes MoE expert rows using dispatch-time UID identities.""" if not cls._is_moe_expert_forward_module(module_name): return tensor if tensor.ndim != 2: @@ -471,34 +658,23 @@ def _canonicalize_moe_expert_row_order( primary_hint.get("op") != "concat" or primary_hint.get("dim") != 0 ): return tensor - row_splits_raw = call.get("primary_output__row_splits") - row_splits = ( - [int(v) for v in row_splits_raw] - if isinstance(row_splits_raw, list) - else None - ) - identities = cls._build_moe_row_identities( - module_name=module_name, - call_index=call_index, - trace=trace, - row_splits=row_splits, - ) - if identities is None or len(identities) != int(tensor.shape[0]): + row_token_uids = call.get("row_token_uids") + if not isinstance(row_token_uids, torch.Tensor): return tensor - order = sorted(range(len(identities)), key=lambda index: identities[index]) - return tensor[order] + if int(row_token_uids.numel()) != int(tensor.shape[0]): + return tensor + order = torch.argsort(row_token_uids, stable=True) + return tensor.index_select(0, order) @classmethod def _canonicalize_primary_output_tensor( cls, *, module_name: str, - call_index: int, tensor: torch.Tensor, - trace: dict[str, list[dict[str, Any]]], call: dict[str, Any], ) -> torch.Tensor: - """Runs all primary-output canonicalization passes for one call tensor.""" + """Runs all remaining primary-output canonicalization passes for one call.""" tensor = cls._canonicalize_etp_fc1_feature_layout( module_name=module_name, tensor=tensor, @@ -506,9 +682,7 @@ def _canonicalize_primary_output_tensor( ) return cls._canonicalize_moe_expert_row_order( module_name=module_name, - call_index=call_index, tensor=tensor, - trace=trace, call=call, ) @@ -528,9 +702,7 @@ def canonicalize_trace( if isinstance(tensor, torch.Tensor): call["primary_output"] = cls._canonicalize_primary_output_tensor( module_name=module_name, - call_index=call_index, tensor=tensor, - trace=trace, call=call, ) call[PRIMARY_OUTPUT_CANONICAL_KEY] = True @@ -705,17 +877,25 @@ def _merge_rank_traces( merged: dict[str, list[dict[str, Any]]] = {} module_names = sorted(set().union(*(trace.keys() for trace in rank_traces))) for module_name in module_names: - call_count = max(len(trace.get(module_name, [])) for trace in rank_traces) module_calls: list[dict[str, Any]] = [] - for call_index in range(call_count): - rank_values = [ - trace[module_name][call_index] - for trace in rank_traces - if module_name in trace and call_index < len(trace[module_name]) - ] - if not rank_values: - continue - module_calls.append(cls._merge_rank_call_entries(rank_values)) + grouped_calls: dict[ + tuple[int, int, int, int], + list[dict[str, Any]], + ] = {} + for trace in rank_traces: + for call in trace.get(module_name, []): + sample_kind, sample_sort_index = _trace_call_sort_key(call) + merge_key = ( + sample_kind, + sample_sort_index, + int(call.get("micro_order", 0)), + int(call.get("micro_call_index", call.get("call_index", 0))), + ) + grouped_calls.setdefault(merge_key, []).append(call) + for merged_index, merge_key in enumerate(sorted(grouped_calls)): + merged_call = cls._merge_rank_call_entries(grouped_calls[merge_key]) + merged_call["call_index"] = merged_index + module_calls.append(merged_call) merged[module_name] = module_calls return merged @@ -736,6 +916,59 @@ def _gather_rank_traces( return None return cast(list[dict[str, list[dict[str, Any]]]], gathered) + @staticmethod + def _merge_group_tensor(tensors: list[torch.Tensor]) -> torch.Tensor: + if len(tensors) == 1: + return tensors[0] + first = tensors[0] + if all(tensor.shape == first.shape for tensor in tensors[1:]) and all( + torch.equal(first, tensor) for tensor in tensors[1:] + ): + return first + raise RuntimeError( + "Mismatched output captures for the same micro output across non-DP ranks" + ) + + @staticmethod + def _gather_rank_outputs( + local_outputs: list[tuple[int | None, int, int | None, torch.Tensor]], + ) -> list[list[tuple[int | None, int, int | None, torch.Tensor]]] | None: + if ( + not torch.distributed.is_initialized() + or torch.distributed.get_world_size() == 1 + ): + return [local_outputs] + gathered: list[ + list[tuple[int | None, int, int | None, torch.Tensor]] | None + ] = [None] * torch.distributed.get_world_size() + torch.distributed.all_gather_object(gathered, local_outputs) + if torch.distributed.get_rank() != 0: + return None + return cast( + list[list[tuple[int | None, int, int | None, torch.Tensor]]], + gathered, + ) + + def ordered_step_outputs(self) -> list[torch.Tensor] | None: + if not self.enabled: + return None + gathered_outputs = self._gather_rank_outputs(self.current_step_outputs) + if gathered_outputs is None: + return None + grouped: dict[tuple[int | None, int | None, int], list[torch.Tensor]] = {} + for rank_outputs in gathered_outputs: + for sample_index, micro_order, micro_slot, tensor in rank_outputs: + group_key = (sample_index, micro_slot, micro_order) + grouped.setdefault(group_key, []).append(tensor) + ordered_group_keys = sorted( + grouped, + key=lambda item: _captured_output_sort_key(item[0], item[2], item[1]), + ) + return [ + self._merge_group_tensor(grouped[group_key]) + for group_key in ordered_group_keys + ] + def save_current_step(self, traces_dir: Path) -> Path | None: if not self.enabled or self.current_step_index is None: return None @@ -745,7 +978,9 @@ def save_current_step(self, traces_dir: Path) -> Path | None: merged_trace = self.canonicalize_trace(self._merge_rank_traces(gathered_traces)) traces_dir.mkdir(parents=True, exist_ok=True) trace_path = traces_dir / f"forward_trace_step_{self.current_step_index:03d}.pt" - torch.save(merged_trace, trace_path) + tmp_trace_path = trace_path.with_suffix(f"{trace_path.suffix}.tmp") + torch.save(merged_trace, tmp_trace_path) + os.replace(tmp_trace_path, trace_path) return trace_path @classmethod diff --git a/tests/integration/megatron_oracle_harness.py b/tests/integration/megatron_oracle_harness.py index ad19c194..033cd5b9 100644 --- a/tests/integration/megatron_oracle_harness.py +++ b/tests/integration/megatron_oracle_harness.py @@ -113,6 +113,31 @@ def world_size(self) -> int: return attention_world +TOPOLOGIES = [ + Topology(tp=1, ep=1, etp=1, dp=1, sp=False), + Topology(tp=2, ep=1, etp=1, dp=1, sp=True), + Topology(tp=2, ep=2, etp=1, dp=1, sp=True), + Topology(tp=2, ep=1, etp=2, dp=1, sp=True), +] +EXTENDED_TOPOLOGIES = [ + Topology(tp=1, ep=1, etp=1, dp=2, sp=False), + Topology(tp=1, ep=2, etp=1, dp=2, sp=False), + Topology(tp=1, ep=1, etp=2, dp=2, sp=True), +] +ORACLE_TOPOLOGY = TOPOLOGIES[0] +SENSITIVITY_TOPOLOGY = Topology(tp=2, ep=2, etp=1, dp=1, sp=True) +SENSITIVITY_TOPOLOGY_BY_MUTATION: dict[SensitivityMutation, Topology] = { + mutation: SENSITIVITY_TOPOLOGY for mutation in SUPPORTED_SENSITIVITY_MUTATIONS +} +SENSITIVITY_TOPOLOGY_BY_MUTATION["bwd_skip_sync_fc1_a"] = Topology( + tp=2, ep=1, etp=2, dp=1, sp=True +) +SENSITIVITY_TOPOLOGY_BY_MUTATION |= { + k: Topology(tp=1, ep=2, etp=1, dp=2, sp=False) + for k in ["dp_grad_accumulation_seqs", "dp_local_token_normalization"] +} + + class PackedTensorConfig(BaseModel): """Controls synthetic packed tensor generation used by oracle harness runs.""" @@ -172,6 +197,7 @@ class OracleCaseConfig(BaseModel): """Contains all deterministic run parameters for one oracle case.""" base_model: str + precision: Literal["bf16", "fp32"] = "fp32" num_layers: int = 4 seed: int = 20260304 num_steps: int = 1 @@ -420,30 +446,6 @@ def _require_not_none(value: T | None, name: str) -> T: return value -TOPOLOGIES = [ - Topology(tp=1, ep=1, etp=1, dp=1, sp=False), - Topology(tp=2, ep=1, etp=1, dp=1, sp=True), - Topology(tp=2, ep=2, etp=1, dp=1, sp=True), - Topology(tp=2, ep=1, etp=2, dp=1, sp=True), -] -EXTENDED_TOPOLOGIES = [ - Topology(tp=1, ep=1, etp=1, dp=2, sp=False), - Topology(tp=1, ep=2, etp=1, dp=2, sp=True), -] -ORACLE_TOPOLOGY = TOPOLOGIES[0] -SENSITIVITY_TOPOLOGY = Topology(tp=2, ep=2, etp=1, dp=1, sp=True) -SENSITIVITY_TOPOLOGY_BY_MUTATION: dict[SensitivityMutation, Topology] = { - mutation: SENSITIVITY_TOPOLOGY for mutation in SUPPORTED_SENSITIVITY_MUTATIONS -} -SENSITIVITY_TOPOLOGY_BY_MUTATION["bwd_skip_sync_fc1_a"] = Topology( - tp=2, ep=1, etp=2, dp=1, sp=True -) -SENSITIVITY_TOPOLOGY_BY_MUTATION |= { - k: Topology(tp=1, ep=2, etp=1, dp=2, sp=False) - for k in ["dp_grad_accumulation_seqs", "dp_local_token_normalization"] -} - - def _truthy(value: str | None) -> bool: """Parses env-var style booleans using a small accepted truthy set.""" if value is None: @@ -798,6 +800,7 @@ def _stacked_layers( import torch grouped: dict[str, list[tuple[Any, Any]]] = {} + original_names_by_group: dict[str, list[str]] = {} for name, reference, candidate in pairs: normalized = _layer_agnostic_param_key(name) if normalized is None: @@ -807,10 +810,18 @@ def _stacked_layers( grouped.setdefault(normalized, []).append( (reference.detach().float(), candidate.detach().float()) ) + original_names_by_group.setdefault(normalized, []).append(name) stacked_pairs: list[tuple[str, Any, Any]] = [] for normalized in sorted(grouped): group = grouped[normalized] + reference_shapes = {tuple(reference.shape) for reference, _ in group} + candidate_shapes = {tuple(candidate.shape) for _, candidate in group} + if len(reference_shapes) != 1 or len(candidate_shapes) != 1: + original_names = original_names_by_group[normalized] + for original_name, (reference, candidate) in zip(original_names, group): + stacked_pairs.append((original_name, reference, candidate)) + continue stacked_pairs.append( ( normalized, @@ -840,7 +851,7 @@ def __init__( self.case_dir / ORACLE_MOE_ROUTING_BUNDLE_DIRNAME ) self.shared_init_path = Path(self.case_artifacts.shared_init_adapter_path) - self.console = console or Console(width=160) + self.console = console or Console(width=140) self._oracle_initialized = False self._oracle_regenerated = False @@ -1318,7 +1329,6 @@ def print_report(self, report: VariantReport) -> None: detail_table.add_column("mean_abs_pct", justify="right") detail_table.add_column("typical_abs", justify="right") detail_table.add_column("mean_abs_diff", justify="right") - # detail_table.add_column("Thresholds") detail_table.add_column("Failure") sorted_rows = sorted( table_rows, @@ -1382,17 +1392,9 @@ def _default_phase_pass_fns() -> dict[str, PhasePassFn]: # note the metrics get averaged across layers to reduce noise # we don't expect particular layers to see errors as opposed to the others so this is helpful fwd_out_loss = MetricThresholdRule( - limits={"relative_l2": 3e-2, "mean_abs_pct": 3.0} - ) - grads = lambda summary: ( - summary["mean_abs_pct"] < 5.0 - or ( - summary["typical_abs_scale"] < 1e-6 - and summary["mean_abs_diff"] < 2e-8 - and summary["relative_l2"] < 1.0 - ) + limits={"relative_l2": 1e-2, "mean_abs_pct": 1.0} ) - deltas = lambda summary: summary["mean_abs_pct"] < 15.0 + grads_deltas = MetricThresholdRule(limits={"mean_abs_pct": 10.0}) router_topk_rule = ( MetricThresholdRule( # should be no mismatch due to router replay limits={ @@ -1402,8 +1404,8 @@ def _default_phase_pass_fns() -> dict[str, PhasePassFn]: ) ) return {key: fwd_out_loss for key in ["forward", "outputs", "losses"]} | { - "grads": grads, - "deltas": deltas, + "grads": grads_deltas, + "deltas": grads_deltas, "router_topk_ids": router_topk_rule, } diff --git a/tests/integration/megatron_oracle_worker.py b/tests/integration/megatron_oracle_worker.py index d3c5b836..33f3c08a 100644 --- a/tests/integration/megatron_oracle_worker.py +++ b/tests/integration/megatron_oracle_worker.py @@ -144,7 +144,9 @@ def _collect_lora_state( return _gather_full_state(local_state) -def _collect_lora_grads(model_chunks: list[Any]) -> dict[str, Any] | None: +def _collect_lora_grads( + model_chunks: list[Any], +) -> dict[str, Any] | None: """Collects full LoRA gradient tensors across all ranks.""" from art.megatron.lora import LoRA @@ -163,11 +165,8 @@ def _collect_lora_grads(model_chunks: list[Any]) -> dict[str, Any] | None: raise RuntimeError(f"LoRA param main_grad is None for key '{key}'") if hasattr(grad, "_local_tensor"): grad = grad._local_tensor - local_grads[key] = ( - grad[expert].detach().cpu().T - if expert is not None - else grad.detach().cpu().T - ) + captured_grad = grad[expert] if expert is not None else grad + local_grads[key] = captured_grad.detach().cpu().T return _gather_full_state(local_grads) @@ -259,12 +258,20 @@ def _configure_provider( provider.tensor_model_parallel_size = topology.tp provider.expert_model_parallel_size = topology.ep provider.expert_tensor_parallel_size = topology.etp - # These are intentionally pinned to 1 for now; switching to topology-driven - # values is the single lever to start CP/PP coverage in the harness. + # These are intentionally pinned to 1 for now provider.pipeline_model_parallel_size = 1 provider.context_parallel_size = 1 provider.sequence_parallel = topology.sp provider.num_layers = case_config.num_layers + if case_config.precision == "fp32": + provider.bf16 = False + provider.fp16 = False + provider.params_dtype = torch.float32 + provider.pipeline_dtype = torch.float32 + provider.enable_autocast = False + provider.autocast_dtype = None + provider.attention_softmax_in_fp32 = True + provider.fp32_residual_connection = True if hasattr(provider, "attention_dropout"): provider.attention_dropout = 0.0 if hasattr(provider, "hidden_dropout"): @@ -275,8 +282,26 @@ def _build_optimizer_config(case_config: OracleCaseConfig): """Builds Megatron optimizer settings for deterministic harness runs.""" from megatron.core.optimizer import OptimizerConfig + if case_config.precision == "fp32": + return OptimizerConfig( + bf16=False, + fp16=False, + params_dtype=torch.float32, + main_grads_dtype=torch.float32, + main_params_dtype=torch.float32, + exp_avg_dtype=torch.float32, + exp_avg_sq_dtype=torch.float32, + lr=case_config.learning_rate, + adam_beta1=0.9, + adam_beta2=0.99, + clip_grad=0.1, + weight_decay=0.0, + adam_eps=1e-13, + ) + return OptimizerConfig( bf16=True, + fp16=False, lr=case_config.learning_rate, adam_beta1=0.9, adam_beta2=0.99, @@ -286,6 +311,14 @@ def _build_optimizer_config(case_config: OracleCaseConfig): ) +def _configure_cuda_precision(case_config: OracleCaseConfig) -> None: + if case_config.precision != "fp32": + return + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + torch.set_float32_matmul_precision("highest") + + def _assert_runtime_configuration( model_chunks: list[Any], case_config: OracleCaseConfig, @@ -470,6 +503,72 @@ def _mutated_forward(self: Any, x: Any): module.forward = original_forward +@contextmanager +def _patch_lora_for_fp32( + model_chunks: list[Any], + optimizer: Any, +): + """ + torch grouped_gemm is bf16 only, so we have a simple custom fp32 path + to make the numbers match closely + """ + from art.megatron.lora import LoRA + + del model_chunks + del optimizer + original_forward = LoRA.forward + + def _reference_forward( + self: Any, + x: torch.Tensor, + tokens_per_expert: list[int] | torch.Tensor | None = None, + ) -> torch.Tensor: + work_dtype = ( + torch.float32 + if torch.is_floating_point(x) and x.dtype != torch.float32 + else x.dtype + ) + work_x = x.to(dtype=work_dtype) + work_a = self.A_T.to(dtype=work_dtype) + work_b = self.B_T.to(dtype=work_dtype) + + if tokens_per_expert is None or self.num_local_experts == 1: + return (((work_x @ work_a) @ work_b) * self.scale).to(dtype=x.dtype) + + counts = ( + tokens_per_expert.tolist() + if isinstance(tokens_per_expert, torch.Tensor) + else list(tokens_per_expert) + ) + out = work_x.new_zeros((work_x.shape[0], work_b.shape[-1])) + + cursor = 0 + for expert_index, count in enumerate(counts): + count_int = int(count) + if count_int <= 0: + continue + next_cursor = cursor + count_int + x_chunk = work_x[cursor:next_cursor] + out[cursor:next_cursor] = (x_chunk @ work_a[expert_index]) @ work_b[ + expert_index + ] + cursor = next_cursor + + if cursor != int(work_x.shape[0]): + raise RuntimeError( + "Expert LoRA reference path did not consume all grouped rows: " + f"consumed={cursor}, rows={int(work_x.shape[0])}" + ) + + return (out * self.scale).to(dtype=x.dtype) + + LoRA.forward = _reference_forward + try: + yield + finally: + LoRA.forward = original_forward + + @contextmanager def _mutation_hook( megatron_train_module: Any, @@ -480,11 +579,14 @@ def _mutation_hook( loss_scale: float = 1.0, ): """Applies optional sensitivity mutation hooks around training steps.""" - original_finalize = megatron_train_module._finalize_grads + original_finalize = megatron_train_module.finalize_model_grads_extended original_optimizer_step = megatron_train_module._optimizer_step original_loss_fn = megatron_train_module.loss_fn - original_token_normalization_scale = ( - megatron_train_module._global_token_normalization_scale + original_local_token_count_tensor = ( + megatron_train_module._local_trainable_token_count_tensor + ) + original_build_micro_sample_indices = ( + megatron_train_module.build_micro_sample_indices ) known_mutations = {None, *SUPPORTED_SENSITIVITY_MUTATIONS} @@ -492,46 +594,55 @@ def _mutation_hook( raise ValueError(f"Unsupported mutation: {mutation}") if mutation == "skip_finalize": - megatron_train_module._finalize_grads = lambda _model: None + megatron_train_module.finalize_model_grads_extended = ( + lambda _model, **_kwargs: (None) + ) if mutation == "dp_local_token_normalization": - def _wrong_local_token_normalization_scale( + def _wrong_local_trainable_token_count_tensor( micro_inputs: list[Any], device: torch.device, - ) -> float: - del device + ) -> torch.Tensor: local_token_total = sum( megatron_train_module._count_trainable_tokens(micro) for micro in micro_inputs ) - if local_token_total <= 0.0: - return 0.0 - # Intentionally wrong normalization: use only local token total. dp_world_size = int( megatron_train_module.ps.get_data_parallel_world_size( with_context_parallel=True ) ) - return float(dp_world_size) / float(local_token_total) + wrong_local_token_total = local_token_total / max(dp_world_size, 1) + return torch.tensor( + [wrong_local_token_total], + device=device, + dtype=torch.float32, + ) - megatron_train_module._global_token_normalization_scale = ( - _wrong_local_token_normalization_scale + megatron_train_module._local_trainable_token_count_tensor = ( + _wrong_local_trainable_token_count_tensor ) if mutation == "dp_grad_accumulation_seqs": - def _wrong_resolve_local_grad_accumulation_sequences( + def _wrong_build_micro_sample_indices( + *, + step_index: int, + num_sequences: int, global_grad_accumulation_sequences: int, - ) -> int: - return megatron_train_module.resolve_local_grad_accumulation_sequences( - global_grad_accumulation_sequences=( - topology.dp * global_grad_accumulation_sequences + ) -> list[int | None]: + base_global_sample_index = step_index * global_grad_accumulation_sequences + return [ + (global_sample_index if global_sample_index < num_sequences else None) + for global_sample_index in range( + base_global_sample_index, + base_global_sample_index + global_grad_accumulation_sequences, ) - ) + ] - megatron_train_module.resolve_local_grad_accumulation_sequences = ( - _wrong_resolve_local_grad_accumulation_sequences + megatron_train_module.build_micro_sample_indices = ( + _wrong_build_micro_sample_indices ) if pre_optimizer_step_hook is not None: @@ -554,8 +665,8 @@ def _scaled_loss_fn(*args: Any, **kwargs: Any): loss = original_loss_fn(*args, **kwargs) return loss.model_copy( update={ - "mean_policy_loss": loss.mean_policy_loss * effective_loss_scale, - "mean_kl": loss.mean_kl * effective_loss_scale, + "policy_loss": loss.policy_loss * effective_loss_scale, + "kl": loss.kl * effective_loss_scale, "policy_loss_sum": loss.policy_loss_sum * effective_loss_scale, } ) @@ -572,11 +683,14 @@ def _scaled_loss_fn(*args: Any, **kwargs: Any): try: yield finally: - megatron_train_module._finalize_grads = original_finalize + megatron_train_module.finalize_model_grads_extended = original_finalize megatron_train_module._optimizer_step = original_optimizer_step megatron_train_module.loss_fn = original_loss_fn - megatron_train_module._global_token_normalization_scale = ( - original_token_normalization_scale + megatron_train_module._local_trainable_token_count_tensor = ( + original_local_token_count_tensor + ) + megatron_train_module.build_micro_sample_indices = ( + original_build_micro_sample_indices ) @@ -593,9 +707,13 @@ def _worker_run(request: WorkerRunRequest) -> None: torch.cuda.set_device(local_rank) torch.distributed.init_process_group(backend="nccl") _set_deterministic_seed(request.case_config.seed) + _configure_cuda_precision(request.case_config) runtime = megatron_train.build_training_runtime( model_identifier=request.case_config.base_model, + provider_torch_dtype=( + torch.float32 if request.case_config.precision == "fp32" else torch.bfloat16 + ), provider_configure=lambda provider: _configure_provider( provider, request.topology, request.case_config ), @@ -660,27 +778,40 @@ def _worker_run(request: WorkerRunRequest) -> None: experimental_config: dev.TrainConfig = {} step_traces: list[StepTrace] = [] captured_grads: dict[str, Any] | None = None - forward_trace_capture = ForwardTraceCapture(model_chunks, enabled=True) + routing_replay_controller = runtime.moe_routing_replay_controller + micro_start_callback = ( + routing_replay_controller.begin_micro + if routing_replay_controller is not None + else None + ) + forward_trace_capture = ForwardTraceCapture( + model_chunks, + enabled=True, + micro_start_callback=micro_start_callback, + ) def _capture_lora_grads() -> None: nonlocal captured_grads captured_grads = _collect_lora_grads(model_chunks) - with _mutation_hook( - megatron_train, - model_chunks, - request.mutation, - request.topology, - pre_optimizer_step_hook=_capture_lora_grads, - loss_scale=request.case_config.loss_scale, + with ( + _mutation_hook( + megatron_train, + model_chunks, + request.mutation, + request.topology, + pre_optimizer_step_hook=_capture_lora_grads, + loss_scale=request.case_config.loss_scale, + ), + _patch_lora_for_fp32(model_chunks, optimizer), ): for step_index in range(request.case_config.num_steps): - forward_trace_capture.set_step(step_index) micro_sample_indices = megatron_train.build_micro_sample_indices( step_index=step_index, num_sequences=request.packed_tensors.num_sequences, global_grad_accumulation_sequences=global_grad_accumulation_sequences, ) + forward_trace_capture.set_step(step_index, micro_sample_indices) micro_inputs = megatron_train.select_micro_inputs( packed_tensors, micro_sample_indices, zero_template ) @@ -698,12 +829,12 @@ def _capture_lora_grads() -> None: sample_index=micro_sample_indices, moe_routing_replay_controller=runtime.moe_routing_replay_controller, ) + ordered_micro_outputs = forward_trace_capture.ordered_step_outputs() forward_trace_capture.save_current_step(traces_dir) torch.distributed.barrier() current_lora_state = _collect_lora_state(model_chunks) if torch.distributed.get_rank() == 0: - # save artifacts (outputs, grads, lora deltas, current lora) grads = _require_not_none(captured_grads, "captured_grads") initial_state = _require_not_none( initial_lora_state, "initial_lora_state" @@ -727,16 +858,20 @@ def _capture_lora_grads() -> None: Path("traces") / f"deltas_step_{step_index:03d}.safetensors" ) lora_rel = Path(f"lora_step_{step_index:03d}.safetensors") + ordered_outputs = _require_not_none( + ordered_micro_outputs, "ordered_micro_outputs" + ) + if not ordered_outputs: + raise RuntimeError("Expected at least one captured micro output") torch.save( - step_result.new_logprobs.detach().cpu().float(), + torch.stack(ordered_outputs, dim=0), topology_dir / output_rel, ) save_file(grads, str(topology_dir / grads_rel)) save_file(saved_deltas, str(topology_dir / deltas_rel)) save_file(saved_current_state, str(topology_dir / lora_rel)) - # build and append the step trace step_traces.append( StepTrace( step_index=step_index, From 9cde0d43f6741fb24fc174cdd58682f9c1af058c Mon Sep 17 00:00:00 2001 From: FurtherAI Date: Tue, 17 Mar 2026 20:46:16 +0000 Subject: [PATCH 19/19] Clean up type errors in Megatron correctness changes --- src/art/megatron/finalize_grads.py | 22 +++++++++----- src/art/megatron/lora.py | 20 ++++++++++--- src/art/megatron/provider.py | 7 +++-- src/art/megatron/routing_replay.py | 13 ++++++-- src/art/megatron/train.py | 24 ++++++++------- src/art/unsloth/train.py | 7 ++--- tests/integration/megatron_forward_trace.py | 26 ++++++++-------- tests/integration/megatron_oracle_worker.py | 33 +++++++++++---------- 8 files changed, 92 insertions(+), 60 deletions(-) diff --git a/src/art/megatron/finalize_grads.py b/src/art/megatron/finalize_grads.py index 6fce32c3..2a770fea 100644 --- a/src/art/megatron/finalize_grads.py +++ b/src/art/megatron/finalize_grads.py @@ -4,6 +4,7 @@ from megatron.core import parallel_state as ps from megatron.core.distributed.finalize_model_grads import finalize_model_grads +from megatron.core.transformer.module import MegatronModule import torch from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors @@ -20,7 +21,7 @@ def _iter_named_trainable_parameters( - model: list[torch.nn.Module], + model: list[MegatronModule], ) -> Iterable[tuple[str, torch.nn.Parameter]]: seen: set[int] = set() for chunk_index, model_chunk in enumerate(model): @@ -36,7 +37,7 @@ def _iter_named_trainable_parameters( def _resolve_domain_group( domain: GradSyncDomain, -) -> torch.distributed.ProcessGroup | None: +) -> Any | None: if domain == TP_DEFAULT_GRAD_SYNC_DOMAIN: group = ps.get_tensor_model_parallel_group(check_initialized=False) if group is None or group.size() <= 1: @@ -53,14 +54,14 @@ def _resolve_domain_group( def _resolve_reduce_op(op: GradSyncOp) -> Any: if op == GRAD_SYNC_OP_SUM: - return torch.distributed.ReduceOp.SUM + return torch.distributed.ReduceOp.SUM # ty: ignore[possibly-missing-attribute] if op == GRAD_SYNC_OP_AVG: - return torch.distributed.ReduceOp.AVG + return torch.distributed.ReduceOp.AVG # ty: ignore[possibly-missing-attribute] raise RuntimeError(f"Unknown grad sync op: {op}") def finalize_model_grads_extended( - model: list[torch.nn.Module], + model: list[MegatronModule], num_tokens: torch.Tensor | None = None, ) -> None: """Run Megatron finalize, then apply extra LoRA grad-sync reductions. @@ -71,7 +72,10 @@ def finalize_model_grads_extended( """ # All-reduce all model grads across DP replicas, layernorm grads for sequence parallelism, # embedding grads across first and last pipeline stages (if not tied) - finalize_model_grads(model, num_tokens=num_tokens) + finalize_model_grads( + cast(list[torch.nn.Module], model), + num_tokens=num_tokens, + ) buckets: dict[ tuple[GradSyncDomain, GradSyncOp, torch.dtype, torch.device], @@ -119,7 +123,11 @@ def finalize_model_grads_extended( if torch.is_floating_point(coalesced) and coalesced.dtype != torch.float32 else coalesced ) - torch.distributed.all_reduce(reduced, op=_resolve_reduce_op(op), group=group) + torch.distributed.all_reduce( # ty: ignore[possibly-missing-attribute] + reduced, + op=_resolve_reduce_op(op), + group=group, + ) if reduced is not coalesced: reduced = reduced.to(dtype=coalesced.dtype) for grad, synced in zip(grads, _unflatten_dense_tensors(reduced, grads)): diff --git a/src/art/megatron/lora.py b/src/art/megatron/lora.py index fd62a249..56aa3f86 100644 --- a/src/art/megatron/lora.py +++ b/src/art/megatron/lora.py @@ -45,7 +45,12 @@ class LoRAParallelSpec(BaseModel): def _distributed_initialized() -> bool: - return torch.distributed.is_available() and torch.distributed.is_initialized() + is_initialized = getattr(torch.distributed, "is_initialized", None) + return ( + torch.distributed.is_available() + and callable(is_initialized) + and bool(is_initialized()) + ) def _get_shard_world_size(domain: ShardDomain) -> int: @@ -70,7 +75,7 @@ def _get_shard_rank(domain: ShardDomain) -> int: return group.rank() -def _get_shard_group(domain: ShardDomain) -> torch.distributed.ProcessGroup | None: +def _get_shard_group(domain: ShardDomain) -> Any | None: if not _distributed_initialized(): return None if domain == "tp": @@ -193,8 +198,14 @@ def _broadcast_if_replicated(self, param: torch.nn.Parameter) -> None: raise RuntimeError( f"{self.adapter_model_prefix}: missing process group for replicated parameter domain={domain}" ) - src = torch.distributed.get_global_rank(group, 0) - torch.distributed.broadcast(param.data, src=src, group=group) + src = torch.distributed.get_global_rank( # ty: ignore[possibly-missing-attribute] + group, 0 + ) + torch.distributed.broadcast( # ty: ignore[possibly-missing-attribute] + param.data, + src=src, + group=group, + ) def reset_lora_parameters(self) -> None: """Initialize LoRA weights (A=Kaiming, B=zeros) like PEFT defaults.""" @@ -595,6 +606,7 @@ def _build_fc1_lora( alpha: float, num_local_experts: int, ) -> LoRA: + assert linear_fc1 is not None assert isinstance(linear_fc1.weight0, torch.Tensor) a_parallel_spec = LoRAParallelSpec( shard_domain="expert_tp", diff --git a/src/art/megatron/provider.py b/src/art/megatron/provider.py index acd2eda1..1b016628 100644 --- a/src/art/megatron/provider.py +++ b/src/art/megatron/provider.py @@ -1,7 +1,8 @@ import copy from functools import partial import inspect -from typing import Callable +from pathlib import Path +from typing import Callable, cast from megatron.bridge import AutoBridge from megatron.bridge.models.gpt_provider import GPTModelProvider @@ -70,9 +71,11 @@ def get_provider( "Only Qwen3 MoE models are supported" ) if torch_dtype != torch.bfloat16: + model_name_or_path = bridge.hf_pretrained.model_name_or_path + assert model_name_or_path is not None bridge.hf_pretrained._state_dict_accessor = StateDict( _CastingStateSource( - SafeTensorsStateSource(bridge.hf_pretrained.model_name_or_path), + SafeTensorsStateSource(cast(str | Path, model_name_or_path)), dtype=torch_dtype, ) ) diff --git a/src/art/megatron/routing_replay.py b/src/art/megatron/routing_replay.py index 104fe185..86f1c4df 100644 --- a/src/art/megatron/routing_replay.py +++ b/src/art/megatron/routing_replay.py @@ -833,13 +833,16 @@ def patched_dispatch_postprocess( tokens_per_expert=tokens_per_expert, ) self._art_replay_expert_input_inverse_permutation = inverse_order_cpu + active_step_routes = controller._active_step_routes + if active_step_routes is None: + raise RuntimeError( + "MoE replay dispatcher preprocess called before set_step" + ) trace_row_uids, trace_uid_span = _canonical_trace_row_uids( canonical_expert_token_uids, tokens_per_expert=tokens_per_expert, local_expert_indices=getattr(self, "local_expert_indices", None), - sample_uid_span=int( - controller._active_step_routes.global_token_uids.numel() - ), + sample_uid_span=int(active_step_routes.global_token_uids.numel()), num_experts=int(getattr(self, "num_experts", 1)), ) _attach_trace_row_uids( @@ -1257,6 +1260,10 @@ def get_route_for_router( context_parallel_size: int, ) -> tuple[torch.Tensor, torch.Tensor]: step_routes = self._active_step_routes + if step_routes is None: + raise RuntimeError( + "Routing replay get_route_for_router called before set_step" + ) call_cursor = self._router_call_cursors.get(router_key, 0) call_sequence = self._router_call_sequences.get(router_key) if call_sequence is None: diff --git a/src/art/megatron/train.py b/src/art/megatron/train.py index a67b6eea..b3da7e24 100644 --- a/src/art/megatron/train.py +++ b/src/art/megatron/train.py @@ -206,12 +206,12 @@ def build_training_runtime( ), ) - if not torch.distributed.is_initialized(): + if not torch.distributed.is_initialized(): # ty: ignore[possibly-missing-attribute] raise RuntimeError( "torch.distributed must be initialized before building runtime" ) - rank = torch.distributed.get_rank() - world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() # ty: ignore[possibly-missing-attribute] + world_size = torch.distributed.get_world_size() # ty: ignore[possibly-missing-attribute] if rank == 0 and print_env: print("TORCHINDUCTOR_CACHE_DIR:", os.environ["TORCHINDUCTOR_CACHE_DIR"]) @@ -404,11 +404,15 @@ def _optimizer_step( def _reduce_loss( loss: torch.Tensor, - op: torch.distributed.ReduceOp.RedOpType = torch.distributed.ReduceOp.AVG, - group: torch.distributed.ProcessGroup | None = None, + op: Any = torch.distributed.ReduceOp.AVG, # ty: ignore[possibly-missing-attribute] + group: Any | None = None, ) -> torch.Tensor: reduced_loss = loss.detach().clone() - torch.distributed.all_reduce(reduced_loss, op=op, group=group) + torch.distributed.all_reduce( # ty: ignore[possibly-missing-attribute] + reduced_loss, + op=op, + group=group, + ) return reduced_loss @@ -495,7 +499,7 @@ def run_training_step( experimental_config, reduction="sum", ) - micro_loss = loss_info.policy_loss + config.beta * loss_info.kl + micro_loss = loss_info.policy_loss micro_loss.backward() probs_corr_sum += float(loss_info.probs_corr.item()) detached_micro_loss = micro_loss.detach() @@ -515,7 +519,7 @@ def run_training_step( global_num_tokens = max(num_tokens.item(), 1.0) reduced_loss = _reduce_loss( raw_loss_sum / global_num_tokens, - op=torch.distributed.ReduceOp.SUM, + op=torch.distributed.ReduceOp.SUM, # ty: ignore[possibly-missing-attribute] group=ps.get_data_parallel_group(with_context_parallel=True), ) @@ -537,7 +541,7 @@ def _run_service_loop(runtime: TrainingRuntime) -> None: offload_to_cpu(runtime.model, runtime.optimizer, runtime.rank, offload_state) while True: - torch.distributed.barrier() + torch.distributed.barrier() # ty: ignore[possibly-missing-attribute] jobs_dir = "/tmp/megatron_training_jobs" os.makedirs(jobs_dir, exist_ok=True) job_names = sorted( @@ -673,7 +677,7 @@ def _run_service_loop(runtime: TrainingRuntime) -> None: gc.collect() torch.cuda.empty_cache() - torch.distributed.barrier() + torch.distributed.barrier() # ty: ignore[possibly-missing-attribute] if runtime.rank == 0: os.remove(job_path) with open( diff --git a/src/art/unsloth/train.py b/src/art/unsloth/train.py index 2f93e316..e5d4b026 100644 --- a/src/art/unsloth/train.py +++ b/src/art/unsloth/train.py @@ -207,17 +207,14 @@ def compute_loss( ) trainer._metrics["train"]["loss/learning_rate"].append(config.learning_rate) - total_loss = loss.policy_loss + config.beta * loss.kl - trainer._metrics["train"]["loss/train"].append(total_loss.item()) + trainer._metrics["train"]["loss/train"].append(loss.policy_loss.item()) if loss.entropy is not None: trainer._metrics["train"]["loss/entropy"].append(loss.entropy.item()) - if config.beta > 0.0: - trainer._metrics["train"]["loss/kl_div"].append(loss.kl.item()) if loss.kl_policy_ref is not None: trainer._metrics["train"]["loss/kl_policy_ref"].append( loss.kl_policy_ref.item() ) - return total_loss + return loss.policy_loss return compute_loss diff --git a/tests/integration/megatron_forward_trace.py b/tests/integration/megatron_forward_trace.py index 5e36fc87..98f43fc6 100644 --- a/tests/integration/megatron_forward_trace.py +++ b/tests/integration/megatron_forward_trace.py @@ -47,9 +47,9 @@ def _rank_metadata() -> dict[str, int]: """Builds lightweight distributed metadata for one trace call.""" rank = 0 world_size = 1 - if torch.distributed.is_initialized(): - rank = _safe_int(torch.distributed.get_rank(), 0) - world_size = _safe_int(torch.distributed.get_world_size(), 1) + if torch.distributed.is_initialized(): # ty: ignore[possibly-missing-attribute] + rank = _safe_int(torch.distributed.get_rank(), 0) # ty: ignore[possibly-missing-attribute] + world_size = _safe_int(torch.distributed.get_world_size(), 1) # ty: ignore[possibly-missing-attribute] return { "global_rank": rank, "world_size": world_size, @@ -904,15 +904,15 @@ def _gather_rank_traces( local_trace: dict[str, list[dict[str, Any]]], ) -> list[dict[str, list[dict[str, Any]]]] | None: if ( - not torch.distributed.is_initialized() - or torch.distributed.get_world_size() == 1 + not torch.distributed.is_initialized() # ty: ignore[possibly-missing-attribute] + or torch.distributed.get_world_size() == 1 # ty: ignore[possibly-missing-attribute] ): return [local_trace] gathered: list[dict[str, list[dict[str, Any]]] | None] = [ None - ] * torch.distributed.get_world_size() - torch.distributed.all_gather_object(gathered, local_trace) - if torch.distributed.get_rank() != 0: + ] * torch.distributed.get_world_size() # ty: ignore[possibly-missing-attribute] + torch.distributed.all_gather_object(gathered, local_trace) # ty: ignore[possibly-missing-attribute] + if torch.distributed.get_rank() != 0: # ty: ignore[possibly-missing-attribute] return None return cast(list[dict[str, list[dict[str, Any]]]], gathered) @@ -934,15 +934,15 @@ def _gather_rank_outputs( local_outputs: list[tuple[int | None, int, int | None, torch.Tensor]], ) -> list[list[tuple[int | None, int, int | None, torch.Tensor]]] | None: if ( - not torch.distributed.is_initialized() - or torch.distributed.get_world_size() == 1 + not torch.distributed.is_initialized() # ty: ignore[possibly-missing-attribute] + or torch.distributed.get_world_size() == 1 # ty: ignore[possibly-missing-attribute] ): return [local_outputs] gathered: list[ list[tuple[int | None, int, int | None, torch.Tensor]] | None - ] = [None] * torch.distributed.get_world_size() - torch.distributed.all_gather_object(gathered, local_outputs) - if torch.distributed.get_rank() != 0: + ] = [None] * torch.distributed.get_world_size() # ty: ignore[possibly-missing-attribute] + torch.distributed.all_gather_object(gathered, local_outputs) # ty: ignore[possibly-missing-attribute] + if torch.distributed.get_rank() != 0: # ty: ignore[possibly-missing-attribute] return None return cast( list[list[tuple[int | None, int, int | None, torch.Tensor]]], diff --git a/tests/integration/megatron_oracle_worker.py b/tests/integration/megatron_oracle_worker.py index 33f3c08a..f84179b3 100644 --- a/tests/integration/megatron_oracle_worker.py +++ b/tests/integration/megatron_oracle_worker.py @@ -114,10 +114,12 @@ def _gather_full_state( """Gathers local state dicts to rank 0 and merges them.""" import torch - rank = torch.distributed.get_rank() - world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() # ty: ignore[possibly-missing-attribute] + world_size = torch.distributed.get_world_size() # ty: ignore[possibly-missing-attribute] gathered = [None for _ in range(world_size)] if rank == 0 else None - torch.distributed.gather_object(local_state, gathered, dst=0) + torch.distributed.gather_object( # ty: ignore[possibly-missing-attribute] + local_state, gathered, dst=0 + ) if rank != 0: return None assert gathered is not None @@ -562,7 +564,7 @@ def _reference_forward( return (out * self.scale).to(dtype=x.dtype) - LoRA.forward = _reference_forward + LoRA.forward = _reference_forward # ty: ignore[invalid-assignment] try: yield finally: @@ -705,7 +707,7 @@ def _worker_run(request: WorkerRunRequest) -> None: local_rank = int(os.environ["LOCAL_RANK"]) torch.cuda.set_device(local_rank) - torch.distributed.init_process_group(backend="nccl") + torch.distributed.init_process_group(backend="nccl") # ty: ignore[possibly-missing-attribute] _set_deterministic_seed(request.case_config.seed) _configure_cuda_precision(request.case_config) @@ -738,7 +740,7 @@ def _worker_run(request: WorkerRunRequest) -> None: shared_init_path = Path(request.shared_init_adapter_path) if not shared_init_path.exists(): initial_state = _collect_lora_state(model_chunks) - if torch.distributed.get_rank() == 0: + if torch.distributed.get_rank() == 0: # ty: ignore[possibly-missing-attribute] shared_init_path.parent.mkdir(parents=True, exist_ok=True) deterministic_init = _build_deterministic_shared_init( _require_not_none(initial_state, "initial_state"), @@ -748,17 +750,17 @@ def _worker_run(request: WorkerRunRequest) -> None: deterministic_init, str(shared_init_path), ) - torch.distributed.barrier() + torch.distributed.barrier() # ty: ignore[possibly-missing-attribute] # load the shared initial lora into the model and validate we can collect it from the model adapter_model = load_file(str(shared_init_path)) megatron_train.load_adapter_into_model(model_chunks, adapter_model, optimizer) loaded_state = _collect_lora_state(model_chunks) - if torch.distributed.get_rank() == 0: + if torch.distributed.get_rank() == 0: # ty: ignore[possibly-missing-attribute] _validate_loaded_state_matches_adapter( _require_not_none(loaded_state, "loaded_state"), adapter_model ) - torch.distributed.barrier() + torch.distributed.barrier() # ty: ignore[possibly-missing-attribute] # load the inputs packed_tensors = packed_tensors_from_dir( @@ -771,7 +773,6 @@ def _worker_run(request: WorkerRunRequest) -> None: train_config = types.TrainConfig( learning_rate=request.case_config.learning_rate, - beta=request.case_config.beta, kl_penalty_coef=0.0, grad_accumulation_sequences=global_grad_accumulation_sequences, ) @@ -831,10 +832,10 @@ def _capture_lora_grads() -> None: ) ordered_micro_outputs = forward_trace_capture.ordered_step_outputs() forward_trace_capture.save_current_step(traces_dir) - torch.distributed.barrier() + torch.distributed.barrier() # ty: ignore[possibly-missing-attribute] current_lora_state = _collect_lora_state(model_chunks) - if torch.distributed.get_rank() == 0: + if torch.distributed.get_rank() == 0: # ty: ignore[possibly-missing-attribute] grads = _require_not_none(captured_grads, "captured_grads") initial_state = _require_not_none( initial_lora_state, "initial_lora_state" @@ -886,11 +887,11 @@ def _capture_lora_grads() -> None: lora_file=str(lora_rel), ) ) - torch.distributed.barrier() + torch.distributed.barrier() # ty: ignore[possibly-missing-attribute] forward_trace_capture.close() - if torch.distributed.get_rank() == 0: + if torch.distributed.get_rank() == 0: # ty: ignore[possibly-missing-attribute] # build and save the moe routing replay bundle if request.capture_moe_routing_bundle_path is not None: replay_bundle = build_bundle_from_forward_trace_dir( @@ -918,8 +919,8 @@ def _capture_lora_grads() -> None: steps=step_traces, ) _write_json(topology_dir / "manifest.json", manifest.model_dump(mode="json")) - torch.distributed.barrier() - torch.distributed.destroy_process_group() + torch.distributed.barrier() # ty: ignore[possibly-missing-attribute] + torch.distributed.destroy_process_group() # ty: ignore[possibly-missing-attribute] def run_worker_cli(run_request_path: Path) -> None: