diff --git a/src/art/dev/train.py b/src/art/dev/train.py index b0e232c5..0ada9ccb 100644 --- a/src/art/dev/train.py +++ b/src/art/dev/train.py @@ -1,7 +1,10 @@ -from typing import Literal +from typing import TYPE_CHECKING, Literal from typing_extensions import TypedDict +if TYPE_CHECKING: + from art.megatron.routing_replay import MoeRoutingReplayBundle + class TrainConfig(TypedDict, total=False): advantage_balance: float @@ -22,6 +25,9 @@ class TrainConfig(TypedDict, total=False): logprob_calculation_chunk_size: int mask_prob_ratio: bool max_negative_advantage_importance_sampling_weight: float + moe_routing_replay_bundle: "MoeRoutingReplayBundle | None" + moe_routing_replay_path: str | None + moe_routing_replay_strict: bool num_trajectories_learning_rate_multiplier_power: float plot_tensors: bool ppo: bool diff --git a/src/art/local/backend.py b/src/art/local/backend.py index 5baf200f..64e4f420 100644 --- a/src/art/local/backend.py +++ b/src/art/local/backend.py @@ -778,7 +778,10 @@ async def _train_model( packed_tensors, f"{get_model_dir(model=model, art_path=self._path)}/tensors" ) # Note: scale_learning_rate_by_reward_std_dev is now handled by the frontend (Model.train()) - estimated_gradient_steps = disk_packed_tensors["num_sequences"] + grad_accumulation_sequences = max(1, int(config.grad_accumulation_sequences)) + estimated_gradient_steps = math.ceil( + disk_packed_tensors["num_sequences"] / grad_accumulation_sequences + ) pbar = tqdm.tqdm(total=estimated_gradient_steps, desc="train") async for result in service.train( disk_packed_tensors, config, dev_config, verbose diff --git a/src/art/loss.py b/src/art/loss.py index 5a73d7b7..0aab6084 100644 --- a/src/art/loss.py +++ b/src/art/loss.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal from pydantic import BaseModel, ConfigDict import torch @@ -13,8 +13,10 @@ class Loss(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) - mean_policy_loss: 
torch.Tensor - mean_entropy: torch.Tensor | None + reduction: Literal["mean", "sum"] + policy_loss: torch.Tensor + kl: torch.Tensor + entropy: torch.Tensor | None policy_loss_sum: torch.Tensor probs_corr: torch.Tensor kl_policy_ref: torch.Tensor | None = None @@ -26,6 +28,7 @@ def loss_fn( ref_logprobs: torch.Tensor | None, entropies: torch.Tensor | None, experimental_config: dev.TrainConfig, + reduction: Literal["mean", "sum"] = "mean", ) -> Loss: old_logprobs = shift_tensor(inputs["logprobs"], float("nan")) advantages = shift_tensor(inputs["advantages"], 0.0) @@ -123,19 +126,28 @@ def loss_fn( logprob_diff = old_logprobs - original_logprobs prob_ratio = torch.exp(logprob_diff) policy_loss *= torch.clamp(prob_ratio, max=upper_bound).detach() + if ref_logprobs is not None: + kl_div = ( + torch.exp(ref_logprobs - new_logprobs) - (ref_logprobs - new_logprobs) - 1.0 + ) + else: + kl_div = torch.zeros_like(policy_loss) policy_loss = policy_loss * weights * assistant_mask - mean_policy_loss = policy_loss.sum() / (assistant_mask.sum() + 1e-6) - # Compute mean entropy for the current step + kl_div = kl_div * weights * assistant_mask + denominator = assistant_mask.sum() + 1e-6 if reduction == "mean" else 1.0 + reduced_policy_loss = policy_loss.sum() / denominator + kl = kl_div.sum() / denominator + # Compute reduced entropy for the current step. 
if entropies is not None: shifted_entropies = shift_tensor(entropies, 0.0) - mean_entropy = (shifted_entropies * weights * assistant_mask).sum() / ( - assistant_mask.sum() + 1e-6 - ) + entropy = (shifted_entropies * weights * assistant_mask).sum() / denominator else: - mean_entropy = None + entropy = None return Loss( - mean_policy_loss=mean_policy_loss, - mean_entropy=mean_entropy, + reduction=reduction, + policy_loss=reduced_policy_loss, + kl=kl, + entropy=entropy, policy_loss_sum=policy_loss.sum(), probs_corr=probs_corr, kl_policy_ref=kl_policy_ref, diff --git a/src/art/megatron/finalize_grads.py b/src/art/megatron/finalize_grads.py new file mode 100644 index 00000000..2a770fea --- /dev/null +++ b/src/art/megatron/finalize_grads.py @@ -0,0 +1,134 @@ +from collections import defaultdict +from collections.abc import Iterable +from typing import Any, Literal, cast + +from megatron.core import parallel_state as ps +from megatron.core.distributed.finalize_model_grads import finalize_model_grads +from megatron.core.transformer.module import MegatronModule +import torch +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + +GradSyncDomain = Literal["tp_default", "expert_tp"] +GradSyncOp = Literal["none", "sum", "avg"] + +TP_DEFAULT_GRAD_SYNC_DOMAIN: GradSyncDomain = "tp_default" +EXPERT_TP_GRAD_SYNC_DOMAIN: GradSyncDomain = "expert_tp" +GRAD_SYNC_OP_NONE: GradSyncOp = "none" +GRAD_SYNC_OP_SUM: GradSyncOp = "sum" +GRAD_SYNC_OP_AVG: GradSyncOp = "avg" +VALID_DOMAINS = (TP_DEFAULT_GRAD_SYNC_DOMAIN, EXPERT_TP_GRAD_SYNC_DOMAIN) +VALID_SYNC_OPS = (GRAD_SYNC_OP_NONE, GRAD_SYNC_OP_SUM, GRAD_SYNC_OP_AVG) + + +def _iter_named_trainable_parameters( + model: list[MegatronModule], +) -> Iterable[tuple[str, torch.nn.Parameter]]: + seen: set[int] = set() + for chunk_index, model_chunk in enumerate(model): + for name, param in model_chunk.named_parameters(): + if not param.requires_grad: + continue + param_id = id(param) + if param_id in seen: + continue + 
seen.add(param_id) + yield f"chunk{chunk_index}.{name}", param + + +def _resolve_domain_group( + domain: GradSyncDomain, +) -> Any | None: + if domain == TP_DEFAULT_GRAD_SYNC_DOMAIN: + group = ps.get_tensor_model_parallel_group(check_initialized=False) + if group is None or group.size() <= 1: + return None + return group + if domain != EXPERT_TP_GRAD_SYNC_DOMAIN: + raise RuntimeError(f"Unknown grad sync domain: {domain}") + + group = ps.get_expert_tensor_parallel_group(check_initialized=False) + if group is None or group.size() <= 1: + return None + return group + + +def _resolve_reduce_op(op: GradSyncOp) -> Any: + if op == GRAD_SYNC_OP_SUM: + return torch.distributed.ReduceOp.SUM # ty: ignore[possibly-missing-attribute] + if op == GRAD_SYNC_OP_AVG: + return torch.distributed.ReduceOp.AVG # ty: ignore[possibly-missing-attribute] + raise RuntimeError(f"Unknown grad sync op: {op}") + + +def finalize_model_grads_extended( + model: list[MegatronModule], + num_tokens: torch.Tensor | None = None, +) -> None: + """Run Megatron finalize, then apply extra LoRA grad-sync reductions. + + Megatron finalize handles DP/CP(via `param.allreduce=True`)(and expert-DP via `param.allreduce=False`) internally. + This extension handles extra TP/expert-TP reductions for params annotated + with grad_sync_* metadata. 
+ """ + # All-reduce all model grads across DP replicas, layernorm grads for sequence parallelism, + # embedding grads across first and last pipeline stages (if not tied) + finalize_model_grads( + cast(list[torch.nn.Module], model), + num_tokens=num_tokens, + ) + + buckets: dict[ + tuple[GradSyncDomain, GradSyncOp, torch.dtype, torch.device], + list[tuple[str, torch.Tensor]], + ] = defaultdict(list) + + for name, param in _iter_named_trainable_parameters(model): + domain: GradSyncDomain = getattr( + param, "grad_sync_domain", TP_DEFAULT_GRAD_SYNC_DOMAIN + ) + if _resolve_domain_group(domain) is None: + continue + + op: GradSyncOp = getattr(param, "grad_sync_op", GRAD_SYNC_OP_NONE) + if op not in VALID_SYNC_OPS: + raise RuntimeError(f"{name}: unsupported grad_sync_op={op}") + if op == GRAD_SYNC_OP_NONE: + continue + + if not hasattr(param, "main_grad"): + raise RuntimeError( + f"{name}: expected main_grad for domain={domain} reduce_op={op}, but attribute is missing" + ) + grad = param.main_grad + if grad is None: + raise RuntimeError( + f"{name}: expected non-None main_grad for domain={domain} reduce_op={op}" + ) + local_grad = cast( # local part of dtensor + torch.Tensor, grad._local_tensor if hasattr(grad, "_local_tensor") else grad + ) + buckets[(domain, op, local_grad.dtype, local_grad.device)].append( + (name, local_grad) + ) + + for (domain, op, _dtype, _device), entries in buckets.items(): + group = _resolve_domain_group( + domain + ) # already checked if the domain is one we are handling + + grads = [grad for _name, grad in entries] + coalesced = _flatten_dense_tensors(grads) + reduced = ( + coalesced.float() + if torch.is_floating_point(coalesced) and coalesced.dtype != torch.float32 + else coalesced + ) + torch.distributed.all_reduce( # ty: ignore[possibly-missing-attribute] + reduced, + op=_resolve_reduce_op(op), + group=group, + ) + if reduced is not coalesced: + reduced = reduced.to(dtype=coalesced.dtype) + for grad, synced in zip(grads, 
_unflatten_dense_tensors(reduced, grads)): + grad.copy_(synced) diff --git a/src/art/megatron/lora.py b/src/art/megatron/lora.py index 3ba97a77..56aa3f86 100644 --- a/src/art/megatron/lora.py +++ b/src/art/megatron/lora.py @@ -1,5 +1,6 @@ +from collections.abc import Sequence import math -from typing import Sequence +from typing import Any, Literal from megatron.bridge.models.gpt_provider import GPTModelProvider from megatron.core import parallel_state as ps @@ -9,12 +10,133 @@ TERowParallelGroupedLinear, TERowParallelLinear, ) +from megatron.core.tensor_parallel.mappings import ( + reduce_from_tensor_model_parallel_region, + reduce_scatter_to_sequence_parallel_region, +) from megatron.core.transformer.attention import SelfAttention from megatron.core.transformer.moe import grouped_gemm_util from megatron.core.transformer.moe.experts import TEGroupedMLP from megatron.core.transformer.transformer_layer import TransformerLayer +from pydantic import BaseModel, ConfigDict import torch +ShardDomain = Literal["tp", "expert_tp"] +GradSyncDomain = Literal["tp_default", "expert_tp"] +GradSyncOp = Literal["none", "sum", "avg"] + +TP_DEFAULT_GRAD_SYNC_DOMAIN: GradSyncDomain = "tp_default" +EXPERT_TP_GRAD_SYNC_DOMAIN: GradSyncDomain = "expert_tp" +GRAD_SYNC_OP_NONE: GradSyncOp = "none" +GRAD_SYNC_OP_SUM: GradSyncOp = "sum" +GRAD_SYNC_OP_AVG: GradSyncOp = "avg" + + +class LoRAParallelSpec(BaseModel): + # This spec only describes TP / expert-TP behavior. + # DP/CP vs expert-DP behavior is selected separately via `allreduce`. 
+ model_config = ConfigDict(frozen=True) + + shard_domain: ShardDomain = "tp" + sharded: bool = False + shard_dim: int | None = None + grad_sync_domain: GradSyncDomain = TP_DEFAULT_GRAD_SYNC_DOMAIN + grad_sync_op: GradSyncOp = GRAD_SYNC_OP_NONE + + +def _distributed_initialized() -> bool: + is_initialized = getattr(torch.distributed, "is_initialized", None) + return ( + torch.distributed.is_available() + and callable(is_initialized) + and bool(is_initialized()) + ) + + +def _get_shard_world_size(domain: ShardDomain) -> int: + if not _distributed_initialized(): + return 1 + if domain == "tp": + return ps.get_tensor_model_parallel_world_size() + group = ps.get_expert_tensor_parallel_group(check_initialized=False) + if group is None: + return 1 + return group.size() + + +def _get_shard_rank(domain: ShardDomain) -> int: + if not _distributed_initialized(): + return 0 + if domain == "tp": + return ps.get_tensor_model_parallel_rank() + group = ps.get_expert_tensor_parallel_group(check_initialized=False) + if group is None: + return 0 + return group.rank() + + +def _get_shard_group(domain: ShardDomain) -> Any | None: + if not _distributed_initialized(): + return None + if domain == "tp": + return ps.get_tensor_model_parallel_group() + return ps.get_expert_tensor_parallel_group(check_initialized=False) + + +def _normalize_axis(axis: int, ndim: int) -> int: + if axis < 0: + axis += ndim + if axis < 0 or axis >= ndim: + raise ValueError(f"Invalid shard axis {axis} for tensor ndim={ndim}") + return axis + + +def _set_lora_parallel_metadata( + param: torch.nn.Parameter, + *, + parallel_spec: LoRAParallelSpec, + allreduce: bool, +) -> None: + replicated = not parallel_spec.sharded + setattr(param, "lora_shard_domain", parallel_spec.shard_domain) + setattr(param, "lora_tp_sharded", parallel_spec.sharded) + setattr(param, "lora_tp_replicated", replicated) + setattr(param, "lora_tp_shard_dim", parallel_spec.shard_dim) + setattr(param, "grad_sync_domain", 
parallel_spec.grad_sync_domain) + setattr(param, "grad_sync_op", parallel_spec.grad_sync_op) + # Megatron DDP routing flag: + # - allreduce=True: sync with regular DP/CP replicas. + # - allreduce=False: sync with expert-DP replicas. + # TP / expert-TP replica handling is controlled by grad_sync_* metadata. + setattr(param, "allreduce", allreduce) + + # Megatron's native TP finalize path consumes this attr. + setattr( + param, + "average_gradients_across_tp_domain", + ( + replicated + and parallel_spec.grad_sync_domain == TP_DEFAULT_GRAD_SYNC_DOMAIN + and parallel_spec.grad_sync_op == GRAD_SYNC_OP_AVG + ), + ) + + # Megatron optimizer and checkpoint logic rely on tensor model-parallel metadata + # to distinguish true shards from TP-duplicate params. + if parallel_spec.sharded: + shard_dim = parallel_spec.shard_dim + if shard_dim is None: + raise ValueError("LoRAParallelSpec.shard_dim must be set when sharded=True") + setattr(param, "tensor_model_parallel", True) + setattr(param, "partition_dim", _normalize_axis(shard_dim, param.ndim)) + # stride > 1 means the dim is split into blocks and each tp rank holds a shard of the block + # this might happen for fused e.g. 
gate_(up|proj), but loras are individual per module + setattr(param, "partition_stride", 1) + else: + setattr(param, "tensor_model_parallel", False) + setattr(param, "partition_dim", -1) + setattr(param, "partition_stride", 1) + class LoRA(torch.nn.Module): def __init__( @@ -27,6 +149,9 @@ def __init__( dtype: torch.dtype, device: torch.device, num_local_experts: int = 1, + a_parallel_spec: LoRAParallelSpec = LoRAParallelSpec(), + b_parallel_spec: LoRAParallelSpec = LoRAParallelSpec(), + allreduce: bool = True, ) -> None: super().__init__() assert num_local_experts == 1 or "{expert}" in adapter_model_prefix, ( @@ -44,6 +169,16 @@ def __init__( num_local_experts, rank, out_features, dtype=dtype, device=device ).squeeze(0) ) + _set_lora_parallel_metadata( + self.A_T, + parallel_spec=a_parallel_spec, + allreduce=allreduce, + ) + _set_lora_parallel_metadata( + self.B_T, + parallel_spec=b_parallel_spec, + allreduce=allreduce, + ) self._expert_offset = ps.get_expert_model_parallel_rank() * num_local_experts self.reset_lora_parameters() @@ -51,6 +186,27 @@ def __init__( def num_local_experts(self) -> int: return self.A_T.shape[0] if self.A_T.ndim == 3 else 1 + def _broadcast_if_replicated(self, param: torch.nn.Parameter) -> None: + if not param.lora_tp_replicated: # ty: ignore[unresolved-attribute] + return + domain = param.lora_shard_domain # ty: ignore[unresolved-attribute] + world_size = _get_shard_world_size(domain) + if world_size <= 1: + return + group = _get_shard_group(domain) + if group is None: + raise RuntimeError( + f"{self.adapter_model_prefix}: missing process group for replicated parameter domain={domain}" + ) + src = torch.distributed.get_global_rank( # ty: ignore[possibly-missing-attribute] + group, 0 + ) + torch.distributed.broadcast( # ty: ignore[possibly-missing-attribute] + param.data, + src=src, + group=group, + ) + def reset_lora_parameters(self) -> None: """Initialize LoRA weights (A=Kaiming, B=zeros) like PEFT defaults.""" if self.A_T.ndim == 3: 
@@ -59,22 +215,38 @@ def reset_lora_parameters(self) -> None: else: torch.nn.init.kaiming_uniform_(self.A_T.T, a=math.sqrt(5)) torch.nn.init.zeros_(self.B_T) + self._broadcast_if_replicated(self.A_T) + self._broadcast_if_replicated(self.B_T) + + def _expected_weight_keys(self, suffix: str) -> list[str]: + if self.num_local_experts > 1: + return [ + f"{self.adapter_model_prefix.format(expert=expert + self._expert_offset)}.{suffix}.weight" + for expert in range(self.num_local_experts) + ] + return [f"{self.adapter_model_prefix}.{suffix}.weight"] def load_lora(self, adapter_model: dict[str, torch.Tensor]) -> None: - try: - self.load_weights( - adapter_model, - suffix="lora_A", - into=self.A_T, + missing_keys = [ + key + for suffix in ("lora_A", "lora_B") + for key in self._expected_weight_keys(suffix) + if key not in adapter_model + ] + if missing_keys: + raise KeyError( + f"Missing LoRA adapter keys for {self.adapter_model_prefix}: {sorted(missing_keys)}" ) - self.load_weights( - adapter_model, - suffix="lora_B", - into=self.B_T, - ) - except KeyError: - print("Unable to find LoRA weights for", self.adapter_model_prefix) - self.reset_lora_parameters() + self.load_weights( + adapter_model, + suffix="lora_A", + into=self.A_T, + ) + self.load_weights( + adapter_model, + suffix="lora_B", + into=self.B_T, + ) def load_weights( self, @@ -83,65 +255,104 @@ def load_weights( suffix: str, into: torch.nn.Parameter, ) -> None: - self.load_weight( - ( - torch.stack( - [ - adapter_model[ - f"{self.adapter_model_prefix.format(expert=expert + self._expert_offset)}.{suffix}.weight" - ].T - for expert in range(self.num_local_experts) - ] - ) - if self.num_local_experts > 1 - else adapter_model[f"{self.adapter_model_prefix}.{suffix}.weight"].T - ), - into=into, - ) + keys = self._expected_weight_keys(suffix) + if self.num_local_experts > 1: + weight = torch.stack([adapter_model[key].T for key in keys]) + else: + weight = adapter_model[keys[0]].T + self.load_weight(weight, into=into) 
def load_weight(self, weight: torch.Tensor, *, into: torch.nn.Parameter) -> None: - setattr(into, "sharded", False) - tp_world_size = ps.get_tensor_model_parallel_world_size() - tp_rank = ps.get_tensor_model_parallel_rank() - for axis in (-2, -1): - if weight.shape[axis] == into.shape[axis]: - continue - # assume our param is tensor sharded along this axis - assert weight.shape[axis] // tp_world_size == into.shape[axis], ( - f"Weight shape {weight.shape} does not match into shape {into.shape} along axis {axis}" + domain = into.lora_shard_domain # ty: ignore[unresolved-attribute] + if into.lora_tp_sharded: # ty: ignore[unresolved-attribute] + axis = into.lora_tp_shard_dim # ty: ignore[unresolved-attribute] + axis = _normalize_axis(axis, weight.ndim) + world_size = _get_shard_world_size(domain) + rank = _get_shard_rank(domain) + if weight.shape[axis] % world_size != 0: + raise ValueError( + f"{self.adapter_model_prefix}: weight shape {tuple(weight.shape)} is not divisible by world size " + f"{world_size} on axis {axis}" + ) + local_size = weight.shape[axis] // world_size + if into.shape[axis] != local_size: + raise ValueError( + f"{self.adapter_model_prefix}: expected local shard size {into.shape[axis]}, got {local_size}" + ) + weight = weight.narrow(axis, rank * local_size, local_size) + elif tuple(weight.shape) != tuple(into.shape): + raise ValueError( + f"{self.adapter_model_prefix}: unsharded load shape mismatch, got {tuple(weight.shape)} " + f"expected {tuple(into.shape)}" ) - s = into.shape[axis] - weight = weight.narrow(axis, tp_rank * s, s) - setattr(into, "sharded", True) into.data.copy_(weight) into.requires_grad = True - def sharded_lora_state_dict(self) -> dict[str, torch.Tensor]: - if self.num_local_experts > 1: + def _should_export_parameter(self, param: torch.nn.Parameter) -> bool: + """ + Determine if the given LoRA param should be exported in the sharded LoRA state dict + (drop replicated ranks/params). 
+ """ + if self.num_local_experts > 1: # self is a MoE layer if ps.get_expert_data_parallel_rank() != 0: - return {} - return { - f"{self.adapter_model_prefix.format(expert=expert + self._expert_offset)}.{key}": param.data[ - expert - ].T - for expert in range(self.num_local_experts) - for key, param in ( - ("lora_A.weight", self.A_T), - ("lora_B.weight", self.B_T), - ) - } - if ps.get_data_parallel_rank() != 0 or torch.all(self.A_T == 0): - return {} + return False + else: # self is a non-MoE layer + # dp x cp rank 0 participates + if ps.get_data_parallel_rank(with_context_parallel=True) != 0: + return False + + # this param is fully sharded, all shard ranks participate + if param.lora_tp_sharded: # ty: ignore[unresolved-attribute] + return True + # param is replicated, tp rank 0 or etp rank 0 participates + return _get_shard_rank(param.lora_shard_domain) == 0 # ty: ignore[unresolved-attribute] + + def _manifest_for_param(self, param: torch.nn.Parameter) -> dict[str, Any]: return { - f"{self.adapter_model_prefix}.{key}": param.data.T - for key, param in ( - ("lora_A.weight", self.A_T), - ("lora_B.weight", self.B_T), - ) - if getattr(param, "sharded", False) - or ps.get_tensor_model_parallel_rank() == 0 + "domain": param.lora_shard_domain, # ty: ignore[unresolved-attribute] + "sharded": param.lora_tp_sharded, # ty: ignore[unresolved-attribute] + "shard_dim": param.lora_tp_shard_dim, # ty: ignore[unresolved-attribute] + "shard_world_size": _get_shard_world_size(param.lora_shard_domain) # ty: ignore[unresolved-attribute] + if param.lora_tp_sharded # ty: ignore[unresolved-attribute] + else 1, + "shard_rank": _get_shard_rank(param.lora_shard_domain) # ty: ignore[unresolved-attribute] + if param.lora_tp_sharded # ty: ignore[unresolved-attribute] + else 0, } + def _lora_params(self) -> list[tuple[str, torch.nn.Parameter]]: + return [ + ("lora_A.weight", self.A_T), + ("lora_B.weight", self.B_T), + ] + + def _export_items( + self, + ) -> list[tuple[str, torch.nn.Parameter, 
int | None]]: + export_items: list[tuple[str, torch.nn.Parameter, int | None]] = [] + for key, param in self._lora_params(): + if not self._should_export_parameter(param): + continue + if self.num_local_experts > 1: + for expert in range(self.num_local_experts): + full_key = f"{self.adapter_model_prefix.format(expert=expert + self._expert_offset)}.{key}" + export_items.append((full_key, param, expert)) + else: + export_items.append((f"{self.adapter_model_prefix}.{key}", param, None)) + return export_items + + def sharded_lora_manifest(self) -> dict[str, dict[str, Any]]: + return { + key: self._manifest_for_param(param) + for key, param, _expert in self._export_items() + } + + def sharded_lora_state_dict(self) -> dict[str, torch.Tensor]: + state: dict[str, torch.Tensor] = {} + for key, param, expert in self._export_items(): + state[key] = param.data[expert].T if expert is not None else param.data.T + return state + def forward( self, x: torch.Tensor, tokens_per_expert: list[int] | torch.Tensor | None = None ) -> torch.Tensor: @@ -152,14 +363,13 @@ def forward( bsz = tokens_per_expert if isinstance(bsz, list): bsz = torch.tensor(bsz, dtype=torch.int64, device="cpu") - # If no tokens routed locally, return zeros + # If no tokens routed locally, return zeros. 
if isinstance(bsz, torch.Tensor) and int(torch.count_nonzero(bsz)) == 0: return x.new_zeros((x.shape[0], self.B_T.shape[-1])) tmp = grouped_gemm_util.ops.gmm(x, self.A_T, bsz, trans_b=False) # type: ignore[attr-defined] out = grouped_gemm_util.ops.gmm(tmp, self.B_T, bsz, trans_b=False) # type: ignore[attr-defined] return out * self.scale - else: - return ((x @ self.A_T) @ self.B_T) * self.scale + return ((x @ self.A_T) @ self.B_T) * self.scale class SelfAttentionLinearProjLoRA(torch.nn.Module): @@ -175,6 +385,20 @@ def __init__( self.provider = provider self.linear_proj = linear_proj assert isinstance(linear_proj.weight, torch.Tensor) + a_parallel_spec = LoRAParallelSpec( + shard_domain="tp", + sharded=True, + shard_dim=-2, + grad_sync_domain=TP_DEFAULT_GRAD_SYNC_DOMAIN, + grad_sync_op=GRAD_SYNC_OP_NONE, # only need DP-type reductions + ) + b_parallel_spec = a_parallel_spec.model_copy( + update={ + "sharded": False, + "shard_dim": None, + "grad_sync_op": GRAD_SYNC_OP_SUM, # sum replicated TP contributions + } + ) self.lora = LoRA( adapter_model_prefix=adapter_model_prefix, in_features=linear_proj.in_features, @@ -183,22 +407,23 @@ def __init__( alpha=alpha, dtype=linear_proj.weight.dtype, device=linear_proj.weight.device, + a_parallel_spec=a_parallel_spec, + b_parallel_spec=b_parallel_spec, + # Non-expert LoRA params use Megatron's dense DP/CP gradient buckets. 
+ allreduce=True, ) def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]: base_output, bias_output = self.linear_proj(x) assert isinstance(base_output, torch.Tensor) assert isinstance(bias_output, (torch.Tensor, type(None))) + lora_output = self.lora(x) - if ( - self.provider.sequence_parallel - and self.provider.tensor_model_parallel_size > 1 - ): - tp_rank = ps.get_tensor_model_parallel_rank() - tokens_per_rank = base_output.shape[0] - start = tp_rank * tokens_per_rank - end = start + tokens_per_rank - lora_output = lora_output[start:end] + if self.provider.tensor_model_parallel_size > 1: + if self.provider.sequence_parallel: + lora_output = reduce_scatter_to_sequence_parallel_region(lora_output) + else: + lora_output = reduce_from_tensor_model_parallel_region(lora_output) return base_output + lora_output, bias_output @@ -219,6 +444,10 @@ def __init__( assert self.provider.kv_channels is not None assert self.provider.num_query_groups is not None assert self.provider.num_attention_heads is not None + if self.provider.num_attention_heads % self.provider.num_query_groups != 0: + raise ValueError( + "num_attention_heads must be divisible by num_query_groups for QKV LoRA" + ) q_out_features = self.provider.kv_channels * self.provider.num_attention_heads kv_out_features = self.provider.kv_channels * self.provider.num_query_groups tp_world_size = ps.get_tensor_model_parallel_world_size() @@ -230,33 +459,72 @@ def __init__( ) q_out_features_per_rank = q_out_features // tp_world_size kv_out_features_per_rank = kv_out_features // tp_world_size + self.num_query_groups_per_partition = ( + self.provider.num_query_groups // tp_world_size + ) + self.num_attention_heads_per_group = ( + self.provider.num_attention_heads // self.provider.num_query_groups + ) + self.hidden_size_per_attention_head = self.provider.kv_channels assert isinstance(linear_qkv.weight, torch.Tensor) - self.q_proj_lora = LoRA( + self.q_proj_lora = self._build_qkv_lora( 
adapter_model_prefix=f"{adapter_model_prefix}.q_proj", - in_features=linear_qkv.in_features, - out_features=q_out_features_per_rank, + linear_qkv=linear_qkv, rank=rank, alpha=alpha, - dtype=linear_qkv.weight.dtype, - device=linear_qkv.weight.device, + out_features=q_out_features_per_rank, ) - self.k_proj_lora = LoRA( + self.k_proj_lora = self._build_qkv_lora( adapter_model_prefix=f"{adapter_model_prefix}.k_proj", - in_features=linear_qkv.in_features, - out_features=kv_out_features_per_rank, + linear_qkv=linear_qkv, rank=rank, alpha=alpha, - dtype=linear_qkv.weight.dtype, - device=linear_qkv.weight.device, + out_features=kv_out_features_per_rank, ) - self.v_proj_lora = LoRA( + self.v_proj_lora = self._build_qkv_lora( adapter_model_prefix=f"{adapter_model_prefix}.v_proj", - in_features=linear_qkv.in_features, + linear_qkv=linear_qkv, + rank=rank, + alpha=alpha, out_features=kv_out_features_per_rank, + ) + + @staticmethod + def _build_qkv_lora( + *, + adapter_model_prefix: str, + linear_qkv: TELayerNormColumnParallelLinear, + rank: int, + alpha: float, + out_features: int, + ) -> LoRA: + assert isinstance(linear_qkv.weight, torch.Tensor) + a_parallel_spec = LoRAParallelSpec( + shard_domain="tp", + sharded=False, + shard_dim=None, + grad_sync_domain=TP_DEFAULT_GRAD_SYNC_DOMAIN, + grad_sync_op=GRAD_SYNC_OP_SUM, # sum replicated TP contributions + ) + b_parallel_spec = a_parallel_spec.model_copy( + update={ + "sharded": True, + "shard_dim": -1, + "grad_sync_op": GRAD_SYNC_OP_NONE, # only need DP-type reductions + } + ) + return LoRA( + adapter_model_prefix=adapter_model_prefix, + in_features=linear_qkv.in_features, + out_features=out_features, rank=rank, alpha=alpha, dtype=linear_qkv.weight.dtype, device=linear_qkv.weight.device, + a_parallel_spec=a_parallel_spec, + b_parallel_spec=b_parallel_spec, + # Non-expert LoRA params use Megatron's dense DP/CP gradient buckets. 
+ allreduce=True, ) def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]: @@ -272,20 +540,32 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]: query = self.q_proj_lora(layernorm_output) key = self.k_proj_lora(layernorm_output) value = self.v_proj_lora(layernorm_output) - - assert isinstance(self.linear_qkv.config.kv_channels, int) - query_4d = query.reshape( - query.shape[0], query.shape[1], -1, self.linear_qkv.config.kv_channels + # Match Megatron mixed_qkv layout: + # [S, B, nqg, (nah/nqg + 2), hn] where each query-group packs + # [all query heads for that group, key, value]. + query_5d = query.reshape( + query.shape[0], + query.shape[1], + self.num_query_groups_per_partition, + self.num_attention_heads_per_group, + self.hidden_size_per_attention_head, ) - key_4d = key.reshape( - key.shape[0], key.shape[1], -1, self.linear_qkv.config.kv_channels + key_5d = key.reshape( + key.shape[0], + key.shape[1], + self.num_query_groups_per_partition, + 1, + self.hidden_size_per_attention_head, ) - value_4d = value.reshape( - value.shape[0], value.shape[1], -1, self.linear_qkv.config.kv_channels + value_5d = value.reshape( + value.shape[0], + value.shape[1], + self.num_query_groups_per_partition, + 1, + self.hidden_size_per_attention_head, ) - - qkv_4d = torch.cat([query_4d, key_4d, value_4d], dim=2) - adapter_output = qkv_4d.reshape(qkv_4d.shape[0], qkv_4d.shape[1], -1) + qkv_5d = torch.cat([query_5d, key_5d, value_5d], dim=3) + adapter_output = qkv_5d.reshape(qkv_5d.shape[0], qkv_5d.shape[1], -1) return linear_output + adapter_output, bias @@ -302,19 +582,49 @@ def __init__( super().__init__() assert linear_fc1 is not None self.linear_fc1 = linear_fc1 - assert isinstance(linear_fc1.weight0, torch.Tensor) - self.gate_lora = LoRA( + self.gate_lora = self._build_fc1_lora( adapter_model_prefix=f"{adapter_model_prefix}.{{expert}}.gate_proj", - in_features=linear_fc1.in_features, - out_features=linear_fc1.out_features // 
2, + linear_fc1=linear_fc1, rank=rank, alpha=alpha, - dtype=linear_fc1.weight0.dtype, - device=linear_fc1.weight0.device, num_local_experts=num_local_experts, ) - self.up_lora = LoRA( + self.up_lora = self._build_fc1_lora( adapter_model_prefix=f"{adapter_model_prefix}.{{expert}}.up_proj", + linear_fc1=linear_fc1, + rank=rank, + alpha=alpha, + num_local_experts=num_local_experts, + ) + + @staticmethod + def _build_fc1_lora( + *, + adapter_model_prefix: str, + linear_fc1: TEColumnParallelGroupedLinear, + rank: int, + alpha: float, + num_local_experts: int, + ) -> LoRA: + assert linear_fc1 is not None + assert isinstance(linear_fc1.weight0, torch.Tensor) + a_parallel_spec = LoRAParallelSpec( + shard_domain="expert_tp", + sharded=False, + shard_dim=None, + grad_sync_domain=EXPERT_TP_GRAD_SYNC_DOMAIN, + grad_sync_op=GRAD_SYNC_OP_SUM, # we handle this with extended finalize_grads + ) + b_parallel_spec = a_parallel_spec.model_copy( + update={ + "sharded": True, + "shard_dim": -1, + "grad_sync_domain": EXPERT_TP_GRAD_SYNC_DOMAIN, + "grad_sync_op": GRAD_SYNC_OP_NONE, # only need DP-type reductions + } + ) + return LoRA( + adapter_model_prefix=adapter_model_prefix, in_features=linear_fc1.in_features, out_features=linear_fc1.out_features // 2, rank=rank, @@ -322,6 +632,10 @@ def __init__( dtype=linear_fc1.weight0.dtype, device=linear_fc1.weight0.device, num_local_experts=num_local_experts, + a_parallel_spec=a_parallel_spec, + b_parallel_spec=b_parallel_spec, + # Expert LoRA params use Megatron's expert-DP gradient buckets. 
+ allreduce=False, ) def forward( @@ -347,6 +661,21 @@ def __init__( assert linear_fc2 is not None assert isinstance(linear_fc2.weight0, torch.Tensor) self.linear_fc2 = linear_fc2 + a_parallel_spec = LoRAParallelSpec( + shard_domain="expert_tp", + sharded=True, + shard_dim=-2, + grad_sync_domain=EXPERT_TP_GRAD_SYNC_DOMAIN, + grad_sync_op=GRAD_SYNC_OP_NONE, # only need DP-type reductions + ) + b_parallel_spec = a_parallel_spec.model_copy( + update={ + "sharded": False, + "shard_dim": None, + "grad_sync_domain": EXPERT_TP_GRAD_SYNC_DOMAIN, + "grad_sync_op": GRAD_SYNC_OP_SUM, # we handle this with extended finalize_grads + } + ) self.lora = LoRA( adapter_model_prefix=f"{adapter_model_prefix}.{{expert}}.down_proj", in_features=linear_fc2.in_features, @@ -356,6 +685,10 @@ def __init__( dtype=linear_fc2.weight0.dtype, device=linear_fc2.weight0.device, num_local_experts=num_local_experts, + a_parallel_spec=a_parallel_spec, + b_parallel_spec=b_parallel_spec, + # Expert LoRA params use Megatron's expert-DP gradient buckets. 
+ allreduce=False, ) def forward( @@ -363,83 +696,76 @@ def forward( ) -> tuple[torch.Tensor, torch.Tensor | None]: base_out, bias_out = self.linear_fc2(x, tokens_per_expert) adapter_out = self.lora(x, tokens_per_expert=tokens_per_expert) + # the reason there is no TP comm here is because the MoE token routing handles + # expert TP comm externally return base_out + adapter_out, bias_out def apply_lora_adapters( model: Sequence[torch.nn.Module], provider: GPTModelProvider, -) -> None: - with torch.no_grad(): - for chunk in model: - for module in chunk.modules(): - if isinstance(module, TransformerLayer): - adapter_model_prefix = ( - f"base_model.model.model.layers.{module.layer_number - 1}" - ) - assert isinstance(module.self_attention, SelfAttention) - self_attention_linear_proj = module.self_attention.linear_proj - if not isinstance(self_attention_linear_proj, TERowParallelLinear): - self_attention_linear_proj = ( - self_attention_linear_proj.linear_proj - ) - assert isinstance( - self_attention_linear_proj, TERowParallelLinear - ) - module.self_attention.linear_proj = SelfAttentionLinearProjLoRA( - adapter_model_prefix=f"{adapter_model_prefix}.self_attn.o_proj", - linear_proj=self_attention_linear_proj, - rank=1, - alpha=32, - provider=provider, - ) - self_attention_linear_qkv = module.self_attention.linear_qkv - if not isinstance( - self_attention_linear_qkv, TELayerNormColumnParallelLinear - ): - self_attention_linear_qkv = self_attention_linear_qkv.linear_qkv - assert isinstance( - self_attention_linear_qkv, TELayerNormColumnParallelLinear - ) - module.self_attention.linear_qkv = SelfAttentionLinearQKVLoRA( - adapter_model_prefix=f"{adapter_model_prefix}.self_attn", - linear_qkv=self_attention_linear_qkv, - rank=1, - alpha=32, - provider=provider, - ) - assert isinstance(module.mlp.experts, TEGroupedMLP) - mlp_experts_linear_fc1 = module.mlp.experts.linear_fc1 - if not isinstance( - mlp_experts_linear_fc1, - TEColumnParallelGroupedLinear, # type: ignore - ): - 
mlp_experts_linear_fc1 = mlp_experts_linear_fc1.linear_fc1 - assert isinstance( - mlp_experts_linear_fc1, - TEColumnParallelGroupedLinear, # type: ignore - ) - module.mlp.experts.linear_fc1 = MLPExpertsLinearFC1LoRA( - adapter_model_prefix=f"{adapter_model_prefix}.mlp.experts", - linear_fc1=mlp_experts_linear_fc1, - rank=1, - alpha=32, - num_local_experts=module.mlp.experts.num_local_experts, - ) - mlp_experts_linear_fc2 = module.mlp.experts.linear_fc2 - if not isinstance( - mlp_experts_linear_fc2, - TERowParallelGroupedLinear, # type: ignore - ): - mlp_experts_linear_fc2 = mlp_experts_linear_fc2.linear_fc2 - assert isinstance( - mlp_experts_linear_fc2, - TERowParallelGroupedLinear, # type: ignore - ) - module.mlp.experts.linear_fc2 = MLPExpertsLinearFC2LoRA( - adapter_model_prefix=f"{adapter_model_prefix}.mlp.experts", - linear_fc2=mlp_experts_linear_fc2, - rank=1, - alpha=32, - num_local_experts=module.mlp.experts.num_local_experts, - ) +) -> list[torch.nn.Module]: + def _unwrap_attr(value: Any, attr_name: str, expected_type: type[Any]) -> Any: + if isinstance(value, expected_type): + return value + unwrapped = getattr(value, attr_name) + assert isinstance(unwrapped, expected_type) + return unwrapped + + for chunk in model: + for module in chunk.modules(): + if isinstance(module, TransformerLayer): + adapter_model_prefix = ( + f"base_model.model.model.layers.{module.layer_number - 1}" + ) + assert isinstance(module.self_attention, SelfAttention) + self_attention_linear_proj = _unwrap_attr( + module.self_attention.linear_proj, + "linear_proj", + TERowParallelLinear, + ) + module.self_attention.linear_proj = SelfAttentionLinearProjLoRA( + adapter_model_prefix=f"{adapter_model_prefix}.self_attn.o_proj", + linear_proj=self_attention_linear_proj, + rank=1, + alpha=32, + provider=provider, + ) + self_attention_linear_qkv = _unwrap_attr( + module.self_attention.linear_qkv, + "linear_qkv", + TELayerNormColumnParallelLinear, + ) + module.self_attention.linear_qkv = 
SelfAttentionLinearQKVLoRA( + adapter_model_prefix=f"{adapter_model_prefix}.self_attn", + linear_qkv=self_attention_linear_qkv, + rank=1, + alpha=32, + provider=provider, + ) + assert isinstance(module.mlp.experts, TEGroupedMLP) + mlp_experts_linear_fc1 = _unwrap_attr( + module.mlp.experts.linear_fc1, + "linear_fc1", + TEColumnParallelGroupedLinear, # type: ignore[arg-type] + ) + module.mlp.experts.linear_fc1 = MLPExpertsLinearFC1LoRA( + adapter_model_prefix=f"{adapter_model_prefix}.mlp.experts", + linear_fc1=mlp_experts_linear_fc1, + rank=1, + alpha=32, + num_local_experts=module.mlp.experts.num_local_experts, + ) + mlp_experts_linear_fc2 = _unwrap_attr( + module.mlp.experts.linear_fc2, + "linear_fc2", + TERowParallelGroupedLinear, # type: ignore[arg-type] + ) + module.mlp.experts.linear_fc2 = MLPExpertsLinearFC2LoRA( + adapter_model_prefix=f"{adapter_model_prefix}.mlp.experts", + linear_fc2=mlp_experts_linear_fc2, + rank=1, + alpha=32, + num_local_experts=module.mlp.experts.num_local_experts, + ) + return list(model) diff --git a/src/art/megatron/provider.py b/src/art/megatron/provider.py index d1029b35..1b016628 100644 --- a/src/art/megatron/provider.py +++ b/src/art/megatron/provider.py @@ -1,10 +1,16 @@ import copy from functools import partial import inspect -from typing import Callable +from pathlib import Path +from typing import Callable, cast from megatron.bridge import AutoBridge from megatron.bridge.models.gpt_provider import GPTModelProvider +from megatron.bridge.models.hf_pretrained.state import ( + SafeTensorsStateSource, + StateDict, + StateSource, +) from megatron.bridge.models.qwen.qwen3_moe_bridge import Qwen3MoEBridge from megatron.core.transformer.enums import AttnBackend from megatron.core.transformer.spec_utils import ModuleSpec @@ -28,15 +34,51 @@ def _resolve_layer_spec( return base_layer_spec(config, **kwargs) -def get_provider(model: str) -> GPTModelProvider: +class _CastingStateSource(StateSource): + def __init__(self, source: 
StateSource, *, dtype: torch.dtype): + self._source = source + self._dtype = dtype + + def get_all_keys(self) -> list[str]: + return self._source.get_all_keys() + + def load_tensors(self, keys: list[str]) -> dict[str, torch.Tensor]: + loaded = self._source.load_tensors(keys) + return { + key: ( + value.to(dtype=self._dtype) + if torch.is_floating_point(value) and value.dtype != self._dtype + else value + ) + for key, value in loaded.items() + } + + def has_glob(self, pattern: str) -> bool: + return self._source.has_glob(pattern) + + +def get_provider( + model: str, + *, + torch_dtype: torch.dtype = torch.bfloat16, +) -> GPTModelProvider: bridge = AutoBridge.from_hf_pretrained( model, - torch_dtype=torch.bfloat16, + dtype=torch_dtype, trust_remote_code=True, ) assert isinstance(bridge._model_bridge, Qwen3MoEBridge), ( "Only Qwen3 MoE models are supported" ) + if torch_dtype != torch.bfloat16: + model_name_or_path = bridge.hf_pretrained.model_name_or_path + assert model_name_or_path is not None + bridge.hf_pretrained._state_dict_accessor = StateDict( + _CastingStateSource( + SafeTensorsStateSource(cast(str | Path, model_name_or_path)), + dtype=torch_dtype, + ) + ) provider = bridge.to_megatron_provider() base_layer_spec = provider.transformer_layer_spec @@ -62,6 +104,11 @@ def _flex_attention_layer_spec( provider.expert_tensor_parallel_size = 1 provider.moe_shared_expert_overlap = True provider.moe_router_dtype = "fp32" + # params are disabled anyways, but should know about this if we switch to full FT + # because DP 'dummy' microbatches will unintentionally have loss for this + provider.moe_aux_loss_coeff = 0.0 + # effectively just a flag modifying finalize_model_grads behavior for DPxCP + provider.calculate_per_token_loss = True if provider.tensor_model_parallel_size > 1: provider.sequence_parallel = True return provider diff --git a/src/art/megatron/routing_replay.py b/src/art/megatron/routing_replay.py new file mode 100644 index 00000000..86f1c4df --- /dev/null 
+++ b/src/art/megatron/routing_replay.py
@@ -0,0 +1,1460 @@
+from __future__ import annotations
+
+from collections import defaultdict
+import json
+from pathlib import Path
+import re
+import types
+from typing import Any, Protocol
+
+from megatron.core.tensor_parallel import (
+    all_to_all,
+    gather_from_sequence_parallel_region,
+)
+from megatron.core.transformer.moe.moe_utils import permute, sort_chunks_by_idxs
+from pydantic import BaseModel, ConfigDict, model_validator
+from safetensors.torch import load_file, save_file
+import torch
+
+ROUTER_NAME_TOKEN = ".mlp.router"
+ROUTER_KEY_FORMAT_VERSION = "moe_routing_replay_v1"
+GLOBAL_TOKEN_UIDS_KEY = "global_token_uids"
+TRACE_ROW_TOKEN_UIDS_ATTR = "_art_trace_row_token_uids"
+TRACE_UID_SPAN_ATTR = "_art_trace_uid_span"
+
+_ROUTER_LAYER_PATTERN = re.compile(r"decoder\.layers\.(?P<layer>\d+)\.mlp\.router$")
+_TRACE_CHUNK_PREFIX_PATTERN = re.compile(r"^chunk(?P<chunk>\d+)\.(?P<name>.+)$")
+
+
+def _to_tensor_cpu_contiguous(
+    tensor: torch.Tensor, *, dtype: torch.dtype
+) -> torch.Tensor:
+    if not isinstance(tensor, torch.Tensor):
+        raise TypeError(f"Expected torch.Tensor, got {type(tensor)}")
+    return tensor.detach().to(device="cpu", dtype=dtype).contiguous()
+
+
+def _normalize_step_index(step_index: int) -> str:
+    if step_index < 0:
+        raise ValueError(f"step_index must be non-negative, got {step_index}")
+    return f"{step_index:06d}"
+
+
+def _build_tensor_key(router_key: str, call_index: int, field_name: str) -> str:
+    return f"{router_key}/call_{call_index}/{field_name}"
+
+
+def _flatten_router_tensor(tensor: torch.Tensor) -> torch.Tensor:
+    if tensor.ndim < 2:
+        raise RuntimeError(
+            f"Router tensor must have rank >=2, got shape={tuple(tensor.shape)}"
+        )
+    num_experts = int(tensor.shape[-1])
+    return tensor.reshape(-1, num_experts).contiguous()
+
+
+def _extract_router_output_tensors(output: Any) -> tuple[torch.Tensor, torch.Tensor]:
+    if isinstance(output, (list, tuple)) and len(output) >= 2:
+        probs, routing_map = output[0],
output[1] + elif isinstance(output, dict): + probs = output.get("probs") + routing_map = output.get("routing_map") + else: + raise RuntimeError(f"Unsupported router output type: {type(output)}") + + if not isinstance(probs, torch.Tensor): + raise RuntimeError(f"Expected probs tensor, got {type(probs)}") + if not isinstance(routing_map, torch.Tensor): + raise RuntimeError(f"Expected routing_map tensor, got {type(routing_map)}") + + probs_2d = _flatten_router_tensor(probs.to(torch.float32)) + routing_map_2d = _flatten_router_tensor(routing_map.bool()) + if probs_2d.shape != routing_map_2d.shape: + raise RuntimeError( + "Router output shape mismatch: " + f"probs={tuple(probs_2d.shape)} routing_map={tuple(routing_map_2d.shape)}" + ) + return probs_2d, routing_map_2d + + +def _extract_dp_slot_from_rank_meta(rank_meta: Any) -> tuple[int, int] | None: + if isinstance(rank_meta, dict): + rank_meta = [rank_meta] + if not isinstance(rank_meta, list) or not rank_meta: + return None + dp_ranks = { + int(item["dp_rank"]) + for item in rank_meta + if isinstance(item, dict) and "dp_rank" in item + } + dp_world_sizes = { + int(item["dp_world_size"]) + for item in rank_meta + if isinstance(item, dict) and "dp_world_size" in item + } + if len(dp_ranks) != 1 or len(dp_world_sizes) != 1: + return None + return next(iter(dp_ranks)), next(iter(dp_world_sizes)) + + +def _trace_call_route_metadata( + call_entry: dict[str, Any], +) -> tuple[int | None, int | None]: + sample_index = call_entry.get("micro_sample_index") + if isinstance(sample_index, int): + return int(sample_index), None + dp_slot = _extract_dp_slot_from_rank_meta(call_entry.get("rank_meta")) + micro_order = int(call_entry.get("micro_order", 0)) + if dp_slot is None: + return None, micro_order + dp_rank, dp_world_size = dp_slot + return None, micro_order * dp_world_size + dp_rank + + +def build_router_key_from_module_name(*, chunk_index: int, module_name: str) -> str: + match = _ROUTER_LAYER_PATTERN.search(module_name) + if 
match is None: + raise RuntimeError( + f"Unable to derive router key from module name '{module_name}'. " + f"Expected suffix matching '{_ROUTER_LAYER_PATTERN.pattern}'." + ) + layer_index = int(match.group("layer")) + return f"chunk_{chunk_index:02d}.layer_{layer_index:04d}.mlp.router" + + +def build_router_key_from_trace_name(trace_module_name: str) -> str: + chunk_match = _TRACE_CHUNK_PREFIX_PATTERN.match(trace_module_name) + if chunk_match is None: + raise RuntimeError( + "Forward trace router module name must start with 'chunk.'; " + f"got '{trace_module_name}'" + ) + chunk_index = int(chunk_match.group("chunk")) + module_name = chunk_match.group("name") + return build_router_key_from_module_name( + chunk_index=chunk_index, + module_name=module_name, + ) + + +class ParallelTopology(BaseModel): + tp: int + ep: int + etp: int = 1 + dp: int = 1 + sp: bool = False + cp: int = 1 + pp: int = 1 + vpp: int = 1 + + +class RouterCallRoute(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + expert_indices: torch.Tensor + expert_probs: torch.Tensor + expert_mask: torch.Tensor + routing_map: torch.Tensor | None = None + num_experts: int + sample_index: int | None = None + micro_slot: int | None = None + + @model_validator(mode="after") + def _validate(self) -> "RouterCallRoute": + self.expert_indices = _to_tensor_cpu_contiguous( + self.expert_indices, dtype=torch.int32 + ) + self.expert_probs = _to_tensor_cpu_contiguous( + self.expert_probs, dtype=torch.float32 + ) + self.expert_mask = _to_tensor_cpu_contiguous(self.expert_mask, dtype=torch.bool) + if self.routing_map is not None: + self.routing_map = _to_tensor_cpu_contiguous( + self.routing_map, dtype=torch.bool + ) + + if self.expert_indices.ndim != 2: + raise RuntimeError( + "expert_indices must have shape [num_tokens, max_topk], got " + f"{tuple(self.expert_indices.shape)}" + ) + if self.expert_probs.shape != self.expert_indices.shape: + raise RuntimeError( + "expert_probs shape must match 
expert_indices shape, got " + f"{tuple(self.expert_probs.shape)} vs {tuple(self.expert_indices.shape)}" + ) + if self.expert_mask.shape != self.expert_indices.shape: + raise RuntimeError( + "expert_mask shape must match expert_indices shape, got " + f"{tuple(self.expert_mask.shape)} vs {tuple(self.expert_indices.shape)}" + ) + if self.num_experts <= 0: + raise RuntimeError(f"num_experts must be >0, got {self.num_experts}") + if self.sample_index is not None: + self.sample_index = int(self.sample_index) + if self.micro_slot is not None: + self.micro_slot = int(self.micro_slot) + if self.routing_map is not None: + expected = (self.expert_indices.shape[0], self.num_experts) + if tuple(self.routing_map.shape) != expected: + raise RuntimeError( + "routing_map shape mismatch: " + f"expected={expected}, got={tuple(self.routing_map.shape)}" + ) + return self + + @property + def num_global_tokens(self) -> int: + return int(self.expert_indices.shape[0]) + + @property + def max_topk(self) -> int: + return int(self.expert_indices.shape[1]) + + +class StepRouterRoutes(BaseModel): + calls: dict[int, RouterCallRoute] + + @model_validator(mode="after") + def _validate_calls(self) -> "StepRouterRoutes": + if not self.calls: + raise RuntimeError("StepRouterRoutes.calls cannot be empty") + for call_index in self.calls: + if call_index < 0: + raise RuntimeError(f"call_index must be >=0, got {call_index}") + return self + + +class StepRoutes(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + routers: dict[str, StepRouterRoutes] + global_token_uids: torch.Tensor + + @model_validator(mode="after") + def _validate(self) -> "StepRoutes": + if not self.routers: + raise RuntimeError("StepRoutes.routers cannot be empty") + self.global_token_uids = _to_tensor_cpu_contiguous( + self.global_token_uids, dtype=torch.int64 + ) + if self.global_token_uids.ndim != 1: + raise RuntimeError( + "global_token_uids must have shape [num_global_tokens], got " + 
f"{tuple(self.global_token_uids.shape)}" + ) + if int(torch.unique(self.global_token_uids).numel()) != int( + self.global_token_uids.numel() + ): + raise RuntimeError("global_token_uids must be unique per step") + expected_tokens = int(self.global_token_uids.numel()) + for router_key, step_router in self.routers.items(): + for call_index, route in step_router.calls.items(): + if route.num_global_tokens != expected_tokens: + raise RuntimeError( + "Route token count mismatch for " + f"router='{router_key}' call={call_index}: " + f"route_tokens={route.num_global_tokens}, " + f"expected_tokens={expected_tokens}" + ) + return self + + +class MoeRoutingReplayBundle(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + format_version: str = ROUTER_KEY_FORMAT_VERSION + topology: ParallelTopology + num_steps: int + max_topk: int + router_keys: list[str] + steps: dict[int, StepRoutes] + + @model_validator(mode="after") + def _validate(self) -> "MoeRoutingReplayBundle": + if self.format_version != ROUTER_KEY_FORMAT_VERSION: + raise RuntimeError( + f"Unsupported format_version={self.format_version}; " + f"expected={ROUTER_KEY_FORMAT_VERSION}" + ) + if self.num_steps <= 0: + raise RuntimeError(f"num_steps must be >0, got {self.num_steps}") + if self.max_topk < 0: + raise RuntimeError(f"max_topk must be >=0, got {self.max_topk}") + if set(self.steps.keys()) != set(range(self.num_steps)): + raise RuntimeError( + "steps must be indexed from 0..num_steps-1 without gaps: " + f"num_steps={self.num_steps}, step_keys={sorted(self.steps.keys())}" + ) + if not self.router_keys: + raise RuntimeError("router_keys cannot be empty") + router_key_set = set(self.router_keys) + for step_index, step_routes in self.steps.items(): + step_router_keys = set(step_routes.routers.keys()) + if step_router_keys != router_key_set: + raise RuntimeError( + f"Step {step_index} router set mismatch. 
" + f"expected={sorted(router_key_set)}, got={sorted(step_router_keys)}" + ) + return self + + @classmethod + def from_dir(cls, bundle_dir: str | Path) -> "MoeRoutingReplayBundle": + base_dir = Path(bundle_dir) + manifest_path = base_dir / "manifest.json" + if not manifest_path.exists(): + raise FileNotFoundError(f"Missing routing replay manifest: {manifest_path}") + with manifest_path.open("r", encoding="utf-8") as handle: + manifest = json.load(handle) + + if manifest.get("format_version") != ROUTER_KEY_FORMAT_VERSION: + raise RuntimeError( + "Unsupported routing replay manifest version: " + f"{manifest.get('format_version')}" + ) + + topology = ParallelTopology.model_validate(manifest["topology"]) + num_steps = int(manifest["num_steps"]) + max_topk = int(manifest["max_topk"]) + router_keys = [str(key) for key in manifest["router_keys"]] + manifest_steps = manifest["steps"] + + steps: dict[int, StepRoutes] = {} + for step_index in range(num_steps): + step_manifest = manifest_steps[str(step_index)] + step_file = base_dir / step_manifest["file"] + if not step_file.exists(): + raise FileNotFoundError( + f"Missing routing replay step file for step={step_index}: {step_file}" + ) + step_tensors = load_file(str(step_file)) + if GLOBAL_TOKEN_UIDS_KEY not in step_tensors: + raise RuntimeError( + f"Step file missing '{GLOBAL_TOKEN_UIDS_KEY}': {step_file}" + ) + global_token_uids = step_tensors[GLOBAL_TOKEN_UIDS_KEY] + + routers: dict[str, StepRouterRoutes] = {} + for router_key in router_keys: + router_step_manifest = step_manifest["routers"].get(router_key) + if router_step_manifest is None: + raise RuntimeError( + f"Step manifest missing router_key='{router_key}' for step={step_index}" + ) + calls: dict[int, RouterCallRoute] = {} + for call_index_raw, call_manifest in router_step_manifest.items(): + call_index = int(call_index_raw) + expert_indices_key = _build_tensor_key( + router_key, call_index, "expert_indices" + ) + expert_probs_key = _build_tensor_key( + 
router_key, call_index, "expert_probs" + ) + expert_mask_key = _build_tensor_key( + router_key, call_index, "expert_mask" + ) + routing_map_key = _build_tensor_key( + router_key, call_index, "routing_map" + ) + if expert_indices_key not in step_tensors: + raise RuntimeError( + f"Missing tensor key '{expert_indices_key}' in {step_file}" + ) + if expert_probs_key not in step_tensors: + raise RuntimeError( + f"Missing tensor key '{expert_probs_key}' in {step_file}" + ) + if expert_mask_key not in step_tensors: + raise RuntimeError( + f"Missing tensor key '{expert_mask_key}' in {step_file}" + ) + routing_map = ( + step_tensors[routing_map_key] + if routing_map_key in step_tensors + else None + ) + calls[call_index] = RouterCallRoute( + expert_indices=step_tensors[expert_indices_key], + expert_probs=step_tensors[expert_probs_key], + expert_mask=step_tensors[expert_mask_key], + routing_map=routing_map, + num_experts=int(call_manifest["num_experts"]), + sample_index=call_manifest.get("sample_index"), + micro_slot=call_manifest.get("micro_slot"), + ) + routers[router_key] = StepRouterRoutes(calls=calls) + steps[step_index] = StepRoutes( + routers=routers, + global_token_uids=global_token_uids, + ) + + return cls( + format_version=ROUTER_KEY_FORMAT_VERSION, + topology=topology, + num_steps=num_steps, + max_topk=max_topk, + router_keys=router_keys, + steps=steps, + ) + + def to_dir(self, bundle_dir: str | Path) -> None: + base_dir = Path(bundle_dir) + base_dir.mkdir(parents=True, exist_ok=True) + + manifest_steps: dict[str, dict[str, Any]] = {} + for step_index in range(self.num_steps): + step_routes = self.steps[step_index] + step_file_name = f"step_{_normalize_step_index(step_index)}.safetensors" + step_file_path = base_dir / step_file_name + step_tensors: dict[str, torch.Tensor] = { + GLOBAL_TOKEN_UIDS_KEY: _to_tensor_cpu_contiguous( + step_routes.global_token_uids, dtype=torch.int64 + ) + } + step_manifest_routers: dict[str, dict[str, dict[str, int]]] = {} + for 
router_key in self.router_keys: + router_routes = step_routes.routers[router_key] + call_manifest: dict[str, dict[str, int]] = {} + for call_index, route in sorted(router_routes.calls.items()): + step_tensors[ + _build_tensor_key(router_key, call_index, "expert_indices") + ] = _to_tensor_cpu_contiguous( + route.expert_indices, dtype=torch.int32 + ) + step_tensors[ + _build_tensor_key(router_key, call_index, "expert_probs") + ] = _to_tensor_cpu_contiguous( + route.expert_probs, dtype=torch.float32 + ) + step_tensors[ + _build_tensor_key(router_key, call_index, "expert_mask") + ] = _to_tensor_cpu_contiguous(route.expert_mask, dtype=torch.bool) + if route.routing_map is not None: + step_tensors[ + _build_tensor_key(router_key, call_index, "routing_map") + ] = _to_tensor_cpu_contiguous( + route.routing_map, dtype=torch.bool + ) + call_entry: dict[str, int] = {"num_experts": route.num_experts} + if route.sample_index is not None: + call_entry["sample_index"] = int(route.sample_index) + if route.micro_slot is not None: + call_entry["micro_slot"] = int(route.micro_slot) + call_manifest[str(call_index)] = call_entry + step_manifest_routers[router_key] = call_manifest + save_file(step_tensors, str(step_file_path)) + manifest_steps[str(step_index)] = { + "file": step_file_name, + "routers": step_manifest_routers, + } + + manifest = { + "format_version": ROUTER_KEY_FORMAT_VERSION, + "topology": self.topology.model_dump(mode="json"), + "num_steps": self.num_steps, + "max_topk": self.max_topk, + "router_keys": self.router_keys, + "steps": manifest_steps, + } + with (base_dir / "manifest.json").open("w", encoding="utf-8") as handle: + json.dump(manifest, handle, indent=2, sort_keys=True) + + +class LocalTokenIndexer(Protocol): + def build_local_token_uids( + self, + *, + global_token_uids: torch.Tensor, + num_local_tokens: int, + sequence_parallel: bool, + context_parallel_size: int, + ) -> torch.Tensor: + """Build local token uid order for current rank.""" + + +class 
TopologyAwareLocalTokenIndexer: + def __init__(self, parallel_state_module: Any | None = None) -> None: + self._parallel_state = parallel_state_module + + def _ps(self) -> Any: + if self._parallel_state is not None: + return self._parallel_state + from megatron.core import parallel_state as ps + + self._parallel_state = ps + return ps + + def build_local_token_uids( + self, + *, + global_token_uids: torch.Tensor, + num_local_tokens: int, + sequence_parallel: bool, + context_parallel_size: int, + ) -> torch.Tensor: + ps = self._ps() + + local_uids = global_token_uids.to(dtype=torch.int64, device="cpu").view(1, -1) + + cp_size = int(ps.get_context_parallel_world_size()) + if context_parallel_size > 1 and cp_size > 1: + from megatron.core.utils import get_batch_on_this_cp_rank + + local_uids = get_batch_on_this_cp_rank({"tokens": local_uids})["tokens"] + + tp_size = int(ps.get_tensor_model_parallel_world_size()) + tp_rank = int(ps.get_tensor_model_parallel_rank()) if tp_size > 1 else 0 + if sequence_parallel and tp_size > 1: + tokens_per_tp_rank = local_uids.shape[1] // tp_size + start = tp_rank * tokens_per_tp_rank + local_uids = local_uids[:, start : start + tokens_per_tp_rank] + + return local_uids.reshape(-1).contiguous() + + +_ACTIVE_ROUTING_REPLAY_CONTROLLER: MoeRoutingReplayController | None = None + + +def _active_routing_replay_controller() -> MoeRoutingReplayController | None: + return _ACTIVE_ROUTING_REPLAY_CONTROLLER + + +def _dispatcher_local_token_uids( + controller: MoeRoutingReplayController, + dispatcher: Any, + *, + num_local_tokens: int, +) -> torch.Tensor: + step_routes = controller._active_step_routes + if step_routes is None: + raise RuntimeError("Routing replay dispatcher used without an active step") + local_uids = controller.local_token_indexer.build_local_token_uids( + global_token_uids=step_routes.global_token_uids, + num_local_tokens=num_local_tokens, + sequence_parallel=bool( + getattr(getattr(dispatcher, "config", None), 
"sequence_parallel", False) + ), + context_parallel_size=int( + getattr(getattr(dispatcher, "config", None), "context_parallel_size", 1) + ), + ) + if int(local_uids.numel()) != num_local_tokens: + raise RuntimeError( + "Local routing replay uid count mismatch: " + f"expected={num_local_tokens}, got={int(local_uids.numel())}" + ) + sample_index = getattr(controller, "_active_sample_index", None) + uid_span = int(step_routes.global_token_uids.numel()) + if isinstance(sample_index, int) and sample_index >= 0 and uid_span > 0: + local_uids = local_uids + sample_index * uid_span + return local_uids + + +def _trace_row_uids_from_source(source: Any) -> tuple[torch.Tensor | None, int | None]: + row_token_uids = getattr(source, TRACE_ROW_TOKEN_UIDS_ATTR, None) + if not isinstance(row_token_uids, torch.Tensor): + return None, None + uid_span = getattr(source, TRACE_UID_SPAN_ATTR, None) + uid_span_int = uid_span if isinstance(uid_span, int) and uid_span > 0 else None + return row_token_uids, uid_span_int + + +def _attach_trace_row_uids( + target: Any, + *, + row_token_uids: torch.Tensor, + uid_span: int | None, +) -> None: + setattr( + target, + TRACE_ROW_TOKEN_UIDS_ATTR, + row_token_uids.detach().to(device="cpu", dtype=torch.int64).reshape(-1), + ) + setattr(target, TRACE_UID_SPAN_ATTR, uid_span) + + +def _canonicalize_expert_token_order( + expert_inputs: torch.Tensor, + expert_probs: torch.Tensor, + expert_token_uids: torch.Tensor, + *, + tokens_per_expert: torch.Tensor | list[int], +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + if isinstance(tokens_per_expert, torch.Tensor): + counts = [int(count) for count in tokens_per_expert.tolist()] + else: + counts = [int(count) for count in tokens_per_expert] + + if sum(counts) != int(expert_token_uids.numel()): + raise RuntimeError( + "Expert token uid count mismatch after dispatch: " + f"uids={int(expert_token_uids.numel())}, " + f"tokens_per_expert_sum={sum(counts)}" + ) + + order_segments: 
list[torch.Tensor] = [] + cursor = 0 + for count in counts: + if count <= 1: + order_segments.append( + torch.arange(cursor, cursor + count, dtype=torch.long) + ) + cursor += count + continue + segment_uids = expert_token_uids[cursor : cursor + count].to(device="cpu") + segment_order = torch.argsort(segment_uids, stable=True) + cursor + order_segments.append(segment_order) + cursor += count + + if not order_segments: + empty = torch.empty(0, dtype=torch.long) + return expert_inputs, expert_probs, expert_token_uids, empty + + canonical_order_cpu = torch.cat(order_segments, dim=0) + inverse_order_cpu = torch.empty_like(canonical_order_cpu) + inverse_order_cpu[canonical_order_cpu] = torch.arange( + canonical_order_cpu.numel(), dtype=torch.long + ) + + canonical_order = canonical_order_cpu.to( + device=expert_inputs.device, dtype=torch.long + ) + reordered_inputs = expert_inputs.index_select(0, canonical_order) + reordered_probs = expert_probs.index_select(0, canonical_order) + reordered_uids = expert_token_uids.index_select( + 0, + canonical_order_cpu.to(device=expert_token_uids.device, dtype=torch.long), + ) + return ( + reordered_inputs, + reordered_probs, + reordered_uids, + inverse_order_cpu, + ) + + +def _canonical_trace_row_uids( + expert_token_uids: torch.Tensor, + *, + tokens_per_expert: torch.Tensor | list[int], + local_expert_indices: list[int] | tuple[int, ...] 
| None, + sample_uid_span: int, + num_experts: int, +) -> tuple[torch.Tensor, int]: + if isinstance(tokens_per_expert, torch.Tensor): + counts = [int(count) for count in tokens_per_expert.tolist()] + else: + counts = [int(count) for count in tokens_per_expert] + + expert_indices = ( + [int(expert_index) for expert_index in local_expert_indices] + if local_expert_indices is not None + else list(range(len(counts))) + ) + if len(expert_indices) != len(counts): + raise RuntimeError( + "Local expert index metadata mismatch: " + f"num_expert_indices={len(expert_indices)}, num_counts={len(counts)}" + ) + row_uid_span = sample_uid_span * max(int(num_experts), 1) + row_uid_chunks: list[torch.Tensor] = [] + cursor = 0 + for global_expert_id, count in zip(expert_indices, counts): + count_int = int(count) + segment = expert_token_uids[cursor : cursor + count_int].to(dtype=torch.int64) + sample_ids = torch.div(segment, sample_uid_span, rounding_mode="floor") + local_token_ids = torch.remainder(segment, sample_uid_span) + row_uid_chunks.append( + sample_ids * row_uid_span + + int(global_expert_id) * sample_uid_span + + local_token_ids + ) + cursor += count_int + if cursor != int(expert_token_uids.numel()): + raise RuntimeError( + "Canonical trace row uid construction did not consume all expert rows: " + f"consumed={cursor}, total={int(expert_token_uids.numel())}" + ) + if not row_uid_chunks: + return expert_token_uids.new_empty((0,), dtype=torch.int64), row_uid_span + return torch.cat(row_uid_chunks, dim=0).contiguous(), row_uid_span + + +def _patch_alltoall_dispatcher_preprocess() -> None: + try: + from megatron.core.transformer.moe.experts import TEGroupedMLP + from megatron.core.transformer.moe.token_dispatcher import ( + MoEAlltoAllTokenDispatcher, + ) + + from art.megatron.lora import MLPExpertsLinearFC2LoRA + except Exception: + return + + if hasattr(MoEAlltoAllTokenDispatcher, "_art_router_replay_preprocess_patched"): + return + + original_preprocess = 
MoEAlltoAllTokenDispatcher.preprocess + original_dispatch_preprocess = MoEAlltoAllTokenDispatcher.dispatch_preprocess + original_token_dispatch = MoEAlltoAllTokenDispatcher.token_dispatch + original_dispatch_postprocess = MoEAlltoAllTokenDispatcher.dispatch_postprocess + original_combine_preprocess = MoEAlltoAllTokenDispatcher.combine_preprocess + original_te_grouped_mlp_forward = TEGroupedMLP.forward + original_fc2_forward = MLPExpertsLinearFC2LoRA.forward + + def patched_preprocess( + self: Any, routing_map: torch.Tensor, *args: Any, **kwargs: Any + ): + result = original_preprocess(self, routing_map, *args, **kwargs) + if ( + not getattr(self, "drop_and_pad", False) + and getattr(self.config, "moe_expert_capacity_factor", None) is None + and not ( + getattr(self.config, "moe_router_padding_for_quantization", None) + or getattr(self.config, "moe_router_padding_for_fp8", None) + ) + ): + self.num_out_tokens = int(routing_map.sum().item()) + return result + + def patched_dispatch_preprocess( + self: Any, + hidden_states: torch.Tensor, + routing_map: torch.Tensor, + probs: torch.Tensor, + ): + result = original_dispatch_preprocess(self, hidden_states, routing_map, probs) + self._art_replay_permuted_local_token_uids = None + self._art_replay_global_input_token_uids = None + self._art_replay_expert_input_inverse_permutation = None + + controller = _active_routing_replay_controller() + if controller is None: + return result + + local_token_uids = _dispatcher_local_token_uids( + controller, + self, + num_local_tokens=int(routing_map.shape[0]), + ) + permuted_local_uids, _, _ = permute( + local_token_uids.to( + device=hidden_states.device, dtype=torch.int64 + ).unsqueeze(-1), + self.routing_map, + num_out_tokens=self.num_out_tokens, + fused=False, + drop_and_pad=self.drop_and_pad, + ) + self._art_replay_permuted_local_token_uids = permuted_local_uids.reshape( + -1 + ).contiguous() + return result + + def patched_token_dispatch( + self: Any, + 
permutated_local_input_tokens: torch.Tensor, + permuted_probs: torch.Tensor, + ): + result = original_token_dispatch( + self, + permutated_local_input_tokens, + permuted_probs, + ) + controller = _active_routing_replay_controller() + permuted_local_token_uids = getattr( + self, "_art_replay_permuted_local_token_uids", None + ) + if controller is None or permuted_local_token_uids is None: + return result + + global_token_uids = permuted_local_token_uids.to( + device=permutated_local_input_tokens.device, dtype=torch.int64 + ).unsqueeze(-1) + if self.ep_size > 1: + global_token_uids = all_to_all( + self.ep_group, + global_token_uids, + self.output_splits, + self.input_splits, + ) + if self.tp_size > 1: + output_split_sizes = ( + None + if self.output_splits_tp is None + else self.output_splits_tp.tolist() + ) + global_token_uids = gather_from_sequence_parallel_region( + global_token_uids, + group=self.tp_group, + output_split_sizes=output_split_sizes, + ) + self._art_replay_global_input_token_uids = global_token_uids.reshape( + -1 + ).contiguous() + return result + + def patched_dispatch_postprocess( + self: Any, + global_input_tokens: torch.Tensor, + global_probs: torch.Tensor, + ): + expert_inputs, tokens_per_expert, expert_probs = original_dispatch_postprocess( + self, + global_input_tokens, + global_probs, + ) + controller = _active_routing_replay_controller() + global_input_token_uids = getattr( + self, "_art_replay_global_input_token_uids", None + ) + if controller is None or global_input_token_uids is None or self.drop_and_pad: + return expert_inputs, tokens_per_expert, expert_probs + + expert_token_uids = global_input_token_uids + if self.num_local_experts > 1: + sorted_token_uids, _ = sort_chunks_by_idxs( + expert_token_uids.unsqueeze(-1), + self.num_global_tokens_per_local_expert.ravel(), + self.sort_input_by_local_experts, + fused=False, + ) + expert_token_uids = sorted_token_uids.reshape(-1).contiguous() + + ( + expert_inputs, + expert_probs, + 
canonical_expert_token_uids, + inverse_order_cpu, + ) = _canonicalize_expert_token_order( + expert_inputs, + expert_probs, + expert_token_uids, + tokens_per_expert=tokens_per_expert, + ) + self._art_replay_expert_input_inverse_permutation = inverse_order_cpu + active_step_routes = controller._active_step_routes + if active_step_routes is None: + raise RuntimeError( + "MoE replay dispatcher preprocess called before set_step" + ) + trace_row_uids, trace_uid_span = _canonical_trace_row_uids( + canonical_expert_token_uids, + tokens_per_expert=tokens_per_expert, + local_expert_indices=getattr(self, "local_expert_indices", None), + sample_uid_span=int(active_step_routes.global_token_uids.numel()), + num_experts=int(getattr(self, "num_experts", 1)), + ) + _attach_trace_row_uids( + expert_inputs, + row_token_uids=trace_row_uids, + uid_span=trace_uid_span, + ) + return expert_inputs, tokens_per_expert, expert_probs + + def patched_combine_preprocess(self: Any, hidden_states: torch.Tensor): + inverse_order_cpu = getattr( + self, "_art_replay_expert_input_inverse_permutation", None + ) + if inverse_order_cpu is not None and inverse_order_cpu.numel() > 0: + hidden_states = hidden_states.index_select( + 0, + inverse_order_cpu.to(device=hidden_states.device, dtype=torch.long), + ) + self._art_replay_expert_input_inverse_permutation = None + return original_combine_preprocess(self, hidden_states) + + def patched_te_grouped_mlp_forward( + self: Any, + permuted_local_hidden_states: torch.Tensor, + tokens_per_expert: torch.Tensor, + permuted_probs: torch.Tensor, + ): + row_token_uids, uid_span = _trace_row_uids_from_source( + permuted_local_hidden_states + ) + if row_token_uids is not None: + _attach_trace_row_uids( + self.linear_fc2, + row_token_uids=row_token_uids, + uid_span=uid_span, + ) + return original_te_grouped_mlp_forward( + self, + permuted_local_hidden_states, + tokens_per_expert, + permuted_probs, + ) + + def patched_fc2_forward( + self: Any, + x: torch.Tensor, + 
tokens_per_expert: list[int] | torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + row_token_uids, uid_span = _trace_row_uids_from_source(x) + if row_token_uids is None: + row_token_uids, uid_span = _trace_row_uids_from_source(self) + if row_token_uids is not None: + _attach_trace_row_uids( + self.linear_fc2, + row_token_uids=row_token_uids, + uid_span=uid_span, + ) + _attach_trace_row_uids( + self.lora, + row_token_uids=row_token_uids, + uid_span=uid_span, + ) + return original_fc2_forward(self, x, tokens_per_expert) + + setattr(MoEAlltoAllTokenDispatcher, "preprocess", patched_preprocess) + setattr( + MoEAlltoAllTokenDispatcher, + "dispatch_preprocess", + patched_dispatch_preprocess, + ) + setattr(MoEAlltoAllTokenDispatcher, "token_dispatch", patched_token_dispatch) + setattr( + MoEAlltoAllTokenDispatcher, + "dispatch_postprocess", + patched_dispatch_postprocess, + ) + setattr( + MoEAlltoAllTokenDispatcher, + "combine_preprocess", + patched_combine_preprocess, + ) + setattr(TEGroupedMLP, "forward", patched_te_grouped_mlp_forward) + setattr(MLPExpertsLinearFC2LoRA, "forward", patched_fc2_forward) + setattr(MoEAlltoAllTokenDispatcher, "_art_router_replay_preprocess_patched", True) + + +class MoeRoutingReplayController: + def __init__( + self, + *, + bundle: MoeRoutingReplayBundle, + strict: bool, + local_token_indexer: LocalTokenIndexer | None = None, + ) -> None: + self.bundle = bundle + self.strict = strict + self.local_token_indexer = ( + local_token_indexer or TopologyAwareLocalTokenIndexer() + ) + + self._active_step_index: int | None = None + self._active_sample_index: int | None = None + self._active_step_routes: StepRoutes | None = None + self._router_call_cursors: dict[str, int] = {} + self._router_call_sequences: dict[str, list[int]] = {} + self._global_uid_to_row_index: dict[int, int] = {} + self._local_router_keys: set[str] = set() + self._active_micro_order: int | None = None + + self._patched_router_modules: list[dict[str, Any]] = [] + + 
def install_router_patches(self, model_chunks: list[Any]) -> None: + if self._patched_router_modules: + return + _patch_alltoall_dispatcher_preprocess() + + for chunk_index, chunk in enumerate(model_chunks): + for module_name, module in chunk.named_modules(): + if ROUTER_NAME_TOKEN not in module_name: + continue + if not hasattr(module, "routing"): + continue + router_key = build_router_key_from_module_name( + chunk_index=chunk_index, + module_name=module_name, + ) + if self.strict and router_key not in self.bundle.router_keys: + raise RuntimeError( + "Router key from model is missing in replay bundle: " + f"router_key='{router_key}'" + ) + + original_routing = module.routing + if getattr(module, "_art_router_replay_patched", False): + continue + + sequence_parallel = bool( + getattr(getattr(module, "config", None), "sequence_parallel", False) + ) + context_parallel_size = int( + getattr(getattr(module, "config", None), "context_parallel_size", 1) + ) + + def routing_wrapper( + _module: Any, + logits: torch.Tensor, + *args: Any, + _router_key: str = router_key, + _sequence_parallel: bool = sequence_parallel, + _context_parallel_size: int = context_parallel_size, + **kwargs: Any, + ) -> tuple[torch.Tensor, torch.Tensor]: + live_probs, live_routing_map = original_routing( + logits, *args, **kwargs + ) + replay_probs, replay_routing_map = self.get_route_for_router( + router_key=_router_key, + logits=live_probs, + sequence_parallel=_sequence_parallel, + context_parallel_size=_context_parallel_size, + ) + # same result, but autograd goes through + probs = ( + live_probs + + ( + replay_probs.to( + device=live_probs.device, + dtype=live_probs.dtype, + ) + - live_probs + ).detach() + ) + routing_map = replay_routing_map.to( + device=live_routing_map.device, + dtype=live_routing_map.dtype, + ) + return probs, routing_map + + module.routing = types.MethodType(routing_wrapper, module) + module._art_router_replay_patched = True + self._local_router_keys.add(router_key) + 
self._patched_router_modules.append( + { + "module": module, + "router_key": router_key, + "original_routing": original_routing, + } + ) + + def remove_router_patches(self) -> None: + global _ACTIVE_ROUTING_REPLAY_CONTROLLER + for item in self._patched_router_modules: + module = item["module"] + module.routing = item["original_routing"] + if hasattr(module, "_art_router_replay_patched"): + delattr(module, "_art_router_replay_patched") + self._patched_router_modules.clear() + self._local_router_keys.clear() + if _ACTIVE_ROUTING_REPLAY_CONTROLLER is self: + _ACTIVE_ROUTING_REPLAY_CONTROLLER = None + + def begin_micro(self, sample_index: int | None, micro_order: int) -> None: + self._active_sample_index = sample_index + self._active_micro_order = micro_order + + def set_step( + self, + *, + step_index: int, + sample_index: int | list[int | None], + global_grad_accumulation_sequences: int | None = None, + ) -> None: + global _ACTIVE_ROUTING_REPLAY_CONTROLLER + + if step_index not in self.bundle.steps: + raise RuntimeError( + f"Replay bundle missing step_index={step_index}. 
" + f"Available steps={sorted(self.bundle.steps.keys())}" + ) + step_routes = self.bundle.steps[step_index] + self._active_step_index = step_index + if isinstance(sample_index, list): + self._active_sample_index = next( + (index for index in sample_index if index is not None), + None, + ) + else: + self._active_sample_index = sample_index + self._active_micro_order = None + self._active_step_routes = step_routes + for local_router_key in sorted(self._local_router_keys): + if local_router_key not in step_routes.routers: + raise RuntimeError( + "Replay bundle step is missing local router key: " + f"step={step_index}, router='{local_router_key}'" + ) + self._router_call_cursors = {} + self._router_call_sequences = {} + local_call_keys = self._build_local_call_keys( + sample_index=sample_index, + ) + for router_key in sorted(self._local_router_keys): + router_calls = step_routes.routers[router_key].calls + if all( + self._router_call_key(route) is not None + for route in router_calls.values() + ): + calls_by_key: dict[tuple[str, int], list[int]] = defaultdict(list) + for call_index, route in sorted(router_calls.items()): + call_key = self._router_call_key(route) + assert call_key is not None + calls_by_key[call_key].append(call_index) + call_sequence = [] + for call_key in local_call_keys: + if call_key is None: + continue + matching_call_indices = calls_by_key.get(call_key) + if not matching_call_indices: + raise RuntimeError( + "Replay router call sequence is missing local micro metadata: " + f"step={step_index}, router='{router_key}', call_key={call_key}" + ) + call_sequence.extend(matching_call_indices) + else: + call_sequence = self._legacy_router_call_sequence( + step_index=step_index, + router_key=router_key, + sample_index=sample_index, + global_grad_accumulation_sequences=global_grad_accumulation_sequences, + total_calls=len(router_calls), + ) + self._router_call_cursors[router_key] = 0 + self._router_call_sequences[router_key] = call_sequence + 
self._global_uid_to_row_index = { + int(uid.item()): row_index + for row_index, uid in enumerate(step_routes.global_token_uids) + } + _ACTIVE_ROUTING_REPLAY_CONTROLLER = self + + def _build_local_call_keys( + self, + *, + sample_index: int | list[int | None], + ) -> list[tuple[str, int] | None]: + if not isinstance(sample_index, list): + if sample_index is None: + return [self._dummy_micro_call_key(local_micro_index=0)] + return [("sample", int(sample_index))] + return [ + self._sample_or_dummy_call_key( + global_sample_index=global_sample_index, + local_micro_index=local_micro_index, + ) + for local_micro_index, global_sample_index in enumerate(sample_index) + ] + + def _sample_or_dummy_call_key( + self, + *, + global_sample_index: int | None, + local_micro_index: int, + ) -> tuple[str, int] | None: + if global_sample_index is not None: + return ("sample", int(global_sample_index)) + return self._dummy_micro_call_key(local_micro_index=local_micro_index) + + def _dummy_micro_call_key( + self, + *, + local_micro_index: int, + ) -> tuple[str, int]: + from megatron.core import parallel_state as ps + + dp_rank = int(ps.get_data_parallel_rank()) + dp_world_size = int(ps.get_data_parallel_world_size()) + micro_slot = local_micro_index * dp_world_size + dp_rank + return ("dummy_micro_slot", micro_slot) + + @staticmethod + def _router_call_key(route: RouterCallRoute) -> tuple[str, int] | None: + if route.sample_index is not None: + return ("sample", int(route.sample_index)) + if route.micro_slot is not None: + return ("dummy_micro_slot", int(route.micro_slot)) + return None + + @staticmethod + def _legacy_router_call_sequence( + *, + step_index: int, + router_key: str, + sample_index: int | list[int | None], + global_grad_accumulation_sequences: int | None, + total_calls: int, + ) -> list[int]: + step_sample_count = global_grad_accumulation_sequences + if step_sample_count is None: + if isinstance(sample_index, list): + step_sample_count = len( + [index for index in 
sample_index if index is not None] + ) + else: + step_sample_count = 1 + if step_sample_count <= 0 or total_calls % step_sample_count != 0: + raise RuntimeError( + "Replay router call count is not divisible by step sample count: " + f"step={step_index}, router='{router_key}', " + f"total_calls={total_calls}, step_sample_count={step_sample_count}" + ) + calls_per_sample = total_calls // step_sample_count + step_base_sample_index = step_index * step_sample_count + if isinstance(sample_index, list): + call_sequence: list[int] = [] + for global_sample_index in sample_index: + if global_sample_index is None: + continue + sample_offset = int(global_sample_index) - step_base_sample_index + if sample_offset < 0 or sample_offset >= step_sample_count: + raise RuntimeError( + "Replay router call index is outside the step-local range: " + f"step={step_index}, router='{router_key}', " + f"global_sample_index={global_sample_index}, " + f"step_base_sample_index={step_base_sample_index}, " + f"step_sample_count={step_sample_count}" + ) + call_start = sample_offset * calls_per_sample + call_sequence.extend(range(call_start, call_start + calls_per_sample)) + return call_sequence + + sample_offset = int(sample_index) - step_base_sample_index + if sample_offset < 0 or sample_offset >= step_sample_count: + raise RuntimeError( + "Replay router call index is outside the step-local range: " + f"step={step_index}, router='{router_key}', " + f"sample_index={sample_index}, " + f"step_sample_count={step_sample_count}" + ) + call_start = sample_offset * calls_per_sample + return list(range(call_start, call_start + calls_per_sample)) + + def finalize_step(self) -> None: + global _ACTIVE_ROUTING_REPLAY_CONTROLLER + if self._active_step_routes is None: + raise RuntimeError("finalize_step called before set_step") + for router_key in sorted(self._local_router_keys): + consumed = self._router_call_cursors.get(router_key, 0) + call_sequence = self._router_call_sequences.get(router_key) + if 
call_sequence is None: + raise RuntimeError( + "Routing replay call sequence missing for router key: " + f"step={self._active_step_index}, router='{router_key}'" + ) + if consumed != len(call_sequence): + raise RuntimeError( + "Routing replay step consumption mismatch: " + f"step={self._active_step_index}, router='{router_key}', " + f"consumed={consumed}, expected={len(call_sequence)}" + ) + self._active_step_index = None + self._active_sample_index = None + self._active_step_routes = None + self._router_call_cursors = {} + self._router_call_sequences = {} + self._global_uid_to_row_index = {} + self._active_micro_order = None + if _ACTIVE_ROUTING_REPLAY_CONTROLLER is self: + _ACTIVE_ROUTING_REPLAY_CONTROLLER = None + + def get_route_for_router( + self, + *, + router_key: str, + logits: torch.Tensor, + sequence_parallel: bool, + context_parallel_size: int, + ) -> tuple[torch.Tensor, torch.Tensor]: + step_routes = self._active_step_routes + if step_routes is None: + raise RuntimeError( + "Routing replay get_route_for_router called before set_step" + ) + call_cursor = self._router_call_cursors.get(router_key, 0) + call_sequence = self._router_call_sequences.get(router_key) + if call_sequence is None: + raise RuntimeError( + "Routing replay call sequence missing for router key: " + f"step={self._active_step_index}, router='{router_key}'" + ) + router_calls = step_routes.routers[router_key].calls + if call_cursor >= len(call_sequence): + raise RuntimeError( + "Routing replay call cursor exceeded local call sequence: " + f"step={self._active_step_index}, router='{router_key}', " + f"call_cursor={call_cursor}, sequence_length={len(call_sequence)}" + ) + route = router_calls[call_sequence[call_cursor]] + self._router_call_cursors[router_key] = call_cursor + 1 + + num_local_tokens = int(logits.shape[0]) + num_experts = int(logits.shape[1]) + + local_uids = self.local_token_indexer.build_local_token_uids( + global_token_uids=step_routes.global_token_uids, + 
num_local_tokens=num_local_tokens, + sequence_parallel=sequence_parallel, + context_parallel_size=context_parallel_size, + ) + row_index_tensor = torch.tensor( + [self._global_uid_to_row_index[int(uid)] for uid in local_uids.tolist()], + dtype=torch.int64, + ) + + local_indices = route.expert_indices.index_select(0, row_index_tensor) + local_probs = route.expert_probs.index_select(0, row_index_tensor) + local_mask = route.expert_mask.index_select(0, row_index_tensor) + + probs = torch.zeros( + (num_local_tokens, num_experts), + dtype=logits.dtype, + device=logits.device, + ) + routing_map = torch.zeros( + (num_local_tokens, num_experts), + dtype=torch.bool, + device=logits.device, + ) + + if local_indices.numel() > 0: + indices_device = local_indices.to(device=logits.device, dtype=torch.long) + probs_device = local_probs.to(device=logits.device, dtype=logits.dtype) + mask_device = local_mask.to(device=logits.device, dtype=torch.bool) + row_index_device = ( + torch.arange(num_local_tokens, device=logits.device) + .unsqueeze(1) + .expand_as(indices_device) + ) + + selected_rows = row_index_device[mask_device] + selected_cols = indices_device[mask_device] + selected_probs = probs_device[mask_device] + + if selected_rows.numel() > 0: + probs[selected_rows, selected_cols] = selected_probs + routing_map[selected_rows, selected_cols] = True + + return probs, routing_map + + +def _compact_route_from_dense( + probs_2d: torch.Tensor, + routing_map_2d: torch.Tensor, +) -> RouterCallRoute: + num_tokens, num_experts = probs_2d.shape + if num_tokens == 0: + return RouterCallRoute( + expert_indices=torch.zeros((0, 0), dtype=torch.int32), + expert_probs=torch.zeros((0, 0), dtype=torch.float32), + expert_mask=torch.zeros((0, 0), dtype=torch.bool), + num_experts=num_experts, + ) + + max_topk = int(routing_map_2d.sum(dim=1).max().item()) + expert_indices = torch.zeros((num_tokens, max_topk), dtype=torch.int32) + expert_probs = torch.zeros((num_tokens, max_topk), dtype=torch.float32) 
+ expert_mask = torch.zeros((num_tokens, max_topk), dtype=torch.bool) + for token_index in range(num_tokens): + expert_ids = torch.nonzero( + routing_map_2d[token_index], as_tuple=False + ).flatten() + slot_count = int(expert_ids.numel()) + if slot_count == 0: + continue + expert_indices[token_index, :slot_count] = expert_ids.to(torch.int32) + expert_probs[token_index, :slot_count] = probs_2d[token_index, expert_ids].to( + torch.float32 + ) + expert_mask[token_index, :slot_count] = True + + return RouterCallRoute( + expert_indices=expert_indices, + expert_probs=expert_probs, + expert_mask=expert_mask, + num_experts=num_experts, + ) + + +def build_bundle_from_forward_trace_dir( + *, + traces_dir: str | Path, + num_steps: int, + topology: ParallelTopology, +) -> MoeRoutingReplayBundle: + """Build a replay bundle from saved forward traces for the correctness harness. + + This helper is intended for testing/oracle routing replay workflows and is not + part of inference routing capture/export. 
+ """ + trace_dir = Path(traces_dir) + steps: dict[int, StepRoutes] = {} + router_keys_union: set[str] = set() + max_topk = 0 + + for step_index in range(num_steps): + trace_path = trace_dir / f"forward_trace_step_{step_index:03d}.pt" + if not trace_path.exists(): + raise FileNotFoundError( + f"Missing forward trace for step={step_index}: {trace_path}" + ) + step_trace: dict[str, list[dict[str, Any]]] = torch.load( + trace_path, map_location="cpu", weights_only=False + ) + + step_routers: dict[str, StepRouterRoutes] = {} + step_global_tokens: int | None = None + for module_name in sorted(step_trace.keys()): + if ROUTER_NAME_TOKEN not in module_name: + continue + router_key = build_router_key_from_trace_name(module_name) + router_calls: dict[int, RouterCallRoute] = {} + for call_index, call_entry in enumerate(step_trace[module_name]): + output = call_entry.get("output") + probs_2d, routing_map_2d = _extract_router_output_tensors(output) + compact_route = _compact_route_from_dense(probs_2d, routing_map_2d) + sample_index, micro_slot = _trace_call_route_metadata(call_entry) + compact_route.sample_index = sample_index + compact_route.micro_slot = micro_slot + router_calls[call_index] = compact_route + max_topk = max(max_topk, compact_route.max_topk) + token_count = compact_route.num_global_tokens + if step_global_tokens is None: + step_global_tokens = token_count + elif step_global_tokens != token_count: + raise RuntimeError( + "Inconsistent token count across routers within step: " + f"step={step_index}, expected={step_global_tokens}, got={token_count}, " + f"router='{router_key}', call={call_index}" + ) + + if not router_calls: + raise RuntimeError( + f"Router trace has no calls for module '{module_name}' at step={step_index}" + ) + step_routers[router_key] = StepRouterRoutes(calls=router_calls) + router_keys_union.add(router_key) + + if not step_routers: + raise RuntimeError( + f"No router traces found for step={step_index} in {trace_path}" + ) + if 
step_global_tokens is None: + raise RuntimeError( + f"Could not infer token count for step={step_index} from router traces" + ) + global_token_uids = torch.arange(step_global_tokens, dtype=torch.int64) + steps[step_index] = StepRoutes( + routers=step_routers, + global_token_uids=global_token_uids, + ) + + router_keys = sorted(router_keys_union) + for step_index, step_routes in steps.items(): + if set(step_routes.routers.keys()) != set(router_keys): + raise RuntimeError( + f"Step {step_index} router keys differ from global set: " + f"step_keys={sorted(step_routes.routers.keys())}, router_keys={router_keys}" + ) + + return MoeRoutingReplayBundle( + format_version=ROUTER_KEY_FORMAT_VERSION, + topology=topology, + num_steps=num_steps, + max_topk=max_topk, + router_keys=router_keys, + steps=steps, + ) diff --git a/src/art/megatron/service.py b/src/art/megatron/service.py index 8ed6b82c..42ec4f9a 100644 --- a/src/art/megatron/service.py +++ b/src/art/megatron/service.py @@ -12,7 +12,7 @@ from peft.tuners.lora.config import LoraConfig from pydantic import BaseModel from safetensors import safe_open -from safetensors.torch import load_file, save_file +from safetensors.torch import save_file import torch from vllm import AsyncEngineArgs from vllm.lora.request import LoRARequest @@ -36,6 +36,8 @@ class MegatronTrainingJob(BaseModel): disk_packed_tensors: DiskPackedTensors config: types.TrainConfig experimental_config: dev.TrainConfig + moe_routing_replay_path: str | None = None + moe_routing_replay_strict: bool = True @dataclass @@ -241,12 +243,19 @@ async def train( for job_name in os.listdir(jobs_dir): if job_name.endswith(".json"): os.remove(os.path.join(jobs_dir, job_name)) + if _config.get("moe_routing_replay_bundle") is not None: + raise RuntimeError( + "moe_routing_replay_bundle is only supported for in-process/runtime APIs; " + "MegatronService subprocess jobs must use moe_routing_replay_path." 
+ ) job = MegatronTrainingJob( lora_path=lora_path, optimizer_state_path=self._optimizer_state_path, disk_packed_tensors=disk_packed_tensors, config=config, experimental_config=_config, + moe_routing_replay_path=_config.get("moe_routing_replay_path"), + moe_routing_replay_strict=_config.get("moe_routing_replay_strict", True), ) job_path = os.path.join(jobs_dir, f"{datetime.datetime.now().isoformat()}.json") with open(job_path, "w") as f: @@ -311,26 +320,91 @@ def _merge_lora_adapter(self, lora_path: str) -> None: if not shard_filenames: return - adapter_model_path = base_dir / "adapter_model.safetensors" - sharded_tensors: dict[str, list[torch.Tensor]] = {} + shard_files_by_suffix = { + path.name.removeprefix("adapter_model-").removesuffix(".safetensors"): path + for path in shard_filenames + } + manifest_filenames = sorted(base_dir.glob("adapter_manifest-*-of-*.json")) + manifest_files_by_suffix = { + path.name.removeprefix("adapter_manifest-").removesuffix(".json"): path + for path in manifest_filenames + } - for filename in shard_filenames: - with safe_open(filename, framework="pt") as file: - for key in file.keys(): - tensor = file.get_tensor(key) - sharded_tensors.setdefault(key, []).append(tensor) + if set(shard_files_by_suffix) != set(manifest_files_by_suffix): + raise RuntimeError( + "Shard/manifest coverage mismatch: " + f"shards={sorted(shard_files_by_suffix)}, " + f"manifests={sorted(manifest_files_by_suffix)}" + ) - adapter_model: dict[str, torch.Tensor] = {} - if adapter_model_path.exists(): - adapter_model = load_file(adapter_model_path) + entries_by_key: dict[str, list[tuple[dict[str, Any], torch.Tensor]]] = {} + for suffix in sorted(shard_files_by_suffix): + shard_path = shard_files_by_suffix[suffix] + manifest_path = manifest_files_by_suffix[suffix] + with open(manifest_path, "r", encoding="utf-8") as manifest_file: + shard_manifest: dict[str, dict[str, Any]] = json.load(manifest_file) - for key, tensors in sharded_tensors.items(): - tensor = 
torch.cat(tensors, dim=1 if "lora_A" in key else 0) + with safe_open(shard_path, framework="pt") as file: + shard_tensors = {key: file.get_tensor(key) for key in file.keys()} + + if set(shard_tensors) != set(shard_manifest): + raise RuntimeError( + f"Tensor/manifest key mismatch for shard suffix={suffix}: " + f"tensor_keys={sorted(shard_tensors)}, " + f"manifest_keys={sorted(shard_manifest)}" + ) + + for key, tensor in shard_tensors.items(): + entries_by_key.setdefault(key, []).append((shard_manifest[key], tensor)) + + adapter_model: dict[str, torch.Tensor] = {} + for key, key_entries in entries_by_key.items(): + first_manifest = key_entries[0][0] + sharded = bool(first_manifest["sharded"]) + shard_world_size = int(first_manifest["shard_world_size"]) + + for manifest_entry, _tensor in key_entries: + if bool(manifest_entry["sharded"]) != sharded: + raise RuntimeError(f"Inconsistent sharded flag for key={key}") + if int(manifest_entry["shard_world_size"]) != shard_world_size: + raise RuntimeError(f"Inconsistent shard world size for key={key}") + + if not sharded: + if len(key_entries) != 1: + raise RuntimeError( + f"Replicated key={key} expected 1 shard, got {len(key_entries)}" + ) + tensor = key_entries[0][1] + else: + shard_rank_to_tensor: dict[int, torch.Tensor] = {} + for manifest_entry, shard_tensor in key_entries: + shard_rank = int(manifest_entry["shard_rank"]) + if shard_rank in shard_rank_to_tensor: + raise RuntimeError( + f"Duplicate shard_rank={shard_rank} for key={key}" + ) + shard_rank_to_tensor[shard_rank] = shard_tensor + + expected_shard_ranks = set(range(shard_world_size)) + if set(shard_rank_to_tensor.keys()) != expected_shard_ranks: + raise RuntimeError( + f"Shard rank coverage mismatch for key={key}: " + f"expected {sorted(expected_shard_ranks)}, got {sorted(shard_rank_to_tensor.keys())}" + ) + + ordered_shards = [ + shard_rank_to_tensor[i] for i in range(shard_world_size) + ] + concat_dim = 1 if "lora_A" in key else 0 + tensor = 
torch.cat(ordered_shards, dim=concat_dim) adapter_model[key] = tensor + adapter_model_path = base_dir / "adapter_model.safetensors" save_file(adapter_model, adapter_model_path) for filename in shard_filenames: filename.unlink() + for filename in manifest_filenames: + filename.unlink() @cached_property def llm(self) -> asyncio.Task[AsyncLLM]: diff --git a/src/art/megatron/train.py b/src/art/megatron/train.py index 02e3b7cd..b3da7e24 100644 --- a/src/art/megatron/train.py +++ b/src/art/megatron/train.py @@ -20,33 +20,73 @@ def _set_cache_dir(env_var: str, default_path: str) -> None: import math import shutil import time -from typing import Any, cast +from typing import Any, Callable, cast from megatron.core import parallel_state as ps from megatron.core.distributed import DistributedDataParallelConfig from megatron.core.models.gpt.gpt_model import GPTModel from megatron.core.optimizer import OptimizerConfig, get_megatron_optimizer from megatron.core.transformer.module import MegatronModule -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from safetensors.torch import load_file, save_file import torch from torch._inductor.runtime.cache_dir_utils import cache_dir as inductor_cache_dir from art import dev, types from art.loss import loss_fn, shift_tensor +from art.megatron.finalize_grads import finalize_model_grads_extended from art.megatron.flex_attention import create_shared_prefix_attention_state from art.megatron.lora import apply_lora_adapters from art.megatron.offload import OffloadState, offload_to_cpu, reload_to_gpu from art.megatron.provider import get_provider +from art.megatron.routing_replay import ( + MoeRoutingReplayBundle, + MoeRoutingReplayController, +) from art.preprocessing.pack import ( DiskPackedTensors, PackedTensors, packed_tensors_from_dir, ) -provider = get_provider( - os.environ.get("MODEL_IDENTIFIER", "Qwen/Qwen3-30B-A3B-Instruct-2507") -) +DEFAULT_MODEL_IDENTIFIER = "Qwen/Qwen3-30B-A3B-Instruct-2507" + + +class 
TrainingJob(BaseModel): + lora_path: str + optimizer_state_path: str + disk_packed_tensors: DiskPackedTensors + config: types.TrainConfig + experimental_config: dev.TrainConfig + moe_routing_replay_path: str | None = None + moe_routing_replay_strict: bool = True + + +class TrainingRuntime(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + provider: Any + model: list[MegatronModule] + optimizer: Any + rank: int + world_size: int + moe_routing_replay_controller: MoeRoutingReplayController | None = None + + +class TrainStepResult(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + reduced_loss: torch.Tensor + probs_corr: float + new_logprobs: torch.Tensor + update_successful: bool + grad_norm: float + num_zeros_in_grad: int | None + + +def print0(rank: int, *values: Any) -> None: + if rank == 0: + print(*values) def freeze_model(model_chunks: list[MegatronModule]) -> list[MegatronModule]: @@ -56,275 +96,603 @@ def freeze_model(model_chunks: list[MegatronModule]) -> list[MegatronModule]: return model_chunks -provider.register_pre_wrap_hook(lambda x: freeze_model(x) or x) - -model = provider.provide_distributed_model( - ddp_config=DistributedDataParallelConfig(), - data_parallel_random_init=False, -) - -rank = torch.distributed.get_rank() # ty:ignore[possibly-missing-attribute] -world_size = torch.distributed.get_world_size() # ty:ignore[possibly-missing-attribute] - -if rank == 0: - print("TORCHINDUCTOR_CACHE_DIR:", os.environ["TORCHINDUCTOR_CACHE_DIR"]) - print("Resolved inductor cache_dir():", inductor_cache_dir()) - print("TRITON_CACHE_DIR:", os.environ["TRITON_CACHE_DIR"]) - -for module in model: - while not isinstance(module, GPTModel) and hasattr(module, "module"): - module = module.module - if isinstance(module, GPTModel): - _preprocess = module._preprocess +def _install_gpt_preprocess_hook(model_chunks: list[MegatronModule]) -> None: + for chunk in model_chunks: + module: Any = chunk + while not isinstance(module, 
GPTModel) and hasattr(module, "module"): + module = module.module + if not isinstance(module, GPTModel): + continue + preprocess = module._preprocess - def _preprocess_hook(*args, **kwargs): + def preprocess_hook(*args, _preprocess=preprocess, **kwargs): preproc_output = list(_preprocess(*args, **kwargs)) - preproc_output[0].requires_grad = True # type: ignore - table = preproc_output[1] # [S,B,1,D] type: ignore - D = table.size(-1) # type: ignore - table_flat = table.view(table.size(0), D) # type: ignore - # position_ids: [B, S] - position_ids = kwargs["position_ids"] - B, S = position_ids.shape - gathered = table_flat.index_select(0, position_ids.reshape(-1)) # [B*S, D] - gathered = gathered.view(B, S, D).permute(1, 0, 2).contiguous() # [S, B, D] + preproc_output[0].requires_grad = True # type: ignore[index] + table = preproc_output[1] # [S, B, 1, D] # type: ignore[index] + embedding_dim = table.size(-1) + table_flat = table.view(table.size(0), embedding_dim) + position_ids = kwargs["position_ids"] # [B, S] + batch_size, sequence_length = position_ids.shape + gathered = table_flat.index_select(0, position_ids.reshape(-1)) + gathered = ( + gathered.view(batch_size, sequence_length, embedding_dim) + .permute(1, 0, 2) + .contiguous() + ) preproc_output[1] = gathered.unsqueeze(2) # [S, B, 1, D] return tuple(preproc_output) - module._preprocess = _preprocess_hook # type: ignore[attr-defined] + module._preprocess = preprocess_hook # type: ignore[attr-defined] -apply_lora_adapters(model, provider) - -optimizer = get_megatron_optimizer( - config=OptimizerConfig( +def _default_optimizer_config() -> OptimizerConfig: + return OptimizerConfig( bf16=True, lr=5e-6, adam_beta1=0.9, adam_beta2=0.99, clip_grad=0.1, weight_decay=0.1, - ), - model_chunks=model, # type: ignore -) - -if rank == 0: - # Print the number of parameters in the optimizer, nicely formatted - num_params = sum( - p.numel() - for group in optimizer.param_groups - if not group["is_decoupled_lr"] - for p in 
group["params"] + adam_eps=1e-13, ) - print(f"Number of parameters in optimizer: {num_params:,}") - total_params = sum(p.numel() for m in model for p in m.parameters()) - percent = (num_params / total_params) * 100 if total_params > 0 else 0 - print(f"Optimizer parameters as percent of total: {percent:0.2f}%") -class TrainingJob(BaseModel): - lora_path: str - optimizer_state_path: str - disk_packed_tensors: DiskPackedTensors - config: types.TrainConfig - experimental_config: dev.TrainConfig +def configure_moe_routing_replay( + runtime: TrainingRuntime, + *, + replay_bundle_path: str | None = None, + replay_bundle: MoeRoutingReplayBundle | None = None, + strict: bool = True, +) -> None: + if runtime.moe_routing_replay_controller is not None: + runtime.moe_routing_replay_controller.remove_router_patches() + runtime.moe_routing_replay_controller = None + + if replay_bundle is not None and replay_bundle_path is not None: + raise RuntimeError( + "Provide either replay_bundle_path or replay_bundle, not both" + ) + if replay_bundle is None and replay_bundle_path is None: + return + + if replay_bundle is None: + if replay_bundle_path is None: + raise RuntimeError( + "replay_bundle_path is required when replay_bundle is None" + ) + replay_bundle = MoeRoutingReplayBundle.from_dir(replay_bundle_path) + + controller = MoeRoutingReplayController( + bundle=replay_bundle, + strict=strict, + ) + controller.install_router_patches(runtime.model) + runtime.moe_routing_replay_controller = controller + + +def build_training_runtime( + *, + model_identifier: str | None = None, + provider_torch_dtype: torch.dtype = torch.bfloat16, + provider_configure: Callable[[Any], None] | None = None, + optimizer_config: OptimizerConfig | None = None, + moe_routing_replay_path: str | None = None, + moe_routing_replay_bundle: MoeRoutingReplayBundle | None = None, + moe_routing_replay_strict: bool = True, + print_env: bool = True, + print_optimizer_stats: bool = True, +) -> TrainingRuntime: + provider 
= get_provider( + model_identifier + or os.environ.get("MODEL_IDENTIFIER", DEFAULT_MODEL_IDENTIFIER), + torch_dtype=provider_torch_dtype, + ) + if provider_configure is not None: + provider_configure(provider) + provider.register_pre_wrap_hook(freeze_model) + provider.register_pre_wrap_hook( + lambda chunks: apply_lora_adapters(chunks, provider) + ) + model = cast( + list[MegatronModule], + provider.provide_distributed_model( + ddp_config=DistributedDataParallelConfig( + # memory and comm for this should be small anyways cause lora + grad_reduce_in_fp32=True, + average_in_collective=False, + ), + data_parallel_random_init=False, + ), + ) -def print0(*values: Any) -> None: - if rank == 0: - print(*values) + if not torch.distributed.is_initialized(): # ty: ignore[possibly-missing-attribute] + raise RuntimeError( + "torch.distributed must be initialized before building runtime" + ) + rank = torch.distributed.get_rank() # ty: ignore[possibly-missing-attribute] + world_size = torch.distributed.get_world_size() # ty: ignore[possibly-missing-attribute] + + if rank == 0 and print_env: + print("TORCHINDUCTOR_CACHE_DIR:", os.environ["TORCHINDUCTOR_CACHE_DIR"]) + print("Resolved inductor cache_dir():", inductor_cache_dir()) + print("TRITON_CACHE_DIR:", os.environ["TRITON_CACHE_DIR"]) + _install_gpt_preprocess_hook(model) -offload_state = OffloadState() + optimizer = get_megatron_optimizer( + config=optimizer_config or _default_optimizer_config(), + model_chunks=model, + ) + if rank == 0 and print_optimizer_stats: + num_params = sum( + p.numel() + for group in optimizer.param_groups + if not group["is_decoupled_lr"] + for p in group["params"] + ) + print(f"Number of parameters in optimizer: {num_params:,}") + total_params = sum(p.numel() for module in model for p in module.parameters()) + percent = (num_params / total_params) * 100 if total_params > 0 else 0 + print(f"Optimizer parameters as percent of total: {percent:0.2f}%") + + runtime = TrainingRuntime( + 
provider=provider, + model=model, + optimizer=optimizer, + rank=rank, + world_size=world_size, + ) + configure_moe_routing_replay( + runtime, + replay_bundle_path=moe_routing_replay_path, + replay_bundle=moe_routing_replay_bundle, + strict=moe_routing_replay_strict, + ) + return runtime -offload_to_cpu(model, optimizer, rank, offload_state) -while True: - torch.distributed.barrier() # ty:ignore[possibly-missing-attribute] - jobs_dir = "/tmp/megatron_training_jobs" - os.makedirs(jobs_dir, exist_ok=True) - job_names = sorted( - job_name for job_name in os.listdir(jobs_dir) if job_name.endswith(".json") +def iter_modules(model_chunks: list[MegatronModule]) -> Any: + for chunk in model_chunks: + for module in chunk.modules(): + yield module + + +def load_adapter_into_model( + model_chunks: list[MegatronModule], + adapter_model: dict[str, torch.Tensor], + optimizer: Any | None = None, +) -> None: + with torch.no_grad(): + for module in iter_modules(model_chunks): + if hasattr(module, "load_lora"): + module.load_lora(adapter_model) # type: ignore[attr-defined] + + if optimizer is None: + return + optimizer.reload_model_params() + + +def collect_sharded_lora_state( + model_chunks: list[MegatronModule], + adapter_model: dict[str, torch.Tensor], +) -> tuple[dict[str, torch.Tensor], dict[str, dict[str, Any]]]: + sharded_state_dict: dict[str, torch.Tensor] = {} + sharded_state_manifest: dict[str, dict[str, Any]] = {} + for module in iter_modules(model_chunks): + if hasattr(module, "sharded_lora_state_dict"): + module_sharded_lora_state_dict: dict[str, torch.Tensor] = ( + module.sharded_lora_state_dict() # type: ignore[attr-defined] + ) + for key, value in module_sharded_lora_state_dict.items(): + target_dtype = ( + adapter_model[key].dtype if key in adapter_model else value.dtype + ) + sharded_state_dict[key] = value.to(target_dtype) + if hasattr(module, "sharded_lora_manifest"): + module_sharded_lora_manifest: dict[str, dict[str, Any]] = ( + module.sharded_lora_manifest() # 
type: ignore[attr-defined] + ) + sharded_state_manifest.update(module_sharded_lora_manifest) + return sharded_state_dict, sharded_state_manifest + + +@torch.no_grad() +def select_indexed_inputs(packed_tensors: PackedTensors, index: int) -> PackedTensors: + return PackedTensors( # type: ignore[call-arg] + **{ + key: value[index : index + 1] + for key, value in packed_tensors.items() + if isinstance(value, torch.Tensor) + }, + pixel_values=[None], + image_grid_thw=[None], ) - if not job_names: - time.sleep(1) - continue - - wake_lock_path = "/tmp/megatron_vllm_waking" - while os.path.exists(wake_lock_path): - time.sleep(0.2) - - reload_to_gpu(model, optimizer, rank, offload_state) - - job_name = job_names[0] - job_path = os.path.join(jobs_dir, job_name) - with open(job_path, "rb") as f: - job = TrainingJob.model_validate_json(f.read()) - config = job.config - experimental_config = job.experimental_config - print0("Loaded job from", job_path) - print0("Job:", job) - adapter_model_path = f"{job.lora_path}/adapter_model.safetensors" - if os.path.exists(adapter_model_path): - print0("Loading adapter model from", adapter_model_path) - adapter_model = load_file(adapter_model_path) - with torch.no_grad(): - for chunk in model: - for module in chunk.modules(): - if hasattr(module, "load_lora"): - module.load_lora(adapter_model) # type: ignore - else: - print0("No adapter model found at", adapter_model_path) - adapter_model = {} - with torch.no_grad(): - for chunk in model: - for module in chunk.modules(): - if hasattr(module, "reset_lora_parameters"): - module.reset_lora_parameters() # type: ignore - optimizer_shard_path = os.path.join( - job.optimizer_state_path, f"{rank + 1:02d}-of-{world_size:02d}.pt" + + +@torch.no_grad() +def _clone_packed_tensors(inputs: PackedTensors) -> PackedTensors: + return PackedTensors( # type: ignore[call-arg] + **{ + key: value.clone() + for key, value in inputs.items() + if isinstance(value, torch.Tensor) + }, + pixel_values=[None], + 
image_grid_thw=[None], ) - if os.path.exists(optimizer_shard_path): - print( - "Loading optimizer state from", - optimizer_shard_path, - ) - optimizer.load_state_dict(torch.load(optimizer_shard_path)) - else: - # No checkpoint for this run; reset optimizer state to avoid cross-run leakage - print( - "No optimizer state found at", - optimizer_shard_path, - "— resetting optimizer for new run", + + +@torch.no_grad() +def _zero_contribution_inputs(template: PackedTensors) -> PackedTensors: + dummy = _clone_packed_tensors(template) + dummy["assistant_mask"].zero_() + return dummy + + +def resolve_local_grad_accumulation_sequences( + global_grad_accumulation_sequences: int, +) -> int: + dp_world_size = ps.get_data_parallel_world_size() + if ( + global_grad_accumulation_sequences <= 0 + or global_grad_accumulation_sequences % dp_world_size != 0 + ): + raise RuntimeError( + "Invalid global grad accumulation / DP world size combination: " + f"global_grad_accumulation_sequences={global_grad_accumulation_sequences}, " + f"dp_world_size={dp_world_size}" ) - optimizer.optimizer.state.clear() - optimizer.reload_model_params() - print0("Loading packed tensors from", job.disk_packed_tensors["dir"]) - packed_tensors = packed_tensors_from_dir(**job.disk_packed_tensors) - num_sequences = job.disk_packed_tensors["num_sequences"] + return global_grad_accumulation_sequences // dp_world_size + + +def build_micro_sample_indices( + step_index: int, + num_sequences: int, + global_grad_accumulation_sequences: int, +) -> list[int | None]: dp_rank = ps.get_data_parallel_rank() dp_world_size = ps.get_data_parallel_world_size() - num_indices = math.ceil(num_sequences / dp_world_size) - indices = list(range(dp_rank, num_sequences, dp_world_size)) - if not indices: - indices = [dp_rank % num_sequences] - # pad indices by repeating & slicing to target length - repeat = math.ceil(num_indices / len(indices)) - indices = (indices * repeat)[:num_indices] - for index in indices: - inputs = 
PackedTensors( # type: ignore - **{ - key: value[index : index + 1] - for key, value in packed_tensors.items() - if isinstance(value, torch.Tensor) - }, - pixel_values=[None], - image_grid_thw=[None], + local_grad_accumulation_sequences = resolve_local_grad_accumulation_sequences( + global_grad_accumulation_sequences=global_grad_accumulation_sequences, + ) + base_global_sample_index = step_index * global_grad_accumulation_sequences + global_step_indices: list[int | None] = [] + for offset in range(global_grad_accumulation_sequences): + global_sample_index = base_global_sample_index + offset + global_step_indices.append( + global_sample_index if global_sample_index < num_sequences else None ) - ref_logprobs = None - device = next(model[0].parameters()).device - for key, value in inputs.items(): - if isinstance(value, torch.Tensor): - inputs[key] = value.to(device) # type: ignore - attention_state = create_shared_prefix_attention_state( # should happen after group_ids is moved to device - group_ids=inputs["group_ids"], - parent_ids=inputs["parent_ids"], + return [ + global_step_indices[offset * dp_world_size + dp_rank] + for offset in range(local_grad_accumulation_sequences) + ] + + +def select_micro_inputs( + packed_tensors: PackedTensors, + sample_indices: list[int | None], + zero_template: PackedTensors, +) -> list[PackedTensors]: + return [ + _clone_packed_tensors(zero_template) + if sample_index is None + else select_indexed_inputs(packed_tensors, sample_index) + for sample_index in sample_indices + ] + + +def _move_inputs_to_device(inputs: PackedTensors, device: torch.device) -> None: + for key, value in inputs.items(): + if isinstance(value, torch.Tensor): + inputs[key] = value.to(device) # type: ignore[index] + + +def _optimizer_step( + optimizer: Any, + learning_rate: float, +) -> tuple[bool, float, int | None]: + for param_group in optimizer.param_groups: + param_group["lr"] = learning_rate + update_successful, grad_norm, num_zeros_in_grad = cast( + 
tuple[bool, float, int | None], optimizer.step() + ) + optimizer.zero_grad() + return update_successful, grad_norm, num_zeros_in_grad + + +def _reduce_loss( + loss: torch.Tensor, + op: Any = torch.distributed.ReduceOp.AVG, # ty: ignore[possibly-missing-attribute] + group: Any | None = None, +) -> torch.Tensor: + reduced_loss = loss.detach().clone() + torch.distributed.all_reduce( # ty: ignore[possibly-missing-attribute] + reduced_loss, + op=op, + group=group, + ) + return reduced_loss + + +def _count_trainable_tokens(inputs: PackedTensors) -> float: + assistant_mask = shift_tensor(inputs["assistant_mask"], False) + return float(assistant_mask.sum().item()) + + +def _local_trainable_token_count_tensor( + micro_inputs: list[PackedTensors], + device: torch.device, +) -> torch.Tensor: + local_token_total = sum(_count_trainable_tokens(micro) for micro in micro_inputs) + return torch.tensor([local_token_total], device=device, dtype=torch.float32) + + +def run_training_step( + *, + model_chunks: list[MegatronModule], + optimizer: Any, + learning_rate: float, + inputs: PackedTensors | list[PackedTensors], + config: types.TrainConfig, + experimental_config: dev.TrainConfig, + step_index: int, + sample_index: int | list[int | None], + ref_logprobs: torch.Tensor | None = None, + moe_routing_replay_controller: MoeRoutingReplayController | None = None, +) -> TrainStepResult: + micro_inputs = inputs if isinstance(inputs, list) else [inputs] + if not micro_inputs: + raise ValueError("run_training_step requires at least one packed sequence") + + if isinstance(sample_index, list): + if len(sample_index) != len(micro_inputs): + raise ValueError( + "sample_index list length must match number of micro inputs: " + f"{len(sample_index)} != {len(micro_inputs)}" + ) + micro_sample_indices = sample_index + else: + assert len(micro_inputs) == 1 + micro_sample_indices = [sample_index] + + if moe_routing_replay_controller is not None: + moe_routing_replay_controller.set_step( + 
step_index=step_index, + sample_index=micro_sample_indices, + global_grad_accumulation_sequences=config.grad_accumulation_sequences, + ) + + device = next(model_chunks[0].parameters()).device + + for chunk in model_chunks: + chunk.zero_grad_buffer() # ty: ignore[call-non-callable] + + micro_count = len(micro_inputs) + raw_loss_sum: torch.Tensor | None = None + num_tokens = _local_trainable_token_count_tensor(micro_inputs, device=device) + probs_corr_sum = 0.0 + new_logprobs: torch.Tensor | None = None + + for micro in micro_inputs: + _move_inputs_to_device(micro, device) + attention_state = create_shared_prefix_attention_state( + group_ids=micro["group_ids"], + parent_ids=micro["parent_ids"], ) - # Megatron full-layer recompute saves positional tensor args, so keep a tiny - # placeholder Tensor here and pass flex BlockMask state via attention_bias. attention_mask = torch.zeros((1, 1, 1, 1), dtype=torch.bool, device=device) - new_logprobs: torch.Tensor = -model[0]( - input_ids=inputs["tokens"], - position_ids=inputs["input_pos"], + + new_logprobs = -model_chunks[0]( + input_ids=micro["tokens"], + position_ids=micro["input_pos"], attention_mask=attention_mask, - labels=shift_tensor(inputs["tokens"], 0), + labels=shift_tensor(micro["tokens"], 0), extra_block_kwargs={"attention_bias": attention_state}, ) - loss = loss_fn( - inputs, # type: ignore + + loss_info = loss_fn( + micro, # ty: ignore[invalid-argument-type] new_logprobs, ref_logprobs, None, experimental_config, + reduction="sum", ) - probs_corr = loss.probs_corr.item() - print0("Correlation between old and new probabilities:", probs_corr) - loss = loss.mean_policy_loss - loss.backward() - # Reduce LoRA grads - start = time.perf_counter() - num_grads = 0 - for chunk in model: - for param in chunk.parameters(): - if param.grad is None: - continue - torch.distributed.all_reduce( # ty:ignore[possibly-missing-attribute] - param.grad, - op=torch.distributed.ReduceOp.AVG, # ty:ignore[possibly-missing-attribute] - 
group=ps.get_data_parallel_group(), - ) - num_grads += 1 - print0( - f"Reduced {num_grads} LoRA grads in {(time.perf_counter() - start) * 1e3:.1f} ms" + micro_loss = loss_info.policy_loss + micro_loss.backward() + probs_corr_sum += float(loss_info.probs_corr.item()) + detached_micro_loss = micro_loss.detach() + if raw_loss_sum is None: + raw_loss_sum = detached_micro_loss + else: + raw_loss_sum = raw_loss_sum + detached_micro_loss + + if new_logprobs is None or raw_loss_sum is None: + raise RuntimeError("run_training_step did not produce outputs") + + finalize_model_grads_extended(model_chunks, num_tokens=num_tokens) + update_successful, grad_norm, num_zeros_in_grad = _optimizer_step( + optimizer, + learning_rate, + ) + global_num_tokens = max(num_tokens.item(), 1.0) + reduced_loss = _reduce_loss( + raw_loss_sum / global_num_tokens, + op=torch.distributed.ReduceOp.SUM, # ty: ignore[possibly-missing-attribute] + group=ps.get_data_parallel_group(with_context_parallel=True), + ) + + if moe_routing_replay_controller is not None: + moe_routing_replay_controller.finalize_step() + + return TrainStepResult( + reduced_loss=reduced_loss, + probs_corr=probs_corr_sum / micro_count, + new_logprobs=new_logprobs, + update_successful=update_successful, + grad_norm=grad_norm, + num_zeros_in_grad=num_zeros_in_grad, + ) + + +def _run_service_loop(runtime: TrainingRuntime) -> None: + offload_state = OffloadState() + offload_to_cpu(runtime.model, runtime.optimizer, runtime.rank, offload_state) + + while True: + torch.distributed.barrier() # ty: ignore[possibly-missing-attribute] + jobs_dir = "/tmp/megatron_training_jobs" + os.makedirs(jobs_dir, exist_ok=True) + job_names = sorted( + job_name for job_name in os.listdir(jobs_dir) if job_name.endswith(".json") ) - for param_group in optimizer.param_groups: - param_group["lr"] = config.learning_rate - update_successful, grad_norm, num_zeros_in_grad = cast( - tuple[bool, float, int | None], optimizer.step() + if not job_names: + 
time.sleep(1) + continue + + wake_lock_path = "/tmp/megatron_vllm_waking" + while os.path.exists(wake_lock_path): + time.sleep(0.2) + + reload_to_gpu(runtime.model, runtime.optimizer, runtime.rank, offload_state) + + job_name = job_names[0] + job_path = os.path.join(jobs_dir, job_name) + with open(job_path, "rb") as handle: + job = TrainingJob.model_validate_json(handle.read()) + config = job.config + experimental_config = job.experimental_config + + configure_moe_routing_replay( + runtime, + replay_bundle_path=job.moe_routing_replay_path, + strict=job.moe_routing_replay_strict, ) - optimizer.zero_grad() - - # Mean reduce loss across all ranks for logging - torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG) # ty:ignore[possibly-missing-attribute] - - if rank == 0: - with open("/tmp/megatron_training_log.jsonl", "a+") as log_file: - log_msg = json.dumps( - { - "loss/train": loss.item(), - "loss/grad_norm": grad_norm, - "probs_corr": probs_corr, - } - ) - print("Logging", log_msg) - log_file.write(log_msg + "\n") - sharded_state_dict = {} - for chunk in model: - for module in chunk.modules(): - if hasattr(module, "sharded_lora_state_dict"): - module_sharded_lora_state_dict: dict[str, torch.Tensor] = ( - module.sharded_lora_state_dict() # type: ignore - ) - for key, value in module_sharded_lora_state_dict.items(): - target_dtype = ( - adapter_model[key].dtype - if key in adapter_model - else value.dtype + print0(runtime.rank, "Loaded job from", job_path) + print0(runtime.rank, "Job:", job) + + adapter_model_path = f"{job.lora_path}/adapter_model.safetensors" + if not os.path.exists(adapter_model_path): + raise FileNotFoundError(f"No adapter model found at {adapter_model_path}") + print0(runtime.rank, "Loading adapter model from", adapter_model_path) + adapter_model = load_file(adapter_model_path) + load_adapter_into_model(runtime.model, adapter_model, runtime.optimizer) + + optimizer_shard_path = os.path.join( + job.optimizer_state_path, + 
f"{runtime.rank + 1:02d}-of-{runtime.world_size:02d}.pt", + ) + if os.path.exists(optimizer_shard_path): + print("Loading optimizer state from", optimizer_shard_path) + runtime.optimizer.load_state_dict(torch.load(optimizer_shard_path)) + else: + print( + "No optimizer state found at", + optimizer_shard_path, + "- resetting optimizer for new run", + ) + runtime.optimizer.optimizer.state.clear() + runtime.optimizer.reload_model_params() + + print0( + runtime.rank, "Loading packed tensors from", job.disk_packed_tensors["dir"] + ) + packed_tensors = packed_tensors_from_dir(**job.disk_packed_tensors) + template = select_indexed_inputs(packed_tensors, 0) + zero_template = _zero_contribution_inputs(template) + num_sequences = job.disk_packed_tensors["num_sequences"] + global_grad_accumulation_sequences = config.grad_accumulation_sequences + num_steps = math.ceil(num_sequences / global_grad_accumulation_sequences) + for step_index in range(num_steps): + micro_indices = build_micro_sample_indices( + step_index=step_index, + num_sequences=num_sequences, + global_grad_accumulation_sequences=global_grad_accumulation_sequences, + ) + micro_inputs = select_micro_inputs( + packed_tensors, micro_indices, zero_template + ) + step_result = run_training_step( + model_chunks=runtime.model, + optimizer=runtime.optimizer, + learning_rate=config.learning_rate, + inputs=micro_inputs, + config=config, + experimental_config=experimental_config, + ref_logprobs=None, + step_index=step_index, + sample_index=micro_indices, + moe_routing_replay_controller=runtime.moe_routing_replay_controller, + ) + print0( + runtime.rank, + "Correlation between old and new probabilities:", + step_result.probs_corr, + ) + + if runtime.rank == 0: + with open( + "/tmp/megatron_training_log.jsonl", "a+", encoding="utf-8" + ) as log_file: + log_msg = json.dumps( + { + "loss": step_result.reduced_loss.item(), + "grad_norm": step_result.grad_norm, + "probs_corr": step_result.probs_corr, + } ) - 
sharded_state_dict[key] = value.to(target_dtype) - shard_path = os.path.join( - job.lora_path, - f"adapter_model-{rank + 1:02d}-of-{world_size:02d}.safetensors", + print("Logging", log_msg) + log_file.write(log_msg + "\n") + + sharded_state_dict, sharded_state_manifest = collect_sharded_lora_state( + runtime.model, + adapter_model, + ) + shard_path = os.path.join( + job.lora_path, + f"adapter_model-{runtime.rank + 1:02d}-of-{runtime.world_size:02d}.safetensors", + ) + manifest_path = os.path.join( + job.lora_path, + f"adapter_manifest-{runtime.rank + 1:02d}-of-{runtime.world_size:02d}.json", + ) + print("Saving adapter shard to", shard_path) + save_file(sharded_state_dict, shard_path) + print("Saving adapter shard manifest to", manifest_path) + with open(manifest_path, "w", encoding="utf-8") as manifest_file: + json.dump(sharded_state_manifest, manifest_file, sort_keys=True) + + print("Saving optimizer shard to", optimizer_shard_path) + os.makedirs(job.optimizer_state_path, exist_ok=True) + torch.save(runtime.optimizer.state_dict(), optimizer_shard_path) + + offload_to_cpu(runtime.model, runtime.optimizer, runtime.rank, offload_state) + + del packed_tensors + del adapter_model + if "micro_inputs" in locals(): + del micro_inputs + gc.collect() + torch.cuda.empty_cache() + + torch.distributed.barrier() # ty: ignore[possibly-missing-attribute] + if runtime.rank == 0: + os.remove(job_path) + with open( + "/tmp/megatron_training_log.jsonl", "a+", encoding="utf-8" + ) as log_file: + log_file.write("all done\n") + shutil.rmtree(job.disk_packed_tensors["dir"]) + + +def main() -> None: + runtime = build_training_runtime( + model_identifier=os.environ.get("MODEL_IDENTIFIER", DEFAULT_MODEL_IDENTIFIER) ) - print("Saving adapter shard to", shard_path) - save_file(sharded_state_dict, shard_path) - print("Saving optimizer shard to", optimizer_shard_path) - os.makedirs(job.optimizer_state_path, exist_ok=True) - torch.save(optimizer.state_dict(), optimizer_shard_path) - 
offload_to_cpu(model, optimizer, rank, offload_state) - # Release mmap-backed packed tensor references on all ranks before rank0 cleanup. - del packed_tensors - del adapter_model - if "inputs" in locals(): - del inputs - gc.collect() - torch.cuda.empty_cache() - # Ensure all ranks have finished saving before signaling completion - torch.distributed.barrier() # ty:ignore[possibly-missing-attribute] - if rank == 0: - os.remove(job_path) - with open("/tmp/megatron_training_log.jsonl", "a+") as log_file: - log_file.write("all done\n") - shutil.rmtree(job.disk_packed_tensors["dir"]) + _run_service_loop(runtime) + + +if __name__ == "__main__": + main() diff --git a/src/art/tinker/service.py b/src/art/tinker/service.py index 2eebafc4..90e704eb 100644 --- a/src/art/tinker/service.py +++ b/src/art/tinker/service.py @@ -80,7 +80,7 @@ def custom_loss_fn( for mask, lp in zip(masks, logprobs_list): logprobs[mask] = lp loss = loss_fn(inputs, logprobs.unsqueeze(0), None, None, _config) - return loss.mean_policy_loss, {"loss/train": loss.mean_policy_loss.item()} + return loss.policy_loss, {"loss/train": loss.policy_loss.item()} shifted_tokens = shift_tensor(packed_tensors["tokens"], 0) diff --git a/src/art/types.py b/src/art/types.py index 088041ad..f905d881 100644 --- a/src/art/types.py +++ b/src/art/types.py @@ -17,6 +17,7 @@ class TrainConfig(pydantic.BaseModel): learning_rate: float = 5e-6 kl_penalty_coef: float = 0.0 + grad_accumulation_sequences: int = pydantic.Field(default=1, ge=1) class TrainSFTConfig(pydantic.BaseModel): diff --git a/src/art/unsloth/train.py b/src/art/unsloth/train.py index f6c42a2c..e5d4b026 100644 --- a/src/art/unsloth/train.py +++ b/src/art/unsloth/train.py @@ -207,14 +207,14 @@ def compute_loss( ) trainer._metrics["train"]["loss/learning_rate"].append(config.learning_rate) - trainer._metrics["train"]["loss/train"].append(loss.mean_policy_loss.item()) - if loss.mean_entropy is not None: - 
trainer._metrics["train"]["loss/entropy"].append(loss.mean_entropy.item()) + trainer._metrics["train"]["loss/train"].append(loss.policy_loss.item()) + if loss.entropy is not None: + trainer._metrics["train"]["loss/entropy"].append(loss.entropy.item()) if loss.kl_policy_ref is not None: trainer._metrics["train"]["loss/kl_policy_ref"].append( loss.kl_policy_ref.item() ) - return loss.mean_policy_loss + return loss.policy_loss return compute_loss diff --git a/tests/integration/megatron_forward_trace.py b/tests/integration/megatron_forward_trace.py new file mode 100644 index 00000000..98f43fc6 --- /dev/null +++ b/tests/integration/megatron_forward_trace.py @@ -0,0 +1,994 @@ +from __future__ import annotations + +import os +from pathlib import Path +from typing import Any, Callable, cast + +import torch + +CAPTURE_NAME_TOKENS = ( + ".self_attention.linear_qkv", + ".self_attention.linear_qkv.q_proj_lora", + ".self_attention.linear_qkv.k_proj_lora", + ".self_attention.linear_qkv.v_proj_lora", + ".self_attention.linear_proj", + ".self_attention.linear_proj.lora", + ".mlp.router", + ".mlp.experts.linear_fc1", + ".mlp.experts.linear_fc1.gate_lora", + ".mlp.experts.linear_fc1.up_lora", + ".mlp.experts.linear_fc2", + ".mlp.experts.linear_fc2.lora", +) +ROUTER_NAME_TOKEN = ".mlp.router" +PRIMARY_OUTPUT_CANONICAL_KEY = "primary_output__is_canonical" + + +def _safe_int(value: Any, default: int = 0) -> int: + """Coerces scalar values to int for trace metadata.""" + try: + return int(value) + except Exception: + return default + + +def _safe_ps_stat(name: str, default: int) -> int: + """Reads one Megatron parallel-state integer when available.""" + try: + from megatron.core import parallel_state as ps + + getter = getattr(ps, name) + return _safe_int(getter(), default) + except Exception: + return default + + +def _rank_metadata() -> dict[str, int]: + """Builds lightweight distributed metadata for one trace call.""" + rank = 0 + world_size = 1 + if 
torch.distributed.is_initialized(): # ty: ignore[possibly-missing-attribute] + rank = _safe_int(torch.distributed.get_rank(), 0) # ty: ignore[possibly-missing-attribute] + world_size = _safe_int(torch.distributed.get_world_size(), 1) # ty: ignore[possibly-missing-attribute] + return { + "global_rank": rank, + "world_size": world_size, + "tp_rank": _safe_ps_stat("get_tensor_model_parallel_rank", 0), + "tp_world_size": _safe_ps_stat("get_tensor_model_parallel_world_size", 1), + "ep_rank": _safe_ps_stat("get_expert_model_parallel_rank", 0), + "ep_world_size": _safe_ps_stat("get_expert_model_parallel_world_size", 1), + "etp_rank": _safe_ps_stat("get_expert_tensor_parallel_rank", 0), + "etp_world_size": _safe_ps_stat("get_expert_tensor_parallel_world_size", 1), + "dp_rank": _safe_ps_stat("get_data_parallel_rank", 0), + "dp_world_size": _safe_ps_stat("get_data_parallel_world_size", 1), + "expert_dp_rank": _safe_ps_stat("get_expert_data_parallel_rank", 0), + "expert_dp_world_size": _safe_ps_stat("get_expert_data_parallel_world_size", 1), + } + + +def _extract_dp_slot_from_rank_meta(rank_meta: Any) -> tuple[int, int] | None: + """Returns one stable `(dp_rank, dp_world_size)` pair from merged rank metadata.""" + if isinstance(rank_meta, dict): + rank_meta = [rank_meta] + if not isinstance(rank_meta, list) or not rank_meta: + return None + dp_ranks = { + _safe_int(item.get("dp_rank"), 0) + for item in rank_meta + if isinstance(item, dict) and "dp_rank" in item + } + dp_world_sizes = { + _safe_int(item.get("dp_world_size"), 1) + for item in rank_meta + if isinstance(item, dict) and "dp_world_size" in item + } + if len(dp_ranks) != 1 or len(dp_world_sizes) != 1: + return None + return next(iter(dp_ranks)), next(iter(dp_world_sizes)) + + +def _trace_call_sort_key(call: dict[str, Any]) -> tuple[int, int]: + """Builds a stable micro identity for merged trace ordering.""" + sample_index = call.get("micro_sample_index") + if isinstance(sample_index, int): + return 0, 
int(sample_index) + micro_order = _safe_int(call.get("micro_order"), 0) + dp_slot = _extract_dp_slot_from_rank_meta(call.get("rank_meta")) + if dp_slot is None: + return 1, micro_order + dp_rank, dp_world_size = dp_slot + return 1, micro_order * dp_world_size + dp_rank + + +def _local_dummy_micro_slot(micro_order: int) -> int: + """Builds the stable dummy-micro slot used when one micro has no sample id.""" + dp_rank = _safe_ps_stat("get_data_parallel_rank", 0) + dp_world_size = _safe_ps_stat("get_data_parallel_world_size", 1) + return micro_order * dp_world_size + dp_rank + + +def _captured_output_sort_key( + sample_index: int | None, + micro_order: int, + micro_slot: int | None, +) -> tuple[int, int, int]: + """Builds the deterministic ordering used for captured root outputs.""" + if isinstance(sample_index, int): + return 0, int(sample_index), micro_order + return 1, _safe_int(micro_slot, micro_order), 0 + + +def _shard_world_size_for_domain(domain: Any) -> int: + """Returns shard-group world size for one LoRA shard domain.""" + if domain == "tp": + return _safe_ps_stat("get_tensor_model_parallel_world_size", 1) + if domain == "expert_tp": + return _safe_ps_stat("get_expert_tensor_parallel_world_size", 1) + return 1 + + +def _extract_primary_tensor(value: Any) -> torch.Tensor | None: + if isinstance(value, torch.Tensor): + return value + if isinstance(value, dict): + for item in value.values(): + tensor = _extract_primary_tensor(item) + if tensor is not None: + return tensor + if isinstance(value, (list, tuple)): + for item in value: + tensor = _extract_primary_tensor(item) + if tensor is not None: + return tensor + return None + + +def _materialize_tensor(tensor: torch.Tensor) -> torch.Tensor: + full_tensor = getattr(tensor, "full_tensor", None) + if callable(full_tensor): + tensor = cast(torch.Tensor, full_tensor()) + else: + to_local = getattr(tensor, "to_local", None) + if callable(to_local): + tensor = cast(torch.Tensor, to_local()) + else: + local_tensor = 
getattr(tensor, "_local_tensor", None) + if isinstance(local_tensor, torch.Tensor): + tensor = local_tensor + return tensor.detach().cpu() + + +def _materialize_trace_value(value: Any) -> Any: + if isinstance(value, torch.Tensor): + return _materialize_tensor(value) + if isinstance(value, dict): + return {key: _materialize_trace_value(item) for key, item in value.items()} + if isinstance(value, list): + return [_materialize_trace_value(item) for item in value] + if isinstance(value, tuple): + return tuple(_materialize_trace_value(item) for item in value) + return value + + +def _extract_tensor_attr(value: Any, attr_name: str) -> Any: + if isinstance(value, torch.Tensor): + return getattr(value, attr_name, None) + if isinstance(value, dict): + for item in value.values(): + attr_value = _extract_tensor_attr(item, attr_name) + if attr_value is not None: + return attr_value + if isinstance(value, (list, tuple)): + for item in value: + attr_value = _extract_tensor_attr(item, attr_name) + if attr_value is not None: + return attr_value + return None + + +def _extract_router_topk(output: Any) -> tuple[torch.Tensor, torch.Tensor] | None: + if not isinstance(output, tuple) or len(output) < 2: + return None + probs = output[0] + routing_map = output[1] + if not isinstance(probs, torch.Tensor) or not isinstance(routing_map, torch.Tensor): + return None + probs = _materialize_tensor(probs.float()) + routing_map = _materialize_tensor(routing_map) + topk = int(routing_map.sum(dim=-1).max().item()) + if topk < 0: + raise RuntimeError(f"Invalid router topk={topk}") + if topk == 0: + topk_scores = probs.new_zeros((probs.shape[0], 0)) + topk_ids = torch.zeros((probs.shape[0], 0), dtype=torch.int64) + else: + topk_scores, topk_ids = torch.topk(probs, k=topk, dim=-1) + return topk_ids.contiguous(), topk_scores.contiguous() + + +class ForwardTraceCapture: + def __init__( + self, + model_chunks: list[Any], + *, + enabled: bool, + capture_name_tokens: tuple[str, ...] 
= CAPTURE_NAME_TOKENS, + micro_start_callback: Callable[[int | None, int], None] | None = None, + ) -> None: + self.enabled = enabled + self.capture_name_tokens = capture_name_tokens + self.micro_start_callback = micro_start_callback + self.current_step_index: int | None = None + self.current_step_trace: dict[str, list[dict[str, Any]]] = {} + self.current_micro_sample_index: int | None = None + self.current_micro_order = 0 + self.current_micro_module_call_counts: dict[str, int] = {} + self.current_step_sample_indices: list[int | None] = [] + self.current_step_outputs: list[ + tuple[int | None, int, int | None, torch.Tensor] + ] = [] + self._next_micro_order = 0 + self._hook_handles: list[Any] = [] + if not enabled: + return + self._register_hooks(model_chunks) + + def _register_hooks(self, model_chunks: list[Any]) -> None: + if not model_chunks: + raise RuntimeError("Expected at least one model chunk for forward tracing") + root_module = model_chunks[0] + self._hook_handles.append( + root_module.register_forward_pre_hook(self._root_pre_hook) + ) + self._hook_handles.append( + root_module.register_forward_hook(self._root_post_hook) + ) + for chunk_index, chunk in enumerate(model_chunks): + for module_name, module in chunk.named_modules(): + trace_module_name = f"chunk{chunk_index}.{module_name}" + is_layer_output = ( + ".decoder.layers." 
in module_name + and module_name.rsplit(".", 1)[-1].isdigit() + ) + if not is_layer_output and not any( + module_name.endswith(token) for token in self.capture_name_tokens + ): + continue + self._hook_handles.append( + module.register_forward_hook( + self._make_hook(trace_module_name, module) + ) + ) + + @staticmethod + def _sequence_parallel_enabled(module: Any) -> bool: + """Returns sequence-parallel flag from module/provider/config when present.""" + for owner in ( + module, + getattr(module, "provider", None), + getattr(module, "config", None), + ): + if owner is None: + continue + value = getattr(owner, "sequence_parallel", None) + if isinstance(value, bool): + return value + return False + + @staticmethod + def _lora_primary_output_merge_hint(module: Any) -> dict[str, Any] | None: + """Infers the correct output merge op for LoRA modules.""" + if module.__class__.__name__ != "LoRA": + return None + lora_module = module + b_param = getattr(lora_module, "B_T", None) + if b_param is None: + return None + b_domain = getattr(b_param, "lora_shard_domain", None) + b_world_size = _shard_world_size_for_domain(b_domain) + if bool(getattr(b_param, "lora_tp_sharded", False)) and b_world_size > 1: + shard_dim = getattr(b_param, "lora_tp_shard_dim", None) + if isinstance(shard_dim, int): + return {"op": "concat", "dim": shard_dim} + a_param = getattr(lora_module, "A_T", None) + if a_param is None: + return None + a_domain = getattr(a_param, "lora_shard_domain", None) + a_world_size = _shard_world_size_for_domain(a_domain) + if bool(getattr(a_param, "lora_tp_sharded", False)) and a_world_size > 1: + return {"op": "sum"} + return None + + def _infer_primary_output_merge_hint( + self, name: str, module: Any + ) -> dict[str, Any] | None: + """Chooses canonical cross-rank concat axis for one module output.""" + if ROUTER_NAME_TOKEN in name: + return {"op": "concat", "dim": 0} + + lora_hint = self._lora_primary_output_merge_hint(module) + if lora_hint is not None: + return 
lora_hint + + # Base MoE expert linears need expert-TP aware merge semantics. + # With etp>1: + # - FC1 (column-parallel) shards output features -> concat on feature dim. + # - FC2 (row-parallel) emits partial output contributions -> sum across ranks. + # With etp==1, keep the existing token-row concat behavior. + etp_world_size = _safe_ps_stat("get_expert_tensor_parallel_world_size", 1) + if ".mlp.experts.linear_fc1" in name and ".lora" not in name: + if etp_world_size > 1: + return { + "op": "concat", + "dim": -1, + "layout": "gate_up_rank_interleaved", + } + return {"op": "concat", "dim": 0} + if ".mlp.experts.linear_fc2" in name and ".lora" not in name: + if etp_world_size > 1: + return {"op": "sum"} + return {"op": "concat", "dim": 0} + + gather_output = getattr(module, "gather_output", None) + if isinstance(gather_output, bool) and not gather_output: + return {"op": "concat", "dim": -1} + + if ".self_attention.linear_qkv" in name: + return {"op": "concat", "dim": -1} + + if ".mlp.experts." 
in name: + return {"op": "concat", "dim": 0} + + if bool( + getattr(module, "input_is_parallel", False) + ) and self._sequence_parallel_enabled(module): + return {"op": "concat", "dim": 0} + + return None + + def _build_merge_hints(self, name: str, module: Any) -> dict[str, dict[str, Any]]: + """Builds field-level tensor merge hints for one call record.""" + hints: dict[str, dict[str, Any]] = {} + primary_output_hint = self._infer_primary_output_merge_hint(name, module) + if primary_output_hint is not None: + hints["primary_output"] = primary_output_hint + if ROUTER_NAME_TOKEN in name: + concat_dim0 = {"op": "concat", "dim": 0} + hints["output"] = concat_dim0 + hints["router_topk_ids"] = concat_dim0 + hints["router_topk_scores"] = concat_dim0 + return hints + + def _make_hook(self, name: str, module: Any): + def _hook(_module: Any, inputs: Any, output: Any) -> None: + if self.current_step_index is None: + return + micro_call_index = self.current_micro_module_call_counts.get(name, 0) + self.current_micro_module_call_counts[name] = micro_call_index + 1 + trace_item: dict[str, Any] = { + "micro_call_index": micro_call_index, + "micro_order": self.current_micro_order, + "micro_sample_index": self.current_micro_sample_index, + "module_type": module.__class__.__name__, + "rank_meta": _rank_metadata(), + "merge_hints": self._build_merge_hints(name, module), + "inputs": _materialize_trace_value(inputs), + "output": _materialize_trace_value(output), + "primary_input": self.guess_primary_tensor(inputs), + "primary_output": self.guess_primary_tensor(output), + } + if ROUTER_NAME_TOKEN in name: + router_topk = _extract_router_topk(output) + if router_topk is not None: + topk_ids, topk_scores = router_topk + trace_item["router_topk_ids"] = topk_ids + trace_item["router_topk_scores"] = topk_scores + trace_items = self._split_expert_trace_items( + module_name=name, + module=module, + inputs=inputs, + trace_item=trace_item, + ) + trace_calls = 
self.current_step_trace.setdefault(name, []) + for split_item in trace_items: + split_item["call_index"] = len(trace_calls) + trace_calls.append(split_item) + + return _hook + + @staticmethod + def guess_primary_tensor(value: Any) -> torch.Tensor | None: + tensor = _extract_primary_tensor(value) + if tensor is None: + return None + return _materialize_tensor(tensor) + + def _sample_index_for_micro(self, micro_order: int) -> int | None: + if micro_order < len(self.current_step_sample_indices): + return self.current_step_sample_indices[micro_order] + return None + + def _root_pre_hook(self, _module: Any, _args: Any) -> None: + if self.current_step_index is None: + return + micro_order = self._next_micro_order + sample_index = self._sample_index_for_micro(micro_order) + self.begin_micro(sample_index=sample_index, micro_order=micro_order) + + def _root_post_hook(self, _module: Any, _inputs: Any, output: Any) -> None: + if self.current_step_index is None: + return + output_tensor = self.guess_primary_tensor(output) + if output_tensor is None: + raise RuntimeError( + f"Expected root forward output to contain a tensor, got {type(output)}" + ) + sample_index = self.current_micro_sample_index + micro_order = self.current_micro_order + self.current_step_outputs.append( + ( + sample_index, + micro_order, + None + if sample_index is not None + else _local_dummy_micro_slot(micro_order), + output_tensor.float(), + ) + ) + self._next_micro_order = micro_order + 1 + + def set_step( + self, + step_index: int, + sample_indices: list[int | None] | None = None, + ) -> None: + self.current_step_index = step_index + self.current_step_trace = {} + self.current_step_sample_indices = list(sample_indices or []) + self.current_step_outputs = [] + self.current_micro_sample_index = None + self.current_micro_order = 0 + self.current_micro_module_call_counts = {} + self._next_micro_order = 0 + + def begin_micro(self, sample_index: int | None, micro_order: int) -> None: + 
self.current_micro_sample_index = sample_index + self.current_micro_order = micro_order + self.current_micro_module_call_counts = {} + if self.micro_start_callback is not None: + self.micro_start_callback(sample_index, micro_order) + + @staticmethod + def _row_token_uids_for_trace( + *, + inputs: Any, + module: Any, + ) -> tuple[torch.Tensor | None, int | None]: + row_token_uids = _extract_tensor_attr(inputs, "_art_trace_row_token_uids") + if row_token_uids is None: + row_token_uids = getattr(module, "_art_trace_row_token_uids", None) + if not isinstance(row_token_uids, torch.Tensor): + return None, None + + uid_span = _extract_tensor_attr(inputs, "_art_trace_uid_span") + if uid_span is None: + uid_span = getattr(module, "_art_trace_uid_span", None) + uid_span_int = uid_span if isinstance(uid_span, int) and uid_span > 0 else None + return ( + row_token_uids.detach().to(device="cpu", dtype=torch.int64).reshape(-1), + uid_span_int, + ) + + @classmethod + def _slice_row_aligned_value( + cls, + value: Any, + *, + row_indices: torch.Tensor, + total_rows: int, + ) -> Any: + if isinstance(value, torch.Tensor): + if value.ndim > 0 and int(value.shape[0]) == total_rows: + return value.index_select(0, row_indices) + return value + if isinstance(value, dict): + return { + key: cls._slice_row_aligned_value( + item, + row_indices=row_indices, + total_rows=total_rows, + ) + for key, item in value.items() + } + if isinstance(value, list): + return [ + cls._slice_row_aligned_value( + item, + row_indices=row_indices, + total_rows=total_rows, + ) + for item in value + ] + if isinstance(value, tuple): + return tuple( + cls._slice_row_aligned_value( + item, + row_indices=row_indices, + total_rows=total_rows, + ) + for item in value + ) + return value + + @classmethod + def _split_expert_trace_items( + cls, + *, + module_name: str, + module: Any, + inputs: Any, + trace_item: dict[str, Any], + ) -> list[dict[str, Any]]: + if not cls._is_moe_expert_forward_module(module_name): + return 
[trace_item] + + primary_output = trace_item.get("primary_output") + if not isinstance(primary_output, torch.Tensor) or primary_output.ndim == 0: + return [trace_item] + + row_token_uids, uid_span = cls._row_token_uids_for_trace( + inputs=inputs, + module=module, + ) + if row_token_uids is None: + return [trace_item] + + total_rows = int(row_token_uids.numel()) + if total_rows == 0 or int(primary_output.shape[0]) != total_rows: + return [trace_item] + + trace_item["row_token_uids"] = row_token_uids + if uid_span is None: + return [trace_item] + + sample_ids = torch.div(row_token_uids, uid_span, rounding_mode="floor") + ordered_sample_ids: list[int] = [] + seen_sample_ids: set[int] = set() + for sample_id in sample_ids.tolist(): + sample_id_int = int(sample_id) + if sample_id_int in seen_sample_ids: + continue + seen_sample_ids.add(sample_id_int) + ordered_sample_ids.append(sample_id_int) + + if len(ordered_sample_ids) <= 1: + if ordered_sample_ids: + trace_item["micro_sample_index"] = ordered_sample_ids[0] + return [trace_item] + + split_items: list[dict[str, Any]] = [] + for sample_id in ordered_sample_ids: + row_indices = (sample_ids == sample_id).nonzero(as_tuple=False).reshape(-1) + split_item = { + key: cls._slice_row_aligned_value( + value, + row_indices=row_indices, + total_rows=total_rows, + ) + for key, value in trace_item.items() + if key not in {"call_index", "micro_sample_index", "row_token_uids"} + } + split_item["micro_sample_index"] = sample_id + split_item["row_token_uids"] = row_token_uids.index_select(0, row_indices) + split_items.append(split_item) + return split_items + + @staticmethod + def _is_moe_expert_forward_module(module_name: str) -> bool: + """Returns whether one module emits MoE expert forward outputs.""" + if ".mlp.experts." 
not in module_name: + return False + if ".mlp.router" in module_name: + return False + return ".linear_fc1" in module_name or ".linear_fc2" in module_name + + @staticmethod + def _primary_output_merge_hint(call: dict[str, Any]) -> dict[str, Any] | None: + """Reads primary-output merge metadata from one call payload.""" + merge_hints = call.get("merge_hints") + if not isinstance(merge_hints, dict): + return None + primary_hint = merge_hints.get("primary_output") + if not isinstance(primary_hint, dict): + return None + return primary_hint + + @classmethod + def _canonicalize_etp_fc1_feature_layout( + cls, + *, + module_name: str, + tensor: torch.Tensor, + call: dict[str, Any], + ) -> torch.Tensor: + """Normalizes expert-TP fc1 feature order to a topology-independent layout.""" + if ".mlp.experts.linear_fc1" not in module_name or ".lora" in module_name: + return tensor + if tensor.ndim != 2: + return tensor + primary_hint = cls._primary_output_merge_hint(call) + if not isinstance(primary_hint, dict): + return tensor + if primary_hint.get("layout") != "gate_up_rank_interleaved": + return tensor + rank_meta = call.get("rank_meta") + etp_world_size = None + if isinstance(rank_meta, list) and rank_meta: + first_meta = rank_meta[0] + if isinstance(first_meta, dict): + etp_world_size = first_meta.get("etp_world_size") + elif isinstance(rank_meta, dict): + etp_world_size = rank_meta.get("etp_world_size") + if not isinstance(etp_world_size, int) or etp_world_size <= 1: + return tensor + block_count = 2 * etp_world_size + if tensor.shape[1] % block_count != 0: + return tensor + blocks = torch.chunk(tensor, block_count, dim=1) + reordered = [blocks[index] for index in range(0, block_count, 2)] + [ + blocks[index] for index in range(1, block_count, 2) + ] + return torch.cat(reordered, dim=1).contiguous() + + @classmethod + def _canonicalize_moe_expert_row_order( + cls, + *, + module_name: str, + tensor: torch.Tensor, + call: dict[str, Any], + ) -> torch.Tensor: + 
"""Canonicalizes MoE expert rows using dispatch-time UID identities.""" + if not cls._is_moe_expert_forward_module(module_name): + return tensor + if tensor.ndim != 2: + return tensor + primary_hint = cls._primary_output_merge_hint(call) + if isinstance(primary_hint, dict) and ( + primary_hint.get("op") != "concat" or primary_hint.get("dim") != 0 + ): + return tensor + row_token_uids = call.get("row_token_uids") + if not isinstance(row_token_uids, torch.Tensor): + return tensor + if int(row_token_uids.numel()) != int(tensor.shape[0]): + return tensor + order = torch.argsort(row_token_uids, stable=True) + return tensor.index_select(0, order) + + @classmethod + def _canonicalize_primary_output_tensor( + cls, + *, + module_name: str, + tensor: torch.Tensor, + call: dict[str, Any], + ) -> torch.Tensor: + """Runs all remaining primary-output canonicalization passes for one call.""" + tensor = cls._canonicalize_etp_fc1_feature_layout( + module_name=module_name, + tensor=tensor, + call=call, + ) + return cls._canonicalize_moe_expert_row_order( + module_name=module_name, + tensor=tensor, + call=call, + ) + + @classmethod + def canonicalize_trace( + cls, + trace: dict[str, list[dict[str, Any]]], + ) -> dict[str, list[dict[str, Any]]]: + """Canonicalizes topology-dependent trace outputs in place.""" + for module_name in sorted(trace.keys()): + calls = trace[module_name] + for call_offset, call in enumerate(calls): + if bool(call.get(PRIMARY_OUTPUT_CANONICAL_KEY)): + continue + call_index = int(call.get("call_index", call_offset)) + tensor = call.get("primary_output") + if isinstance(tensor, torch.Tensor): + call["primary_output"] = cls._canonicalize_primary_output_tensor( + module_name=module_name, + tensor=tensor, + call=call, + ) + call[PRIMARY_OUTPUT_CANONICAL_KEY] = True + return trace + + @classmethod + def flatten_trace_tensors( + cls, + trace: dict[str, list[dict[str, Any]]], + *, + value_key: str, + ) -> dict[str, Any]: + """Flattens trace calls into deterministic 
key->value tensor maps.""" + if value_key == "primary_output": + cls.canonicalize_trace(trace) + flattened: dict[str, Any] = {} + for module_name in sorted(trace.keys()): + for call_offset, call in enumerate(trace[module_name]): + tensor = call.get(value_key) + if tensor is None: + continue + call_index = call.get("call_index", call_offset) + flattened[f"{module_name}.call_{call_index}"] = tensor + return flattened + + @classmethod + def _merge_rank_values( + cls, + values_by_rank: list[Any], + *, + preferred_cat_dim: int | None = None, + preferred_reduce: str | None = None, + ) -> Any: + if not values_by_rank: + raise RuntimeError("Cannot merge empty rank value list") + if all(isinstance(value, torch.Tensor) for value in values_by_rank): + tensors = cast(list[torch.Tensor], values_by_rank) + if preferred_reduce == "sum" and all( + tensors[0].shape == tensor.shape for tensor in tensors[1:] + ): + return torch.stack(tensors, dim=0).sum(dim=0) + if ( + preferred_cat_dim is not None + and all(tensor.ndim > 0 for tensor in tensors) + and cls._can_cat_along_dim(tensors, dim=preferred_cat_dim) + ): + return torch.cat(tensors, dim=preferred_cat_dim) + if all(tensor.ndim > 0 for tensor in tensors): + if cls._can_cat_along_dim(tensors, dim=0): + return torch.cat(tensors, dim=0) + if cls._can_cat_along_dim(tensors, dim=-1): + return torch.cat(tensors, dim=-1) + if all(tensors[0].shape == tensor.shape for tensor in tensors[1:]): + return torch.stack(tensors, dim=0) + return tensors + if all(isinstance(value, dict) for value in values_by_rank): + dicts = cast(list[dict[str, Any]], values_by_rank) + keys = sorted(set().union(*(value.keys() for value in dicts))) + return { + key: cls._merge_rank_values( + [value[key] for value in dicts if key in value], + preferred_cat_dim=preferred_cat_dim, + preferred_reduce=preferred_reduce, + ) + for key in keys + } + if all(isinstance(value, list) for value in values_by_rank): + lists = cast(list[list[Any]], values_by_rank) + if 
any(len(values) != len(lists[0]) for values in lists[1:]): + return lists + return [ + cls._merge_rank_values( + [value[index] for value in lists], + preferred_cat_dim=preferred_cat_dim, + preferred_reduce=preferred_reduce, + ) + for index in range(len(lists[0])) + ] + if all(isinstance(value, tuple) for value in values_by_rank): + tuples = cast(list[tuple[Any, ...]], values_by_rank) + if any(len(values) != len(tuples[0]) for values in tuples[1:]): + return tuples + return tuple( + cls._merge_rank_values( + [value[index] for value in tuples], + preferred_cat_dim=preferred_cat_dim, + preferred_reduce=preferred_reduce, + ) + for index in range(len(tuples[0])) + ) + if all(value == values_by_rank[0] for value in values_by_rank[1:]): + return values_by_rank[0] + return values_by_rank + + @classmethod + def _merge_rank_call_entries( + cls, + rank_call_entries: list[dict[str, Any]], + ) -> dict[str, Any]: + """Merges one module call across ranks using per-field merge hints.""" + merged_call: dict[str, Any] = {} + keys = sorted(set().union(*(entry.keys() for entry in rank_call_entries))) + for key in keys: + values = [entry[key] for entry in rank_call_entries if key in entry] + if key == "rank_meta": + merged_call[key] = values + continue + preferred_cat_dim: int | None = None + preferred_reduce: str | None = None + if values and key not in {"merge_hints", "call_index", "module_type"}: + hint_values = [ + cast(dict[str, Any], entry["merge_hints"]).get(key) + for entry in rank_call_entries + if isinstance(entry.get("merge_hints"), dict) + ] + op_hints = [ + hint + for hint in hint_values + if isinstance(hint, dict) and isinstance(hint.get("op"), str) + ] + if op_hints: + selected_hint = op_hints[0] + op = selected_hint.get("op") + if op == "concat": + dim = selected_hint.get("dim") + if isinstance(dim, int): + preferred_cat_dim = dim + elif op == "sum": + preferred_reduce = "sum" + if ( + preferred_reduce is None + and preferred_cat_dim == 0 + and all(isinstance(value, 
torch.Tensor) for value in values) + ): + merged_call[f"{key}__row_splits"] = [ + int(cast(torch.Tensor, value).shape[0]) for value in values + ] + merged_call[key] = cls._merge_rank_values( + values, + preferred_cat_dim=preferred_cat_dim, + preferred_reduce=preferred_reduce, + ) + return merged_call + + @staticmethod + def _can_cat_along_dim(tensors: list[torch.Tensor], dim: int) -> bool: + if not tensors: + return False + if tensors[0].ndim == 0: + return False + ndim = tensors[0].ndim + axis = dim if dim >= 0 else ndim + dim + if axis < 0 or axis >= ndim: + return False + if any(tensor.ndim != ndim for tensor in tensors[1:]): + return False + for dim_index in range(ndim): + if dim_index == axis: + continue + dim_size = tensors[0].shape[dim_index] + if any(tensor.shape[dim_index] != dim_size for tensor in tensors[1:]): + return False + return True + + @classmethod + def _merge_rank_traces( + cls, + rank_traces: list[dict[str, list[dict[str, Any]]]], + ) -> dict[str, list[dict[str, Any]]]: + if len(rank_traces) == 1: + return rank_traces[0] + merged: dict[str, list[dict[str, Any]]] = {} + module_names = sorted(set().union(*(trace.keys() for trace in rank_traces))) + for module_name in module_names: + module_calls: list[dict[str, Any]] = [] + grouped_calls: dict[ + tuple[int, int, int, int], + list[dict[str, Any]], + ] = {} + for trace in rank_traces: + for call in trace.get(module_name, []): + sample_kind, sample_sort_index = _trace_call_sort_key(call) + merge_key = ( + sample_kind, + sample_sort_index, + int(call.get("micro_order", 0)), + int(call.get("micro_call_index", call.get("call_index", 0))), + ) + grouped_calls.setdefault(merge_key, []).append(call) + for merged_index, merge_key in enumerate(sorted(grouped_calls)): + merged_call = cls._merge_rank_call_entries(grouped_calls[merge_key]) + merged_call["call_index"] = merged_index + module_calls.append(merged_call) + merged[module_name] = module_calls + return merged + + @staticmethod + def 
_gather_rank_traces( + local_trace: dict[str, list[dict[str, Any]]], + ) -> list[dict[str, list[dict[str, Any]]]] | None: + if ( + not torch.distributed.is_initialized() # ty: ignore[possibly-missing-attribute] + or torch.distributed.get_world_size() == 1 # ty: ignore[possibly-missing-attribute] + ): + return [local_trace] + gathered: list[dict[str, list[dict[str, Any]]] | None] = [ + None + ] * torch.distributed.get_world_size() # ty: ignore[possibly-missing-attribute] + torch.distributed.all_gather_object(gathered, local_trace) # ty: ignore[possibly-missing-attribute] + if torch.distributed.get_rank() != 0: # ty: ignore[possibly-missing-attribute] + return None + return cast(list[dict[str, list[dict[str, Any]]]], gathered) + + @staticmethod + def _merge_group_tensor(tensors: list[torch.Tensor]) -> torch.Tensor: + if len(tensors) == 1: + return tensors[0] + first = tensors[0] + if all(tensor.shape == first.shape for tensor in tensors[1:]) and all( + torch.equal(first, tensor) for tensor in tensors[1:] + ): + return first + raise RuntimeError( + "Mismatched output captures for the same micro output across non-DP ranks" + ) + + @staticmethod + def _gather_rank_outputs( + local_outputs: list[tuple[int | None, int, int | None, torch.Tensor]], + ) -> list[list[tuple[int | None, int, int | None, torch.Tensor]]] | None: + if ( + not torch.distributed.is_initialized() # ty: ignore[possibly-missing-attribute] + or torch.distributed.get_world_size() == 1 # ty: ignore[possibly-missing-attribute] + ): + return [local_outputs] + gathered: list[ + list[tuple[int | None, int, int | None, torch.Tensor]] | None + ] = [None] * torch.distributed.get_world_size() # ty: ignore[possibly-missing-attribute] + torch.distributed.all_gather_object(gathered, local_outputs) # ty: ignore[possibly-missing-attribute] + if torch.distributed.get_rank() != 0: # ty: ignore[possibly-missing-attribute] + return None + return cast( + list[list[tuple[int | None, int, int | None, torch.Tensor]]], + 
gathered, + ) + + def ordered_step_outputs(self) -> list[torch.Tensor] | None: + if not self.enabled: + return None + gathered_outputs = self._gather_rank_outputs(self.current_step_outputs) + if gathered_outputs is None: + return None + grouped: dict[tuple[int | None, int | None, int], list[torch.Tensor]] = {} + for rank_outputs in gathered_outputs: + for sample_index, micro_order, micro_slot, tensor in rank_outputs: + group_key = (sample_index, micro_slot, micro_order) + grouped.setdefault(group_key, []).append(tensor) + ordered_group_keys = sorted( + grouped, + key=lambda item: _captured_output_sort_key(item[0], item[2], item[1]), + ) + return [ + self._merge_group_tensor(grouped[group_key]) + for group_key in ordered_group_keys + ] + + def save_current_step(self, traces_dir: Path) -> Path | None: + if not self.enabled or self.current_step_index is None: + return None + gathered_traces = self._gather_rank_traces(self.current_step_trace) + if gathered_traces is None: + return None + merged_trace = self.canonicalize_trace(self._merge_rank_traces(gathered_traces)) + traces_dir.mkdir(parents=True, exist_ok=True) + trace_path = traces_dir / f"forward_trace_step_{self.current_step_index:03d}.pt" + tmp_trace_path = trace_path.with_suffix(f"{trace_path.suffix}.tmp") + torch.save(merged_trace, tmp_trace_path) + os.replace(tmp_trace_path, trace_path) + return trace_path + + @classmethod + def load_trace(cls, trace_path: Path) -> dict[str, list[dict[str, Any]]]: + trace = torch.load(trace_path, map_location="cpu", weights_only=False) + return cls.canonicalize_trace(trace) + + def close(self) -> None: + for handle in self._hook_handles: + handle.remove() + self._hook_handles.clear() diff --git a/tests/integration/megatron_oracle_harness.py b/tests/integration/megatron_oracle_harness.py new file mode 100644 index 00000000..033cd5b9 --- /dev/null +++ b/tests/integration/megatron_oracle_harness.py @@ -0,0 +1,1467 @@ +from __future__ import annotations + +from functools import 
partial +import hashlib +import json +import math +import os +from pathlib import Path +import re +import shutil +from typing import Any, Callable, Literal, TypeVar, cast + +from pydantic import BaseModel, ConfigDict, Field +from rich import box +from rich.console import Console +from rich.table import Table +import torch + +from .megatron_forward_trace import ForwardTraceCapture + +REPO_ROOT = Path(__file__).resolve().parents[2] +ARTIFACT_ROOT = Path(REPO_ROOT / ".local/megatron_lora_correctness") +ORACLE_MOE_ROUTING_BUNDLE_DIRNAME = "oracle_moe_routing_replay" +ORACLE_REPLAY_TOPOLOGY_SUFFIX = "oracle_replay" + +REGENERATE_ENV = "ART_REGENERATE_ORACLE" +EXTENDED_TOPOLOGIES_ENV = "ART_ENABLE_EXTENDED_TOPOLOGIES" +SENSITIVITY_MUTATION_ENV = "ART_SENSITIVITY_MUTATIONS" + +DEFAULT_SENSITIVITY_MUTATION = "skip_finalize" +SUPPORTED_SENSITIVITY_MUTATIONS = ( + DEFAULT_SENSITIVITY_MUTATION, + "fwd_skip_o_proj_tp_reduce", + "fwd_o_proj_tp_reduce_avg_not_sum", + "bwd_skip_sync_qkv_a", + "bwd_skip_sync_o_proj_b", + "bwd_skip_sync_fc1_a", + "save_drop_nonzero_ranked_tp_shards", + "save_duplicate_replicated_entries", + "dp_grad_accumulation_seqs", + "dp_local_token_normalization", +) +SensitivityMutation = str + +REQUIRED_PACKED_TENSOR_FILES = ( + "tokens.pt", + "group_ids.pt", + "parent_ids.pt", + "input_pos.pt", + "assistant_mask.pt", + "logprobs.pt", + "advantages.pt", + "weights.pt", +) +NON_FINITE_METRIC_VALUE = 1e30 +EXPERT_TABLE_ROW_LIMIT = 8 +EXPERT_TRIPLET_PARAM_RE = re.compile( + r"layers\.(?P\d+|__layer_avg__)\.mlp\.experts\.(?P\d+)\." + r"(?Pgate_proj|up_proj|down_proj)\." 
+) +LAYER_INDEX_RE = re.compile(r"layers\.(\d+)\.") +PHASE_PRINT_ORDER = { + "forward": 0, + "router_scores": 1, + "router_topk_ids": 2, + "outputs": 3, + "losses": 4, + "grads": 5, + "deltas": 6, +} + + +class Topology(BaseModel): + """Defines distributed topology settings for one Megatron run variant.""" + + model_config = ConfigDict(frozen=True) + + tp: int + ep: int + etp: int = 1 + dp: int = 1 + sp: bool = False + cp: int = 1 + pp: int = 1 + vpp: int = 1 + + def resolved_expert_dp(self) -> int: + """Derives expert data parallel size from topology/world-size constraints.""" + attention_world = self.tp * self.cp * self.pp * self.dp + expert_divisor = self.etp * self.ep * self.pp + if attention_world % expert_divisor != 0: + raise ValueError( + "Invalid topology for Megatron expert parallelism: " + f"world_size={attention_world} is not divisible by " + f"etp*ep*pp={expert_divisor}." + ) + return attention_world // expert_divisor + + def slug(self) -> str: + """Builds a deterministic topology identifier used for output directories.""" + return ( + f"tp{self.tp}_ep{self.ep}_etp{self.etp}" + f"_dp{self.dp}_edp{self.resolved_expert_dp()}" + f"_cp{self.cp}_pp{self.pp}_vpp{self.vpp}_sp{int(self.sp)}" + ) + + def world_size(self) -> int: + # Mirrors Megatron parallel-state sizing: + # attention side: world = tp * pp * cp * dp + # expert side must also divide this world size (validated in resolved_expert_dp()). 
+ attention_world = self.tp * self.cp * self.pp * self.dp + self.resolved_expert_dp() + return attention_world + + +TOPOLOGIES = [ + Topology(tp=1, ep=1, etp=1, dp=1, sp=False), + Topology(tp=2, ep=1, etp=1, dp=1, sp=True), + Topology(tp=2, ep=2, etp=1, dp=1, sp=True), + Topology(tp=2, ep=1, etp=2, dp=1, sp=True), +] +EXTENDED_TOPOLOGIES = [ + Topology(tp=1, ep=1, etp=1, dp=2, sp=False), + Topology(tp=1, ep=2, etp=1, dp=2, sp=False), + Topology(tp=1, ep=1, etp=2, dp=2, sp=True), +] +ORACLE_TOPOLOGY = TOPOLOGIES[0] +SENSITIVITY_TOPOLOGY = Topology(tp=2, ep=2, etp=1, dp=1, sp=True) +SENSITIVITY_TOPOLOGY_BY_MUTATION: dict[SensitivityMutation, Topology] = { + mutation: SENSITIVITY_TOPOLOGY for mutation in SUPPORTED_SENSITIVITY_MUTATIONS +} +SENSITIVITY_TOPOLOGY_BY_MUTATION["bwd_skip_sync_fc1_a"] = Topology( + tp=2, ep=1, etp=2, dp=1, sp=True +) +SENSITIVITY_TOPOLOGY_BY_MUTATION |= { + k: Topology(tp=1, ep=2, etp=1, dp=2, sp=False) + for k in ["dp_grad_accumulation_seqs", "dp_local_token_normalization"] +} + + +class PackedTensorConfig(BaseModel): + """Controls synthetic packed tensor generation used by oracle harness runs.""" + + num_sequences: int = 4 + sequence_length: int = 256 + prefill_tokens: int = 64 + decode_tokens: int = 64 + decode_tokens_jitter: int = Field(default=32, ge=0) + vocab_high: int = 8192 + + +class LoraConfig(BaseModel): + """Configures LoRA adapter dimensions and targeted module families.""" + + rank: int = 1 + alpha: int = 32 + target_modules: list[str] = Field( + default_factory=lambda: [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ] + ) + + +MetricSummary = dict[str, float] +PhasePassFn = Callable[[MetricSummary], bool] + + +class MetricThresholdRule(BaseModel): + """Callable row pass rule that AND-checks configured metric upper bounds.""" + + limits: dict[str, float] = Field(default_factory=dict) + + def failure_reasons(self, summary: MetricSummary) -> list[str]: + """Builds readable failure 
reasons for this threshold rule.""" + reasons: list[str] = [] + for key, limit in sorted(self.limits.items()): + value = summary.get(key) + if not isinstance(value, (int, float)): + reasons.append(f"{key}=missing") + continue + if float(value) > float(limit): + reasons.append(f"{key}={float(value):.6g}>{float(limit):.6g}") + return reasons + + def __call__(self, summary: MetricSummary) -> bool: + """Evaluates whether the summary satisfies all configured bounds.""" + return len(self.failure_reasons(summary)) == 0 + + +class OracleCaseConfig(BaseModel): + """Contains all deterministic run parameters for one oracle case.""" + + base_model: str + precision: Literal["bf16", "fp32"] = "fp32" + num_layers: int = 4 + seed: int = 20260304 + num_steps: int = 1 + grad_accumulation_sequences: int = Field(default=4, ge=1) + learning_rate: float = 5e-6 + beta: float = 0.0 + loss_scale: float = 1 + packed_tensors: PackedTensorConfig = Field(default_factory=PackedTensorConfig) + lora: LoraConfig = Field(default_factory=LoraConfig) + + +class DiskPackedTensorsSpec(BaseModel): + """Describes packed tensor artifacts persisted on disk for reuse.""" + + dir: str + num_sequences: int + sequence_length: int + pixel_values: tuple[int, list[int]] | None = None + image_grid_thw: tuple[int, list[int]] | None = None + + +class CaseArtifacts(BaseModel): + """Holds stable case-level artifact paths used by all variants.""" + + case_id: str + case_dir: str + packed_tensors: DiskPackedTensorsSpec + shared_init_adapter_path: str + + +class WorkerRunRequest(BaseModel): + """Defines one distributed worker invocation for generating variant artifacts.""" + + case_id: str + case_config: OracleCaseConfig + topology: Topology + topology_dir: str + packed_tensors: DiskPackedTensorsSpec + shared_init_adapter_path: str + mutation: SensitivityMutation | None = None + moe_routing_replay_path: str | None = None + moe_routing_replay_strict: bool = True + capture_moe_routing_bundle_path: str | None = None + + 
+class StepTrace(BaseModel): + """Tracks per-step trace artifact filenames and loss metadata.""" + + step_index: int + loss: float + probs_corr: float + output_file: str + grads_file: str + deltas_file: str + lora_file: str + + +class RunManifest(BaseModel): + """Records run metadata and per-step trace references for one topology output.""" + + case_id: str + base_model: str + num_layers: int + topology: str + world_size: int + seed: int + num_steps: int + packed_tensors: DiskPackedTensorsSpec + steps: list[StepTrace] + + +class MetricRow(BaseModel): + """Represents one comparable unit (param/module/global) for one phase and step.""" + + case_id: str + variant: str + topology: str + oracle_topology: str + step_index: int + phase: str + param: str + numel: float + mean_abs_diff: float + relative_l2: float + typical_abs_scale: float + mean_abs_pct: float + topk_mismatch_fraction: float | None = None + top1_mismatch_fraction: float | None = None + pass_signal: bool = True + failure_reasons: list[str] = Field(default_factory=list) + + +class VariantSpec(BaseModel): + """Declares how to execute and evaluate one candidate variant against the oracle.""" + + name: str + topology: Topology + pass_fn_by_phase: dict[str, PhasePassFn] = Field( + default_factory=dict, + repr=False, + exclude=True, + ) + output_slug: str | None = None + reference_slug: str | None = None + mutation: SensitivityMutation | None = None + expected_signal: Literal["pass", "fail"] = "pass" + force_regenerate: bool = True + + def resolved_output_slug(self) -> str: + """Resolves the artifact slug for this run, including mutation suffix when present.""" + if self.output_slug is not None: + return self.output_slug + return _topology_output_slug(self.topology, self.mutation) + + def resolved_reference_slug(self) -> str: + """Resolves which topology slug should be treated as the comparison oracle.""" + if self.reference_slug is not None: + return self.reference_slug + return ORACLE_TOPOLOGY.slug() + + +class 
VariantReport(BaseModel): + """Captures full comparison output for one variant run.""" + + case_id: str + variant: str + topology: str + reference_topology: str + expected_signal: Literal["pass", "fail"] + signal: Literal["pass", "fail"] + pass_count: int + fail_count: int + step_summaries: dict[int, dict[str, Any]] = Field(repr=False) + metrics: list[MetricRow] = Field(repr=False) + + +class DiffAccumulator: + """Accumulates diff statistics across tensors and router-id mismatch counters.""" + + def __init__(self) -> None: + self.numel = 0 + self.abs_sum = 0.0 + self.diff_sq_sum = 0.0 + self.ref_sq_sum = 0.0 + self.ref_abs_sum = 0.0 + self.router_topk_total = 0 + self.router_topk_mismatch = 0 + self.router_top1_total = 0 + self.router_top1_mismatch = 0 + + def update(self, reference, candidate) -> None: # type: ignore[no-untyped-def] + """Adds one tensor pair into the accumulator.""" + ref = reference.detach().float() + cand = candidate.detach().float() + diff = (cand - ref).abs() + if diff.numel() == 0: + return + self.numel += int(diff.numel()) + self.abs_sum += float(diff.sum().item()) + self.diff_sq_sum += float((cand - ref).square().sum().item()) + self.ref_sq_sum += float(ref.square().sum().item()) + self.ref_abs_sum += float(ref.abs().sum().item()) + + @staticmethod + def layer_averaged_summary(reference_stack, candidate_stack) -> dict[str, float]: # type: ignore[no-untyped-def] + """Computes normal per-layer summaries, then averages those summaries.""" + ref = reference_stack.detach().float() + cand = candidate_stack.detach().float() + layer_count = int(ref.shape[0]) + metrics = { + k: 0.0 + for k in [ + "numel", + "mean_abs_diff", + "relative_l2", + "typical_abs_scale", + "mean_abs_pct", + ] + } + for layer_index in range(layer_count): + layer_accumulator = DiffAccumulator() + layer_accumulator.update(ref[layer_index], cand[layer_index]) + layer_summary = layer_accumulator.as_summary() + metrics = {k: metrics[k] + layer_summary[k] for k in metrics.keys()} 
+ return {k: _finite_metric(metrics[k] / layer_count) for k in metrics.keys()} + + def update_router_ids(self, reference_ids, candidate_ids) -> None: # type: ignore[no-untyped-def] + """Adds router top-k id mismatch counts into the accumulator.""" + self.router_topk_total += int(reference_ids.numel()) + self.router_topk_mismatch += int((reference_ids != candidate_ids).sum().item()) + if reference_ids.ndim >= 2 and reference_ids.shape[1] > 0: + self.router_top1_total += int(reference_ids.shape[0]) + self.router_top1_mismatch += int( + (reference_ids[:, 0] != candidate_ids[:, 0]).sum().item() + ) + + def as_summary(self) -> dict[str, float]: + """Returns normalized summary values for one row.""" + if self.numel == 0: + topk_fraction = 0.0 + top1_fraction = 0.0 + else: + topk_fraction = ( + self.router_topk_mismatch / self.router_topk_total + if self.router_topk_total > 0 + else 0.0 + ) + top1_fraction = ( + self.router_top1_mismatch / self.router_top1_total + if self.router_top1_total > 0 + else 0.0 + ) + if self.numel == 0: + return { + "numel": 0.0, + "mean_abs_diff": 0.0, + "relative_l2": 0.0, + "typical_abs_scale": 0.0, + "mean_abs_pct": 0.0, + "topk_mismatch_fraction": topk_fraction, + "top1_mismatch_fraction": top1_fraction, + } + mean_abs = self.abs_sum / self.numel + typical_abs = self.ref_abs_sum / self.numel + mean_abs_pct = (mean_abs / (typical_abs + 1e-12)) * 100.0 + return { + "numel": _finite_metric(float(self.numel), default=0.0), + "mean_abs_diff": _finite_metric(mean_abs), + "relative_l2": _finite_metric( + (self.diff_sq_sum**0.5) / max(self.ref_sq_sum**0.5, 1e-12) + ), + "typical_abs_scale": _finite_metric(typical_abs, default=0.0), + "mean_abs_pct": _finite_metric(mean_abs_pct), + "topk_mismatch_fraction": _finite_metric(topk_fraction, default=1.0), + "top1_mismatch_fraction": _finite_metric(top1_fraction, default=1.0), + } + + +T = TypeVar("T") + + +def _require_not_none(value: T | None, name: str) -> T: + """Asserts non-None values for required 
artifacts and raises a named runtime error."""
    if value is None:
        raise RuntimeError(f"{name} is None")
    return value


def _truthy(value: str | None) -> bool:
    """Parses env-var style booleans using a small accepted truthy set."""
    if value is None:
        return False
    return value.strip().lower() in {"1", "true", "yes", "on"}


def sensitivity_mutations() -> list[SensitivityMutation]:
    """Parses sensitivity mutation selectors from env as a CSV list.

    Accepted forms: empty/unset -> [], "all" -> every supported mutation,
    a truthy flag -> the default mutation, otherwise a CSV of supported
    mutation names (validated; unknown names raise ValueError).
    """
    raw = os.environ.get(SENSITIVITY_MUTATION_ENV)
    if raw is None or raw.strip() == "":
        return []
    normalized = raw.strip().lower()
    if normalized == "all":
        return list(SUPPORTED_SENSITIVITY_MUTATIONS)
    if normalized in {"1", "true", "yes", "on"}:
        return [DEFAULT_SENSITIVITY_MUTATION]
    mutations = [item.strip().lower() for item in raw.split(",") if item.strip()]
    unsupported = [
        mutation
        for mutation in mutations
        if mutation not in SUPPORTED_SENSITIVITY_MUTATIONS
    ]
    if not unsupported:
        return mutations
    supported = ", ".join(SUPPORTED_SENSITIVITY_MUTATIONS)
    raise ValueError(
        f"Unsupported {SENSITIVITY_MUTATION_ENV} value '{raw}'. "
        f"Supported values: {supported}, CSV of supported values, all, 1/true/yes/on."
    )


def sensitivity_enabled() -> bool:
    """Returns whether any sensitivity mutation has been requested via environment."""
    return bool(sensitivity_mutations())


def sensitivity_topology_for_mutation(mutation: SensitivityMutation) -> Topology:
    """Returns the sensitivity topology required for one mutation."""
    return SENSITIVITY_TOPOLOGY_BY_MUTATION[mutation]


def sensitivity_required_world_size(mutations: list[SensitivityMutation]) -> int:
    """Returns the max world-size required by a selected mutation set."""
    return max(
        sensitivity_topology_for_mutation(mutation).world_size()
        for mutation in mutations
    )


def extended_topologies_enabled() -> bool:
    """Returns whether extended topologies are enabled for the suite."""
    return _truthy(os.environ.get(EXTENDED_TOPOLOGIES_ENV))


def regenerate_requested() -> bool:
    """Returns whether regeneration mode is enabled for oracle artifacts."""
    return _truthy(os.environ.get(REGENERATE_ENV))


def case_config(
    base_model: str = "Qwen/Qwen3-30B-A3B-Instruct-2507",
) -> OracleCaseConfig:
    """Builds the deterministic default oracle case config."""
    return OracleCaseConfig(base_model=base_model)


def available_gpu_count() -> int:
    """Reports visible CUDA device count for topology scheduling and test skips."""
    import torch

    return int(torch.cuda.device_count())


def stable_case_id(case_config: OracleCaseConfig) -> str:
    """Builds a deterministic case id from case config contents.

    The id is a sanitized model tag plus the first 16 hex chars of a
    SHA-256 over the canonical (sorted, compact) JSON dump of the config,
    so any config change produces a new artifact directory.
    """
    payload = case_config.model_dump(mode="json")
    encoded = json.dumps(payload, sort_keys=True, separators=(",", ":"))
    digest = hashlib.sha256(encoded.encode("utf-8")).hexdigest()[:16]
    model_tag = (
        case_config.base_model.replace("/", "_")
        .replace("-", "_")
        .replace(".", "_")
        .lower()
    )
    return f"{model_tag}_{digest}"


def _write_json(path: Path, payload: Any) -> None:
    """Writes canonical pretty JSON to disk, creating parent directories as needed."""
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as handle:
        # allow_nan=False: reports must stay strictly JSON-compliant.
        json.dump(payload, handle, indent=2, sort_keys=True, allow_nan=False)


def _read_json(path: Path) -> dict[str, Any]:
    """Loads a JSON object from disk."""
    with path.open("r", encoding="utf-8") as handle:
        return json.load(handle)


def _build_packed_tensors(
    config: PackedTensorConfig,
    seed: int,
) -> dict[str, Any]:
    """Generates deterministic synthetic packed tensors used in integration runs.

    All randomness is drawn from one seeded torch.Generator, so identical
    (config, seed) pairs reproduce identical tensors. Row pairs (i, i+half)
    are forced to differ in both token content and decode length so that
    DP rank shards never see identical work.
    """
    import torch

    if config.num_sequences <= 1:
        raise ValueError("num_sequences must be greater than 1")
    shape = (config.num_sequences, config.sequence_length)
    generator = torch.Generator().manual_seed(seed)
    tokens = torch.randint(
        low=10,
        high=config.vocab_high,
        size=shape,
        dtype=torch.long,
        generator=generator,
    )
    # Ensure paired cross-DP rows are never token-identical.
    half = config.num_sequences // 2
    if half > 0 and config.num_sequences % 2 == 0:
        for pair_index in range(half):
            left_index = pair_index
            right_index = pair_index + half
            if torch.equal(tokens[left_index], tokens[right_index]):
                # Rotate the right row by one step within [10, vocab_high).
                token_span = max(1, config.vocab_high - 10)
                tokens[right_index] = ((tokens[right_index] - 10 + 1) % token_span) + 10
    group_ids = torch.zeros(shape, dtype=torch.long)
    parent_ids = torch.full(shape, -1, dtype=torch.long)
    input_pos = (
        torch.arange(config.sequence_length, dtype=torch.long)
        .unsqueeze(0)
        .expand(config.num_sequences, -1)
        .clone()
    )
    prefix_length = max(1, min(config.sequence_length - 1, config.prefill_tokens))
    assistant_mask = torch.zeros(shape, dtype=torch.bool)
    max_decode_tokens = max(1, config.sequence_length - prefix_length)
    base_decode_tokens = max(1, min(config.decode_tokens, max_decode_tokens))
    jitter_width = min(config.decode_tokens_jitter, max_decode_tokens - 1)
    candidate_decode_lengths: list[int] = []
    for _ in range(config.num_sequences):
        if jitter_width > 0:
            jitter = int(
                torch.randint(
                    low=-jitter_width,
                    high=jitter_width + 1,
                    size=(1,),
                    generator=generator,
                    dtype=torch.long,
                ).item()
            )
        else:
            jitter = 0
        decode_length = max(
            1,
            min(max_decode_tokens, base_decode_tokens + jitter),
        )
        candidate_decode_lengths.append(decode_length)
    # Keep jitter local around the configured decode length, but force pairwise
    # differences across halves so default DP rank shards see different lengths.
    if half > 0 and config.num_sequences % 2 == 0:
        for pair_index in range(half):
            left_index = pair_index
            right_index = pair_index + half
            if (
                candidate_decode_lengths[left_index]
                != candidate_decode_lengths[right_index]
            ):
                continue
            if candidate_decode_lengths[right_index] < max_decode_tokens:
                candidate_decode_lengths[right_index] += 1
            elif candidate_decode_lengths[right_index] > 1:
                candidate_decode_lengths[right_index] -= 1

    for sequence_index, decode_length in enumerate(candidate_decode_lengths):
        active_stop = prefix_length + decode_length
        assistant_mask[sequence_index, prefix_length:active_stop] = True
        # Carve the decode region into consecutive branches of `decode_span`
        # tokens, all parented to branch 0 (the prefill).
        decode_span = max(1, min(config.decode_tokens, decode_length))
        cursor = prefix_length
        branch = 1
        while cursor < active_stop:
            end = min(active_stop, cursor + decode_span)
            group_ids[sequence_index, cursor:end] = branch
            parent_ids[sequence_index, cursor:end] = 0
            cursor = end
            branch += 1
    logprobs = (
        torch.randn(
            shape,
            generator=generator,
            dtype=torch.float32,
        )
        * 0.25
        - 1.75
    )
    advantages = (
        torch.randn(
            shape,
            generator=generator,
            dtype=torch.float32,
        )
        * 0.1
        + 1.0
    )
    weights = torch.ones(shape, dtype=torch.float32)
    return {
        "tokens": tokens,
        "group_ids": group_ids,
        "parent_ids": parent_ids,
        "input_pos": input_pos,
        "assistant_mask": assistant_mask,
        "logprobs": logprobs,
        "advantages": advantages,
        "weights": weights,
        "pixel_values": [None] * config.num_sequences,
        "image_grid_thw": [None] * config.num_sequences,
    }


def _create_packed_tensors(
    case_config: OracleCaseConfig,
    packed_dir: Path,
) -> DiskPackedTensorsSpec:
    """Persists packed tensors to disk and returns their descriptor."""
    from art.preprocessing.pack import PackedTensors, packed_tensors_to_dir

    packed_tensors = cast(
        PackedTensors,
        _build_packed_tensors(case_config.packed_tensors, case_config.seed),
    )
    descriptor = packed_tensors_to_dir(packed_tensors, str(packed_dir))
    return DiskPackedTensorsSpec.model_validate(descriptor)


def ensure_case_artifacts(case_config: OracleCaseConfig) -> CaseArtifacts:
    """Ensures stable case-level artifacts (input tensors) are present and reusable."""
    case_id = stable_case_id(case_config)
    case_dir = ARTIFACT_ROOT / case_id
    case_dir.mkdir(parents=True, exist_ok=True)
    _write_json(case_dir / "case_config.json", case_config.model_dump(mode="json"))
    regenerate = regenerate_requested()

    descriptor_path = case_dir / "packed_tensors.json"
    packed_dir = case_dir / "packed_tensors"
    if descriptor_path.exists() and not regenerate:
        # Reuse previously generated tensors; the case id pins the config.
        packed_spec = DiskPackedTensorsSpec.model_validate(_read_json(descriptor_path))
    else:
        if packed_dir.exists():
            shutil.rmtree(packed_dir)
        packed_spec = _create_packed_tensors(case_config, packed_dir)
        _write_json(descriptor_path, packed_spec.model_dump(mode="json"))

    shared_init_path = case_dir / "shared_init" / "adapter_model.safetensors"
    shared_init_path.parent.mkdir(parents=True, exist_ok=True)
    return CaseArtifacts(
        case_id=case_id,
        case_dir=str(case_dir),
        packed_tensors=packed_spec,
        shared_init_adapter_path=str(shared_init_path),
    )


def _replace_topology_dir(path: Path) -> None:
    """Resets one topology output directory before regeneration."""
    if path.exists():
        shutil.rmtree(path)
    path.mkdir(parents=True, exist_ok=True)
    (path / "traces").mkdir(parents=True, exist_ok=True)


def _topology_output_slug(
    topology: Topology,
    mutation: SensitivityMutation | None = None,
) -> str:
    """Builds output slug for a topology and optional mutation variant."""
    return topology.slug() if mutation is None else f"{topology.slug()}__{mutation}"


def _load_manifest(topology_dir: Path) -> RunManifest:
    """Loads one run manifest for a topology output directory."""
    manifest_path = topology_dir / "manifest.json"
    return RunManifest.model_validate(_read_json(manifest_path))


def _load_output_tensor(topology_dir: Path, step: StepTrace):
    """Loads one output trace tensor referenced by a step trace entry."""
    import torch

    path = topology_dir / step.output_file
    return torch.load(path, map_location="cpu")


def _load_safetensor_map(path: Path) -> dict[str, Any]:
    """Loads one safetensor map from disk."""
    from safetensors.torch import load_file

    return load_file(str(path))


def _align_sequence_parallel(reference, candidate):  # type: ignore[no-untyped-def]
    """Aligns sequence-parallel-shaped tensors so diff computation is topology-agnostic.

    Accepts a candidate whose leading dim was split into (shards, chunk) and
    flattens it back to the reference shape; returns None on any other
    shape mismatch so callers can report a structural failure.
    """
    if reference.shape == candidate.shape:
        return candidate
    if (
        candidate.ndim == reference.ndim + 1
        and candidate.shape[0] * candidate.shape[1] == reference.shape[0]
        and tuple(candidate.shape[2:]) == tuple(reference.shape[1:])
    ):
        return candidate.reshape(reference.shape)
    return None


def _load_forward_trace(
    topology_dir: Path, step_index: int
) -> dict[str, list[dict[str, Any]]]:
    """Loads one merged forward-trace file for a given step."""
    trace_path = topology_dir / "traces" / f"forward_trace_step_{step_index:03d}.pt"
    return ForwardTraceCapture.load_trace(trace_path)


def _finite_metric(value: float, *, default: float = NON_FINITE_METRIC_VALUE) -> float:
    """Maps NaN/Inf metric values to a large finite sentinel for JSON-safe reports."""
    value_f = float(value)
    if math.isnan(value_f):
        return default
    if math.isinf(value_f):
        # Preserve the sign of infinities when clamping.
        return default if value_f > 0 else -default
    return value_f


def _triplet_expert_key(param: str) -> tuple[str, int] | None:
"""Returns (projection, expert_id) for expert gate/up/down params.""" + match = EXPERT_TRIPLET_PARAM_RE.search(param) + if match is None: + return None + return match.group("proj"), int(match.group("expert")) + + +def _layer_agnostic_param_key(param: str) -> str | None: + """Normalizes one parameter name by stripping the explicit layer index.""" + if LAYER_INDEX_RE.search(param) is None: + return None + return LAYER_INDEX_RE.sub("layers.__layer_avg__.", param, count=1) + + +def _stacked_layers( + pairs: list[tuple[str, Any, Any]], +) -> list[tuple[str, Any, Any]]: + """Builds layer-stacked tensor pairs keyed without explicit layer index.""" + import torch + + grouped: dict[str, list[tuple[Any, Any]]] = {} + original_names_by_group: dict[str, list[str]] = {} + for name, reference, candidate in pairs: + normalized = _layer_agnostic_param_key(name) + if normalized is None: + raise RuntimeError( + f"Expected all compared params to include a layer index, got '{name}'." + ) + grouped.setdefault(normalized, []).append( + (reference.detach().float(), candidate.detach().float()) + ) + original_names_by_group.setdefault(normalized, []).append(name) + + stacked_pairs: list[tuple[str, Any, Any]] = [] + for normalized in sorted(grouped): + group = grouped[normalized] + reference_shapes = {tuple(reference.shape) for reference, _ in group} + candidate_shapes = {tuple(candidate.shape) for _, candidate in group} + if len(reference_shapes) != 1 or len(candidate_shapes) != 1: + original_names = original_names_by_group[normalized] + for original_name, (reference, candidate) in zip(original_names, group): + stacked_pairs.append((original_name, reference, candidate)) + continue + stacked_pairs.append( + ( + normalized, + torch.stack([reference for reference, _ in group], dim=0), + torch.stack([candidate for _, candidate in group], dim=0), + ) + ) + return stacked_pairs + + +class VariantRunner: + """Runs oracle/candidate variants and emits row-level comparison reports.""" + + def 
__init__( + self, + *, + case_config: OracleCaseConfig, + console: Console | None = None, + ) -> None: + self.case_config = case_config + self.case_artifacts = ensure_case_artifacts(case_config) + self.case_id = self.case_artifacts.case_id + self.case_dir = Path(self.case_artifacts.case_dir) + self.oracle_slug = ORACLE_TOPOLOGY.slug() + self.oracle_dir = self.case_dir / self.oracle_slug + self.oracle_routing_bundle_dir = ( + self.case_dir / ORACLE_MOE_ROUTING_BUNDLE_DIRNAME + ) + self.shared_init_path = Path(self.case_artifacts.shared_init_adapter_path) + self.console = console or Console(width=140) + self._oracle_initialized = False + self._oracle_regenerated = False + + def _run_topology( + self, + *, + topology: Topology, + output_slug: str, + mutation: SensitivityMutation | None, + replay_bundle_dir: Path | None, + capture_bundle_dir: Path | None, + regenerate: bool, + ) -> Path: + """Executes one topology worker run and returns its output directory.""" + topology_dir = self.case_dir / output_slug + manifest_path = topology_dir / "manifest.json" + if manifest_path.exists() and not regenerate: + return topology_dir + _replace_topology_dir(topology_dir) + run_case_config = self.case_config + request = WorkerRunRequest( + case_id=self.case_id, + case_config=run_case_config, + topology=topology, + topology_dir=str(topology_dir), + packed_tensors=self.case_artifacts.packed_tensors, + shared_init_adapter_path=str(self.shared_init_path), + mutation=mutation, + moe_routing_replay_path=( + None if replay_bundle_dir is None else str(replay_bundle_dir) + ), + moe_routing_replay_strict=True, + capture_moe_routing_bundle_path=( + None if capture_bundle_dir is None else str(capture_bundle_dir) + ), + ) + from .megatron_oracle_worker import run_worker_subprocess + + run_worker_subprocess(request, topology_dir, repo_root=REPO_ROOT) + return topology_dir + + def ensure_oracle(self) -> Path: + """Ensures oracle capture and canonical replay artifacts exist exactly once per 
session.""" + regenerate = regenerate_requested() + if self._oracle_initialized and (not regenerate or self._oracle_regenerated): + return self.oracle_dir + if regenerate and self.shared_init_path.exists(): + self.shared_init_path.unlink() + bundle_manifest = self.oracle_routing_bundle_dir / "manifest.json" + oracle_manifest = self.oracle_dir / "manifest.json" + need_capture = ( + regenerate + or not bundle_manifest.exists() + or not self.shared_init_path.exists() + ) + run_oracle_topology = partial( + self._run_topology, + topology=ORACLE_TOPOLOGY, + mutation=None, + regenerate=True, + ) + if need_capture: + run_oracle_topology( + output_slug=f"{self.oracle_slug}__oracle_capture", + replay_bundle_dir=None, + capture_bundle_dir=self.oracle_routing_bundle_dir, + ) + if regenerate or not oracle_manifest.exists(): + run_oracle_topology( + output_slug=self.oracle_slug, + replay_bundle_dir=self.oracle_routing_bundle_dir, + capture_bundle_dir=None, + ) + self._oracle_initialized = True + self._oracle_regenerated = self._oracle_regenerated or regenerate + return self.oracle_dir + + def ensure_variant_artifacts( + self, + variant: VariantSpec, + ) -> Path: + """Ensures oracle prerequisites and candidate artifacts for one variant.""" + self.ensure_oracle() + output_slug = variant.resolved_output_slug() + if output_slug == self.oracle_slug and variant.mutation is None: + return self.oracle_dir + return self._run_topology( + topology=variant.topology, + output_slug=output_slug, + mutation=variant.mutation, + replay_bundle_dir=self.oracle_routing_bundle_dir, + capture_bundle_dir=None, + regenerate=variant.force_regenerate, + ) + + @staticmethod + def _apply_phase_pass( + *, + row: MetricRow, + phase: str, + summary: MetricSummary, + pass_fn_by_phase: dict[str, PhasePassFn], + ) -> None: + """Evaluates a per-phase pass function against one summary payload.""" + pass_fn = pass_fn_by_phase.get(phase) + if pass_fn is None: + row.pass_signal = True + row.failure_reasons = [] + 
            return
        row.pass_signal = bool(pass_fn(summary))
        if row.pass_signal:
            row.failure_reasons = []
            return
        # Prefer rule-provided explanations (e.g. MetricThresholdRule) when
        # available; otherwise record a generic failure reason.
        explain = getattr(pass_fn, "failure_reasons", None)
        if callable(explain):
            reasons = explain(summary)
            row.failure_reasons = (
                reasons if reasons else ["phase pass function returned false"]
            )
            return
        row.failure_reasons = ["phase pass function returned false"]

    @staticmethod
    def _inf_summary() -> dict[str, float]:
        """Builds a large-error finite summary for structural mismatches."""
        return {
            "numel": 0.0,
            "mean_abs_diff": NON_FINITE_METRIC_VALUE,
            "relative_l2": NON_FINITE_METRIC_VALUE,
            "typical_abs_scale": 0.0,
            "mean_abs_pct": NON_FINITE_METRIC_VALUE,
            "topk_mismatch_fraction": 1.0,
            "top1_mismatch_fraction": 1.0,
        }

    def _build_metric_row(
        self,
        *,
        variant: VariantSpec,
        step_index: int,
        phase: str,
        param: str,
        summary: dict[str, float],
        structural_failure: str | None = None,
    ) -> MetricRow:
        """Builds one metric row and applies per-phase pass evaluation."""
        row = MetricRow(
            case_id=self.case_id,
            variant=variant.name,
            topology=variant.resolved_output_slug(),
            oracle_topology=variant.resolved_reference_slug(),
            step_index=step_index,
            phase=phase,
            param=param,
            numel=summary["numel"],
            mean_abs_diff=summary["mean_abs_diff"],
            relative_l2=summary["relative_l2"],
            typical_abs_scale=summary["typical_abs_scale"],
            mean_abs_pct=summary["mean_abs_pct"],
            topk_mismatch_fraction=summary.get("topk_mismatch_fraction"),
            top1_mismatch_fraction=summary.get("top1_mismatch_fraction"),
        )
        self._apply_phase_pass(
            row=row,
            phase=phase,
            summary=summary,
            pass_fn_by_phase=variant.pass_fn_by_phase,
        )
        # A structural failure overrides any per-phase pass verdict.
        if structural_failure is not None:
            row.pass_signal = False
            row.failure_reasons = [structural_failure, *row.failure_reasons]
        return row

    def _build_metric_rows_from_tensor_pairs(
        self,
        *,
        variant: VariantSpec,
        step_index: int,
        phase: str,
        pairs: list[tuple[str, Any, Any]],
        router_ids: bool = False,
        layer_averaged: bool = False,
    ) -> list[MetricRow]:
        """Builds rows from named tensor pairs with one shared diff path."""
        rows: list[MetricRow] = []
        for name, reference, candidate in pairs:
            reference_aligned = reference
            candidate_aligned = candidate
            aligned_candidate = _align_sequence_parallel(
                reference_aligned, candidate_aligned
            )
            if aligned_candidate is None:
                # Unalignable shapes: emit a forced-failure structural row.
                rows.append(
                    self._build_metric_row(
                        variant=variant,
                        step_index=step_index,
                        phase=phase,
                        param=name,
                        summary=self._inf_summary(),
                        structural_failure="shape mismatch",
                    )
                )
                continue
            summary: dict[str, float]
            if router_ids:
                accumulator = DiffAccumulator()
                accumulator.update_router_ids(reference_aligned, aligned_candidate)
                summary = accumulator.as_summary()
            elif layer_averaged:
                summary = DiffAccumulator.layer_averaged_summary(
                    reference_aligned, aligned_candidate
                )
            else:
                accumulator = DiffAccumulator()
                accumulator.update(reference_aligned, aligned_candidate)
                summary = accumulator.as_summary()
            rows.append(
                self._build_metric_row(
                    variant=variant,
                    step_index=step_index,
                    phase=phase,
                    param=name,
                    summary=summary,
                )
            )
        return rows

    def _check_matching_keys(
        self,
        reference: dict[str, Any],
        candidate: dict[str, Any],
        variant: VariantSpec,
        step_index: int,
        phase: str,
    ) -> tuple[bool, list[MetricRow] | None]:
        """Checks if the keys of two tensor maps match and builds a metric row if they don't."""
        reference_keys = set(reference.keys())
        candidate_keys = set(candidate.keys())
        if reference_keys != candidate_keys:
            missing = sorted(reference_keys - candidate_keys)
            extra = sorted(candidate_keys - reference_keys)
            # Only the first five names each way, to keep the reason readable.
            return False, [
                self._build_metric_row(
                    variant=variant,
                    step_index=step_index,
                    phase=phase,
                    param="__keys__",
                    summary=self._inf_summary(),
                    structural_failure=f"missing={missing[:5]} extra={extra[:5]}",
                )
            ]
        return True, None

    def _build_metric_rows_from_tensor_maps(
        self,
        *,
        variant: VariantSpec,
        step_index: int,
        phase: str,
        reference: dict[str, Any],
        candidate: dict[str, Any],
        router_ids: bool = False,
    ) -> list[MetricRow]:
        """Builds rows from two keyed tensor maps through a unified compare path."""
        matching, rows = self._check_matching_keys(
            reference, candidate, variant, step_index, phase
        )
        if not matching:
            return rows if rows is not None else []
        pairs = [
            (key, reference[key], candidate[key])
            for key in sorted(set(reference.keys()))
        ]
        # Per-layer tensor phases are stacked and averaged layer-by-layer.
        if phase in {"forward", "grads", "deltas"}:
            pairs = _stacked_layers(pairs)
        return self._build_metric_rows_from_tensor_pairs(
            variant=variant,
            step_index=step_index,
            phase=phase,
            pairs=pairs,
            router_ids=router_ids,
            layer_averaged=phase in {"forward", "grads", "deltas"},
        )

    @staticmethod
    def _build_step_summaries(rows: list[MetricRow]) -> dict[int, dict[str, Any]]:
        """Builds step-indexed payloads directly from row model dumps."""
        step_summaries: dict[int, dict[str, Any]] = {}
        for row in rows:
            step_entry = step_summaries.setdefault(row.step_index, {})
            phase_entry = cast(dict[str, Any], step_entry.setdefault(row.phase, {}))
            phase_entry[row.param] = row.model_dump(mode="json")
        return step_summaries

    def compare_variant(self, variant: VariantSpec) -> VariantReport:
        """Compares one candidate variant against its reference topology."""
        reference_slug = variant.resolved_reference_slug()
        topology_slug = variant.resolved_output_slug()
        reference_dir = self.case_dir / reference_slug
        topology_dir = self.case_dir / topology_slug
        reference_manifest = _load_manifest(reference_dir)
        topology_manifest = _load_manifest(topology_dir)
        rows: list[MetricRow] = []
        if len(reference_manifest.steps) != len(topology_manifest.steps):
            rows.append(
                self._build_metric_row(
                    variant=variant,
                    step_index=0,
                    phase="step_count",
                    param="__step_count__",
                    summary=self._inf_summary(),
                    structural_failure=(
                        f"reference={len(reference_manifest.steps)} "
                        f"candidate={len(topology_manifest.steps)}"
                    ),
                )
            )

        import torch

        # NOTE(review): zip truncates to the shorter manifest; the step-count
        # row above already marks the run failed when lengths differ.
        for reference_step, topology_step in zip(
            reference_manifest.steps, topology_manifest.steps
        ):
            step_index = reference_step.step_index
            reference_trace = _load_forward_trace(reference_dir, step_index)
            topology_trace = _load_forward_trace(topology_dir, step_index)
            # (phase, reference_map, candidate_map, router_ids) tuples feeding
            # the unified map-comparison path.
            map_phase_inputs = [
                (
                    "outputs",
                    {"logprobs": _load_output_tensor(reference_dir, reference_step)},
                    {"logprobs": _load_output_tensor(topology_dir, topology_step)},
                    False,
                ),
                (
                    "losses",
                    {"loss": torch.tensor([reference_step.loss], dtype=torch.float32)},
                    {"loss": torch.tensor([topology_step.loss], dtype=torch.float32)},
                    False,
                ),
                (
                    "grads",
                    _load_safetensor_map(reference_dir / reference_step.grads_file),
                    _load_safetensor_map(topology_dir / topology_step.grads_file),
                    False,
                ),
                (
                    "deltas",
                    _load_safetensor_map(reference_dir / reference_step.deltas_file),
                    _load_safetensor_map(topology_dir / topology_step.deltas_file),
                    False,
                ),
                *[
                    (
                        phase,
                        ForwardTraceCapture.flatten_trace_tensors(
                            reference_trace,
                            value_key=value_key,
                        ),
                        ForwardTraceCapture.flatten_trace_tensors(
                            topology_trace,
                            value_key=value_key,
                        ),
                        phase == "router_topk_ids",
                    )
                    for phase, value_key in (
                        ("forward", "primary_output"),
                        ("router_scores", "router_topk_scores"),
                        ("router_topk_ids", "router_topk_ids"),
                    )
                ],
            ]
            for phase, reference_map, candidate_map, router_ids in map_phase_inputs:
                rows.extend(
                    self._build_metric_rows_from_tensor_maps(
                        variant=variant,
                        step_index=step_index,
                        phase=phase,
                        reference=reference_map,
                        candidate=candidate_map,
                        router_ids=router_ids,
                    )
                )
        pass_count = sum(1 for row in rows if row.pass_signal)
        fail_count = len(rows) - pass_count
        signal: Literal["pass", "fail"] = "pass" if fail_count == 0 else "fail"
        return VariantReport(
            case_id=self.case_id,
            variant=variant.name,
            topology=topology_slug,
            reference_topology=reference_slug,
            expected_signal=variant.expected_signal,
            signal=signal,
            pass_count=pass_count,
            fail_count=fail_count,
            step_summaries=self._build_step_summaries(rows),
            metrics=rows,
        )

    @staticmethod
    def assert_expected_signal(
        report: VariantReport,
        context: str,
        *,
        report_path: Path,
    ) -> None:
        """Raises when observed run signal diverges from variant expectation."""
        if report.signal == report.expected_signal:
            return
        if report.signal == "fail":
            # Surface the first failing row so the assertion is actionable.
            first_failure = next(row for row in report.metrics if not row.pass_signal)
            raise AssertionError(
                f"{context}: topology={report.topology} phase={first_failure.phase} "
                f"step={first_failure.step_index} param={first_failure.param} "
                f"reasons={'; '.join(first_failure.failure_reasons)} "
                f"report={report_path}"
            )
        raise AssertionError(
            f"{context}: expected_signal={report.expected_signal} "
            f"observed_signal={report.signal} topology={report.topology} "
            f"report={report_path}"
        )

    def _write_variant_report(self, topology_dir: Path, report: VariantReport) -> None:
        """Persists full variant report JSON for debugging and regression inspection."""
        _write_json(
            topology_dir / "variant_report.json", report.model_dump(mode="json")
        )

    def print_report(self, report: VariantReport) -> None:
        """Prints a row-level table with expert rows subsampled by highest mean_abs_pct."""
        non_expert_rows: list[MetricRow] = []
        triplet_rows: list[tuple[tuple[str, int], MetricRow]] = []
        for row in report.metrics:
            expert_key = _triplet_expert_key(row.param)
            if expert_key is None:
                non_expert_rows.append(row)
                continue
            triplet_rows.append((expert_key, row))

        # Per projection, keep each expert's worst (max) mean_abs_pct score.
        scores_by_proj: dict[str, dict[int, float]] = {}
        for (projection, expert_id), row in triplet_rows:
            projection_scores = scores_by_proj.setdefault(projection, {})
            projection_scores[expert_id] = max(
                projection_scores.get(expert_id, float("-inf")), row.mean_abs_pct
            )

        selected_experts: set[tuple[str, int]] = set()
        for projection, expert_scores in scores_by_proj.items():
            top_experts = sorted(
                expert_scores.items(),
                key=lambda item: item[1],
                reverse=True,
            )[:EXPERT_TABLE_ROW_LIMIT]
            for expert_id, _score in top_experts:
                selected_experts.add((projection, expert_id))

        selected_triplet_rows = [
            row for expert_key, row in triplet_rows if expert_key in selected_experts
        ]
        table_rows = non_expert_rows + selected_triplet_rows
        detail_table = Table(
            title=(
                f"Variant Report | variant={report.variant} "
                f"| selected_experts={len(selected_experts)} "
                f"(top {EXPERT_TABLE_ROW_LIMIT} per projection by mean_abs_pct)"
            ),
            box=box.SIMPLE_HEAVY,
            show_lines=False,
        )
        detail_table.add_column("Step", justify="right")
        detail_table.add_column("Phase", style="cyan")
        detail_table.add_column("Param", overflow="fold")
        detail_table.add_column("Status")
        detail_table.add_column("relative_l2", justify="right")
        detail_table.add_column("mean_abs_pct", justify="right")
        detail_table.add_column("typical_abs", justify="right")
        detail_table.add_column("mean_abs_diff", justify="right")
        detail_table.add_column("Failure")
        sorted_rows = sorted(
            table_rows,
            key=lambda row: (
                row.step_index,
                PHASE_PRINT_ORDER.get(row.phase, 999),
                row.param,
                row.pass_signal,
            ),
        )
        for row in sorted_rows:
            status_text = (
                "[green]PASS[/green]" if row.pass_signal else "[red]FAIL[/red]"
            )
            failure_text = "" if row.pass_signal else "; ".join(row.failure_reasons)
            detail_table.add_row(
                str(row.step_index),
                row.phase,
                row.param,
                status_text,
                f"{row.relative_l2:.6g}",
                f"{row.mean_abs_pct:.6g}%",
                f"{row.typical_abs_scale:.6g}",
                f"{row.mean_abs_diff:.6g}",
                failure_text,
            )
        self.console.print(detail_table)

    def run_variant(
        self,
        variant: VariantSpec,
    ) -> VariantReport:
        """Runs a variant end-to-end, writes JSON report, and prints row table."""
topology_dir = self.ensure_variant_artifacts(variant) + report = self.compare_variant(variant) + self._write_variant_report(topology_dir, report) + self.print_report(report) + return report + + def run_suite( + self, + variants: list[VariantSpec], + ) -> list[VariantReport]: + """Runs variants in order and stops at the first unexpected signal.""" + reports: list[VariantReport] = [] + for variant in variants: + report = self.run_variant(variant) + reports.append(report) + self.assert_expected_signal( + report, + "Megatron correctness suite mismatch", + report_path=self.case_dir + / variant.resolved_output_slug() + / "variant_report.json", + ) + return reports + + +def _default_phase_pass_fns() -> dict[str, PhasePassFn]: + """Builds default per-phase pass functions over diff summaries.""" + # note the metrics get averaged across layers to reduce noise + # we don't expect particular layers to see errors as opposed to the others so this is helpful + fwd_out_loss = MetricThresholdRule( + limits={"relative_l2": 1e-2, "mean_abs_pct": 1.0} + ) + grads_deltas = MetricThresholdRule(limits={"mean_abs_pct": 10.0}) + router_topk_rule = ( + MetricThresholdRule( # should be no mismatch due to router replay + limits={ + "topk_mismatch_fraction": 0.0, + "top1_mismatch_fraction": 0.0, + } + ) + ) + return {key: fwd_out_loss for key in ["forward", "outputs", "losses"]} | { + "grads": grads_deltas, + "deltas": grads_deltas, + "router_topk_ids": router_topk_rule, + } + + +def _suite_variants() -> list[VariantSpec]: + """Builds the standard oracle suite variant ordering.""" + phase_pass = _default_phase_pass_fns() + variants = [ + VariantSpec( + name="oracle_replay_parity", + topology=ORACLE_TOPOLOGY, + output_slug=_topology_output_slug( + ORACLE_TOPOLOGY, ORACLE_REPLAY_TOPOLOGY_SUFFIX + ), + pass_fn_by_phase=phase_pass, + force_regenerate=regenerate_requested(), + ) + ] + for topology in TOPOLOGIES[1:] + ( + EXTENDED_TOPOLOGIES if extended_topologies_enabled() else [] + ): + 
variants.append( + VariantSpec( + name=f"topology_{topology.slug()}", + topology=topology, + pass_fn_by_phase=phase_pass, + ) + ) + return variants + + +def run_suite( + *, + case_config: OracleCaseConfig, +) -> list[VariantReport]: + """Runs replay parity and topology variants with fail-fast assertions.""" + runner = VariantRunner(case_config=case_config) + return runner.run_suite(_suite_variants()) + + +def run_sensitivity_suite( + *, + case_config: OracleCaseConfig, + mutations: list[SensitivityMutation], +) -> list[VariantReport]: + """Runs a list of sensitivity mutations and expects each to fail.""" + runner = VariantRunner(case_config=case_config) + phase_pass = _default_phase_pass_fns() + variants = [ + VariantSpec( + name=f"sensitivity_{mutation}", + topology=sensitivity_topology_for_mutation(mutation), + mutation=mutation, + expected_signal="fail", + pass_fn_by_phase=phase_pass, + ) + for mutation in mutations + ] + return runner.run_suite(variants) diff --git a/tests/integration/megatron_oracle_worker.py b/tests/integration/megatron_oracle_worker.py new file mode 100644 index 00000000..f84179b3 --- /dev/null +++ b/tests/integration/megatron_oracle_worker.py @@ -0,0 +1,952 @@ +from __future__ import annotations + +import argparse +from contextlib import ExitStack, contextmanager +import hashlib +import os +from pathlib import Path +import random +import subprocess +import sys +from types import MethodType +from typing import Any, Callable + +import numpy as np +import torch + +from art.megatron.routing_replay import ( + ParallelTopology as ReplayParallelTopology, +) +from art.megatron.routing_replay import ( + build_bundle_from_forward_trace_dir, +) + +from .megatron_forward_trace import ForwardTraceCapture +from .megatron_oracle_harness import ( + SUPPORTED_SENSITIVITY_MUTATIONS, + OracleCaseConfig, + RunManifest, + SensitivityMutation, + StepTrace, + Topology, + WorkerRunRequest, + _read_json, + _require_not_none, + _write_json, +) + + +def 
run_worker_subprocess( + request: WorkerRunRequest, + topology_dir: Path, + *, + repo_root: Path, +) -> None: + """Runs one distributed worker subprocess and stores combined logs.""" + request_path = topology_dir / "run_request.json" + _write_json(request_path, request.model_dump(mode="json")) + worker_module = "integration.megatron_oracle_worker" + worker_cwd = repo_root / "tests" + + command = [ + sys.executable, + "-m", + "torch.distributed.run", + "--standalone", + "--nproc_per_node", + str(request.topology.world_size()), + "-m", + worker_module, + "--worker-run", + "--run-request", + str(request_path), + ] + run = subprocess.run( + command, + cwd=str(worker_cwd), + env={**os.environ, "PYTHONUNBUFFERED": "1"}, + capture_output=True, + text=True, + check=False, + ) + combined_output = f"{run.stdout}\n{run.stderr}".strip() + (topology_dir / "worker.log").write_text(combined_output + "\n", encoding="utf-8") + if run.returncode != 0: + tail = "\n".join(combined_output.splitlines()[-80:]) + raise RuntimeError( + f"Topology run failed for {request.topology.slug()} with exit code " + f"{run.returncode}.\n{tail}" + ) + + +def _set_deterministic_seed(seed: int) -> None: + import torch + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def _merge_sharded_dicts(shards_by_rank: list[dict[str, Any]]) -> dict[str, Any]: + """Merges rank-sharded LoRA tensors into a full state dict on rank 0.""" + import torch + + merged: dict[str, list[Any]] = {} + for rank_shards in shards_by_rank: + for key, tensor in rank_shards.items(): + merged.setdefault(key, []).append(tensor.detach().cpu()) + full_state: dict[str, Any] = {} + for key, shards in merged.items(): + if len(shards) == 1: + full_state[key] = shards[0].contiguous() + continue + concat_dim = 1 if ".lora_A." 
in key else 0 + full_state[key] = torch.cat(shards, dim=concat_dim).contiguous() + return full_state + + +def _gather_full_state( + local_state: dict[str, Any], +) -> dict[str, Any] | None: + """Gathers local state dicts to rank 0 and merges them.""" + import torch + + rank = torch.distributed.get_rank() # ty: ignore[possibly-missing-attribute] + world_size = torch.distributed.get_world_size() # ty: ignore[possibly-missing-attribute] + gathered = [None for _ in range(world_size)] if rank == 0 else None + torch.distributed.gather_object( # ty: ignore[possibly-missing-attribute] + local_state, gathered, dst=0 + ) + if rank != 0: + return None + assert gathered is not None + entries = [entry for entry in gathered if entry is not None] + return _merge_sharded_dicts(entries) + + +def _collect_lora_state( + model_chunks: list[Any], +) -> dict[str, Any] | None: + """Collects full LoRA adapter state for validation and delta computation.""" + local_state: dict[str, Any] = {} + for chunk in model_chunks: + for module in chunk.modules(): + if not hasattr(module, "sharded_lora_state_dict"): + continue + module_state = module.sharded_lora_state_dict() + for key, value in module_state.items(): + if key in local_state: + raise RuntimeError( + f"Duplicate LoRA key while collecting state: {key}" + ) + local_state[key] = value.detach().cpu() + return _gather_full_state(local_state) + + +def _collect_lora_grads( + model_chunks: list[Any], +) -> dict[str, Any] | None: + """Collects full LoRA gradient tensors across all ranks.""" + from art.megatron.lora import LoRA + + local_grads: dict[str, Any] = {} + for chunk in model_chunks: + for module in chunk.modules(): + if not isinstance(module, LoRA): + continue + for key, param, expert in module._export_items(): # type: ignore[attr-defined] + if not hasattr(param, "main_grad"): + raise RuntimeError( + f"LoRA param missing main_grad attribute for key '{key}'" + ) + grad = param.main_grad + if grad is None: + raise RuntimeError(f"LoRA param 
main_grad is None for key '{key}'") + if hasattr(grad, "_local_tensor"): + grad = grad._local_tensor + captured_grad = grad[expert] if expert is not None else grad + local_grads[key] = captured_grad.detach().cpu().T + return _gather_full_state(local_grads) + + +def _apply_save_mutation_to_tensor_map( + tensor_map: dict[str, Any], + *, + mutation: SensitivityMutation | None, +) -> dict[str, Any]: + """Applies save-only mutation transforms to already-collected full tensor maps.""" + if mutation == "save_drop_nonzero_ranked_tp_shards": + mutated: dict[str, Any] = {} + for key, value in tensor_map.items(): + if not isinstance(value, torch.Tensor): + mutated[key] = value + continue + if ".lora_A." in key and value.ndim >= 2 and value.shape[1] > 1: + keep = max(1, value.shape[1] // 2) + mutated[key] = value.narrow(1, 0, keep).contiguous() + continue + if ".lora_B." in key and value.ndim >= 2 and value.shape[0] > 1: + keep = max(1, value.shape[0] // 2) + mutated[key] = value.narrow(0, 0, keep).contiguous() + continue + mutated[key] = value + return mutated + + if mutation == "save_duplicate_replicated_entries": + mutated = dict(tensor_map) + source_by_bucket: dict[tuple[tuple[int, ...], str], torch.Tensor] = {} + for key in sorted(mutated.keys()): + value = mutated[key] + if not isinstance(value, torch.Tensor): + continue + if not key.endswith(".weight"): + continue + bucket = (tuple(value.shape), str(value.dtype)) + source = source_by_bucket.get(bucket) + if source is None: + source_by_bucket[bucket] = value + continue + mutated[key] = source.clone().contiguous() + return mutated + + return tensor_map + + +def _validate_loaded_state_matches_adapter( + loaded_state: dict[str, Any], + adapter_model: dict[str, Any], +) -> None: + """Checks loaded model LoRA state exactly matches adapter tensors and keys.""" + import torch + + for key in sorted(adapter_model.keys()): + assert torch.equal(loaded_state[key].cpu(), adapter_model[key].cpu()), ( + f"Loaded LoRA state mismatch for 
key '{key}'" + ) + + +def _build_deterministic_shared_init( + initial_state: dict[str, Any], + *, + seed: int, +) -> dict[str, Any]: + """Builds deterministic nonzero LoRA init values for both A and B tensors.""" + initialized: dict[str, Any] = {} + for key in sorted(initial_state.keys()): + value = initial_state[key] + if not isinstance(value, torch.Tensor): + raise TypeError(f"Expected tensor value for key '{key}', got {type(value)}") + digest = hashlib.sha256(f"{seed}:{key}".encode("utf-8")).digest() + key_seed = int.from_bytes(digest[:8], "little") % (2**31) + generator = torch.Generator(device="cpu").manual_seed(key_seed) + random_values = torch.randn( + value.shape, + generator=generator, + dtype=torch.float32, + ) + initialized[key] = (0.01 * random_values).to(dtype=value.dtype).contiguous() + return initialized + + +def _configure_provider( + provider: Any, + topology: Topology, + case_config: OracleCaseConfig, +) -> None: + """Applies deterministic topology/model overrides to provider config.""" + provider.tensor_model_parallel_size = topology.tp + provider.expert_model_parallel_size = topology.ep + provider.expert_tensor_parallel_size = topology.etp + # These are intentionally pinned to 1 for now + provider.pipeline_model_parallel_size = 1 + provider.context_parallel_size = 1 + provider.sequence_parallel = topology.sp + provider.num_layers = case_config.num_layers + if case_config.precision == "fp32": + provider.bf16 = False + provider.fp16 = False + provider.params_dtype = torch.float32 + provider.pipeline_dtype = torch.float32 + provider.enable_autocast = False + provider.autocast_dtype = None + provider.attention_softmax_in_fp32 = True + provider.fp32_residual_connection = True + if hasattr(provider, "attention_dropout"): + provider.attention_dropout = 0.0 + if hasattr(provider, "hidden_dropout"): + provider.hidden_dropout = 0.0 + + +def _build_optimizer_config(case_config: OracleCaseConfig): + """Builds Megatron optimizer settings for deterministic 
harness runs.""" + from megatron.core.optimizer import OptimizerConfig + + if case_config.precision == "fp32": + return OptimizerConfig( + bf16=False, + fp16=False, + params_dtype=torch.float32, + main_grads_dtype=torch.float32, + main_params_dtype=torch.float32, + exp_avg_dtype=torch.float32, + exp_avg_sq_dtype=torch.float32, + lr=case_config.learning_rate, + adam_beta1=0.9, + adam_beta2=0.99, + clip_grad=0.1, + weight_decay=0.0, + adam_eps=1e-13, + ) + + return OptimizerConfig( + bf16=True, + fp16=False, + lr=case_config.learning_rate, + adam_beta1=0.9, + adam_beta2=0.99, + clip_grad=0.1, + weight_decay=0.0, + adam_eps=1e-13, + ) + + +def _configure_cuda_precision(case_config: OracleCaseConfig) -> None: + if case_config.precision != "fp32": + return + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + torch.set_float32_matmul_precision("highest") + + +def _assert_runtime_configuration( + model_chunks: list[Any], + case_config: OracleCaseConfig, +) -> None: + """Validates runtime model depth equals requested oracle case config.""" + observed_num_layers: set[int] = set() + + for chunk in model_chunks: + module: Any = chunk + while hasattr(module, "module"): + module = module.module + config = getattr(module, "config", None) + if config is not None and hasattr(config, "num_layers"): + observed_num_layers.add(int(config.num_layers)) + + if observed_num_layers != {case_config.num_layers}: + raise RuntimeError( + "Runtime num_layers mismatch: " + f"requested={case_config.num_layers}, observed={sorted(observed_num_layers)}" + ) + + +def _delta_state( + initial_state: dict[str, Any], + current_state: dict[str, Any], +) -> dict[str, Any]: + """Computes LoRA parameter deltas while enforcing stable key sets.""" + initial_keys = set(initial_state.keys()) + current_keys = set(current_state.keys()) + if initial_keys != current_keys: + missing = sorted(initial_keys - current_keys) + extra = sorted(current_keys - initial_keys) + raise 
KeyError( + f"LoRA state keys changed during training: missing={missing[:3]} extra={extra[:3]}" + ) + return { + key: current_state[key].detach().cpu() - initial_state[key].detach().cpu() + for key in sorted(initial_keys) + } + + +def _iter_named_unique_parameters( + model_chunks: list[Any], +) -> list[tuple[str, torch.nn.Parameter]]: + seen: set[int] = set() + params: list[tuple[str, torch.nn.Parameter]] = [] + for chunk_index, chunk in enumerate(model_chunks): + for name, param in chunk.named_parameters(): + param_id = id(param) + if param_id in seen: + continue + seen.add(param_id) + params.append((f"chunk{chunk_index}.{name}", param)) + return params + + +def _matches_grad_sync_skip_mutation( + param_name: str, mutation: SensitivityMutation +) -> bool: + if mutation == "bwd_skip_sync_qkv_a": + return any( + token in param_name + for token in ( + ".self_attention.linear_qkv.q_proj_lora.A_T", + ".self_attention.linear_qkv.k_proj_lora.A_T", + ".self_attention.linear_qkv.v_proj_lora.A_T", + ) + ) + if mutation == "bwd_skip_sync_o_proj_b": + return ".self_attention.linear_proj.lora.B_T" in param_name + if mutation == "bwd_skip_sync_fc1_a": + return ( + ".mlp.experts.linear_fc1.gate_lora.A_T" in param_name + or ".mlp.experts.linear_fc1.up_lora.A_T" in param_name + ) + return False + + +@contextmanager +def _apply_grad_sync_skip_mutation( + model_chunks: list[Any], + mutation: SensitivityMutation | None, +): + if mutation not in { + "bwd_skip_sync_qkv_a", + "bwd_skip_sync_o_proj_b", + "bwd_skip_sync_fc1_a", + }: + yield + return + + saved_attrs: list[tuple[Any, str, Any]] = [] + for param_name, param in _iter_named_unique_parameters(model_chunks): + # this only passes lora params atm, so we assume lora params below + if not _matches_grad_sync_skip_mutation(param_name, mutation): + continue + if ( + mutation == "bwd_skip_sync_fc1_a" and param.grad_sync_domain != "expert_tp" # ty: ignore[unresolved-attribute] + ): + continue + + # For fc1 A params, extended finalize 
handles expert-TP sync via grad_sync_op. + saved_attrs.append((param, "grad_sync_op", param.grad_sync_op)) # ty: ignore[unresolved-attribute] + param.grad_sync_op = "none" # ty: ignore[unresolved-attribute] + + # Megatron native TP finalize uses this only for tp_default-domain params. + average_gradients_across_tp_domain = param.average_gradients_across_tp_domain # ty: ignore[unresolved-attribute] + grad_sync_domain = param.grad_sync_domain # ty: ignore[unresolved-attribute] + if average_gradients_across_tp_domain and grad_sync_domain == "tp_default": + saved_attrs.append( + ( + param, + "average_gradients_across_tp_domain", + average_gradients_across_tp_domain, + ) + ) + param.average_gradients_across_tp_domain = False # ty: ignore[unresolved-attribute] + try: + yield + finally: + for param, attr, value in reversed(saved_attrs): + setattr(param, attr, value) + + +@contextmanager +def _apply_o_proj_forward_mutation( + model_chunks: list[Any], + mutation: SensitivityMutation | None, +): + if mutation not in { + "fwd_skip_o_proj_tp_reduce", + "fwd_o_proj_tp_reduce_avg_not_sum", + }: + yield + return + + from megatron.core import parallel_state as ps + from megatron.core.tensor_parallel.mappings import ( + reduce_from_tensor_model_parallel_region, + reduce_scatter_to_sequence_parallel_region, + ) + + from art.megatron.lora import SelfAttentionLinearProjLoRA + + original_forwards: list[tuple[Any, Any]] = [] + for chunk in model_chunks: + for module in chunk.modules(): + if not isinstance(module, SelfAttentionLinearProjLoRA): + continue + original_forwards.append((module, module.forward)) + + def _mutated_forward(self: Any, x: Any): + base_output, bias_output = self.linear_proj(x) + lora_output = self.lora(x) + tp_size = self.provider.tensor_model_parallel_size + if tp_size > 1: + if mutation == "fwd_o_proj_tp_reduce_avg_not_sum": + if self.provider.sequence_parallel: + lora_output = reduce_scatter_to_sequence_parallel_region( + lora_output + ) + else: + lora_output = 
reduce_from_tensor_model_parallel_region( + lora_output + ) + lora_output = lora_output / tp_size + elif mutation == "fwd_skip_o_proj_tp_reduce": + if self.provider.sequence_parallel: + seq_per_rank = lora_output.shape[0] // tp_size + tp_rank = ps.get_tensor_model_parallel_rank() + lora_output = lora_output.narrow( + 0, tp_rank * seq_per_rank, seq_per_rank + ) + return base_output + lora_output, bias_output + + module.forward = MethodType(_mutated_forward, module) + + try: + yield + finally: + for module, original_forward in reversed(original_forwards): + module.forward = original_forward + + +@contextmanager +def _patch_lora_for_fp32( + model_chunks: list[Any], + optimizer: Any, +): + """ + torch grouped_gemm is bf16 only, so we have a simple custom fp32 path + to make the numbers match closely + """ + from art.megatron.lora import LoRA + + del model_chunks + del optimizer + original_forward = LoRA.forward + + def _reference_forward( + self: Any, + x: torch.Tensor, + tokens_per_expert: list[int] | torch.Tensor | None = None, + ) -> torch.Tensor: + work_dtype = ( + torch.float32 + if torch.is_floating_point(x) and x.dtype != torch.float32 + else x.dtype + ) + work_x = x.to(dtype=work_dtype) + work_a = self.A_T.to(dtype=work_dtype) + work_b = self.B_T.to(dtype=work_dtype) + + if tokens_per_expert is None or self.num_local_experts == 1: + return (((work_x @ work_a) @ work_b) * self.scale).to(dtype=x.dtype) + + counts = ( + tokens_per_expert.tolist() + if isinstance(tokens_per_expert, torch.Tensor) + else list(tokens_per_expert) + ) + out = work_x.new_zeros((work_x.shape[0], work_b.shape[-1])) + + cursor = 0 + for expert_index, count in enumerate(counts): + count_int = int(count) + if count_int <= 0: + continue + next_cursor = cursor + count_int + x_chunk = work_x[cursor:next_cursor] + out[cursor:next_cursor] = (x_chunk @ work_a[expert_index]) @ work_b[ + expert_index + ] + cursor = next_cursor + + if cursor != int(work_x.shape[0]): + raise RuntimeError( + "Expert 
LoRA reference path did not consume all grouped rows: " + f"consumed={cursor}, rows={int(work_x.shape[0])}" + ) + + return (out * self.scale).to(dtype=x.dtype) + + LoRA.forward = _reference_forward # ty: ignore[invalid-assignment] + try: + yield + finally: + LoRA.forward = original_forward + + +@contextmanager +def _mutation_hook( + megatron_train_module: Any, + model_chunks: list[Any], + mutation: SensitivityMutation | None, + topology: Topology, + pre_optimizer_step_hook: Callable[[], None] | None = None, + loss_scale: float = 1.0, +): + """Applies optional sensitivity mutation hooks around training steps.""" + original_finalize = megatron_train_module.finalize_model_grads_extended + original_optimizer_step = megatron_train_module._optimizer_step + original_loss_fn = megatron_train_module.loss_fn + original_local_token_count_tensor = ( + megatron_train_module._local_trainable_token_count_tensor + ) + original_build_micro_sample_indices = ( + megatron_train_module.build_micro_sample_indices + ) + + known_mutations = {None, *SUPPORTED_SENSITIVITY_MUTATIONS} + if mutation not in known_mutations: + raise ValueError(f"Unsupported mutation: {mutation}") + + if mutation == "skip_finalize": + megatron_train_module.finalize_model_grads_extended = ( + lambda _model, **_kwargs: (None) + ) + + if mutation == "dp_local_token_normalization": + + def _wrong_local_trainable_token_count_tensor( + micro_inputs: list[Any], + device: torch.device, + ) -> torch.Tensor: + local_token_total = sum( + megatron_train_module._count_trainable_tokens(micro) + for micro in micro_inputs + ) + dp_world_size = int( + megatron_train_module.ps.get_data_parallel_world_size( + with_context_parallel=True + ) + ) + wrong_local_token_total = local_token_total / max(dp_world_size, 1) + return torch.tensor( + [wrong_local_token_total], + device=device, + dtype=torch.float32, + ) + + megatron_train_module._local_trainable_token_count_tensor = ( + _wrong_local_trainable_token_count_tensor + ) + + if 
mutation == "dp_grad_accumulation_seqs": + + def _wrong_build_micro_sample_indices( + *, + step_index: int, + num_sequences: int, + global_grad_accumulation_sequences: int, + ) -> list[int | None]: + base_global_sample_index = step_index * global_grad_accumulation_sequences + return [ + (global_sample_index if global_sample_index < num_sequences else None) + for global_sample_index in range( + base_global_sample_index, + base_global_sample_index + global_grad_accumulation_sequences, + ) + ] + + megatron_train_module.build_micro_sample_indices = ( + _wrong_build_micro_sample_indices + ) + + if pre_optimizer_step_hook is not None: + + def _patched_optimizer_step(optimizer: Any, learning_rate: float): + if pre_optimizer_step_hook is not None: + pre_optimizer_step_hook() + return original_optimizer_step(optimizer, learning_rate) + + megatron_train_module._optimizer_step = _patched_optimizer_step + + effective_loss_scale = loss_scale + if effective_loss_scale <= 0: + raise ValueError( + f"effective_loss_scale must be > 0, got {effective_loss_scale}" + ) + if effective_loss_scale != 1.0: + + def _scaled_loss_fn(*args: Any, **kwargs: Any): + loss = original_loss_fn(*args, **kwargs) + return loss.model_copy( + update={ + "policy_loss": loss.policy_loss * effective_loss_scale, + "kl": loss.kl * effective_loss_scale, + "policy_loss_sum": loss.policy_loss_sum * effective_loss_scale, + } + ) + + megatron_train_module.loss_fn = _scaled_loss_fn + + if mutation is None: + if pre_optimizer_step_hook is None and effective_loss_scale == 1.0: + yield + return + with ExitStack() as stack: + stack.enter_context(_apply_o_proj_forward_mutation(model_chunks, mutation)) + stack.enter_context(_apply_grad_sync_skip_mutation(model_chunks, mutation)) + try: + yield + finally: + megatron_train_module.finalize_model_grads_extended = original_finalize + megatron_train_module._optimizer_step = original_optimizer_step + megatron_train_module.loss_fn = original_loss_fn + 
megatron_train_module._local_trainable_token_count_tensor = ( + original_local_token_count_tensor + ) + megatron_train_module.build_micro_sample_indices = ( + original_build_micro_sample_indices + ) + + +def _worker_run(request: WorkerRunRequest) -> None: + """Executes one full distributed training trace generation worker run.""" + from safetensors.torch import load_file, save_file + import torch + + from art import dev, types + from art.megatron import train as megatron_train + from art.preprocessing.pack import packed_tensors_from_dir + + local_rank = int(os.environ["LOCAL_RANK"]) + torch.cuda.set_device(local_rank) + torch.distributed.init_process_group(backend="nccl") # ty: ignore[possibly-missing-attribute] + _set_deterministic_seed(request.case_config.seed) + _configure_cuda_precision(request.case_config) + + runtime = megatron_train.build_training_runtime( + model_identifier=request.case_config.base_model, + provider_torch_dtype=( + torch.float32 if request.case_config.precision == "fp32" else torch.bfloat16 + ), + provider_configure=lambda provider: _configure_provider( + provider, request.topology, request.case_config + ), + optimizer_config=_build_optimizer_config(request.case_config), + print_env=False, + print_optimizer_stats=False, + ) + model_chunks = runtime.model + optimizer = runtime.optimizer + megatron_train.configure_moe_routing_replay( + runtime, + replay_bundle_path=request.moe_routing_replay_path, + strict=request.moe_routing_replay_strict, + ) + _assert_runtime_configuration(model_chunks, request.case_config) + + topology_dir = Path(request.topology_dir) + traces_dir = topology_dir / "traces" + traces_dir.mkdir(parents=True, exist_ok=True) + + # setup the shared initial lora + shared_init_path = Path(request.shared_init_adapter_path) + if not shared_init_path.exists(): + initial_state = _collect_lora_state(model_chunks) + if torch.distributed.get_rank() == 0: # ty: ignore[possibly-missing-attribute] + 
shared_init_path.parent.mkdir(parents=True, exist_ok=True) + deterministic_init = _build_deterministic_shared_init( + _require_not_none(initial_state, "initial_state"), + seed=request.case_config.seed, + ) + save_file( + deterministic_init, + str(shared_init_path), + ) + torch.distributed.barrier() # ty: ignore[possibly-missing-attribute] + + # load the shared initial lora into the model and validate we can collect it from the model + adapter_model = load_file(str(shared_init_path)) + megatron_train.load_adapter_into_model(model_chunks, adapter_model, optimizer) + loaded_state = _collect_lora_state(model_chunks) + if torch.distributed.get_rank() == 0: # ty: ignore[possibly-missing-attribute] + _validate_loaded_state_matches_adapter( + _require_not_none(loaded_state, "loaded_state"), adapter_model + ) + torch.distributed.barrier() # ty: ignore[possibly-missing-attribute] + + # load the inputs + packed_tensors = packed_tensors_from_dir( + **request.packed_tensors.model_dump(exclude_none=True) + ) + template = megatron_train.select_indexed_inputs(packed_tensors, 0) + zero_template = megatron_train._zero_contribution_inputs(template) + initial_lora_state = loaded_state + global_grad_accumulation_sequences = request.case_config.grad_accumulation_sequences + + train_config = types.TrainConfig( + learning_rate=request.case_config.learning_rate, + kl_penalty_coef=0.0, + grad_accumulation_sequences=global_grad_accumulation_sequences, + ) + experimental_config: dev.TrainConfig = {} + step_traces: list[StepTrace] = [] + captured_grads: dict[str, Any] | None = None + routing_replay_controller = runtime.moe_routing_replay_controller + micro_start_callback = ( + routing_replay_controller.begin_micro + if routing_replay_controller is not None + else None + ) + forward_trace_capture = ForwardTraceCapture( + model_chunks, + enabled=True, + micro_start_callback=micro_start_callback, + ) + + def _capture_lora_grads() -> None: + nonlocal captured_grads + captured_grads = 
_collect_lora_grads(model_chunks) + + with ( + _mutation_hook( + megatron_train, + model_chunks, + request.mutation, + request.topology, + pre_optimizer_step_hook=_capture_lora_grads, + loss_scale=request.case_config.loss_scale, + ), + _patch_lora_for_fp32(model_chunks, optimizer), + ): + for step_index in range(request.case_config.num_steps): + micro_sample_indices = megatron_train.build_micro_sample_indices( + step_index=step_index, + num_sequences=request.packed_tensors.num_sequences, + global_grad_accumulation_sequences=global_grad_accumulation_sequences, + ) + forward_trace_capture.set_step(step_index, micro_sample_indices) + micro_inputs = megatron_train.select_micro_inputs( + packed_tensors, micro_sample_indices, zero_template + ) + captured_grads = None + + step_result = megatron_train.run_training_step( + model_chunks=model_chunks, + optimizer=optimizer, + learning_rate=train_config.learning_rate, + inputs=micro_inputs, + config=train_config, + experimental_config=experimental_config, + ref_logprobs=None, + step_index=step_index, + sample_index=micro_sample_indices, + moe_routing_replay_controller=runtime.moe_routing_replay_controller, + ) + ordered_micro_outputs = forward_trace_capture.ordered_step_outputs() + forward_trace_capture.save_current_step(traces_dir) + torch.distributed.barrier() # ty: ignore[possibly-missing-attribute] + current_lora_state = _collect_lora_state(model_chunks) + + if torch.distributed.get_rank() == 0: # ty: ignore[possibly-missing-attribute] + grads = _require_not_none(captured_grads, "captured_grads") + initial_state = _require_not_none( + initial_lora_state, "initial_lora_state" + ) + current_state = _require_not_none( + current_lora_state, "current_lora_state" + ) + deltas = _delta_state(initial_state, current_state) + saved_deltas = _apply_save_mutation_to_tensor_map( + deltas, + mutation=request.mutation, + ) + saved_current_state = _apply_save_mutation_to_tensor_map( + current_state, + mutation=request.mutation, + ) + + 
output_rel = Path("traces") / f"output_step_{step_index:03d}.pt" + grads_rel = Path("traces") / f"grads_step_{step_index:03d}.safetensors" + deltas_rel = ( + Path("traces") / f"deltas_step_{step_index:03d}.safetensors" + ) + lora_rel = Path(f"lora_step_{step_index:03d}.safetensors") + ordered_outputs = _require_not_none( + ordered_micro_outputs, "ordered_micro_outputs" + ) + if not ordered_outputs: + raise RuntimeError("Expected at least one captured micro output") + + torch.save( + torch.stack(ordered_outputs, dim=0), + topology_dir / output_rel, + ) + save_file(grads, str(topology_dir / grads_rel)) + save_file(saved_deltas, str(topology_dir / deltas_rel)) + save_file(saved_current_state, str(topology_dir / lora_rel)) + + step_traces.append( + StepTrace( + step_index=step_index, + loss=float( + step_result.reduced_loss.item() + / request.case_config.loss_scale + ), + probs_corr=step_result.probs_corr, + output_file=str(output_rel), + grads_file=str(grads_rel), + deltas_file=str(deltas_rel), + lora_file=str(lora_rel), + ) + ) + torch.distributed.barrier() # ty: ignore[possibly-missing-attribute] + + forward_trace_capture.close() + + if torch.distributed.get_rank() == 0: # ty: ignore[possibly-missing-attribute] + # build and save the moe routing replay bundle + if request.capture_moe_routing_bundle_path is not None: + replay_bundle = build_bundle_from_forward_trace_dir( + traces_dir=traces_dir, + num_steps=request.case_config.num_steps, + topology=ReplayParallelTopology.model_validate( + request.topology.model_dump( + include={"tp", "ep", "etp", "dp", "sp", "cp", "pp", "vpp"}, + mode="python", + ) + ), + ) + replay_bundle.to_dir(request.capture_moe_routing_bundle_path) + + # build and save the run manifest + manifest = RunManifest( + case_id=request.case_id, + base_model=request.case_config.base_model, + num_layers=request.case_config.num_layers, + topology=request.topology.slug(), + world_size=request.topology.world_size(), + seed=request.case_config.seed, + 
num_steps=request.case_config.num_steps, + packed_tensors=request.packed_tensors, + steps=step_traces, + ) + _write_json(topology_dir / "manifest.json", manifest.model_dump(mode="json")) + torch.distributed.barrier() # ty: ignore[possibly-missing-attribute] + torch.distributed.destroy_process_group() # ty: ignore[possibly-missing-attribute] + + +def run_worker_cli(run_request_path: Path) -> None: + """Loads a worker request and dispatches worker execution.""" + request = WorkerRunRequest.model_validate(_read_json(run_request_path)) + _worker_run(request) + + +def _parse_args(argv: list[str]) -> argparse.Namespace: + """Parses worker CLI arguments.""" + parser = argparse.ArgumentParser(description="Megatron oracle harness worker") + parser.add_argument("--worker-run", action="store_true") + parser.add_argument("--run-request", type=Path) + return parser.parse_args(argv) + + +def _main(argv: list[str]) -> int: + """CLI entry for worker-only execution mode.""" + args = _parse_args(argv) + if not args.worker_run: + raise SystemExit("This module is intended for test imports or --worker-run") + if args.run_request is None: + raise SystemExit("--run-request is required with --worker-run") + run_worker_cli(args.run_request) + return 0 + + +if __name__ == "__main__": + raise SystemExit(_main(sys.argv[1:])) diff --git a/tests/integration/test_megatron_lora_oracle_correctness.py b/tests/integration/test_megatron_lora_oracle_correctness.py new file mode 100644 index 00000000..67c35adb --- /dev/null +++ b/tests/integration/test_megatron_lora_oracle_correctness.py @@ -0,0 +1,126 @@ +from contextlib import redirect_stderr, redirect_stdout +from pathlib import Path +from typing import Callable + +import pytest + +from .megatron_oracle_harness import ( + EXTENDED_TOPOLOGIES, + SENSITIVITY_MUTATION_ENV, + TOPOLOGIES, + available_gpu_count, + case_config, + extended_topologies_enabled, + run_sensitivity_suite, + run_suite, + sensitivity_enabled, + sensitivity_mutations, + 
sensitivity_required_world_size, +) + +REPO_ROOT = Path(__file__).resolve().parents[2] +CORRECTNESS_LOG_PATH = REPO_ROOT / ".local" / "correctness.log" +SENSITIVITY_LOG_PATH = REPO_ROOT / ".local" / "sensitivity.log" + + +def _run_suite_with_log( + *, + log_path: Path, + run: Callable[[], object], +) -> None: + log_path.parent.mkdir(parents=True, exist_ok=True) + with log_path.open("w", encoding="utf-8") as log_file: + with redirect_stdout(log_file), redirect_stderr(log_file): + run() + + +def _announce_report_log( + *, + log_path: Path, + capsys: pytest.CaptureFixture[str], +) -> None: + with capsys.disabled(): + print(f"\nMegatron LoRA oracle report log: {log_path}", flush=True) + + +def _require_gpus_for(topology_world_size: int) -> None: + gpu_count = available_gpu_count() + if gpu_count < topology_world_size: + pytest.skip( + f"Need {topology_world_size} GPUs for topology run, only found {gpu_count}" + ) + + +def _suite_world_size() -> int: + suite_topologies = list(TOPOLOGIES) + if extended_topologies_enabled(): + suite_topologies.extend(EXTENDED_TOPOLOGIES) + return max(topology.world_size() for topology in suite_topologies) + + +def test_megatron_lora_diff_sensitivity(capsys: pytest.CaptureFixture[str]) -> None: + """ + Runs each of the sensitivity mutations (e.g. drop megatron finalize grads) + and expects each to fail (numerical differences larger than our thresholds) + + This test ensures we can catch errors we know of (implying we will be able to catch unknown errors as well) + """ + _announce_report_log(log_path=SENSITIVITY_LOG_PATH, capsys=capsys) + if not sensitivity_enabled(): + SENSITIVITY_LOG_PATH.parent.mkdir(parents=True, exist_ok=True) + SENSITIVITY_LOG_PATH.write_text( + ( + "Sensitivity suite skipped. " + f"Set {SENSITIVITY_MUTATION_ENV}=all (or one mutation / CSV).\n" + ), + encoding="utf-8", + ) + pytest.skip( + f"Set {SENSITIVITY_MUTATION_ENV}=all (or one mutation / CSV) to enable sensitivity check."
+ ) + mutations = sensitivity_mutations() + assert mutations + sensitivity_world_size = sensitivity_required_world_size(mutations) + gpu_count = available_gpu_count() + if gpu_count < sensitivity_world_size: + SENSITIVITY_LOG_PATH.parent.mkdir(parents=True, exist_ok=True) + SENSITIVITY_LOG_PATH.write_text( + ( + "Sensitivity suite skipped. " + f"Need {sensitivity_world_size} GPUs, found {gpu_count}.\n" + ), + encoding="utf-8", + ) + _require_gpus_for(sensitivity_world_size) + _run_suite_with_log( + log_path=SENSITIVITY_LOG_PATH, + run=lambda: run_sensitivity_suite( + case_config=case_config(), + mutations=mutations, + ), + ) + + +def test_megatron_lora_topology_suite(capsys: pytest.CaptureFixture[str]) -> None: + """ + Runs the suite of topologies and expects each to pass (numerical differences within our thresholds) + """ + _announce_report_log(log_path=CORRECTNESS_LOG_PATH, capsys=capsys) + suite_world_size = _suite_world_size() + gpu_count = available_gpu_count() + if gpu_count < suite_world_size: + CORRECTNESS_LOG_PATH.parent.mkdir(parents=True, exist_ok=True) + CORRECTNESS_LOG_PATH.write_text( + ( + "Topology suite skipped. 
" + f"Need {suite_world_size} GPUs, found {gpu_count}.\n" + ), + encoding="utf-8", + ) + _require_gpus_for(suite_world_size) + _run_suite_with_log( + log_path=CORRECTNESS_LOG_PATH, + run=lambda: run_suite( + case_config=case_config(), + ), + ) diff --git a/tests/unit/test_moe_routing_replay.py b/tests/unit/test_moe_routing_replay.py new file mode 100644 index 00000000..15d1ebc6 --- /dev/null +++ b/tests/unit/test_moe_routing_replay.py @@ -0,0 +1,191 @@ +from __future__ import annotations + +from pathlib import Path +import tempfile +from typing import cast + +import pytest +import torch +from torch import nn + +from art.megatron.routing_replay import ( + MoeRoutingReplayBundle, + MoeRoutingReplayController, + ParallelTopology, + RouterCallRoute, + StepRouterRoutes, + StepRoutes, +) + + +def _dense_from_compact( + route: RouterCallRoute, + *, + dtype: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor]: + num_tokens = route.expert_indices.shape[0] + num_experts = route.num_experts + probs = torch.zeros((num_tokens, num_experts), dtype=dtype) + routing_map = torch.zeros((num_tokens, num_experts), dtype=torch.bool) + for token_idx in range(num_tokens): + for slot in range(route.expert_indices.shape[1]): + if not bool(route.expert_mask[token_idx, slot]): + continue + expert_idx = int(route.expert_indices[token_idx, slot].item()) + probs[token_idx, expert_idx] = route.expert_probs[token_idx, slot].to(dtype) + routing_map[token_idx, expert_idx] = True + return probs, routing_map + + +def _make_bundle() -> tuple[MoeRoutingReplayBundle, RouterCallRoute]: + router_key = "chunk_00.layer_0000.mlp.router" + route = RouterCallRoute( + expert_indices=torch.tensor( + [ + [0, 2], + [1, 0], + [2, 1], + [1, 0], + ], + dtype=torch.int32, + ), + expert_probs=torch.tensor( + [ + [0.70, 0.30], + [1.00, 0.00], + [0.65, 0.35], + [1.00, 0.00], + ], + dtype=torch.float32, + ), + expert_mask=torch.tensor( + [ + [True, True], + [True, False], + [True, True], + [True, False], + ], + 
dtype=torch.bool, + ), + num_experts=3, + ) + bundle = MoeRoutingReplayBundle( + topology=ParallelTopology(tp=1, ep=1, etp=1, dp=1, sp=False, cp=1, pp=1, vpp=1), + num_steps=1, + max_topk=2, + router_keys=[router_key], + steps={ + 0: StepRoutes( + routers={router_key: StepRouterRoutes(calls={0: route})}, + global_token_uids=torch.arange(4, dtype=torch.int64), + ) + }, + ) + return bundle, route + + +class _IdentityIndexer: + def build_local_token_uids( + self, + *, + global_token_uids: torch.Tensor, + num_local_tokens: int, + sequence_parallel: bool, + context_parallel_size: int, + ) -> torch.Tensor: + del sequence_parallel, context_parallel_size + if int(global_token_uids.numel()) < num_local_tokens: + raise RuntimeError("num_local_tokens exceeds global token count") + return global_token_uids[:num_local_tokens].clone() + + +class _FakeRouter(nn.Module): + def __init__(self) -> None: + super().__init__() + self.config = type( + "Config", + (), + {"sequence_parallel": False, "context_parallel_size": 1}, + )() + + def routing(self, logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + probs = torch.softmax(logits, dim=-1) + routing_map = torch.zeros_like(logits, dtype=torch.bool) + return probs, routing_map + + +class _FakeMlp(nn.Module): + def __init__(self) -> None: + super().__init__() + self.router = _FakeRouter() + + +class _FakeLayer(nn.Module): + def __init__(self) -> None: + super().__init__() + self.mlp = _FakeMlp() + + +class _FakeDecoder(nn.Module): + def __init__(self) -> None: + super().__init__() + self.layers = nn.ModuleList([_FakeLayer()]) + + +class _FakeChunk(nn.Module): + def __init__(self) -> None: + super().__init__() + self.decoder = _FakeDecoder() + + +def test_bundle_roundtrip_disk() -> None: + bundle, route = _make_bundle() + with tempfile.TemporaryDirectory() as tmp_dir: + bundle_path = Path(tmp_dir) + bundle.to_dir(bundle_path) + loaded = MoeRoutingReplayBundle.from_dir(bundle_path) + + assert loaded.num_steps == 1 + assert 
loaded.max_topk == 2 + assert loaded.router_keys == bundle.router_keys + loaded_route = loaded.steps[0].routers[bundle.router_keys[0]].calls[0] + assert torch.equal(loaded_route.expert_indices, route.expert_indices) + assert torch.equal(loaded_route.expert_probs, route.expert_probs) + assert torch.equal(loaded_route.expert_mask, route.expert_mask) + + +def test_controller_patches_router_and_replays() -> None: + bundle, route = _make_bundle() + controller = MoeRoutingReplayController( + bundle=bundle, + strict=True, + local_token_indexer=_IdentityIndexer(), + ) + chunk = _FakeChunk() + controller.install_router_patches([chunk]) + controller.set_step(step_index=0, sample_index=0) + + logits = torch.randn((4, 3), dtype=torch.float32) + router = cast(_FakeRouter, chunk.decoder.layers[0].mlp.router) + replay_probs, replay_map = router.routing(logits) + expected_probs, expected_map = _dense_from_compact(route, dtype=logits.dtype) + + assert torch.equal(replay_map.cpu(), expected_map) + assert torch.allclose(replay_probs.cpu(), expected_probs, atol=0.0, rtol=0.0) + + controller.finalize_step() + controller.remove_router_patches() + + +def test_controller_finalize_fails_when_unconsumed_calls_remain() -> None: + bundle, _route = _make_bundle() + controller = MoeRoutingReplayController( + bundle=bundle, + strict=True, + local_token_indexer=_IdentityIndexer(), + ) + chunk = _FakeChunk() + controller.install_router_patches([chunk]) + controller.set_step(step_index=0, sample_index=0) + with pytest.raises(RuntimeError, match="consumption mismatch"): + controller.finalize_step()