From 680caf5d65c495ed620f87d10b1f98f8942fd14b Mon Sep 17 00:00:00 2001 From: tomiock Date: Thu, 23 Apr 2026 21:26:31 +0200 Subject: [PATCH 01/23] pipeline parallel init --- configs/cvc/qwen3_5_2b.toml | 3 +- configs/cvc/qwen3_5_9b.toml | 12 +- models/qwen3_5/model.py | 6 +- train/config.py | 1 + train/infra.py | 410 ++++++++++++++++++++++++++++++++++-- train/train_qwen.py | 143 ++++++++++--- 6 files changed, 521 insertions(+), 54 deletions(-) diff --git a/configs/cvc/qwen3_5_2b.toml b/configs/cvc/qwen3_5_2b.toml index 7875207..4a256ac 100644 --- a/configs/cvc/qwen3_5_2b.toml +++ b/configs/cvc/qwen3_5_2b.toml @@ -23,11 +23,12 @@ wsd_decay_ratio: float = 0.1 min_lr_ratio: float = 0.1 tp_size = 1 +pp_size = 2 data_parallel = 'ddp' compile = false -tpi_multiplier = 12 +tpi_multiplier = 1 [data] data_path = "/data/151-1/datasets/llava_recap" diff --git a/configs/cvc/qwen3_5_9b.toml b/configs/cvc/qwen3_5_9b.toml index 221eb5e..65622ba 100644 --- a/configs/cvc/qwen3_5_9b.toml +++ b/configs/cvc/qwen3_5_9b.toml @@ -16,17 +16,19 @@ output_dir = "/data/151-1/users/shared_cache/qwen_finetune/checkpoints" save_steps = 10000 total_steps = 10000 -random_init_mlp = false +random_init = false -tp_size = 4 +tp_size = 1 +pp_size = 2 data_parallel = 'fsdp' ac_mode = "full" - compile = false [data] data_path = "/data/151-1/datasets/synth_test_datasets/cap_pretrain" -seq_len = 4096 +seq_len = 512 + +packing_buffer_size = 100 -batch_size = 32 +batch_size = 0 \ No newline at end of file diff --git a/models/qwen3_5/model.py b/models/qwen3_5/model.py index 5db7f2b..fd361c5 100644 --- a/models/qwen3_5/model.py +++ b/models/qwen3_5/model.py @@ -548,10 +548,14 @@ def forward( merged = self.merger(hidden_states) return merged, deepstack +class Qwen3_5InnerLanguage(nn.Module): + def __init__(self, cfg: Qwen3VLConfig): + super().__init__() + self.language_model = LanguageModel(cfg.text) + class Qwen3_5Inner(nn.Module): """HF name: `model`. Groups `language_model` and `visual`. This is only used to match the state keys. """ - def __init__(self, cfg: Qwen3VLConfig): super().__init__() self.language_model = LanguageModel(cfg.text) diff --git a/train/config.py b/train/config.py index 6dcde76..b309346 100644 --- a/train/config.py +++ b/train/config.py @@ -91,6 +91,7 @@ class Training: data_parallel: str = "ddp" # fsdp, ddp tp_size: int = 1 # 1 means disabled + pp_size: int = 1 # 1 means disabled; supported values: 1, 2, 4 """ Use `fsdp` when you want to decrease usage to increase seq_len/batch_size. """ diff --git a/train/infra.py b/train/infra.py index 55aa03f..dd72091 100644 --- a/train/infra.py +++ b/train/infra.py @@ -1,10 +1,13 @@ from dataclasses import dataclass from functools import partial +from typing import Optional from train.config import ModelType import torch import torch._inductor.config +import torch.distributed as dist +import torch.nn.functional as F from torch.distributed.device_mesh import init_device_mesh from torch.distributed.tensor import Replicate, Shard @@ -113,27 +116,35 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: ) def get_mesh(training_args, world_size): - """ - Creates a 2D DeviceMesh based on tp_size and world_size. - Always returns ('dp', 'tp'). 
- """ tp_size = training_args.tp_size - - if world_size % tp_size != 0: - raise ValueError(f"World size {world_size} is not divisible by TP size {tp_size}") + pp_size = getattr(training_args, "pp_size", 1) - dp_size = world_size // tp_size + if world_size % (tp_size * pp_size) != 0: + raise ValueError( + f"world_size {world_size} not divisible by tp_size*pp_size={tp_size * pp_size}" + ) + dp_size = world_size // (tp_size * pp_size) + + if pp_size > 1: + return init_device_mesh( + "cuda", (dp_size, pp_size, tp_size), mesh_dim_names=("dp", "pp", "tp") + ) return init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp")) def get_tp_group(mesh): if "tp" in mesh.mesh_dim_names: - return mesh['tp'] + return mesh["tp"] return None def get_dp_group(mesh): if "dp" in mesh.mesh_dim_names: - return mesh['dp'] + return mesh["dp"] + return None + +def get_pp_group(mesh): + if "pp" in mesh.mesh_dim_names: + return mesh["pp"] return None def module_filter_float8_fn(mod: torch.nn.Module, fqn: str): @@ -238,7 +249,7 @@ def compile_model(model: torch.nn.Module): model.language_model = torch.compile(model.language_model, fullgraph=False, mode='default') model.visual = torch.compile(model.visual, fullgraph=False, mode='default') model.visual.merger = torch.compile(model.visual.merger, fullgraph=False, mode='default',) - #model = torch.compile(model, mode='default') + #model = torch.compile(model, mode='default') def apply_fsdp(model_type, model, **kwargs): if model_type == ModelType.Qwen3_text: @@ -417,7 +428,7 @@ def _apply_tp_to_decoder_qwen3_vl( desired_input_kwarg_layouts={ "hidden_states": Replicate(), }, - ), + ), "self_attn.q_proj": colwise_parallel(use_local_output=False), "self_attn.k_proj": colwise_parallel(use_local_output=False), "self_attn.v_proj": colwise_parallel(use_local_output=False), @@ -440,9 +451,6 @@ def _apply_tp_to_decoder_qwen3_vl( parallelize_plan=layer_plan, ) - if enable_async_tp: - torch._inductor.config._micro_pipeline_tp = True - def _register_tp_sum_hook(param, tp_mesh): """All-reduce SUM a parameter's grad on the TP process group. @@ -650,3 +658,375 @@ def _apply_tp_to_decoder_qwen3_5( if enable_async_tp: torch._inductor.config._micro_pipeline_tp = True + + +# --------------------------------------------------------------------------- +# Pipeline Parallel for Qwen3.5 (native impl, dense model) +# --------------------------------------------------------------------------- + +# cu_seqlens is padded to this fixed size so PipelineStage sees constant shapes. +_PP_MAX_SEQS: int = 256 + + +def _pp_layer_ranges(n_layers: int, pp_size: int) -> list[tuple[int, int]]: + """Equal-split layer index ranges across pp_size stages.""" + assert pp_size in (2, 4), f"pp_size must be 2 or 4, got {pp_size}" + base, rem = divmod(n_layers, pp_size) + ranges, start = [], 0 + for i in range(pp_size): + end = start + base + (1 if i < rem else 0) + ranges.append((start, end)) + start = end + return ranges + + +class PPStageModule(nn.Module): + """ + A single PP stage for Qwen3.5ForCausalLM (native impl). + + Owns decoder layers[layer_start:layer_end]. + Stage 0 (is_first=True) also owns the visual encoder and embed_tokens. + Last stage (is_last=True) also owns norm and lm_head. 
+ + Inter-stage tensor protocol (all fixed shapes): + hidden_states : (1, seq_len, hidden_size) dtype + cos, sin : (1, seq_len, rope_dim) dtype + cu_seqlens_pad : (_PP_MAX_SEQS+1,) int32 + n_seqs : () int64 + """ + + def __init__( + self, + full_model: nn.Module, + layer_start: int, + layer_end: int, + is_first: bool, + is_last: bool, + ): + super().__init__() + lm = full_model.model.language_model + + self.is_first = is_first + self.is_last = is_last + self.layers = nn.ModuleList(list(lm.layers[layer_start:layer_end])) + + if is_first: + n_ds = len(full_model.model.visual.deepstack_visual_indexes) + assert n_ds <= (layer_end - layer_start), ( + f"Stage 0 has {layer_end - layer_start} layers but {n_ds} deepstack " + "injections — all must fit in stage 0. Reduce pp_size or use a " + "larger first-stage split." + ) + self.visual = full_model.model.visual + self.embed_tokens = lm.embed_tokens + self.register_buffer("text_inv_freq", full_model.text_inv_freq.clone()) + self.mrope_section = list(full_model.mrope_section) + self.image_token_id = full_model.cfg.image_token_id + self.video_token_id = full_model.cfg.video_token_id + self.spatial_merge_size = full_model.cfg.vision.spatial_merge_size + # Populated by preprocess() before each forward + self._vis_masks: Optional[torch.Tensor] = None + self._ds_embeds: Optional[list] = None + + if is_last: + self.norm = lm.norm + self.lm_head = full_model.lm_head + + # ------------------------------------------------------------------ + def preprocess(self, batch: dict) -> tuple: + """ + Stage-0 only: visual encoding + token embedding + RoPE. + + Call this on pp_rank=0 before schedule.step() to produce the + fixed-shape tensors that enter the pipeline. + """ + from models.qwen3_5.utils import mrope_cos_sin + + input_ids = batch["input_ids"] # (1, seq_len) + pixel_values = batch.get("pixel_values") + image_grid_thw = batch.get("image_grid_thw") + video_grid_thw = batch.get("video_grid_thw") + cu_seqlens = batch["attention_mask"] # (n_seqs+1,) int32 + device = input_ids.device + total = input_ids.shape[1] + + x = self.embed_tokens(input_ids) # (1, total, H) + + self._vis_masks = None + self._ds_embeds = None + has_img = pixel_values is not None and pixel_values.numel() > 0 + has_vid = video_grid_thw is not None and video_grid_thw.numel() > 0 + + if has_img: + merged, ds = self.visual(pixel_values, image_grid_thw) + merged = merged.to(x.dtype) + mask = input_ids == self.image_token_id + x = x.masked_scatter(mask.unsqueeze(-1).expand_as(x), merged) + self._vis_masks, self._ds_embeds = mask, ds + + if has_vid: + merged_v, ds_v = self.visual(batch["pixel_values_videos"], video_grid_thw) + merged_v = merged_v.to(x.dtype) + vmask = input_ids == self.video_token_id + x = x.masked_scatter(vmask.unsqueeze(-1).expand_as(x), merged_v) + if self._vis_masks is None: + self._vis_masks, self._ds_embeds = vmask, ds_v + else: + combined = self._vis_masks | vmask + merged_ds = [] + for a, b in zip(self._ds_embeds, ds_v): + e = a.new_zeros(combined.sum().item(), a.shape[-1]) + e[self._vis_masks[combined]] = a + e[vmask[combined]] = b + merged_ds.append(e) + self._vis_masks, self._ds_embeds = combined, merged_ds + + # position ids → cos/sin + if has_img or has_vid: + pos = _mrope_position_ids( + input_ids, cu_seqlens, + image_grid_thw if has_img else None, + video_grid_thw if has_vid else None, + self.image_token_id, self.video_token_id, self.spatial_merge_size, + ) + else: + pos = torch.zeros(total, dtype=torch.int64, device=device) + for s, e in 
zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()): + pos[s:e] = torch.arange(e - s, device=device) + pos = pos.view(1, 1, -1).expand(3, 1, -1) + + cos, sin = mrope_cos_sin(self.text_inv_freq, pos, self.mrope_section) + cos, sin = cos.to(x.dtype), sin.to(x.dtype) + + # pad cu_seqlens to fixed size + n = cu_seqlens.shape[0] + assert n <= _PP_MAX_SEQS + 1, f"cu_seqlens len {n} > _PP_MAX_SEQS={_PP_MAX_SEQS}" + cu_pad = F.pad(cu_seqlens, (0, _PP_MAX_SEQS + 1 - n)) + n_t = torch.tensor(n, dtype=torch.int64, device=device) + + return x, cos, sin, cu_pad, n_t + + # ------------------------------------------------------------------ + def forward( + self, + hidden: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + cu_pad: torch.Tensor, + n_t: torch.Tensor, + ) -> tuple | torch.Tensor: + n = int(n_t.item()) + cu = cu_pad[:n].to(torch.int32) + max_s = int((cu[1:] - cu[:-1]).max().item()) + + x = hidden + for i, layer in enumerate(self.layers): + x = layer(x, cos, sin, cu, max_s) + if self.is_first and self._ds_embeds is not None and i < len(self._ds_embeds): + x = x.clone() + x[self._vis_masks] = x[self._vis_masks] + self._ds_embeds[i].to(x.dtype) + + if self.is_last: + return self.lm_head(self.norm(x)) # (1, seq_len, vocab_size) + + return x, cos, sin, cu_pad, n_t + + +class _ScaledLoss: + """Stateful loss callable — update .accum_target each step.""" + + def __init__(self) -> None: + self.accum_target: int = 1 + + def __call__(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + from models.qwen3_5.utils import causal_lm_loss + return causal_lm_loss(logits, labels) / self.accum_target + + +def _mrope_position_ids( + input_ids, cu_seqlens, image_grid_thw, video_grid_thw, + image_token_id, video_token_id, spatial_merge_size, +) -> torch.Tensor: + """3-D MRoPE positions — mirrors Qwen3_5ForCausalLM.get_rope_index.""" + if video_grid_thw is not None: + video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0).clone() + video_grid_thw[:, 0] = 1 + + _, S = input_ids.shape + device = input_ids.device + mm = torch.zeros(S, dtype=torch.int64, device=device) + mm[input_ids[0] == image_token_id] = 1 + mm[input_ids[0] == video_token_id] = 2 + types = mm.tolist() + img_it = iter(image_grid_thw) if image_grid_thw is not None else None + vid_it = iter(video_grid_thw) if video_grid_thw is not None else None + out = torch.zeros(3, 1, S, dtype=torch.int64, device=device) + for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()): + if start == end: + continue + seg = types[start:end] + parts, cur, j = [], 0, 0 + while j < len(seg): + k = j + while k < len(seg) and seg[k] == seg[j]: + k += 1 + key, length = seg[j], k - j + if key == 0: + parts.append(torch.arange(length, device=device).view(1, -1).expand(3, -1) + cur) + cur += length + else: + g = next(img_it if key == 1 else vid_it) + t, h, w = int(g[0]), int(g[1]), int(g[2]) + lh, lw = h // spatial_merge_size, w // spatial_merge_size + n = t * lh * lw + pw = torch.arange(cur, cur + lw, device=device).repeat(lh * t) + ph = torch.arange(cur, cur + lh, device=device).repeat_interleave(lw * t) + pt = torch.full((n,), cur, device=device, dtype=torch.int64) + parts.append(torch.stack([pt, ph, pw])) + cur += max(lh, lw) + j = k + out[:, 0, start:end] = torch.cat(parts, dim=1) + return out + + +class _PPSchedule: + """ + Single-microbatch pipeline schedule using blocking P2P comms. 
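Before the schedule itself, one detail worth spelling out: `preprocess()` above pads `cu_seqlens` to a constant `_PP_MAX_SEQS + 1` entries and ships the true boundary count separately, so every receive buffer downstream can be allocated once with a fixed shape. A standalone sketch (sequence lengths hypothetical):

```python
import torch
import torch.nn.functional as F

_PP_MAX_SEQS = 256
# Three packed sequences of lengths 17, 23 and 24, boundaries as flash-attn varlen expects.
cu = torch.tensor([0, 17, 40, 64], dtype=torch.int32)
cu_pad = F.pad(cu, (0, _PP_MAX_SEQS + 1 - cu.shape[0]))  # fixed length 257, zero-padded tail
n_t = torch.tensor(cu.shape[0])                          # 4; the receiver slices cu_pad[:n_t]
assert cu_pad.shape == (_PP_MAX_SEQS + 1,)
assert torch.equal(cu_pad[: int(n_t)], cu)
```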
+ + Implements GPipe-style all-forward then all-backward for a linear + pp_size-stage pipeline with packed (flash-attn varlen) sequences. + + Inter-stage tensor layout: + metadata : (1,) int64 — n_tokens (variable, sent before hidden/cos/sin) + hidden : (1, n_tokens, H) bfloat16 + cos : (1, n_tokens, R) bfloat16 + sin : (1, n_tokens, R) bfloat16 + cu_pad : (_PP_MAX_SEQS+1,) int32 (fixed size) + n_t : () int64 (scalar) + Backward sends only grad_hidden : (1, n_tokens, H) bfloat16. + """ + + def __init__( + self, + stage_module: nn.Module, + pp_rank: int, + pp_size: int, + pp_group, # DeviceMesh sub-mesh (pp dimension) + hidden_size: int, + rope_dim: int, + dtype: torch.dtype, + loss_fn: Optional["_ScaledLoss"] = None, + ): + self.stage = stage_module + self.pp_rank = pp_rank + self.pp_size = pp_size + self.is_first = pp_rank == 0 + self.is_last = pp_rank == pp_size - 1 + self._group = pp_group.get_group() + self._H = hidden_size + self._R = rope_dim + self._dt = dtype + self._device = torch.device(f"cuda:{torch.cuda.current_device()}") + self.loss_fn = loss_fn + + # Fixed-size recv buffers (cu_pad, n_t, metadata always fixed) + self._cu_buf = torch.empty(_PP_MAX_SEQS + 1, device=self._device, dtype=torch.int32) + self._nt_buf = torch.empty((), device=self._device, dtype=torch.int64) + self._meta_buf = torch.empty(1, device=self._device, dtype=torch.int64) + + # ------------------------------------------------------------------ + def _send_fwd(self, x, cos, sin, cu_pad, n_t, dst: int): + n_tok = torch.tensor([x.shape[1]], device=x.device, dtype=torch.int64) + dist.send(n_tok, dst=dst, group=self._group) + dist.send(x.contiguous(), dst=dst, group=self._group) + dist.send(cos.contiguous(), dst=dst, group=self._group) + dist.send(sin.contiguous(), dst=dst, group=self._group) + dist.send(cu_pad.contiguous(), dst=dst, group=self._group) + dist.send(n_t.contiguous(), dst=dst, group=self._group) + + def _recv_fwd(self, src: int): + dist.recv(self._meta_buf, src=src, group=self._group) + n = int(self._meta_buf.item()) + x = torch.empty(1, n, self._H, device=self._device, dtype=self._dt) + cos = torch.empty(1, n, self._R, device=self._device, dtype=self._dt) + sin = torch.empty(1, n, self._R, device=self._device, dtype=self._dt) + dist.recv(x, src=src, group=self._group) + dist.recv(cos, src=src, group=self._group) + dist.recv(sin, src=src, group=self._group) + dist.recv(self._cu_buf, src=src, group=self._group) + dist.recv(self._nt_buf, src=src, group=self._group) + return x, cos, sin, self._cu_buf, self._nt_buf + + # ------------------------------------------------------------------ + def step(self, *args, target=None, losses=None): + if self.is_first and not self.is_last: + x, cos, sin, cu_pad, n_t = self.stage(*args) + self._send_fwd(x, cos, sin, cu_pad, n_t, dst=self.pp_rank + 1) + grad_x = torch.empty_like(x) + dist.recv(grad_x, src=self.pp_rank + 1, group=self._group) + x.backward(grad_x) + + elif not self.is_first and self.is_last: + x_in, cos, sin, cu_pad, n_t = self._recv_fwd(src=self.pp_rank - 1) + x_leaf = x_in.detach().requires_grad_(True) + logits = self.stage(x_leaf, cos, sin, cu_pad, n_t) + loss = self.loss_fn(logits, target) + if losses is not None: + losses.append(loss.detach()) + loss.backward() + dist.send(x_leaf.grad.contiguous(), dst=self.pp_rank - 1, group=self._group) + + elif not self.is_first and not self.is_last: + # Middle stage (pp_size=4) + x_in, cos, sin, cu_pad, n_t = self._recv_fwd(src=self.pp_rank - 1) + x_leaf = x_in.detach().requires_grad_(True) + x_out, cos_out, 
sin_out, cu_out, nt_out = self.stage(x_leaf, cos, sin, cu_pad, n_t) + self._send_fwd(x_out, cos_out, sin_out, cu_out, nt_out, dst=self.pp_rank + 1) + grad_x_out = torch.empty_like(x_out) + dist.recv(grad_x_out, src=self.pp_rank + 1, group=self._group) + x_out.backward(grad_x_out) + dist.send(x_leaf.grad.contiguous(), dst=self.pp_rank - 1, group=self._group) + + else: + raise RuntimeError("_PPSchedule used with pp_size=1; use regular training path instead") + + +def apply_pp_qwen35( + model: nn.Module, + pp_group, # DeviceMesh sub-mesh for the PP dimension + seq_len: int, +) -> tuple: + """ + Split Qwen3_5ForCausalLM across PP ranks (pp_size must be 2 or 4). + + Returns (stage_module, None, schedule, loss_fn, pp_rank, pp_size, is_last). + The caller should: + - wrap stage_module with DDP/FSDP if needed + - call schedule.step(...) in the training loop (see train_qwen.py) + """ + pp_rank: int = pp_group.get_local_rank() + pp_size: int = pp_group.size() + device = torch.device(f"cuda:{torch.cuda.current_device()}") + + assert pp_size in (2, 4), f"apply_pp_qwen35: pp_size must be 2 or 4, got {pp_size}" + + n_layers = len(model.model.language_model.layers) + ranges = _pp_layer_ranges(n_layers, pp_size) + ls, le = ranges[pp_rank] + is_first = pp_rank == 0 + is_last = pp_rank == pp_size - 1 + + stage_module = PPStageModule(model, ls, le, is_first, is_last).to(device) + + H = model.model.language_model.cfg.hidden_size + R = model.text_inv_freq.shape[0] * 2 # rope_dim = 2 * len(inv_freq) + dt = next(model.parameters()).dtype + + loss_fn = _ScaledLoss() + schedule = _PPSchedule( + stage_module, pp_rank, pp_size, pp_group, + hidden_size=H, rope_dim=R, dtype=dt, + loss_fn=loss_fn if is_last else None, + ) + + return stage_module, None, schedule, loss_fn, pp_rank, pp_size, is_last diff --git a/train/train_qwen.py b/train/train_qwen.py index 1003439..dfa896c 100644 --- a/train/train_qwen.py +++ b/train/train_qwen.py @@ -27,11 +27,13 @@ get_mesh, get_tp_group, get_dp_group, + get_pp_group, apply_fsdp, apply_tp, apply_ac, ACConfig, compile_model, + apply_pp_qwen35, ) from train.utils import ( set_determinism, @@ -77,6 +79,8 @@ def __init__(self, cfg: Config): self.mesh = get_mesh(self.training_args, self.world_size) self.tp_group = get_tp_group(self.mesh) self.dp_group = get_dp_group(self.mesh) + self.pp_group = get_pp_group(self.mesh) + self.pp_size = getattr(self.training_args, "pp_size", 1) self.device = torch.device(f"cuda:{self.local_rank}") if self.if_log_rank(): @@ -159,25 +163,55 @@ def __init__(self, cfg: Config): logger.info("model loaded") - if self.training_args.tp_size > 1: - apply_tp(self.model, self.model_type, self.tp_group, self.training_args.async_tp) - - ac_mode = getattr(self.training_args, "ac_mode", "off") - if ac_mode != "off": - ac_cfg = ACConfig(enabled=True, full=(ac_mode == "full")) - apply_ac( - self.model.model.language_model, - ac_cfg, - model_compile_enabled=self.training_args.compile, + if self.pp_size > 1: + assert self.training_args.tp_size == 1, "TP + PP is not yet supported" + assert self.model_type == ModelType.Qwen3_5, \ + "Pipeline Parallel only implemented for Qwen3.5 native model" + + ( + self.model, + self._pp_pipeline_stage, + self._pp_schedule, + self._pp_loss_fn, + self.pp_rank, + _, + self.pp_is_last, + ) = apply_pp_qwen35(self.model, self.pp_group, int(self.data_args.seq_len)) + + logger.info( + f"PP applied: rank {self.pp_rank}/{self.pp_size}, " + f"layers {[len(list(self.model.layers))]}, " + f"is_last={self.pp_is_last}" ) - logger.info(f"activation 
checkpointing applied ({ac_mode})") - - if self.training_args.data_parallel == 'fsdp': - apply_fsdp(self.model_type, self.model, mesh=self.dp_group) - elif self.training_args.data_parallel == 'ddp': - self.model = replicate(self.model, device_mesh=self.dp_group) else: - raise Exception('invalid sharding strategy for Data Parallel') + self.pp_rank = 0 + self.pp_is_last = True + + if self.training_args.tp_size > 1: + apply_tp(self.model, self.model_type, self.tp_group, self.training_args.async_tp) + + ac_mode = getattr(self.training_args, "ac_mode", "off") + if ac_mode != "off": + ac_cfg = ACConfig(enabled=True, full=(ac_mode == "full")) + apply_ac( + self.model.model.language_model, + ac_cfg, + model_compile_enabled=self.training_args.compile, + ) + logger.info(f"activation checkpointing applied ({ac_mode})") + + if self.training_args.data_parallel == 'fsdp': + apply_fsdp(self.model_type, self.model, mesh=self.dp_group) + elif self.training_args.data_parallel == 'ddp': + self.model = replicate(self.model, device_mesh=self.dp_group) + else: + raise Exception('invalid sharding strategy for Data Parallel') + + if self.pp_size > 1: + # With PP, apply DDP to the stage module within the DP group + if self.training_args.data_parallel == 'ddp': + self.model = replicate(self.model, device_mesh=self.dp_group) + # (FSDP support for PP stages can be added later) # get rank of local GPU that belongs to the DP group data_rank = self.dp_group.get_local_rank() @@ -185,7 +219,7 @@ def __init__(self, cfg: Config): logger.info('sharding/parallelism applied') - if self.training_args.compile: + if self.training_args.compile and self.pp_size == 1: compile_model(self.model) logger.info("model (will be) compiled") @@ -202,7 +236,10 @@ def __init__(self, cfg: Config): max_pixels=1048576, ) - self.model = set_model(self.model_type, self.model_args, self.model) + # set_model freezes/unfreezes param groups; skip for PP (stage module + # doesn't have the full VLM wrapper structure) + if self.pp_size == 1: + self.model = set_model(self.model_type, self.model_args, self.model) worker_config = WorkerConfig( rank=data_rank, @@ -247,6 +284,7 @@ def rank(self): return torch.distributed.get_rank() def if_log_rank(self): + # Log only from global rank 0 (always pp_rank=0 and dp_rank=0) return self.rank() == 0 def create_optimizer(self): @@ -479,14 +517,17 @@ def setup_accumulation(self, tpi_multiplier=1.5): self.current_accum_count = 0 def train_step(self, data_iterator, optimizer): + if self.pp_size > 1: + return self._train_step_pp(data_iterator, optimizer) + return self._train_step_regular(data_iterator, optimizer) + + def _train_step_regular(self, data_iterator, optimizer): batch = next(data_iterator) s_model = time.perf_counter() with record_function("forward_pass"): with torch.autocast('cuda', torch.bfloat16): - outputs = self.model( - **batch - ) + outputs = self.model(**batch) loss = outputs.loss with record_function("backward_pass"): @@ -495,7 +536,49 @@ def train_step(self, data_iterator, optimizer): scaled_loss.backward() self.fwd_bwd_time = time.perf_counter() - s_model + return self._maybe_optimizer_step(loss, optimizer) + + def _train_step_pp(self, data_iterator, optimizer): + """Training step for Pipeline Parallel (Qwen3.5 native, pp_size 2 or 4).""" + batch = next(data_iterator) + + # Update the loss scaling factor used inside the last stage + if self.pp_is_last: + self._pp_loss_fn.accum_target = self.current_accum_target + s_model = time.perf_counter() + losses: list = [] + + with 
record_function("pp_forward_backward"): + with torch.autocast('cuda', torch.bfloat16): + if self.pp_rank == 0: + # Preprocess: visual encode + embed + RoPE → fixed-shape tensors + stage_inputs = self.model.preprocess(batch) + self._pp_schedule.step(*stage_inputs) + elif self.pp_is_last: + # Last stage: receive activations, compute loss, run backward + self._pp_schedule.step(target=batch['labels'], losses=losses) + else: + # Middle stages: receive, forward, send + self._pp_schedule.step() + + self.fwd_bwd_time = time.perf_counter() - s_model + + # Collect the loss scalar for logging (only last stage has it) + if self.pp_is_last and losses: + loss = losses[0].detach() + else: + loss = torch.zeros(1, device=self.device) + + # Reduce loss across the PP group so all ranks have the same value for logging + torch.distributed.all_reduce( + loss, op=torch.distributed.ReduceOp.SUM, group=self.pp_group.get_group() + ) + + return self._maybe_optimizer_step(loss, optimizer) + + def _maybe_optimizer_step(self, loss, optimizer): + """Shared optimizer-step logic after fwd+bwd (regular and PP paths).""" self.current_accum_count += 1 if self.current_accum_count >= self.current_accum_target: @@ -504,31 +587,28 @@ def train_step(self, data_iterator, optimizer): optimizer.zero_grad() lr = optimizer.param_groups[0]['lr'] - self.global_step += 1 avg_loss, max_loss, global_tokens, global_assistant, global_samples = ( dist_mean(loss, self.dp_group), dist_max(loss, self.dp_group), dist_sum( - torch.tensor( - self.tokens_seen, dtype=torch.int64, device=self.device - ), + torch.tensor(self.tokens_seen, dtype=torch.int64, device=self.device), self.dp_group, ), dist_sum( - torch.tensor( - self.tokens_seen_assistant, dtype=torch.int64, device=self.device - ), + torch.tensor(self.tokens_seen_assistant, dtype=torch.int64, device=self.device), self.dp_group, ), dist_sum( torch.tensor(self.samples_since_last_log, dtype=torch.int32, device=self.device), self.dp_group, - ) + ), ) - self.train_step_delta = (time.perf_counter() - self.time_last_log) / self.current_accum_target + self.train_step_delta = ( + (time.perf_counter() - self.time_last_log) / self.current_accum_target + ) if self.if_log_rank(): self.log(avg_loss, max_loss, global_tokens, global_assistant, global_samples, lr) @@ -537,7 +617,6 @@ def train_step(self, data_iterator, optimizer): self.ntokens_since_last_log = 0 self.samples_since_last_log = 0 self.time_last_log = time.perf_counter() - self.current_accum_count = 0 self.current_accum_target = next(self.accum_schedule) From 56c048da030d7ceb4173ac6eda1c524617000a9f Mon Sep 17 00:00:00 2001 From: tomiock <38719343+tomiock@users.noreply.github.com> Date: Thu, 23 Apr 2026 17:23:15 +0200 Subject: [PATCH 02/23] Enhance README with online datapacking and loading info Updated README.md to include details about online datapacking and model loading mechanisms. --- README.md | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index eabf020..b32295c 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,8 @@ Massive-scale VLM pre-training and finetuning on HPC environments. It is specifically designed and tested for **Marenostrum 5** and **JUPITER**. Works similary to torchtitan, only relying on native torch code for the distributed implementation. Compatibilty with HF state-dict, loads weights from HF snapshot directory. +See SCALABILITY.md and USAGE.md for more details. + ## Key Features * **Supported Architectures:** **Qwen3.5**, Qwen3-VL and Qwen3 (text). 
* **2D Parallelism:** FSDP/DDP (Single & Multi-node) and Tensor Parallelism (TP) support. Tested scaling up to 256 GPUs.
@@ -32,19 +34,12 @@ Support for ROCm systems (LUMI) is work in progress.
 - `transformers=5.6.0`
 
 ## Datasets and Dataloading
-Datasets are expected to be as a CrudeWebdataset. With https://github.com/NVIDIA/Megatron-Energon we handle the raw data and tokenize it on the fly. It is an asynchrnos process that does not have an impact on model performance.
-
-**Online datapacking is not yet supported** (no particular issue related to the HPC system, its skill issue on my part). We believe that data packing is a must-have for a visual-language training codebase with native resolution, as the varying image sizes on the datasets can be handled easily.
+Datasets are expected to be in the CrudeWebdataset format. With https://github.com/NVIDIA/Megatron-Energon we handle the raw data and tokenize it on the fly. It is an asynchronous process that does not have an impact on model performance. **Online datapacking is used by default.**
 
 Support for Metadatasets (multiple sources).
 
 ## Model Weights & Offline Loading
-Use `utils/down.py` on a login node to pre-download model weights and tokenizers to a shared filesystem:
-
-```bash
-python utils/down.py
-```
-Go into the file and change the arguments, it does not have CLI support.
+Use `utils/down.py` on a login node to pre-download model weights and tokenizers to a shared filesystem. The model's architecture configuration is read from the downloaded snapshot.
 
-**Loading Mechanism:** During training, models are instantiated directly from these local paths. For Native Torch models, the architecture is initialized purely in PyTorch, and the offline weights are mapped and loaded directly into the native state dictionary.
+**Loading Mechanism:** During training, models are instantiated directly from these local paths. The architecture is initialized purely in PyTorch, and the offline weights are mapped and loaded directly into the native state dictionary.
 
 ## Usage
 1. Ensure your datasets are formatted as Nvidia Energon webdatasets.
@@ -67,5 +62,6 @@ The codebase demonstrates linear scaling up to 256 GPUs using FSDP and Tensor Pa
 For a detailed breakdown of throughput, GPU efficiency, and scaling characteristics, please refer to [SCALABILITY.md](SCALABILITY.md).
 
 ## Known Issues & TODOs
-* Online data packing for Energon dataloading is not yet supported.
+* The entire workflow `training -> checkpoints -> eval/usage` needs a lot of work.
 * Static shape compilation (`torch.compile` with `fullgraph=True`) is pending.
+* A better data packing implementation is needed.

From ed1d14753127021d89894410f7eac4de1a2d9693 Mon Sep 17 00:00:00 2001
From: tomiock <38719343+tomiock@users.noreply.github.com>
Date: Thu, 23 Apr 2026 17:26:01 +0200
Subject: [PATCH 03/23] Revise scalability metrics in SCALABILITY.md

Updated scalability metrics for Qwen3.5-2B and Qwen3.5-9B models,
including new throughput and scaling test results. 
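For reference, figures like tks/sec/gpu and TFLOPS/gpu in SCALABILITY.md are related through the usual ~6·N FLOPs-per-token approximation for a dense decoder (forward plus backward, ignoring attention and the vision tower). A back-of-envelope sketch with purely illustrative numbers, not measurements from this repo:

```python
# Rough tokens/sec <-> TFLOP/s conversion for a dense model (all values hypothetical).
n_params = 9e9                    # a 9B-parameter decoder
flops_per_token = 6 * n_params    # forward + backward approximation
tokens_per_sec_per_gpu = 9_000
tflops = flops_per_token * tokens_per_sec_per_gpu / 1e12
print(f"~{tflops:.0f} TFLOP/s per GPU")  # prints ~486
```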
--- SCALABILITY.md | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/SCALABILITY.md b/SCALABILITY.md index 7672d9a..f8a65e5 100644 --- a/SCALABILITY.md +++ b/SCALABILITY.md @@ -1,6 +1,12 @@ ## Qwen3.5-2B @ JUPITER - 16 node (64 H200 96GB) tested -- 10,000 tks/sec/device +- +15,000 tks/sec/gpu + +## Qwen3.5-9B @ JUPITER +- scaling test from 16 to 256 nodes +- +500 TFLOPS/s/gpu +image + ## Qwen3-VL-8B @ JUPITER - ~380 TFLOPS with 4 nodes (16 GH200 96GB) @@ -16,4 +22,4 @@ ### Results Scalability throughput with 8B model on Marenostrum 5: -image +image From ba3ed0e20f2710c9b8e1f253ea31b278b64edbad Mon Sep 17 00:00:00 2001 From: tomiock Date: Wed, 22 Apr 2026 18:27:45 +0200 Subject: [PATCH 04/23] jup config --- configs/jupiter/qwen3_5_2b.toml | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/configs/jupiter/qwen3_5_2b.toml b/configs/jupiter/qwen3_5_2b.toml index 2f4b2e4..26cbfe5 100644 --- a/configs/jupiter/qwen3_5_2b.toml +++ b/configs/jupiter/qwen3_5_2b.toml @@ -34,10 +34,16 @@ tp_size = 1 resume_checkpoint = false random_init_mlp = false -compile = true +compile = false [data] data_path = "/e/project1/jureap59/ockier1/datasets/cap_pretrain" seq_len = 8192 -batch_size = 66 \ No newline at end of file + +shuffle_buffer_size = 1000 +packing_buffer_size = 1000 +max_samples_per_sequence = 100 + +batch_size = 0 + From 60f2680cde10f5a1419f0bbf5804a19156aef564 Mon Sep 17 00:00:00 2001 From: tomiock Date: Thu, 23 Apr 2026 17:37:14 +0200 Subject: [PATCH 05/23] jupiter data packing config --- configs/jupiter/qwen3_5_2b.toml | 12 ++++++------ data/energon_dataloader.py | 25 ++++++++++++++++++++++++- 2 files changed, 30 insertions(+), 7 deletions(-) diff --git a/configs/jupiter/qwen3_5_2b.toml b/configs/jupiter/qwen3_5_2b.toml index 26cbfe5..cfa8502 100644 --- a/configs/jupiter/qwen3_5_2b.toml +++ b/configs/jupiter/qwen3_5_2b.toml @@ -17,27 +17,27 @@ model_dir = "/e/project1/reformo/ockier1/qwen_models/qwen3_5_2b" output_dir = "/e/scratch/reformo/ockier1/checkpoints/test_35_2b" tpi_multiplier = 1.0 -save_steps = 1000 +save_steps = 200 scheduler_type = "cosine" -total_steps = 18000 -warmup_steps = 100 +total_steps = 500 +warmup_steps = 10 #wsd_decay_ratio = 0.1 min_lr_ratio = 0.1 lr_llm = 0.00002 lr_mlp = 0.0001 -data_parallel = 'fsdp' +data_parallel = 'ddp' tp_size = 1 resume_checkpoint = false -random_init_mlp = false +random_init = true compile = false [data] -data_path = "/e/project1/jureap59/ockier1/datasets/cap_pretrain" +data_path = "/e/data1/datasets/products/llava_onevision_mid_training_85m/imagenet/EN" seq_len = 8192 diff --git a/data/energon_dataloader.py b/data/energon_dataloader.py index a60422e..4a43268 100644 --- a/data/energon_dataloader.py +++ b/data/energon_dataloader.py @@ -93,6 +93,28 @@ class EnergonSample(Sample): image: torch.Tensor messages: list +@stateless +def cooker_llava_imagenet(sample: dict, add_system_prompt: bool = True) -> EnergonSample: + messages = [ + {'role': 'user', 'content': [ + {"type": "image"} + ]}, + {'role': 'assistant', 'content': [ + {"type": "text", "text": sample['txt']} + ]}, + ] + + if not add_system_prompt: + messages.append({"role": "system", "content": [{"type": "text", "text": ""}]}) + + image = sample['jpg'] + + return EnergonSample( + **basic_sample_keys(sample), + image=image, + messages=messages, + ) + @stateless def cooker_captioning(sample: dict, add_system_prompt: bool = True) -> EnergonSample: role_map = {'human': 'user', 'gpt': 'assistant', 'user': 'user', 'assistant': 'assistant'} @@ 
-256,7 +278,8 @@ def __init__(self, processor, max_seq_len): cookers = [ # subflavors can be used to distinguish datasets when using a Metadataset - Cooker(cooker_captioning), + Cooker(cooker_captioning, has_subflavors={"type_dataset": "synth"}), + Cooker(cooker_llava_imagenet, has_subflavors={"type_dataset": "llava_onevision_midtraining"}), ] # transform the RAW data, tokenize a single sample From fe08445c6ca37651dc3583343808c028f1088c9e Mon Sep 17 00:00:00 2001 From: tomiock Date: Thu, 23 Apr 2026 22:25:57 +0200 Subject: [PATCH 06/23] 9B PP=4 but error on 27B --- configs/jupiter/qwen3_5_27b.toml | 49 ++++++++++++++++++++++++++++++++ configs/jupiter/qwen3_5_9b.toml | 14 +++++---- models/qwen3_5/model.py | 7 ++--- scripts/jup_finetune.sh | 3 +- train/train_qwen.py | 14 ++++----- train/utils.py | 2 ++ utils/down.py | 4 +-- 7 files changed, 72 insertions(+), 21 deletions(-) create mode 100644 configs/jupiter/qwen3_5_27b.toml diff --git a/configs/jupiter/qwen3_5_27b.toml b/configs/jupiter/qwen3_5_27b.toml new file mode 100644 index 0000000..6b9753d --- /dev/null +++ b/configs/jupiter/qwen3_5_27b.toml @@ -0,0 +1,49 @@ +[model] +model_name = "Qwen/Qwen3.5-27B" +model_impl = "native" + +train_llm = true +train_mlp = true +train_vit = false + +[wandb] +run_name = "test 27b" +project_name = "scaling_27b" +entity_name = "bsc_runs" + +[training] +model_dir = "/e/project1/reformo/ockier1/qwen_models/qwen3_5_27b" +output_dir = "/e/scratch/reformo/ockier1/checkpoints/test_35_27b" + +tpi_multiplier = 1.0 +save_steps = 1000 + +scheduler_type = "cosine" +total_steps = 18000 +warmup_steps = 100 +#wsd_decay_ratio = 0.1 +min_lr_ratio = 0.1 + +lr_llm = 0.00002 +lr_mlp = 0.0001 + +data_parallel = 'ddp' + +pp_size = 1 +tp_size = 1 + +resume_checkpoint = false +random_init = false + +ac_mode = 'off' + +compile = false +async_tp = false + +[data] +data_path = "/e/project1/jureap59/ockier1/datasets/cap_pretrain" + +packing_buffer_size = 100 + +seq_len = 512 +batch_size = 0 \ No newline at end of file diff --git a/configs/jupiter/qwen3_5_9b.toml b/configs/jupiter/qwen3_5_9b.toml index 8b145b2..dbe49d4 100644 --- a/configs/jupiter/qwen3_5_9b.toml +++ b/configs/jupiter/qwen3_5_9b.toml @@ -28,18 +28,22 @@ lr_llm = 0.00002 lr_mlp = 0.0001 data_parallel = 'fsdp' -tp_size = 4 + +pp_size = 4 +tp_size = 1 resume_checkpoint = false -random_init_mlp = false +random_init = false ac_mode = 'off' -compile = true +compile = false async_tp = false [data] data_path = "/e/project1/jureap59/ockier1/datasets/cap_pretrain" -seq_len = 10240 -batch_size = 64 \ No newline at end of file +packing_buffer_size = 100 + +seq_len = 8192 +batch_size = 0 \ No newline at end of file diff --git a/models/qwen3_5/model.py b/models/qwen3_5/model.py index fd361c5..fe20d94 100644 --- a/models/qwen3_5/model.py +++ b/models/qwen3_5/model.py @@ -1,5 +1,6 @@ from __future__ import annotations +import time from pathlib import Path import torch @@ -808,6 +809,8 @@ def from_pretrained( with torch.device("meta"): model = cls(cfg) + + # error here model = model.to_empty(device=device).to(dtype=dtype) load_safetensors_into( @@ -818,9 +821,6 @@ def from_pretrained( load_vision=load_vision, ) - # `to_empty` above re-materializes every parameter and breaks the - # tie established in `__init__`. Re-tie here so `lm_head` (absent - # from checkpoints when tied) shares storage with the embedding. 
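The comment removed above records a genuine pitfall: `to_empty()` re-materializes every parameter with fresh storage, silently breaking weight tying, so the tie has to be re-established afterwards. A minimal standalone repro of the pattern (module sizes hypothetical):

```python
import torch
import torch.nn as nn

with torch.device("meta"):
    embed = nn.Embedding(1000, 64)
    head = nn.Linear(64, 1000, bias=False)
    head.weight = embed.weight                 # tied while still on meta

embed = embed.to_empty(device="cpu")
head = head.to_empty(device="cpu")             # fresh storage, the tie is gone
head.weight = embed.weight                     # re-tie so both share storage again
assert head.weight.data_ptr() == embed.weight.data_ptr()
```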
if cfg.tie_word_embeddings: model.lm_head.weight = model.model.language_model.embed_tokens.weight @@ -834,7 +834,6 @@ def from_pretrained( ) model.model.visual.rotary_pos_emb.inv_freq = inv_freq_v - # Recompute text inv_freq (non-persistent buffer wiped by `to_empty`). head_dim = cfg.text.head_dim partial = cfg.text.rope_parameters.get('partial_rotary_factor', 1.0) rope_dim = int(head_dim * partial) diff --git a/scripts/jup_finetune.sh b/scripts/jup_finetune.sh index 35bcb04..7d50387 100755 --- a/scripts/jup_finetune.sh +++ b/scripts/jup_finetune.sh @@ -2,7 +2,6 @@ MASTER_ADDR="127.0.0.1" MASTER_PORT=$(shuf -i 20000-29999 -n 1) -NGPUS=$(nvidia-smi --list-gpus | wc -l) export WANDB_MODE=offline export HF_HUB_OFFLINE=1 @@ -23,7 +22,7 @@ ulimit -s unlimited torchrun \ --nnodes=1 \ - --nproc_per_node=$NGPUS \ + --nproc_per_node=4 \ --rdzv_id 101 \ --rdzv_backend c10d \ --rdzv_endpoint="$MASTER_ADDR:$MASTER_PORT" \ diff --git a/train/train_qwen.py b/train/train_qwen.py index dfa896c..2a9e1a8 100644 --- a/train/train_qwen.py +++ b/train/train_qwen.py @@ -152,22 +152,17 @@ def __init__(self, cfg: Config): else: logger.info('model not initlized, incompatible') - # replace flash_attn self.model.train() if self.model_args.model_impl == "hf": self.model.enable_input_require_grads() self.optimizer = None # its defined later on - if self.training_args.bfloat16: - self.model = self.model.to(torch.bfloat16) - logger.info("model loaded") if self.pp_size > 1: assert self.training_args.tp_size == 1, "TP + PP is not yet supported" assert self.model_type == ModelType.Qwen3_5, \ "Pipeline Parallel only implemented for Qwen3.5 native model" - ( self.model, self._pp_pipeline_stage, @@ -184,7 +179,7 @@ def __init__(self, cfg: Config): f"is_last={self.pp_is_last}" ) else: - self.pp_rank = 0 + self.pp_rank = 0 self.pp_is_last = True if self.training_args.tp_size > 1: @@ -208,10 +203,13 @@ def __init__(self, cfg: Config): raise Exception('invalid sharding strategy for Data Parallel') if self.pp_size > 1: - # With PP, apply DDP to the stage module within the DP group if self.training_args.data_parallel == 'ddp': self.model = replicate(self.model, device_mesh=self.dp_group) - # (FSDP support for PP stages can be added later) + + self.model = self.model.to(self.device) + + if self.training_args.bfloat16: + self.model = self.model.to(torch.bfloat16) # get rank of local GPU that belongs to the DP group data_rank = self.dp_group.get_local_rank() diff --git a/train/utils.py b/train/utils.py index c920568..efb117f 100644 --- a/train/utils.py +++ b/train/utils.py @@ -512,6 +512,8 @@ def get_dense_model_nparams_and_flops( tied = False elif "9B" in model_name: tied = False + elif "27B" in model_name: + tied = False elif "2B" in model_name: tied = True elif "4B" in model_name: diff --git a/utils/down.py b/utils/down.py index 09b895e..1428d84 100644 --- a/utils/down.py +++ b/utils/down.py @@ -1,6 +1,6 @@ from huggingface_hub import snapshot_download snapshot_download( - repo_id="Qwen/Qwen3.5-9B", - local_dir="/data/151-1/users/tockier/qwen_finetune/cache/qwen35_9b", + repo_id="Qwen/Qwen3.5-27B", + local_dir="/e/project1/reformo/ockier1/qwen_models/qwen3_5_27b", ) From 92b5ccc68ab679c8c4604bbc61b22237b06fdb88 Mon Sep 17 00:00:00 2001 From: tomiock Date: Fri, 24 Apr 2026 13:28:01 +0200 Subject: [PATCH 07/23] [feat] 27B model running w/ PP=6 - model meta init before PP/TP applied - dynamic layer split according to PP dim --- configs/cvc/qwen3_5_27b.toml | 34 ++++++++++ models/qwen3_5/model.py | 5 ++ models/qwen3_5/utils.py | 88 
+++++++++++++++++++++++++ train/config.py | 1 + train/infra.py | 122 +++++++++++++++++++++++++++-------- train/train_qwen.py | 86 +++++++++++++++--------- train/utils.py | 19 ++++-- 7 files changed, 294 insertions(+), 61 deletions(-) create mode 100644 configs/cvc/qwen3_5_27b.toml diff --git a/configs/cvc/qwen3_5_27b.toml b/configs/cvc/qwen3_5_27b.toml new file mode 100644 index 0000000..edcdf75 --- /dev/null +++ b/configs/cvc/qwen3_5_27b.toml @@ -0,0 +1,34 @@ +[model] +model_name = "Qwen/Qwen3.5-27B" +model_impl = "native" + +train_llm = true +train_mlp = true +train_vit = false + +[wandb] +run_name = "test" +project_name = "qwen35_27b" + +[training] +model_dir = "/data/151-1/users/tockier/qwen_finetune/cache/qwen35_27b" +output_dir = "/data/151-1/users/shared_cache/qwen_finetune/checkpoints" + +save_steps = 10000 +total_steps = 10000 +random_init = false + +tp_size = 1 +pp_size = 4 +data_parallel = 'fsdp' + +ac_mode = "full" +compile = false + +[data] +data_path = "/data/151-1/datasets/synth_test_datasets/cap_pretrain" +seq_len = 400 + +packing_buffer_size = 100 + +batch_size = 0 \ No newline at end of file diff --git a/models/qwen3_5/model.py b/models/qwen3_5/model.py index fd361c5..ba144e6 100644 --- a/models/qwen3_5/model.py +++ b/models/qwen3_5/model.py @@ -796,6 +796,7 @@ def from_pretrained( device: str | torch.device = "cpu", *, load_vision: bool = True, + weights: bool = True, ) -> "Qwen3_5ForCausalLM": snapshot_dir = Path(snapshot_dir) cfg = Qwen3VLConfig.from_json(snapshot_dir / "config.json") @@ -808,6 +809,10 @@ def from_pretrained( with torch.device("meta"): model = cls(cfg) + + if not weights: + return model + model = model.to_empty(device=device).to(dtype=dtype) load_safetensors_into( diff --git a/models/qwen3_5/utils.py b/models/qwen3_5/utils.py index 14bf20f..ea59ed2 100644 --- a/models/qwen3_5/utils.py +++ b/models/qwen3_5/utils.py @@ -1,6 +1,7 @@ from __future__ import annotations import json +from collections import defaultdict from dataclasses import dataclass from pathlib import Path @@ -195,3 +196,90 @@ def load_safetensors_into( missing = {m for m in missing if not m.startswith("model.visual.")} if missing: raise RuntimeError(f"Missing weights after load: {sorted(missing)[:8]} ... ({len(missing)} total)") + + +def load_stage_weights( + stage: nn.Module, + snapshot_dir: Path, + layer_start: int, + layer_end: int, + is_first: bool, + is_last: bool, + device: torch.device | str, + dtype: torch.dtype, +) -> None: + """Load checkpoint weights for a single PP stage directly onto ``device``. + + The stage must have been created from a meta model and materialized with + ``to_empty(device)`` before this call. Uses ``load_state_dict(assign=True)`` + so tensors are assigned directly — no extra CPU copy. + """ + from models.qwen3_5.config import Qwen3VLConfig + + snapshot_dir = Path(snapshot_dir) + cfg = Qwen3VLConfig.from_json(snapshot_dir / "config.json") + + # Map each stage state-dict key to the corresponding checkpoint key. 
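The per-shard reads further down rely on `safetensors`' `safe_open`, which opens a shard lazily and materializes only the requested tensors directly on the target device; this is what keeps per-rank memory bounded by the stage's own weights. A minimal sketch of the access pattern (shard name and keys hypothetical):

```python
from safetensors import safe_open

# Only the two requested tensors are read from the shard; nothing else is loaded.
with safe_open("model-00002-of-00004.safetensors", framework="pt", device="cuda:0") as f:
    q = f.get_tensor("model.language_model.layers.13.self_attn.q_proj.weight")
    k = f.get_tensor("model.language_model.layers.13.self_attn.k_proj.weight")
```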
+ # Stage key namespace → HF checkpoint namespace + # layers.{i}.* → model.language_model.layers.{layer_start+i}.* + # embed_tokens.* → model.language_model.embed_tokens.* + # visual.* → model.visual.* + # norm.* → model.language_model.norm.* + # lm_head.* → lm_head.* + # text_inv_freq → (computed; skip) + stage_to_ckpt: dict[str, str] = {} + for key in stage.state_dict(): + if key == "text_inv_freq": + continue # recomputed from config after loading + elif key.startswith("layers."): + rest = key[len("layers."):] + i_str, suffix = rest.split(".", 1) + ckpt_key = f"model.language_model.layers.{layer_start + int(i_str)}.{suffix}" + stage_to_ckpt[key] = ckpt_key + elif key.startswith("embed_tokens."): + stage_to_ckpt[key] = f"model.language_model.{key}" + elif key.startswith("visual."): + stage_to_ckpt[key] = f"model.{key}" + elif key.startswith("norm."): + stage_to_ckpt[key] = f"model.language_model.{key}" + elif key.startswith("lm_head."): + stage_to_ckpt[key] = key + + # Determine which shard files contain the needed keys. + index_path = snapshot_dir / "model.safetensors.index.json" + if index_path.exists(): + with open(index_path) as f: + weight_map: dict[str, str] = json.load(f)["weight_map"] + shard_to_keys: dict[str, list[str]] = defaultdict(list) + for ckpt_key in stage_to_ckpt.values(): + if ckpt_key in weight_map: + shard_to_keys[weight_map[ckpt_key]].append(ckpt_key) + else: + single = snapshot_dir / "model.safetensors" + assert single.exists(), f"No safetensors found in {snapshot_dir}" + shard_to_keys = {single.name: list(stage_to_ckpt.values())} + + # Load only the needed tensors from disk, directly onto target device. + ckpt_tensors: dict[str, torch.Tensor] = {} + for shard_name, keys in shard_to_keys.items(): + with safe_open(str(snapshot_dir / shard_name), framework="pt", device=str(device)) as f: + for k in keys: + ckpt_tensors[k] = f.get_tensor(k).to(dtype=dtype) + + stage_sd = {sk: ckpt_tensors[ck] for sk, ck in stage_to_ckpt.items() if ck in ckpt_tensors} + stage.load_state_dict(stage_sd, assign=True, strict=False) + + # Recompute non-checkpoint buffers that to_empty() wiped. + if is_first: + head_dim = cfg.text.head_dim + partial = cfg.text.rope_parameters.get("partial_rotary_factor", 1.0) + rope_dim = int(head_dim * partial) + stage.text_inv_freq = 1.0 / ( + cfg.text.rope_parameters["rope_theta"] + ** (torch.arange(0, rope_dim, 2, dtype=torch.float32, device=device) / rope_dim) + ) + head_dim_v = cfg.vision.hidden_size // cfg.vision.num_heads + rdim = head_dim_v // 2 + stage.visual.rotary_pos_emb.inv_freq = 1.0 / ( + 10000.0 ** (torch.arange(0, rdim, 2, dtype=torch.float32, device=device) / rdim) + ) diff --git a/train/config.py b/train/config.py index b309346..2e21abf 100644 --- a/train/config.py +++ b/train/config.py @@ -96,6 +96,7 @@ class Training: Use `fsdp` when you want to decrease usage to increase seq_len/batch_size. 
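These parallelism knobs must factor the world size: `get_mesh` derives the data-parallel degree as `world_size // (tp_size * pp_size)` and builds a ("dp", "pp", "tp") mesh when `pp_size > 1`. A quick sketch of the arithmetic for an assumed 16-GPU job:

```python
# Hypothetical 16-GPU run: pp_size=4, tp_size=1 -> 4-way data parallelism.
world_size, tp_size, pp_size = 16, 1, 4
assert world_size % (tp_size * pp_size) == 0, "world size must be divisible by tp*pp"
dp_size = world_size // (tp_size * pp_size)  # 4 -> mesh shape (dp, pp, tp) = (4, 4, 1)
```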
""" + # compiler flag for TP (goes faster) async_tp: bool = True diff --git a/train/infra.py b/train/infra.py index dd72091..73e7191 100644 --- a/train/infra.py +++ b/train/infra.py @@ -668,15 +668,46 @@ def _apply_tp_to_decoder_qwen3_5( _PP_MAX_SEQS: int = 256 -def _pp_layer_ranges(n_layers: int, pp_size: int) -> list[tuple[int, int]]: - """Equal-split layer index ranges across pp_size stages.""" - assert pp_size in (2, 4), f"pp_size must be 2 or 4, got {pp_size}" - base, rem = divmod(n_layers, pp_size) - ranges, start = [], 0 - for i in range(pp_size): - end = start + base + (1 if i < rem else 0) - ranges.append((start, end)) - start = end +def _pp_layer_ranges( + n_layers: int, + pp_size: int, + first_virtual: float = 1.0, + last_virtual: float = 0.0, +) -> list[tuple[int, int]]: + """ + Distribute n_layers across pp_size stages for balanced memory. + + first_virtual / last_virtual: overhead of the non-layer modules on the + first/last stage expressed in units of a single transformer layer. + Computed from actual parameter counts so the split automatically adapts + to any model size. + + Optimal layers per stage: + target = (n_layers + first_virtual + last_virtual) / pp_size + first_n = round(target - first_virtual) + last_n = round(target - last_virtual) [pp_size == 4 only] + middle = evenly distributed remainder + """ + target = (n_layers + first_virtual + last_virtual) / pp_size + first_n = max(1, round(target - first_virtual)) + + if pp_size == 2: + return [(0, first_n), (first_n, n_layers)] + + last_n = max(1, round(target - last_virtual)) + remaining = n_layers - first_n - last_n + assert remaining >= 2, ( + f"Not enough layers for middle stages with " + f"first_n={first_n}, last_n={last_n}, n_layers={n_layers}" + ) + # all minus 2 (last and first) + mid, extra = divmod(remaining, pp_size - 2) + ranges, pos = [(0, first_n)], first_n + for i in range(pp_size - 2): + n = mid + (1 if i < extra else 0) + ranges.append((pos, pos + n)) + pos += n + ranges.append((pos, n_layers)) return ranges @@ -993,39 +1024,76 @@ def step(self, *args, target=None, losses=None): def apply_pp_qwen35( model: nn.Module, - pp_group, # DeviceMesh sub-mesh for the PP dimension + pp_group, seq_len: int, + *, + snapshot_dir=None, # Path | None — when set, model is on meta; weights loaded per-rank + device=None, # torch.device | None — required when snapshot_dir is set + dtype=None, # torch.dtype | None — required when snapshot_dir is set ) -> tuple: """ Split Qwen3_5ForCausalLM across PP ranks (pp_size must be 2 or 4). - Returns (stage_module, None, schedule, loss_fn, pp_rank, pp_size, is_last). - The caller should: - - wrap stage_module with DDP/FSDP if needed - - call schedule.step(...) in the training loop (see train_qwen.py) + + Meta-loading path (large models): + Pass ``snapshot_dir``, ``device``, ``dtype``. The model must be on + ``torch.device("meta")`` with no weights. Each rank materialises only + its stage slice and loads the corresponding weights directly onto + ``device`` via ``load_stage_weights`` — no full-model CPU or GPU copy. + + Legacy path: + Omit those kwargs. The full model must already reside on the target + device with weights loaded (original behaviour, kept for compatibility). 
""" pp_rank: int = pp_group.get_local_rank() pp_size: int = pp_group.size() - device = torch.device(f"cuda:{torch.cuda.current_device()}") - - assert pp_size in (2, 4), f"apply_pp_qwen35: pp_size must be 2 or 4, got {pp_size}" + meta_load = snapshot_dir is not None + + if not meta_load: + device = torch.device(f"cuda:{torch.cuda.current_device()}") + + lm = model.model.language_model + n_layers = len(lm.layers) + layer_params = sum(p.numel() for p in lm.layers[0].parameters()) + embed_params = lm.embed_tokens.weight.numel() + visual = getattr(model.model, "visual", None) + visual_params = sum(p.numel() for p in visual.parameters()) if visual is not None else 0 + lm_head_params = model.lm_head.weight.numel() + + first_virtual = (embed_params + visual_params) / layer_params + last_virtual = lm_head_params / layer_params + + ranges = _pp_layer_ranges(n_layers, pp_size, first_virtual, last_virtual) + if pp_rank == 0: + counts = [e - s for s, e in ranges] + print( + f"[PP] layer split {counts} " + f"(first_virtual={first_virtual:.2f} last_virtual={last_virtual:.2f})", + flush=True, + ) + ls, le = ranges[pp_rank] + is_first = pp_rank == 0 + is_last = pp_rank == pp_size - 1 - n_layers = len(model.model.language_model.layers) - ranges = _pp_layer_ranges(n_layers, pp_size) - ls, le = ranges[pp_rank] - is_first = pp_rank == 0 - is_last = pp_rank == pp_size - 1 + # Read config values before to_empty() modifies parameters. + H = model.model.language_model.cfg.hidden_size + R = model.text_inv_freq.shape[0] * 2 # rope_dim = 2 * len(inv_freq) - stage_module = PPStageModule(model, ls, le, is_first, is_last).to(device) + stage_module = PPStageModule(model, ls, le, is_first, is_last) - H = model.model.language_model.cfg.hidden_size - R = model.text_inv_freq.shape[0] * 2 # rope_dim = 2 * len(inv_freq) - dt = next(model.parameters()).dtype + if meta_load: + from models.qwen3_5.utils import load_stage_weights + stage_module.to_empty(device=device) + stage_module.to(dtype) + load_stage_weights(stage_module, snapshot_dir, ls, le, is_first, is_last, device, dtype) + else: + dtype = next(model.parameters()).dtype + stage_module = stage_module.to(device) loss_fn = _ScaledLoss() schedule = _PPSchedule( stage_module, pp_rank, pp_size, pp_group, - hidden_size=H, rope_dim=R, dtype=dt, + hidden_size=H, rope_dim=R, dtype=dtype, loss_fn=loss_fn if is_last else None, ) diff --git a/train/train_qwen.py b/train/train_qwen.py index dfa896c..945014c 100644 --- a/train/train_qwen.py +++ b/train/train_qwen.py @@ -5,6 +5,7 @@ import wandb import transformers from itertools import cycle +from pathlib import Path import time @@ -124,9 +125,20 @@ def __init__(self, cfg: Config): else: raise NotImplementedError(f"model not supported: {self.model_args.model_name}") - self.model = select_model_class(self.model_type, self.model_args, self.training_args) + # For PP + native: load on meta, weights loaded per-rank inside apply_pp_qwen35, + # avoiding a full-model CPU→GPU copy on every rank before the split. 
+ pp_meta_load = ( + self.pp_size > 1 + and self.model_args.model_impl == "native" + and not self.training_args.random_init + ) + + self.model = select_model_class( + self.model_type, self.model_args, self.training_args, meta_only=pp_meta_load + ) # we calculate the flops per token used to get the MFU number + # (works on meta tensors: shapes are valid even without data) num_params, self.flops_per_token = get_dense_model_nparams_and_flops( self.model_args.model_name, self.model, @@ -135,39 +147,40 @@ def __init__(self, cfg: Config): logger.info(f"Number params: {num_params}") - if self.training_args.load_text_model: - self.text_model = select_text_model(self.training_args) - self.model = load_text_model(self.model, self.text_model) + self.optimizer = None # defined later on - # MOVE TO cuda:{self.local_rank} - self.model.to(self.device) - - if self.training_args.random_init: - if self.model_type == ModelType.Qwen3_5: - logger.info('initilizing decoder and projecter of Qwen3.5') - init_qwen35(self.model) - elif self.model_type == ModelType.Qwen3_vl: - logger.info('initilizing projector of Qwen3-VL') - init_qwen3vl(self.model) - else: - logger.info('model not initlized, incompatible') + if not pp_meta_load: + if self.training_args.load_text_model: + self.text_model = select_text_model(self.training_args) + self.model = load_text_model(self.model, self.text_model) - # replace flash_attn - self.model.train() - if self.model_args.model_impl == "hf": - self.model.enable_input_require_grads() - self.optimizer = None # its defined later on + self.model.to(self.device) - if self.training_args.bfloat16: - self.model = self.model.to(torch.bfloat16) + if self.training_args.random_init: + if self.model_type == ModelType.Qwen3_5: + logger.info('initilizing decoder and projecter of Qwen3.5') + init_qwen35(self.model) + elif self.model_type == ModelType.Qwen3_vl: + logger.info('initilizing projector of Qwen3-VL') + init_qwen3vl(self.model) + else: + logger.info('model not initlized, incompatible') + + self.model.train() + if self.model_args.model_impl == "hf": + self.model.enable_input_require_grads() + + if self.training_args.bfloat16: + self.model = self.model.to(torch.bfloat16) - logger.info("model loaded") + logger.info("model loaded") if self.pp_size > 1: assert self.training_args.tp_size == 1, "TP + PP is not yet supported" assert self.model_type == ModelType.Qwen3_5, \ "Pipeline Parallel only implemented for Qwen3.5 native model" + dtype = torch.bfloat16 if self.training_args.bfloat16 else torch.float32 ( self.model, self._pp_pipeline_stage, @@ -176,12 +189,27 @@ def __init__(self, cfg: Config): self.pp_rank, _, self.pp_is_last, - ) = apply_pp_qwen35(self.model, self.pp_group, int(self.data_args.seq_len)) + ) = apply_pp_qwen35( + self.model, + self.pp_group, + int(self.data_args.seq_len), + snapshot_dir=Path(self.training_args.model_dir) if pp_meta_load else None, + device=self.device if pp_meta_load else None, + dtype=dtype if pp_meta_load else None, + ) - logger.info( - f"PP applied: rank {self.pp_rank}/{self.pp_size}, " - f"layers {[len(list(self.model.layers))]}, " - f"is_last={self.pp_is_last}" + if pp_meta_load: + self.model.train() + logger.info("model loaded") + + n_layers_this_rank = len(list(self.model.layers)) + mem_alloc_mb = torch.cuda.memory_allocated() / 1024**2 + mem_reserv_mb = torch.cuda.memory_reserved() / 1024**2 + print( + f"[PP rank {self.pp_rank}/{self.pp_size}] " + f"layers={n_layers_this_rank} is_last={self.pp_is_last} " + f"mem_alloc={mem_alloc_mb:.0f} MiB 
mem_reserved={mem_reserv_mb:.0f} MiB",
+                flush=True,
             )
         else:
             self.pp_rank = 0
             self.pp_is_last = True
diff --git a/train/utils.py b/train/utils.py
index c920568..3b13838 100644
--- a/train/utils.py
+++ b/train/utils.py
@@ -233,7 +233,7 @@ def collect(reason: str, generation: int = 1):
         logger.info("[GC] %s took %.2f seconds", reason, time.monotonic() - begin)
 
 
-def select_model_class(model_type: ModelType, model_args: ModelArgs, training_args: TrainArgs):
+def select_model_class(model_type: ModelType, model_args: ModelArgs, training_args: TrainArgs, meta_only: bool = False):
     """
     TODO: use ModelType instead of model name
     """
@@ -245,7 +245,7 @@ def select_model_class(model_type: ModelType, model_args: ModelArgs, training_ar
     model_name = model_args.model_name.lower()
 
     if model_args.model_impl == "native":
-        return _select_native_model_class(training_args, model_name)
+        return _select_native_model_class(training_args, model_name, meta_only=meta_only)
     elif model_args.model_impl != "hf":
         raise ValueError(
             f"Unknown model_impl '{model_args.model_impl}'. Expected 'hf' or 'native'."
@@ -307,8 +307,13 @@ def select_model_class(model_type: ModelType, model_args: ModelArgs, training_ar
     return model
 
 
-def _select_native_model_class(training_args: TrainArgs, model_name: str):
-    """Dispatch to our torch-native model implementations under `models/`."""
+def _select_native_model_class(training_args: TrainArgs, model_name: str, meta_only: bool = False):
+    """Dispatch to our torch-native model implementations under `models/`.
+
+    When ``meta_only=True`` the model is returned on ``torch.device("meta")``
+    with no weights loaded. The caller is responsible for materialising
+    parameters and loading weights (e.g. via ``load_stage_weights`` for PP).
+    """
     dtype = torch.bfloat16 if training_args.bfloat16 else torch.float32
 
     if "qwen3-vl" in model_name:
@@ -326,8 +331,10 @@ def _select_native_model_class(training_args: TrainArgs, model_name: str):
         training_args.model_dir,
         dtype=dtype,
         device="cpu",
+        weights=not meta_only,
     )
-    logger.info(f"Loaded native {model_name} from {training_args.model_dir}")
+    if not meta_only:
+        logger.info(f"Loaded native {model_name} from {training_args.model_dir}")
     return model
 
 def select_text_model(training_args):
@@ -512,6 +519,8 @@ def get_dense_model_nparams_and_flops(
         tied = False
     elif "9B" in model_name:
         tied = False
+    elif "27B" in model_name:
+        tied = False
     elif "2B" in model_name:
         tied = True
     elif "4B" in model_name:

From d48c74f4c69ec6b3aba6924e5897fc084bcc7b34 Mon Sep 17 00:00:00 2001
From: tomiock
Date: Tue, 28 Apr 2026 15:51:39 +0200
Subject: [PATCH 08/23] broken loss, model gives `nan` logits

---
 configs/cvc/qwen3_5_2b.toml |  11 +-
 models/qwen3_5/model.py     |   2 -
 models/qwen3_5/utils.py     |   1 -
 train/config.py             |   8 +-
 train/train_qwen.py         | 306 ++++++++++++++++++++----------------
 5 files changed, 182 insertions(+), 146 deletions(-)

diff --git a/configs/cvc/qwen3_5_2b.toml b/configs/cvc/qwen3_5_2b.toml
index 4a256ac..8f85558 100644
--- a/configs/cvc/qwen3_5_2b.toml
+++ b/configs/cvc/qwen3_5_2b.toml
@@ -17,11 +17,6 @@ random_init = false
 
 scheduler_type = "cosine"
 
-total_steps: int = 1_000
-warmup_steps: int = 100
-wsd_decay_ratio: float = 0.1
-min_lr_ratio: float = 0.1
-
 tp_size = 1
 pp_size = 2
 data_parallel = 'ddp'
@@ -31,11 +26,11 @@ compile = false
 
 tpi_multiplier = 1
 
 [data]
-data_path = "/data/151-1/datasets/llava_recap"
-seq_len = 8192
+data_path = "/data/151-1/datasets/synth_test_datasets/cap_pretrain"
+seq_len = 256
 
 shuffle_buffer_size = 1000
-packing_buffer_size = 1000
+packing_buffer_size = 100 
max_samples_per_sequence = 100 batch_size = 0 diff --git a/models/qwen3_5/model.py b/models/qwen3_5/model.py index 1bd51eb..301cfee 100644 --- a/models/qwen3_5/model.py +++ b/models/qwen3_5/model.py @@ -811,8 +811,6 @@ def from_pretrained( with torch.device("meta"): model = cls(cfg) - if not weights: - return model model = model.to_empty(device=device).to(dtype=dtype) load_safetensors_into( diff --git a/models/qwen3_5/utils.py b/models/qwen3_5/utils.py index ea59ed2..cb9306a 100644 --- a/models/qwen3_5/utils.py +++ b/models/qwen3_5/utils.py @@ -53,7 +53,6 @@ def causal_lm_loss( labels: torch.Tensor, ignore_index: int = -100, ) -> torch.Tensor: - # Match HF ForCausalLMLoss: upcast to fp32 before CE to avoid bf16 precision issues. shift_logits = logits[..., :-1, :].contiguous().float() shift_labels = labels[..., 1:].contiguous() return F.cross_entropy( diff --git a/train/config.py b/train/config.py index 2e21abf..3b6dc21 100644 --- a/train/config.py +++ b/train/config.py @@ -90,12 +90,16 @@ class Training: # --------------- data_parallel: str = "ddp" # fsdp, ddp - tp_size: int = 1 # 1 means disabled - pp_size: int = 1 # 1 means disabled; supported values: 1, 2, 4 """ Use `fsdp` when you want to decrease usage to increase seq_len/batch_size. """ + tp_size: int = 1 # 1 means disabled + pp_size: int = 1 # 1 means disabled; supported values: 1, 2, 4 + + pp_num_layers_first: int = 1 + pp_num_layers_last: int = 1 + # compiler flag for TP (goes faster) async_tp: bool = True diff --git a/train/train_qwen.py b/train/train_qwen.py index 2fba0f9..a218b20 100644 --- a/train/train_qwen.py +++ b/train/train_qwen.py @@ -13,6 +13,8 @@ from torch.distributed.elastic.multiprocessing.errors import record from torch.distributed._composable.replicate import replicate +from torch.distributed.pipelining import PipelineStage +from torch.distributed.pipelining.schedules import ScheduleGPipe from torch.profiler import profile, record_function, ProfilerActivity, schedule @@ -34,8 +36,9 @@ apply_ac, ACConfig, compile_model, - apply_pp_qwen35, ) +from models.qwen3_5.utils import causal_lm_loss + from train.utils import ( set_determinism, generate_accumulation_pattern, @@ -61,6 +64,58 @@ torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True +def get_local_fqns( + num_layers: int, + pp_size: int, + pp_rank: int, + num_first: int, + num_last: int +) -> list[str]: + if pp_size == 1: + return [ + "model.visual", + "model.language_model.embed_tokens", + ] + [f"model.language_model.layers.{i}" for i in range(num_layers)] + [ + "model.language_model.norm", + "lm_head" + ] + + fqns = [] + + if pp_rank == 0: + fqns.extend([ + "model.visual", + "model.language_model.embed_tokens" + ]) + start_idx = 0 + end_idx = num_first + + elif pp_rank == pp_size - 1: + start_idx = num_layers - num_last + end_idx = num_layers + + else: + middle_layers = num_layers - num_first - num_last + middle_ranks = pp_size - 2 + + layers_per_mid = middle_layers // middle_ranks + remainder = middle_layers % middle_ranks + + mid_idx = pp_rank - 1 + start_idx = num_first + (mid_idx * layers_per_mid) + min(mid_idx, remainder) + num_layers_this_rank = layers_per_mid + (1 if mid_idx < remainder else 0) + end_idx = start_idx + num_layers_this_rank + + for i in range(start_idx, end_idx): + fqns.append(f"model.language_model.layers.{i}") + + if pp_rank == pp_size - 1: + fqns.extend([ + "model.language_model.norm", + "lm_head" + ]) + + return fqns class Trainer(torch.distributed.checkpoint.stateful.Stateful): @@ -125,16 +180,9 @@ def
__init__(self, cfg: Config): else: raise NotImplementedError(f"model not supported: {self.model_args.model_name}") - # For PP + native: load on meta, weights loaded per-rank inside apply_pp_qwen35, - # avoiding a full-model CPU→GPU copy on every rank before the split. - pp_meta_load = ( - self.pp_size > 1 - and self.model_args.model_impl == "native" - and not self.training_args.random_init - ) - + # Load the model on CPU with weights; for PP the split stage is moved to GPU below. self.model = select_model_class( - self.model_type, self.model_args, self.training_args, meta_only=pp_meta_load + self.model_type, self.model_args, self.training_args, ) # we calculate the flops per token used to get the MFU number @@ -149,100 +197,105 @@ def __init__(self, cfg: Config): self.optimizer = None # defined later on - if not pp_meta_load: - if self.training_args.load_text_model: - self.text_model = select_text_model(self.training_args) - self.model = load_text_model(self.model, self.text_model) + if self.training_args.load_text_model: + self.text_model = select_text_model(self.training_args) + self.model = load_text_model(self.model, self.text_model) - self.model.to(self.device) + if self.pp_size > 1: + logger.info("Applying Pipeline Parallelism module split...") + pp_rank = self.mesh.get_local_rank(mesh_dim="pp") + total_layers = self.model.cfg.text.num_hidden_layers + + local_fqns = get_local_fqns( + num_layers=total_layers, + pp_size=self.pp_size, + pp_rank=pp_rank, + num_first=self.training_args.pp_num_layers_first, + num_last=self.training_args.pp_num_layers_last + ) - if self.training_args.random_init: - if self.model_type == ModelType.Qwen3_5: - logger.info('initilizing decoder and projecter of Qwen3.5') - init_qwen35(self.model) - elif self.model_type == ModelType.Qwen3_vl: - logger.info('initilizing projector of Qwen3-VL') - init_qwen3vl(self.model) - else: - logger.info('model not initlized, incompatible') + if "model.visual" not in local_fqns: + self.model.model.visual = None + if "model.language_model.embed_tokens" not in local_fqns: + self.model.model.language_model.embed_tokens = None + + layers = self.model.model.language_model.layers + kept_indices = {int(f.split('.')[-1]) for f in local_fqns if "layers." 
in f} + self.model.model.language_model.layers = torch.nn.ModuleList( + [m for i, m in enumerate(layers) if i in kept_indices] + ) + + if "model.language_model.norm" not in local_fqns: + self.model.model.language_model.norm = None + if "lm_head" not in local_fqns: + self.model.lm_head = None - self.model.train() - if self.model_args.model_impl == "hf": - self.model.enable_input_require_grads() + self.model.to_empty(device=self.device) - if self.training_args.bfloat16: - self.model = self.model.to(torch.bfloat16) + self.pp_has_first_stage = (pp_rank == 0) + self.pp_has_last_stage = (pp_rank == self.pp_size - 1) - logger.info("model loaded") + self.pp_stage = PipelineStage( + self.model, + stage_index=pp_rank, + num_stages=self.pp_size, + device=self.device, + group=self.mesh.get_group(mesh_dim="pp"), + ) - if self.pp_size > 1: - assert self.training_args.tp_size == 1, "TP + PP is not yet supported" - assert self.model_type == ModelType.Qwen3_5, \ - "Pipeline Parallel only implemented for Qwen3.5 native model" + def pp_loss_fn(logits, labels): + return causal_lm_loss(logits, labels) / self.current_accum_target - dtype = torch.bfloat16 if self.training_args.bfloat16 else torch.float32 - ( - self.model, - self._pp_pipeline_stage, - self._pp_schedule, - self._pp_loss_fn, - self.pp_rank, - _, - self.pp_is_last, - ) = apply_pp_qwen35( - self.model, - self.pp_group, - int(self.data_args.seq_len), - snapshot_dir=Path(self.training_args.model_dir) if pp_meta_load else None, - device=self.device if pp_meta_load else None, - dtype=dtype if pp_meta_load else None, + self.pp_schedule = ScheduleGPipe( + self.pp_stage, + n_microbatches=1, + loss_fn=pp_loss_fn, ) - if pp_meta_load: - self.model.train() - logger.info("model loaded") - - n_layers_this_rank = len(list(self.model.layers)) - mem_alloc_mb = torch.cuda.memory_allocated() / 1024**2 - mem_reserv_mb = torch.cuda.memory_reserved() / 1024**2 - print( - f"[PP rank {self.pp_rank}/{self.pp_size}] " - f"layers={n_layers_this_rank} is_last={self.pp_is_last} " - f"mem_alloc={mem_alloc_mb:.0f} MiB mem_reserved={mem_reserv_mb:.0f} MiB", - flush=True, + if self.training_args.random_init: + if self.model_type == ModelType.Qwen3_5: + logger.info('initializing decoder and projector of Qwen3.5') + init_qwen35(self.model) + elif self.model_type == ModelType.Qwen3_vl: + logger.info('initializing projector of Qwen3-VL') + init_qwen3vl(self.model) + else: + logger.info('model not initialized, incompatible') + + self.model.train() + if self.training_args.bfloat16: + self.model = self.model.to(torch.bfloat16) + + logger.info("model loaded") + + if self.training_args.tp_size > 1: + apply_tp(self.model, self.model_type, self.tp_group, self.training_args.async_tp) + + ac_mode = getattr(self.training_args, "ac_mode", "off") + if ac_mode != "off": + ac_cfg = ACConfig(enabled=True, full=(ac_mode == "full")) + apply_ac( + self.model.model.language_model, + ac_cfg, + model_compile_enabled=self.training_args.compile, + ) + logger.info(f"activation checkpointing applied ({ac_mode})") + + if self.training_args.data_parallel == 'fsdp': + apply_fsdp(self.model_type, self.model, mesh=self.dp_group) + elif self.training_args.data_parallel == 'ddp': + self.model = replicate(self.model, device_mesh=self.dp_group) else: - self.pp_rank = 0 - self.pp_is_last = True - - if self.training_args.tp_size > 1: - apply_tp(self.model, self.model_type, self.tp_group, self.training_args.async_tp) - - ac_mode = getattr(self.training_args, "ac_mode", "off") - if ac_mode != "off": - ac_cfg =
ACConfig(enabled=True, full=(ac_mode == "full")) - apply_ac( - self.model.model.language_model, - ac_cfg, - model_compile_enabled=self.training_args.compile, - ) - logger.info(f"activation checkpointing applied ({ac_mode})") - - if self.training_args.data_parallel == 'fsdp': - apply_fsdp(self.model_type, self.model, mesh=self.dp_group) - elif self.training_args.data_parallel == 'ddp': - self.model = replicate(self.model, device_mesh=self.dp_group) - else: - raise Exception('invalid sharding strategy for Data Parallel') + raise Exception('invalid sharding strategy for Data Parallel') if self.pp_size > 1: if self.training_args.data_parallel == 'ddp': self.model = replicate(self.model, device_mesh=self.dp_group) - self.model = self.model.to(self.device) + self.model = self.model.to_empty(device=self.device) - if self.training_args.bfloat16: - self.model = self.model.to(torch.bfloat16) + # if self.training_args.bfloat16: + # self.model = self.model.to(torch.bfloat16) # get rank of local GPU that belongs to the DP group data_rank = self.dp_group.get_local_rank() @@ -264,7 +317,7 @@ def __init__(self, cfg: Config): self.processor = AutoProcessor.from_pretrained( self.training_args.model_dir, - max_pixels=1048576, + ) # set_model freezes/unfreezes param groups; skip for PP (stage module @@ -547,66 +600,53 @@ def setup_accumulation(self, tpi_multiplier=1.5): self.current_accum_target = next(self.accum_schedule) self.current_accum_count = 0 - def train_step(self, data_iterator, optimizer): - if self.pp_size > 1: - return self._train_step_pp(data_iterator, optimizer) - return self._train_step_regular(data_iterator, optimizer) - - def _train_step_regular(self, data_iterator, optimizer): + def _train_step_pp(self, data_iterator, optimizer): batch = next(data_iterator) + input_ids = batch.pop('input_ids') + labels = batch.pop('labels', None) - s_model = time.perf_counter() - with record_function("forward_pass"): - with torch.autocast('cuda', torch.bfloat16): - outputs = self.model(**batch) - loss = outputs.loss + losses = [] if self.pp_has_last_stage else None + target = labels if self.pp_has_last_stage else None - with record_function("backward_pass"): - scaled_loss = loss / self.current_accum_target + s_model = time.perf_counter() + with record_function("pp_forward_backward"): with torch.autocast('cuda', torch.bfloat16): - scaled_loss.backward() + if self.pp_has_first_stage: + self.pp_schedule.step(input_ids, **batch, target=target, losses=losses) + else: + self.pp_schedule.step(**batch, target=target, losses=losses) self.fwd_bwd_time = time.perf_counter() - s_model + + loss = torch.stack(losses).sum() if losses else torch.tensor(0.0, device=self.device) + # loss needs to be propagated accross PP group + torch.distributed.all_reduce(loss, group=self.pp_group.get_group()) return self._maybe_optimizer_step(loss, optimizer) - def _train_step_pp(self, data_iterator, optimizer): - """Training step for Pipeline Parallel (Qwen3.5 native, pp_size 2 or 4).""" + def _train_step(self, data_iterator, optimizer): batch = next(data_iterator) - - # Update the loss scaling factor used inside the last stage - if self.pp_is_last: - self._pp_loss_fn.accum_target = self.current_accum_target + input_ids = batch.pop('input_ids') + labels = batch.pop('labels', None) s_model = time.perf_counter() - losses: list = [] - - with record_function("pp_forward_backward"): + with record_function("forward_pass"): with torch.autocast('cuda', torch.bfloat16): - if self.pp_rank == 0: - # Preprocess: visual encode + embed + RoPE → 
fixed-shape tensors - stage_inputs = self.model.preprocess(batch) - self._pp_schedule.step(*stage_inputs) - elif self.pp_is_last: - # Last stage: receive activations, compute loss, run backward - self._pp_schedule.step(target=batch['labels'], losses=losses) - else: - # Middle stages: receive, forward, send - self._pp_schedule.step() + logits = self.model(input_ids, **batch) + loss = causal_lm_loss(logits, labels) + breakpoint() + + with record_function("backward_pass"): + scaled_loss = loss / self.current_accum_target + scaled_loss.backward() self.fwd_bwd_time = time.perf_counter() - s_model + return self._maybe_optimizer_step(loss, optimizer) - # Collect the loss scalar for logging (only last stage has it) - if self.pp_is_last and losses: - loss = losses[0].detach() + def train_step(self, data_iterator, optimizer): + if self.pp_size == 1: + return self._train_step(data_iterator, optimizer) else: - loss = torch.zeros(1, device=self.device) - - # Reduce loss across the PP group so all ranks have the same value for logging - torch.distributed.all_reduce( - loss, op=torch.distributed.ReduceOp.SUM, group=self.pp_group.get_group() - ) - - return self._maybe_optimizer_step(loss, optimizer) + return self._train_step_pp(data_iterator, optimizer) def _maybe_optimizer_step(self, loss, optimizer): """Shared optimizer-step logic after fwd+bwd (regular and PP paths).""" @@ -727,4 +767,4 @@ def trace_handler(prof): torch.manual_seed(42) trainer = Trainer(config) - trainer.train() \ No newline at end of file + trainer.train() From 2afd6c7fe1e2a093d1c50406cffb6823dd4b0f0a Mon Sep 17 00:00:00 2001 From: tomiock Date: Tue, 28 Apr 2026 16:13:18 +0200 Subject: [PATCH 09/23] [fix] correct loss --- train/train_qwen.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/train/train_qwen.py b/train/train_qwen.py index a218b20..0a75db8 100644 --- a/train/train_qwen.py +++ b/train/train_qwen.py @@ -292,10 +292,10 @@ def pp_loss_fn(logits, labels): if self.training_args.data_parallel == 'ddp': self.model = replicate(self.model, device_mesh=self.dp_group) - self.model = self.model.to_empty(device=self.device) - - # if self.training_args.bfloat16: - # self.model = self.model.to(torch.bfloat16) + # loading into GPU + self.model = self.model.to(device=self.device) + if self.training_args.bfloat16: + self.model = self.model.to(torch.bfloat16) # get rank of local GPU that belongs to the DP group data_rank = self.dp_group.get_local_rank() @@ -633,7 +633,6 @@ def _train_step(self, data_iterator, optimizer): with torch.autocast('cuda', torch.bfloat16): logits = self.model(input_ids, **batch) loss = causal_lm_loss(logits, labels) - breakpoint() with record_function("backward_pass"): scaled_loss = loss / self.current_accum_target From a4b5083bd7a354db26cd1980b600a5ab61f7170b Mon Sep 17 00:00:00 2001 From: tomiock Date: Tue, 28 Apr 2026 17:11:59 +0200 Subject: [PATCH 10/23] loss issues w/ pipeline --- models/qwen3_5/model.py | 118 ++++++++++++++++++++-------------------- train/train_qwen.py | 7 ++- 2 files changed, 65 insertions(+), 60 deletions(-) diff --git a/models/qwen3_5/model.py b/models/qwen3_5/model.py index 301cfee..d1e1a27 100644 --- a/models/qwen3_5/model.py +++ b/models/qwen3_5/model.py @@ -294,7 +294,8 @@ def forward( x[visual_pos_masks] = ( x[visual_pos_masks] + deepstack_visual_embeds[i].to(x.dtype) ) - return self.norm(x) + + return self.norm(x) if self.norm is not None else x class VisionPatchEmbed(nn.Module): def __init__(self, cfg: Qwen3_5VisionConfig): @@ -664,9 +665,9 @@ def
_compute_cos_sin(self, position_ids: torch.Tensor) -> tuple[torch.Tensor, to def forward( self, - input_ids: torch.Tensor | None = None, + hidden_states: torch.Tensor | None = None, *, - inputs_embeds: torch.Tensor | None = None, + input_ids: torch.Tensor | None = None, pixel_values: torch.Tensor | None = None, image_grid_thw: torch.Tensor | None = None, pixel_values_videos: torch.Tensor | None = None, @@ -684,17 +685,18 @@ def forward( (same tensor consumed by `torch.nn.attention.varlen.varlen_attn`). If `attention_mask` is None, the whole row is treated as one sample. """ - assert (input_ids is None) ^ (inputs_embeds is None) - if input_ids is not None and input_ids.dim() == 1: - input_ids = input_ids.unsqueeze(0) - if input_ids is not None: - assert input_ids.dim() == 2 and input_ids.shape[0] == 1, ( - f"varlen expects packed (1, total), got {tuple(input_ids.shape)}" - ) - - if inputs_embeds is None: + if getattr(self.model.language_model, "embed_tokens", None) is not None: + input_ids = hidden_states inputs_embeds = self.model.language_model.embed_tokens(input_ids) - assert inputs_embeds.dim() == 3 and inputs_embeds.shape[0] == 1 + else: + inputs_embeds = hidden_states + if input_ids is None: + raise ValueError("input_ids must be passed to intermediate stages for MRoPE calculation.") + + assert inputs_embeds.dim() == 3 and inputs_embeds.shape[0] == 1, ( + f"inputs_embeds should be (1, total, hidden_dim), got {tuple(inputs_embeds.shape)}" + ) + total = inputs_embeds.shape[1] device = inputs_embeds.device @@ -713,43 +715,42 @@ def forward( visual_pos_masks: torch.Tensor | None = None deepstack_visual_embeds: list[torch.Tensor] | None = None - if pixel_values is not None: - assert image_grid_thw is not None - merged, deepstack = self.model.visual(pixel_values, image_grid_thw) - merged = merged.to(inputs_embeds.dtype) - image_mask = input_ids == self.cfg.image_token_id - assert image_mask.sum().item() == merged.shape[0], ( - f"image tokens={image_mask.sum().item()} vs features={merged.shape[0]}" - ) - inputs_embeds = inputs_embeds.masked_scatter( - image_mask.unsqueeze(-1).expand_as(inputs_embeds), merged - ) - visual_pos_masks = image_mask - deepstack_visual_embeds = deepstack - - if pixel_values_videos is not None: - assert video_grid_thw is not None - merged_v, deepstack_v = self.model.visual(pixel_values_videos, video_grid_thw) - merged_v = merged_v.to(inputs_embeds.dtype) - video_mask = input_ids == self.cfg.video_token_id - inputs_embeds = inputs_embeds.masked_scatter( - video_mask.unsqueeze(-1).expand_as(inputs_embeds), merged_v - ) - if visual_pos_masks is None: - visual_pos_masks = video_mask - deepstack_visual_embeds = deepstack_v - else: - combined = visual_pos_masks | video_mask - image_only = visual_pos_masks[combined] - video_only = video_mask[combined] - merged_ds = [] - for img_ds, vid_ds in zip(deepstack_visual_embeds, deepstack_v): - e = img_ds.new_zeros(combined.sum().item(), img_ds.shape[-1]) - e[image_only] = img_ds - e[video_only] = vid_ds - merged_ds.append(e) - visual_pos_masks = combined - deepstack_visual_embeds = merged_ds + if getattr(self.model, "visual", None) is not None: + if pixel_values is not None: + assert image_grid_thw is not None + merged, deepstack = self.model.visual(pixel_values, image_grid_thw) + merged = merged.to(inputs_embeds.dtype) + image_mask = input_ids == self.cfg.image_token_id + #assert image_mask.sum().item() == merged.shape[0], (f"image tokens={image_mask.sum().item()} vs features={merged.shape[0]}") + inputs_embeds = 
inputs_embeds.masked_scatter( + image_mask.unsqueeze(-1).expand_as(inputs_embeds), merged + ) + visual_pos_masks = image_mask + deepstack_visual_embeds = deepstack + + if pixel_values_videos is not None: + assert video_grid_thw is not None + merged_v, deepstack_v = self.model.visual(pixel_values_videos, video_grid_thw) + merged_v = merged_v.to(inputs_embeds.dtype) + video_mask = input_ids == self.cfg.video_token_id + inputs_embeds = inputs_embeds.masked_scatter( + video_mask.unsqueeze(-1).expand_as(inputs_embeds), merged_v + ) + if visual_pos_masks is None: + visual_pos_masks = video_mask + deepstack_visual_embeds = deepstack_v + else: + combined = visual_pos_masks | video_mask + image_only = visual_pos_masks[combined] + video_only = video_mask[combined] + merged_ds = [] + for img_ds, vid_ds in zip(deepstack_visual_embeds, deepstack_v): + e = img_ds.new_zeros(combined.sum().item(), img_ds.shape[-1]) + e[image_only] = img_ds + e[video_only] = vid_ds + merged_ds.append(e) + visual_pos_masks = combined + deepstack_visual_embeds = merged_ds if position_ids is None: if image_grid_thw is not None or video_grid_thw is not None: @@ -780,14 +781,15 @@ def forward( visual_pos_masks=visual_pos_masks, deepstack_visual_embeds=deepstack_visual_embeds, ) - logits = self.lm_head(h) - - if labels is None: - return logits - if labels.dim() == 1: - labels = labels.unsqueeze(0) - loss = causal_lm_loss(logits, labels) - return CausalLMOutput(loss=loss, logits=logits) + + if self.lm_head is not None: + logits = self.lm_head(h) + if labels is None: + return logits + loss = causal_lm_loss(logits, labels) + return CausalLMOutput(loss=loss, logits=logits) + else: + return h @classmethod def from_pretrained( diff --git a/train/train_qwen.py b/train/train_qwen.py index 0a75db8..ad65217 100644 --- a/train/train_qwen.py +++ b/train/train_qwen.py @@ -602,9 +602,12 @@ def setup_accumulation(self, tpi_multiplier=1.5): def _train_step_pp(self, data_iterator, optimizer): batch = next(data_iterator) - input_ids = batch.pop('input_ids') labels = batch.pop('labels', None) + # this is weird but "trust me, it works" + input_ids = batch.pop('input_ids') + batch['input_ids'] = input_ids + losses = [] if self.pp_has_last_stage else None target = labels if self.pp_has_last_stage else None @@ -619,7 +622,7 @@ def _train_step_pp(self, data_iterator, optimizer): self.fwd_bwd_time = time.perf_counter() - s_model loss = torch.stack(losses).sum() if losses else torch.tensor(0.0, device=self.device) - # loss needs to be propagated accross PP group + # loss needs to be propagated across PP group torch.distributed.all_reduce(loss, group=self.pp_group.get_group()) return self._maybe_optimizer_step(loss, optimizer) From 1aff281a6b5c778c89fa68cb505a4edee15fff8f Mon Sep 17 00:00:00 2001 From: tomiock Date: Tue, 28 Apr 2026 17:42:17 +0200 Subject: [PATCH 11/23] we need to load the weights in PP --- models/qwen3_5/model.py | 8 ++------ train/train_qwen.py | 14 ++++++++------ 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/models/qwen3_5/model.py b/models/qwen3_5/model.py index d1e1a27..4f76bc7 100644 --- a/models/qwen3_5/model.py +++ b/models/qwen3_5/model.py @@ -782,12 +782,8 @@ def forward( deepstack_visual_embeds=deepstack_visual_embeds, ) - if self.lm_head is not None: - logits = self.lm_head(h) - if labels is None: - return logits - loss = causal_lm_loss(logits, labels) - return CausalLMOutput(loss=loss, logits=logits) + if getattr(self, "lm_head", None) is not None: + return self.lm_head(h) else: return h diff --git 
a/train/train_qwen.py b/train/train_qwen.py index ad65217..92fe6e6 100644 --- a/train/train_qwen.py +++ b/train/train_qwen.py @@ -244,6 +244,7 @@ def __init__(self, cfg: Config): ) def pp_loss_fn(logits, labels): + print(logits) return causal_lm_loss(logits, labels) / self.current_accum_target self.pp_schedule = ScheduleGPipe( @@ -602,9 +603,8 @@ def setup_accumulation(self, tpi_multiplier=1.5): def _train_step_pp(self, data_iterator, optimizer): batch = next(data_iterator) + labels = batch.pop('labels', None) - - # this is weird but "trust me, it works" input_ids = batch.pop('input_ids') batch['input_ids'] = input_ids @@ -621,10 +621,12 @@ def _train_step_pp(self, data_iterator, optimizer): self.fwd_bwd_time = time.perf_counter() - s_model - loss = torch.stack(losses).sum() if losses else torch.tensor(0.0, device=self.device) - # loss needs to be propagated across PP group - torch.distributed.all_reduce(loss, group=self.pp_group.get_group()) - return self._maybe_optimizer_step(loss, optimizer) + scaled_loss = torch.stack(losses).sum() if losses else torch.tensor(0.0, device=self.device) + + loss_for_logging = scaled_loss * self.current_accum_target + torch.distributed.all_reduce(loss_for_logging, group=self.pp_group.get_group()) + + return self._maybe_optimizer_step(loss_for_logging, optimizer) def _train_step(self, data_iterator, optimizer): batch = next(data_iterator) From 6a7de7c09b46393fd60799a26f027c4a613eb083 Mon Sep 17 00:00:00 2001 From: tomiock Date: Tue, 28 Apr 2026 18:26:21 +0200 Subject: [PATCH 12/23] dev checkpoint --- models/qwen3_5/utils.py | 49 ++++++++++++++--------------------------- train/train_qwen.py | 17 +++++++++----- 2 files changed, 28 insertions(+), 38 deletions(-) diff --git a/models/qwen3_5/utils.py b/models/qwen3_5/utils.py index cb9306a..05bb74f 100644 --- a/models/qwen3_5/utils.py +++ b/models/qwen3_5/utils.py @@ -196,7 +196,6 @@ def load_safetensors_into( if missing: raise RuntimeError(f"Missing weights after load: {sorted(missing)[:8]} ... ({len(missing)} total)") - def load_stage_weights( stage: nn.Module, snapshot_dir: Path, @@ -207,48 +206,29 @@ def load_stage_weights( device: torch.device | str, dtype: torch.dtype, ) -> None: - """Load checkpoint weights for a single PP stage directly onto ``device``. - - The stage must have been created from a meta model and materialized with - ``to_empty(device)`` before this call. Uses ``load_state_dict(assign=True)`` - so tensors are assigned directly — no extra CPU copy. - """ from models.qwen3_5.config import Qwen3VLConfig snapshot_dir = Path(snapshot_dir) cfg = Qwen3VLConfig.from_json(snapshot_dir / "config.json") - # Map each stage state-dict key to the corresponding checkpoint key. 
- # Stage key namespace → HF checkpoint namespace - # layers.{i}.* → model.language_model.layers.{layer_start+i}.* - # embed_tokens.* → model.language_model.embed_tokens.* - # visual.* → model.visual.* - # norm.* → model.language_model.norm.* - # lm_head.* → lm_head.* - # text_inv_freq → (computed; skip) stage_to_ckpt: dict[str, str] = {} for key in stage.state_dict(): - if key == "text_inv_freq": - continue # recomputed from config after loading - elif key.startswith("layers."): - rest = key[len("layers."):] + if "inv_freq" in key: + continue + + if key.startswith("model.language_model.layers."): + rest = key[len("model.language_model.layers."):] i_str, suffix = rest.split(".", 1) - ckpt_key = f"model.language_model.layers.{layer_start + int(i_str)}.{suffix}" - stage_to_ckpt[key] = ckpt_key - elif key.startswith("embed_tokens."): - stage_to_ckpt[key] = f"model.language_model.{key}" - elif key.startswith("visual."): - stage_to_ckpt[key] = f"model.{key}" - elif key.startswith("norm."): - stage_to_ckpt[key] = f"model.language_model.{key}" - elif key.startswith("lm_head."): + global_layer_idx = layer_start + int(i_str) + stage_to_ckpt[key] = f"model.language_model.layers.{global_layer_idx}.{suffix}" + + else: stage_to_ckpt[key] = key - # Determine which shard files contain the needed keys. index_path = snapshot_dir / "model.safetensors.index.json" if index_path.exists(): with open(index_path) as f: - weight_map: dict[str, str] = json.load(f)["weight_map"] + weight_map = json.load(f)["weight_map"] shard_to_keys: dict[str, list[str]] = defaultdict(list) for ckpt_key in stage_to_ckpt.values(): if ckpt_key in weight_map: @@ -258,7 +238,6 @@ def load_stage_weights( assert single.exists(), f"No safetensors found in {snapshot_dir}" shard_to_keys = {single.name: list(stage_to_ckpt.values())} - # Load only the needed tensors from disk, directly onto target device. ckpt_tensors: dict[str, torch.Tensor] = {} for shard_name, keys in shard_to_keys.items(): with safe_open(str(snapshot_dir / shard_name), framework="pt", device=str(device)) as f: @@ -266,9 +245,13 @@ def load_stage_weights( ckpt_tensors[k] = f.get_tensor(k).to(dtype=dtype) stage_sd = {sk: ckpt_tensors[ck] for sk, ck in stage_to_ckpt.items() if ck in ckpt_tensors} + + missing = set(stage_to_ckpt.keys()) - set(stage_sd.keys()) + if missing: + raise RuntimeError(f"Missing weights for stage: {list(missing)[:5]}... ({len(missing)} total)") + stage.load_state_dict(stage_sd, assign=True, strict=False) - # Recompute non-checkpoint buffers that to_empty() wiped. 
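The remapping above becomes clearer with a toy walkthrough; the values are assumed (24 text layers split 12/12 over two ranks, so the last rank rebuilds its ModuleList from local index 0 and is loaded with layer_start=12):

# Runnable illustration of the stage-key -> checkpoint-key remap (assumed split).
layer_start = 12
stage_key = "model.language_model.layers.0.self_attn.q_proj.weight"
rest = stage_key[len("model.language_model.layers."):]  # "0.self_attn.q_proj.weight"
i_str, suffix = rest.split(".", 1)                      # "0", "self_attn.q_proj.weight"
ckpt_key = f"model.language_model.layers.{layer_start + int(i_str)}.{suffix}"
assert ckpt_key == "model.language_model.layers.12.self_attn.q_proj.weight"
# Keys outside the layers prefix (norm, lm_head) fall through the `else`
# branch and map to themselves unchanged.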
if is_first: head_dim = cfg.text.head_dim partial = cfg.text.rope_parameters.get("partial_rotary_factor", 1.0) @@ -279,6 +262,6 @@ def load_stage_weights( ) head_dim_v = cfg.vision.hidden_size // cfg.vision.num_heads rdim = head_dim_v // 2 - stage.visual.rotary_pos_emb.inv_freq = 1.0 / ( + stage.model.visual.rotary_pos_emb.inv_freq = 1.0 / ( 10000.0 ** (torch.arange(0, rdim, 2, dtype=torch.float32, device=device) / rdim) ) diff --git a/train/train_qwen.py b/train/train_qwen.py index 92fe6e6..86340e6 100644 --- a/train/train_qwen.py +++ b/train/train_qwen.py @@ -37,7 +37,7 @@ ACConfig, compile_model, ) -from models.qwen3_5.utils import causal_lm_loss +from models.qwen3_5.utils import causal_lm_loss, load_stage_weights from train.utils import ( set_determinism, @@ -230,10 +230,18 @@ def __init__(self, cfg: Config): if "lm_head" not in local_fqns: self.model.lm_head = None - self.model.to_empty(device=self.device) + self.model.to(device=self.device) - self.pp_has_first_stage = (pp_rank == 0) - self.pp_has_last_stage = (pp_rank == self.pp_size - 1) + self.pp_has_first_stage = self.model.model.visual is not None + self.pp_has_last_stage = self.model.lm_head is not None + + layer_indices = [int(f.split('.')[-1]) for f in local_fqns if "layers." in f] + layer_start = min(layer_indices) if layer_indices else 0 + layer_end = max(layer_indices) + 1 if layer_indices else 0 + + target_dtype = torch.bfloat16 if self.training_args.bfloat16 else torch.float32 + + self.model = self.model.to(self.device) self.pp_stage = PipelineStage( self.model, @@ -244,7 +252,6 @@ def __init__(self, cfg: Config): ) def pp_loss_fn(logits, labels): - print(logits) return causal_lm_loss(logits, labels) / self.current_accum_target self.pp_schedule = ScheduleGPipe( From c89d450674d47f24de2d007b6847ce8740ace5d2 Mon Sep 17 00:00:00 2001 From: tomiock Date: Tue, 28 Apr 2026 22:17:10 +0200 Subject: [PATCH 13/23] [fix] PP loading weights --- models/qwen3_5/model.py | 10 + train/infra.py | 440 ---------------------------------------- train/train_qwen.py | 27 ++- 3 files changed, 35 insertions(+), 442 deletions(-) diff --git a/models/qwen3_5/model.py b/models/qwen3_5/model.py index 4f76bc7..888e0f4 100644 --- a/models/qwen3_5/model.py +++ b/models/qwen3_5/model.py @@ -231,6 +231,16 @@ def __init__(self, cfg: Qwen3_5TextConfig): def forward(self, x): return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) +class MLP_MoE(nn.Module): + def __init__(self, cfg: Qwen3_5TextConfig): + super().__init__() + self.gate_proj = nn.Linear(cfg.hidden_size, cfg.intermediate_size, bias=False) + self.up_proj = nn.Linear(cfg.hidden_size, cfg.intermediate_size, bias=False) + self.down_proj = nn.Linear(cfg.intermediate_size, cfg.hidden_size, bias=False) + + def forward(self, x): + return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) + class DecoderLayer(nn.Module): def __init__(self, cfg: Qwen3_5TextConfig, layer_type: str): super().__init__() diff --git a/train/infra.py b/train/infra.py index 73e7191..490c62a 100644 --- a/train/infra.py +++ b/train/infra.py @@ -658,443 +658,3 @@ def _apply_tp_to_decoder_qwen3_5( if enable_async_tp: torch._inductor.config._micro_pipeline_tp = True - - -# --------------------------------------------------------------------------- -# Pipeline Parallel for Qwen3.5 (native impl, dense model) -# --------------------------------------------------------------------------- - -# cu_seqlens is padded to this fixed size so PipelineStage sees constant shapes. 
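The roughly 440 lines deleted below are the hand-rolled schedule from earlier in the series; they are replaced by torch.distributed.pipelining as wired in the Trainer above. A condensed sketch of that wiring, assuming pp_size=2, a single microbatch, and illustrative variable names:

from torch.distributed.pipelining import PipelineStage
from torch.distributed.pipelining.schedules import ScheduleGPipe

# `model` is this rank's pruned Qwen3_5ForCausalLM (unowned submodules set to None).
stage = PipelineStage(
    model,
    stage_index=pp_rank,                 # 0-based position in the pipeline
    num_stages=pp_size,
    device=device,
    group=mesh.get_group(mesh_dim="pp"),
)
schedule = ScheduleGPipe(stage, n_microbatches=1, loss_fn=causal_lm_loss)

losses: list = []
if pp_rank == 0:
    schedule.step(input_ids)                     # first stage feeds the batch in
else:
    schedule.step(target=labels, losses=losses)  # last stage computes and collects loss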
-_PP_MAX_SEQS: int = 256 - - -def _pp_layer_ranges( - n_layers: int, - pp_size: int, - first_virtual: float = 1.0, - last_virtual: float = 0.0, -) -> list[tuple[int, int]]: - """ - Distribute n_layers across pp_size stages for balanced memory. - - first_virtual / last_virtual: overhead of the non-layer modules on the - first/last stage expressed in units of a single transformer layer. - Computed from actual parameter counts so the split automatically adapts - to any model size. - - Optimal layers per stage: - target = (n_layers + first_virtual + last_virtual) / pp_size - first_n = round(target - first_virtual) - last_n = round(target - last_virtual) [pp_size == 4 only] - middle = evenly distributed remainder - """ - target = (n_layers + first_virtual + last_virtual) / pp_size - first_n = max(1, round(target - first_virtual)) - - if pp_size == 2: - return [(0, first_n), (first_n, n_layers)] - - last_n = max(1, round(target - last_virtual)) - remaining = n_layers - first_n - last_n - assert remaining >= 2, ( - f"Not enough layers for middle stages with " - f"first_n={first_n}, last_n={last_n}, n_layers={n_layers}" - ) - # all minus 2 (last and first) - mid, extra = divmod(remaining, pp_size - 2) - ranges, pos = [(0, first_n)], first_n - for i in range(pp_size - 2): - n = mid + (1 if i < extra else 0) - ranges.append((pos, pos + n)) - pos += n - ranges.append((pos, n_layers)) - return ranges - - -class PPStageModule(nn.Module): - """ - A single PP stage for Qwen3.5ForCausalLM (native impl). - - Owns decoder layers[layer_start:layer_end]. - Stage 0 (is_first=True) also owns the visual encoder and embed_tokens. - Last stage (is_last=True) also owns norm and lm_head. - - Inter-stage tensor protocol (all fixed shapes): - hidden_states : (1, seq_len, hidden_size) dtype - cos, sin : (1, seq_len, rope_dim) dtype - cu_seqlens_pad : (_PP_MAX_SEQS+1,) int32 - n_seqs : () int64 - """ - - def __init__( - self, - full_model: nn.Module, - layer_start: int, - layer_end: int, - is_first: bool, - is_last: bool, - ): - super().__init__() - lm = full_model.model.language_model - - self.is_first = is_first - self.is_last = is_last - self.layers = nn.ModuleList(list(lm.layers[layer_start:layer_end])) - - if is_first: - n_ds = len(full_model.model.visual.deepstack_visual_indexes) - assert n_ds <= (layer_end - layer_start), ( - f"Stage 0 has {layer_end - layer_start} layers but {n_ds} deepstack " - "injections — all must fit in stage 0. Reduce pp_size or use a " - "larger first-stage split." - ) - self.visual = full_model.model.visual - self.embed_tokens = lm.embed_tokens - self.register_buffer("text_inv_freq", full_model.text_inv_freq.clone()) - self.mrope_section = list(full_model.mrope_section) - self.image_token_id = full_model.cfg.image_token_id - self.video_token_id = full_model.cfg.video_token_id - self.spatial_merge_size = full_model.cfg.vision.spatial_merge_size - # Populated by preprocess() before each forward - self._vis_masks: Optional[torch.Tensor] = None - self._ds_embeds: Optional[list] = None - - if is_last: - self.norm = lm.norm - self.lm_head = full_model.lm_head - - # ------------------------------------------------------------------ - def preprocess(self, batch: dict) -> tuple: - """ - Stage-0 only: visual encoding + token embedding + RoPE. - - Call this on pp_rank=0 before schedule.step() to produce the - fixed-shape tensors that enter the pipeline. 
- """ - from models.qwen3_5.utils import mrope_cos_sin - - input_ids = batch["input_ids"] # (1, seq_len) - pixel_values = batch.get("pixel_values") - image_grid_thw = batch.get("image_grid_thw") - video_grid_thw = batch.get("video_grid_thw") - cu_seqlens = batch["attention_mask"] # (n_seqs+1,) int32 - device = input_ids.device - total = input_ids.shape[1] - - x = self.embed_tokens(input_ids) # (1, total, H) - - self._vis_masks = None - self._ds_embeds = None - has_img = pixel_values is not None and pixel_values.numel() > 0 - has_vid = video_grid_thw is not None and video_grid_thw.numel() > 0 - - if has_img: - merged, ds = self.visual(pixel_values, image_grid_thw) - merged = merged.to(x.dtype) - mask = input_ids == self.image_token_id - x = x.masked_scatter(mask.unsqueeze(-1).expand_as(x), merged) - self._vis_masks, self._ds_embeds = mask, ds - - if has_vid: - merged_v, ds_v = self.visual(batch["pixel_values_videos"], video_grid_thw) - merged_v = merged_v.to(x.dtype) - vmask = input_ids == self.video_token_id - x = x.masked_scatter(vmask.unsqueeze(-1).expand_as(x), merged_v) - if self._vis_masks is None: - self._vis_masks, self._ds_embeds = vmask, ds_v - else: - combined = self._vis_masks | vmask - merged_ds = [] - for a, b in zip(self._ds_embeds, ds_v): - e = a.new_zeros(combined.sum().item(), a.shape[-1]) - e[self._vis_masks[combined]] = a - e[vmask[combined]] = b - merged_ds.append(e) - self._vis_masks, self._ds_embeds = combined, merged_ds - - # position ids → cos/sin - if has_img or has_vid: - pos = _mrope_position_ids( - input_ids, cu_seqlens, - image_grid_thw if has_img else None, - video_grid_thw if has_vid else None, - self.image_token_id, self.video_token_id, self.spatial_merge_size, - ) - else: - pos = torch.zeros(total, dtype=torch.int64, device=device) - for s, e in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()): - pos[s:e] = torch.arange(e - s, device=device) - pos = pos.view(1, 1, -1).expand(3, 1, -1) - - cos, sin = mrope_cos_sin(self.text_inv_freq, pos, self.mrope_section) - cos, sin = cos.to(x.dtype), sin.to(x.dtype) - - # pad cu_seqlens to fixed size - n = cu_seqlens.shape[0] - assert n <= _PP_MAX_SEQS + 1, f"cu_seqlens len {n} > _PP_MAX_SEQS={_PP_MAX_SEQS}" - cu_pad = F.pad(cu_seqlens, (0, _PP_MAX_SEQS + 1 - n)) - n_t = torch.tensor(n, dtype=torch.int64, device=device) - - return x, cos, sin, cu_pad, n_t - - # ------------------------------------------------------------------ - def forward( - self, - hidden: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - cu_pad: torch.Tensor, - n_t: torch.Tensor, - ) -> tuple | torch.Tensor: - n = int(n_t.item()) - cu = cu_pad[:n].to(torch.int32) - max_s = int((cu[1:] - cu[:-1]).max().item()) - - x = hidden - for i, layer in enumerate(self.layers): - x = layer(x, cos, sin, cu, max_s) - if self.is_first and self._ds_embeds is not None and i < len(self._ds_embeds): - x = x.clone() - x[self._vis_masks] = x[self._vis_masks] + self._ds_embeds[i].to(x.dtype) - - if self.is_last: - return self.lm_head(self.norm(x)) # (1, seq_len, vocab_size) - - return x, cos, sin, cu_pad, n_t - - -class _ScaledLoss: - """Stateful loss callable — update .accum_target each step.""" - - def __init__(self) -> None: - self.accum_target: int = 1 - - def __call__(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: - from models.qwen3_5.utils import causal_lm_loss - return causal_lm_loss(logits, labels) / self.accum_target - - -def _mrope_position_ids( - input_ids, cu_seqlens, image_grid_thw, video_grid_thw, - image_token_id, 
video_token_id, spatial_merge_size, -) -> torch.Tensor: - """3-D MRoPE positions — mirrors Qwen3_5ForCausalLM.get_rope_index.""" - if video_grid_thw is not None: - video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0).clone() - video_grid_thw[:, 0] = 1 - - _, S = input_ids.shape - device = input_ids.device - mm = torch.zeros(S, dtype=torch.int64, device=device) - mm[input_ids[0] == image_token_id] = 1 - mm[input_ids[0] == video_token_id] = 2 - types = mm.tolist() - img_it = iter(image_grid_thw) if image_grid_thw is not None else None - vid_it = iter(video_grid_thw) if video_grid_thw is not None else None - out = torch.zeros(3, 1, S, dtype=torch.int64, device=device) - for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()): - if start == end: - continue - seg = types[start:end] - parts, cur, j = [], 0, 0 - while j < len(seg): - k = j - while k < len(seg) and seg[k] == seg[j]: - k += 1 - key, length = seg[j], k - j - if key == 0: - parts.append(torch.arange(length, device=device).view(1, -1).expand(3, -1) + cur) - cur += length - else: - g = next(img_it if key == 1 else vid_it) - t, h, w = int(g[0]), int(g[1]), int(g[2]) - lh, lw = h // spatial_merge_size, w // spatial_merge_size - n = t * lh * lw - pw = torch.arange(cur, cur + lw, device=device).repeat(lh * t) - ph = torch.arange(cur, cur + lh, device=device).repeat_interleave(lw * t) - pt = torch.full((n,), cur, device=device, dtype=torch.int64) - parts.append(torch.stack([pt, ph, pw])) - cur += max(lh, lw) - j = k - out[:, 0, start:end] = torch.cat(parts, dim=1) - return out - - -class _PPSchedule: - """ - Single-microbatch pipeline schedule using blocking P2P comms. - - Implements GPipe-style all-forward then all-backward for a linear - pp_size-stage pipeline with packed (flash-attn varlen) sequences. - - Inter-stage tensor layout: - metadata : (1,) int64 — n_tokens (variable, sent before hidden/cos/sin) - hidden : (1, n_tokens, H) bfloat16 - cos : (1, n_tokens, R) bfloat16 - sin : (1, n_tokens, R) bfloat16 - cu_pad : (_PP_MAX_SEQS+1,) int32 (fixed size) - n_t : () int64 (scalar) - Backward sends only grad_hidden : (1, n_tokens, H) bfloat16. 
- """ - - def __init__( - self, - stage_module: nn.Module, - pp_rank: int, - pp_size: int, - pp_group, # DeviceMesh sub-mesh (pp dimension) - hidden_size: int, - rope_dim: int, - dtype: torch.dtype, - loss_fn: Optional["_ScaledLoss"] = None, - ): - self.stage = stage_module - self.pp_rank = pp_rank - self.pp_size = pp_size - self.is_first = pp_rank == 0 - self.is_last = pp_rank == pp_size - 1 - self._group = pp_group.get_group() - self._H = hidden_size - self._R = rope_dim - self._dt = dtype - self._device = torch.device(f"cuda:{torch.cuda.current_device()}") - self.loss_fn = loss_fn - - # Fixed-size recv buffers (cu_pad, n_t, metadata always fixed) - self._cu_buf = torch.empty(_PP_MAX_SEQS + 1, device=self._device, dtype=torch.int32) - self._nt_buf = torch.empty((), device=self._device, dtype=torch.int64) - self._meta_buf = torch.empty(1, device=self._device, dtype=torch.int64) - - # ------------------------------------------------------------------ - def _send_fwd(self, x, cos, sin, cu_pad, n_t, dst: int): - n_tok = torch.tensor([x.shape[1]], device=x.device, dtype=torch.int64) - dist.send(n_tok, dst=dst, group=self._group) - dist.send(x.contiguous(), dst=dst, group=self._group) - dist.send(cos.contiguous(), dst=dst, group=self._group) - dist.send(sin.contiguous(), dst=dst, group=self._group) - dist.send(cu_pad.contiguous(), dst=dst, group=self._group) - dist.send(n_t.contiguous(), dst=dst, group=self._group) - - def _recv_fwd(self, src: int): - dist.recv(self._meta_buf, src=src, group=self._group) - n = int(self._meta_buf.item()) - x = torch.empty(1, n, self._H, device=self._device, dtype=self._dt) - cos = torch.empty(1, n, self._R, device=self._device, dtype=self._dt) - sin = torch.empty(1, n, self._R, device=self._device, dtype=self._dt) - dist.recv(x, src=src, group=self._group) - dist.recv(cos, src=src, group=self._group) - dist.recv(sin, src=src, group=self._group) - dist.recv(self._cu_buf, src=src, group=self._group) - dist.recv(self._nt_buf, src=src, group=self._group) - return x, cos, sin, self._cu_buf, self._nt_buf - - # ------------------------------------------------------------------ - def step(self, *args, target=None, losses=None): - if self.is_first and not self.is_last: - x, cos, sin, cu_pad, n_t = self.stage(*args) - self._send_fwd(x, cos, sin, cu_pad, n_t, dst=self.pp_rank + 1) - grad_x = torch.empty_like(x) - dist.recv(grad_x, src=self.pp_rank + 1, group=self._group) - x.backward(grad_x) - - elif not self.is_first and self.is_last: - x_in, cos, sin, cu_pad, n_t = self._recv_fwd(src=self.pp_rank - 1) - x_leaf = x_in.detach().requires_grad_(True) - logits = self.stage(x_leaf, cos, sin, cu_pad, n_t) - loss = self.loss_fn(logits, target) - if losses is not None: - losses.append(loss.detach()) - loss.backward() - dist.send(x_leaf.grad.contiguous(), dst=self.pp_rank - 1, group=self._group) - - elif not self.is_first and not self.is_last: - # Middle stage (pp_size=4) - x_in, cos, sin, cu_pad, n_t = self._recv_fwd(src=self.pp_rank - 1) - x_leaf = x_in.detach().requires_grad_(True) - x_out, cos_out, sin_out, cu_out, nt_out = self.stage(x_leaf, cos, sin, cu_pad, n_t) - self._send_fwd(x_out, cos_out, sin_out, cu_out, nt_out, dst=self.pp_rank + 1) - grad_x_out = torch.empty_like(x_out) - dist.recv(grad_x_out, src=self.pp_rank + 1, group=self._group) - x_out.backward(grad_x_out) - dist.send(x_leaf.grad.contiguous(), dst=self.pp_rank - 1, group=self._group) - - else: - raise RuntimeError("_PPSchedule used with pp_size=1; use regular training path instead") - - -def apply_pp_qwen35( 
- model: nn.Module, - pp_group, - seq_len: int, - *, - snapshot_dir=None, # Path | None — when set, model is on meta; weights loaded per-rank - device=None, # torch.device | None — required when snapshot_dir is set - dtype=None, # torch.dtype | None — required when snapshot_dir is set -) -> tuple: - """ - Split Qwen3_5ForCausalLM across PP ranks (pp_size must be 2 or 4). - Returns (stage_module, None, schedule, loss_fn, pp_rank, pp_size, is_last). - - Meta-loading path (large models): - Pass ``snapshot_dir``, ``device``, ``dtype``. The model must be on - ``torch.device("meta")`` with no weights. Each rank materialises only - its stage slice and loads the corresponding weights directly onto - ``device`` via ``load_stage_weights`` — no full-model CPU or GPU copy. - - Legacy path: - Omit those kwargs. The full model must already reside on the target - device with weights loaded (original behaviour, kept for compatibility). - """ - pp_rank: int = pp_group.get_local_rank() - pp_size: int = pp_group.size() - meta_load = snapshot_dir is not None - - if not meta_load: - device = torch.device(f"cuda:{torch.cuda.current_device()}") - - lm = model.model.language_model - n_layers = len(lm.layers) - layer_params = sum(p.numel() for p in lm.layers[0].parameters()) - embed_params = lm.embed_tokens.weight.numel() - visual = getattr(model.model, "visual", None) - visual_params = sum(p.numel() for p in visual.parameters()) if visual is not None else 0 - lm_head_params = model.lm_head.weight.numel() - - first_virtual = (embed_params + visual_params) / layer_params - last_virtual = lm_head_params / layer_params - - ranges = _pp_layer_ranges(n_layers, pp_size, first_virtual, last_virtual) - if pp_rank == 0: - counts = [e - s for s, e in ranges] - print( - f"[PP] layer split {counts} " - f"(first_virtual={first_virtual:.2f} last_virtual={last_virtual:.2f})", - flush=True, - ) - ls, le = ranges[pp_rank] - is_first = pp_rank == 0 - is_last = pp_rank == pp_size - 1 - - # Read config values before to_empty() modifies parameters. 
- H = model.model.language_model.cfg.hidden_size - R = model.text_inv_freq.shape[0] * 2 # rope_dim = 2 * len(inv_freq) - - stage_module = PPStageModule(model, ls, le, is_first, is_last) - - if meta_load: - from models.qwen3_5.utils import load_stage_weights - stage_module.to_empty(device=device) - stage_module.to(dtype) - load_stage_weights(stage_module, snapshot_dir, ls, le, is_first, is_last, device, dtype) - else: - dtype = next(model.parameters()).dtype - stage_module = stage_module.to(device) - - loss_fn = _ScaledLoss() - schedule = _PPSchedule( - stage_module, pp_rank, pp_size, pp_group, - hidden_size=H, rope_dim=R, dtype=dtype, - loss_fn=loss_fn if is_last else None, - ) - - return stage_module, None, schedule, loss_fn, pp_rank, pp_size, is_last diff --git a/train/train_qwen.py b/train/train_qwen.py index 86340e6..86a92b7 100644 --- a/train/train_qwen.py +++ b/train/train_qwen.py @@ -91,10 +91,18 @@ def get_local_fqns( end_idx = num_first elif pp_rank == pp_size - 1: - start_idx = num_layers - num_last - end_idx = num_layers + # When there are middle ranks, last rank gets the last num_last layers + # When there are no middle ranks (PP=2), last rank gets everything after first rank + if pp_size > 2: + start_idx = num_layers - num_last + end_idx = num_layers + else: + # PP=2 case: rank 0 gets first num_first layers, rank 1 gets the rest + start_idx = num_first + end_idx = num_layers else: + # Middle ranks middle_layers = num_layers - num_first - num_last middle_ranks = pp_size - 2 @@ -241,6 +249,21 @@ def __init__(self, cfg: Config): target_dtype = torch.bfloat16 if self.training_args.bfloat16 else torch.float32 + # Load stage-specific weights when PP > 1 and not using random init + if self.pp_size > 1 and not self.training_args.random_init: + logger.info(f"PP rank {pp_rank}: About to load stage weights for layers {layer_start}-{layer_end}") + load_stage_weights( + stage=self.model, + snapshot_dir=self.training_args.model_dir, + layer_start=layer_start, + layer_end=layer_end, + is_first=pp_rank == 0, + is_last=pp_rank == self.pp_size - 1, + device=self.device, + dtype=target_dtype, + ) + logger.info(f"PP rank {pp_rank}: Finished loading stage weights") + self.model = self.model.to(self.device) self.pp_stage = PipelineStage( From 6fccd16b397e53451db836434f14811efc0840b4 Mon Sep 17 00:00:00 2001 From: tomiock Date: Tue, 28 Apr 2026 22:38:40 +0200 Subject: [PATCH 14/23] [fix] layers well split in PP ranks --- train/train_qwen.py | 56 ++++++++++++++++++--------------------------- 1 file changed, 22 insertions(+), 34 deletions(-) diff --git a/train/train_qwen.py b/train/train_qwen.py index 86a92b7..aa1528f 100644 --- a/train/train_qwen.py +++ b/train/train_qwen.py @@ -83,45 +83,33 @@ def get_local_fqns( fqns = [] if pp_rank == 0: - fqns.extend([ - "model.visual", - "model.language_model.embed_tokens" - ]) - start_idx = 0 - end_idx = num_first - - elif pp_rank == pp_size - 1: - # When there are middle ranks, last rank gets the last num_last layers - # When there are no middle ranks (PP=2), last rank gets everything after first rank - if pp_size > 2: - start_idx = num_layers - num_last - end_idx = num_layers - else: - # PP=2 case: rank 0 gets first num_first layers, rank 1 gets the rest - start_idx = num_first - end_idx = num_layers + fqns.extend(["model.visual", "model.language_model.embed_tokens"]) + if pp_size == 2: + mid_point = num_layers // 2 + (num_layers % 2) + start_idx = 0 if pp_rank == 0 else mid_point + end_idx = mid_point if pp_rank == 0 else num_layers else: - # Middle ranks - 
middle_layers = num_layers - num_first - num_last - middle_ranks = pp_size - 2 - - layers_per_mid = middle_layers // middle_ranks - remainder = middle_layers % middle_ranks - - mid_idx = pp_rank - 1 - start_idx = num_first + (mid_idx * layers_per_mid) + min(mid_idx, remainder) - num_layers_this_rank = layers_per_mid + (1 if mid_idx < remainder else 0) - end_idx = start_idx + num_layers_this_rank + if pp_rank == 0: + start_idx, end_idx = 0, num_first + elif pp_rank == pp_size - 1: + start_idx, end_idx = num_layers - num_last, num_layers + else: + middle_layers = num_layers - num_first - num_last + middle_ranks = pp_size - 2 + + layers_per_mid = middle_layers // middle_ranks + remainder = middle_layers % middle_ranks + + mid_idx = pp_rank - 1 + start_idx = num_first + (mid_idx * layers_per_mid) + min(mid_idx, remainder) + num_layers_this_rank = layers_per_mid + (1 if mid_idx < remainder else 0) + end_idx = start_idx + num_layers_this_rank - for i in range(start_idx, end_idx): - fqns.append(f"model.language_model.layers.{i}") + fqns.extend([f"model.language_model.layers.{i}" for i in range(start_idx, end_idx)]) if pp_rank == pp_size - 1: - fqns.extend([ - "model.language_model.norm", - "lm_head" - ]) + fqns.extend(["model.language_model.norm", "lm_head"]) return fqns From 579efb32a213d446371bb179a75c434561f916c5 Mon Sep 17 00:00:00 2001 From: tomiock Date: Tue, 28 Apr 2026 23:57:09 +0200 Subject: [PATCH 15/23] init moe --- configs/cvc/moe.toml | 34 +++++++++++ models/qwen3_5/config.py | 16 +++++- models/qwen3_5/model.py | 118 ++++++++++++++++++++++++++++++++++++--- scripts/finetune.sh | 2 +- train/train_qwen.py | 9 +-- train/utils.py | 13 +---- utils/down.py | 4 +- 7 files changed, 166 insertions(+), 30 deletions(-) create mode 100644 configs/cvc/moe.toml diff --git a/configs/cvc/moe.toml b/configs/cvc/moe.toml new file mode 100644 index 0000000..58a23c4 --- /dev/null +++ b/configs/cvc/moe.toml @@ -0,0 +1,34 @@ +[model] +model_name = "Qwen/Qwen3.5-6B-A3B" +model_impl = "native" + +train_llm = true +train_mlp = true +train_vit = false + +[wandb] +run_name = "moe" +project_name = "test" + +[training] +model_dir = "/data/151-1/users/tockier/qwen_finetune/cache/qwen35_6b_a3b" +output_dir = "/data/151-1/users/shared_cache/qwen_finetune/checkpoints" + +save_steps = 10000 +total_steps = 10000 +random_init = false + +tp_size = 1 +pp_size = 1 +data_parallel = 'ddp' + +ac_mode = "full" +compile = false + +[data] +data_path = "/data/151-1/datasets/synth_test_datasets/cap_pretrain" +seq_len = 256 + +packing_buffer_size = 100 + +batch_size = 0 diff --git a/models/qwen3_5/config.py b/models/qwen3_5/config.py index 4acbc46..00f6804 100644 --- a/models/qwen3_5/config.py +++ b/models/qwen3_5/config.py @@ -5,7 +5,7 @@ class Qwen3_5TextConfig: vocab_size: int hidden_size: int - intermediate_size: int + moe_intermediate_size: int num_hidden_layers: int num_attention_heads: int num_key_value_heads: int @@ -14,6 +14,12 @@ class Qwen3_5TextConfig: rms_norm_eps: float tie_word_embeddings: bool + # moe + num_experts: int + num_experts_per_tok: int + router_aux_loss_coef: float + shared_expert_intermediate_size: int + # linear attention layer_types: list[str] full_attention_interval: int @@ -64,7 +70,7 @@ def from_json(cls, path: str) -> "Qwen3VLConfig": text = Qwen3_5TextConfig( vocab_size=tc["vocab_size"], hidden_size=tc["hidden_size"], - intermediate_size=tc["intermediate_size"], + moe_intermediate_size=tc["moe_intermediate_size"], num_hidden_layers=tc["num_hidden_layers"], 
num_attention_heads=tc["num_attention_heads"], num_key_value_heads=tc["num_key_value_heads"], @@ -81,7 +87,11 @@ def from_json(cls, path: str) -> "Qwen3VLConfig": mtp_num_hidden_layers=tc['mtp_num_hidden_layers'], mtp_use_dedicated_embeddings=tc['mtp_use_dedicated_embeddings'], tie_word_embeddings=tc.get("tie_word_embeddings", raw.get("tie_word_embeddings", False)), - rope_parameters=tc['rope_parameters'] + rope_parameters=tc['rope_parameters'], + num_experts=tc['num_experts'], + num_experts_per_tok=tc['num_experts_per_tok'], + router_aux_loss_coef=tc['router_aux_loss_coef'], + shared_expert_intermediate_size=tc['shared_expert_intermediate_size'], ) vc = raw["vision_config"] vision = Qwen3_5VisionConfig( diff --git a/models/qwen3_5/model.py b/models/qwen3_5/model.py index 4f76bc7..9030797 100644 --- a/models/qwen3_5/model.py +++ b/models/qwen3_5/model.py @@ -231,6 +231,102 @@ def __init__(self, cfg: Qwen3_5TextConfig): def forward(self, x): return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) +class MoeMLP(nn.Module): + def __init__(self, config: Qwen3_5TextConfig, intermediate_size: int): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = F.silu + + def forward(self, x): + down_proj = self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + +class MoeExperts(nn.Module): + def __init__(self, config): + super().__init__() + self.num_experts = config.num_experts + self.hidden_dim = config.hidden_size + self.intermediate_dim = config.moe_intermediate_size + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) + self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) + self.act_fn = F.silu + + def forward( + self, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + ) -> torch.Tensor: + final_hidden_states = torch.zeros_like(hidden_states) + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + + for expert_idx in expert_hit: + expert_idx = expert_idx[0] + if expert_idx == self.num_experts: + continue + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) + current_hidden_states = self.act_fn(gate) * up + current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) + current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None] + final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) + + return final_hidden_states + +class TopKRouter(nn.Module): + def __init__(self, config): + super().__init__() + self.top_k = config.num_experts_per_tok + self.num_experts = config.num_experts + self.hidden_dim = config.hidden_size + self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim)) + + def forward(self, hidden_states): + hidden_states =
hidden_states.reshape(-1, self.hidden_dim) + router_logits = F.linear(hidden_states, self.weight) # (seq_len, num_experts) + router_probs = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1) + router_top_value, router_indices = torch.topk(router_probs, self.top_k, dim=-1) # (seq_len, top_k) + router_top_value /= router_top_value.sum(dim=-1, keepdim=True) + router_top_value = router_top_value.to(router_logits.dtype) + router_scores = router_top_value + return router_logits, router_scores, router_indices + +class MoE(nn.Module): + def __init__(self, config): + super().__init__() + self.gate = TopKRouter(config) + self.experts = MoeExperts(config) + self.shared_expert = MoeMLP(config, intermediate_size=config.shared_expert_intermediate_size) + self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states_reshaped = hidden_states.view(-1, hidden_dim) + + shared_expert_output = self.shared_expert(hidden_states_reshaped) + + router_logits, routing_weights, selected_experts = self.gate(hidden_states_reshaped) + expert_output = self.experts(hidden_states_reshaped, selected_experts, routing_weights) + + shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states_reshaped)) * shared_expert_output + expert_output = expert_output + shared_expert_output + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.experts.num_experts) + tokens_per_expert = expert_mask.sum(dim=(0, 1), dtype=torch.float) + router_probs = torch.nn.functional.softmax(router_logits, dim=-1).sum(dim=0) + aux_loss = torch.sum(tokens_per_expert * router_probs) / (batch_size * sequence_length) + + return expert_output.reshape(batch_size, sequence_length, hidden_dim), aux_loss + class DecoderLayer(nn.Module): def __init__(self, cfg: Qwen3_5TextConfig, layer_type: str): super().__init__() @@ -240,7 +336,7 @@ def __init__(self, cfg: Qwen3_5TextConfig, layer_type: str): else: self.linear_attn = GatedDeltaNet(cfg) - self.mlp = MLP(cfg) + self.mlp = MoE(cfg) self.input_layernorm = OffsetRMSNorm(cfg.hidden_size, eps=cfg.rms_norm_eps) self.post_attention_layernorm = OffsetRMSNorm(cfg.hidden_size, eps=cfg.rms_norm_eps) @@ -255,8 +351,9 @@ def forward(self, x, cos, sin, cu_seqlens, max_seqlen): cu_seqlens=cu_seqlens, max_seqlen=max_seqlen ) - x = x + self.mlp(self.post_attention_layernorm(x)) - return x + mlp_out, aux_loss = self.mlp(self.post_attention_layernorm(x)) + x = x + mlp_out + return x, aux_loss class LanguageModel(nn.Module): """HF name: `model.language_model`.""" @@ -287,15 +384,18 @@ def forward( deepstack_visual_embeds: list[torch.Tensor] | None = None, ) -> torch.Tensor: x = inputs_embeds + + total_aux_loss = 0 for i, layer in enumerate(self.layers): - x = layer(x, cos, sin, cu_seqlens, max_seqlen) + x, aux_loss = layer(x, cos, sin, cu_seqlens, max_seqlen) + total_aux_loss += aux_loss if deepstack_visual_embeds is not None and i < len(deepstack_visual_embeds): x = x.clone() x[visual_pos_masks] = ( x[visual_pos_masks] + deepstack_visual_embeds[i].to(x.dtype) ) - return self.norm(x) if self.norm is not None else x + return self.norm(x) if self.norm is not None else x, total_aux_loss class VisionPatchEmbed(nn.Module): def __init__(self, cfg: Qwen3_5VisionConfig): @@ -772,7 +872,7 @@ def forward( cos = cos.to(inputs_embeds.dtype) sin = sin.to(inputs_embeds.dtype) - h = self.model.language_model( + 
h, total_aux_loss = self.model.language_model( inputs_embeds, cos, sin, @@ -783,9 +883,9 @@ def forward( ) if getattr(self, "lm_head", None) is not None: - return self.lm_head(h) + return self.lm_head(h), total_aux_loss else: - return h + return h, total_aux_loss @classmethod def from_pretrained( @@ -811,6 +911,7 @@ def from_pretrained( model = model.to_empty(device=device).to(dtype=dtype) + """ load_safetensors_into( model, snapshot_dir, @@ -818,6 +919,7 @@ def from_pretrained( dtype=dtype, load_vision=load_vision, ) + """ if cfg.tie_word_embeddings: model.lm_head.weight = model.model.language_model.embed_tokens.weight diff --git a/scripts/finetune.sh b/scripts/finetune.sh index d6a03dc..3462bfb 100755 --- a/scripts/finetune.sh +++ b/scripts/finetune.sh @@ -11,7 +11,7 @@ else NGPUS=$(echo $CUDA_VISIBLE_DEVICES | grep -o '[^,]\+' | wc -l) fi -torchrun --nproc_per_node=$NGPUS \ +torchrun --nproc_per_node=2 \ --master_addr=$MASTER_ADDR \ --master_port=$MASTER_PORT \ -m train.train_qwen \ diff --git a/train/train_qwen.py b/train/train_qwen.py index 86340e6..68ae8f9 100644 --- a/train/train_qwen.py +++ b/train/train_qwen.py @@ -296,10 +296,6 @@ def pp_loss_fn(logits, labels): else: raise Exception('invalid sharding strategy for Data Parallel') - if self.pp_size > 1: - if self.training_args.data_parallel == 'ddp': - self.model = replicate(self.model, device_mesh=self.dp_group) - # loading into GPU self.model = self.model.to(device=self.device) if self.training_args.bfloat16: @@ -643,8 +639,9 @@ def _train_step(self, data_iterator, optimizer): s_model = time.perf_counter() with record_function("forward_pass"): with torch.autocast('cuda', torch.bfloat16): - logits = self.model(input_ids, **batch) - loss = causal_lm_loss(logits, labels) + logits, aux_loss = self.model(input_ids, **batch) + ce_loss = causal_lm_loss(logits, labels) + loss = ce_loss + (.01 * aux_loss) with record_function("backward_pass"): scaled_loss = loss / self.current_accum_target diff --git a/train/utils.py b/train/utils.py index 3b13838..9e50527 100644 --- a/train/utils.py +++ b/train/utils.py @@ -515,22 +515,15 @@ def get_dense_model_nparams_and_flops( if isinstance(m, torch.nn.Embedding) ) - if "8B" in model_name: - tied = False - elif "9B" in model_name: - tied = False - elif "27B" in model_name: - tied = False - elif "2B" in model_name: + if "2B" in model_name: tied = True elif "4B" in model_name: tied = True elif "1.7B" in model_name: tied = True else: - # ValueError - return 0, 0 - + tied = False + # we take into account the embedding params num_flops_per_token = 6 * nparams diff --git a/utils/down.py b/utils/down.py index 1428d84..e61fbb9 100644 --- a/utils/down.py +++ b/utils/down.py @@ -1,6 +1,6 @@ from huggingface_hub import snapshot_download snapshot_download( - repo_id="Qwen/Qwen3.5-27B", - local_dir="/e/project1/reformo/ockier1/qwen_models/qwen3_5_27b", + repo_id="Qwen/Qwen3.5-35B-A3B", + local_dir="/data/151-1/users/tockier/qwen_finetune/cache/qwen35_35b_a3", ) From e0696342ce4adca0d9e0b226463dd9d8d9aebddc Mon Sep 17 00:00:00 2001 From: tomiock Date: Wed, 29 Apr 2026 00:36:43 +0200 Subject: [PATCH 16/23] [feat] moe implemented (bad performance) --- configs/cvc/moe.toml | 6 +-- models/qwen3_5/model.py | 96 +++++++++++++++++++++++++++++++++++------ scripts/finetune.sh | 2 +- train/train_qwen.py | 7 ++- 4 files changed, 94 insertions(+), 17 deletions(-) diff --git a/configs/cvc/moe.toml b/configs/cvc/moe.toml index 58a23c4..d6e88df 100644 --- a/configs/cvc/moe.toml +++ b/configs/cvc/moe.toml @@ -7,8 +7,8 @@ 
train_mlp = true train_vit = false [wandb] -run_name = "moe" -project_name = "test" +run_name = "test" +project_name = "moe" [training] model_dir = "/data/151-1/users/tockier/qwen_finetune/cache/qwen35_6b_a3b" @@ -29,6 +29,6 @@ compile = false data_path = "/data/151-1/datasets/synth_test_datasets/cap_pretrain" seq_len = 256 -packing_buffer_size = 100 +packing_buffer_size = 1000 batch_size = 0 diff --git a/models/qwen3_5/model.py b/models/qwen3_5/model.py index 9030797..e82fa95 100644 --- a/models/qwen3_5/model.py +++ b/models/qwen3_5/model.py @@ -310,6 +310,9 @@ def __init__(self, config): def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: batch_size, sequence_length, hidden_dim = hidden_states.shape + N = batch_size * sequence_length + num_experts = self.experts.num_experts + hidden_states_reshaped = hidden_states.view(-1, hidden_dim) shared_expert_output = self.shared_expert(hidden_states_reshaped) @@ -322,8 +325,12 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.experts.num_experts) tokens_per_expert = expert_mask.sum(dim=(0, 1), dtype=torch.float) + fraction_tokens = tokens_per_expert / (N * self.gate.top_k) + router_probs = torch.nn.functional.softmax(router_logits, dim=-1).sum(dim=0) - aux_loss = torch.sum(tokens_per_expert * router_probs) / (batch_size * sequence_length) + fraction_probs = router_probs.sum(dim=0) / N + + aux_loss = num_experts * torch.sum(fraction_tokens * fraction_probs) return expert_output.reshape(batch_size, sequence_length, hidden_dim), aux_loss @@ -909,17 +916,54 @@ def from_pretrained( with torch.device("meta"): model = cls(cfg) - model = model.to_empty(device=device).to(dtype=dtype) - - """ - load_safetensors_into( - model, - snapshot_dir, - device=device, - dtype=dtype, - load_vision=load_vision, - ) - """ + model = model.to_empty(device='cuda').to(dtype=dtype) + if False: + load_safetensors_into( + model, + snapshot_dir, + device=device, + dtype=dtype, + load_vision=load_vision, + ) + else: + with torch.no_grad(): + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + elif isinstance(module, torch.nn.Embedding): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + elif isinstance(module, torch.nn.Conv3d) or isinstance(module, torch.nn.Conv1d): + torch.nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + + elif "Norm" in module.__class__.__name__: + if hasattr(module, "weight") and module.weight is not None: + if "Offset" in module.__class__.__name__: + torch.nn.init.zeros_(module.weight) + else: + torch.nn.init.ones_(module.weight) + if hasattr(module, "bias") and module.bias is not None: + torch.nn.init.zeros_(module.bias) + + elif "MoeExperts" in module.__class__.__name__: + torch.nn.init.normal_(module.gate_up_proj, mean=0.0, std=0.02) + torch.nn.init.normal_(module.down_proj, mean=0.0, std=0.02) + + elif "TopKRouter" in module.__class__.__name__: + # Router weights are typically initialized to 0 or very small values + torch.nn.init.zeros_(module.weight) + + elif "RMSNormGated" in module.__class__.__name__: + if hasattr(module, "weight") and module.weight is not None: + torch.nn.init.ones_(module.weight) + + elif "OffsetRMSNorm" in 
module.__class__.__name__: + if hasattr(module, "weight") and module.weight is not None: + # Offset norm weights start at 0 + torch.nn.init.zeros_(module.weight) if cfg.tie_word_embeddings: model.lm_head.weight = model.model.language_model.embed_tokens.weight @@ -944,3 +988,31 @@ def from_pretrained( model.text_inv_freq = text_inv return model + +@torch.no_grad() +def initialize_missing_weights(model): + for module in model.modules(): + if hasattr(module, 'reset_parameters'): + module.reset_parameters() + + elif isinstance(module, OffsetRMSNorm): + torch.nn.init.zeros_(module.weight) + elif isinstance(module, RMSNormGated): + torch.nn.init.ones_(module.weight) + elif isinstance(module, GatedDeltaNet): + torch.nn.init.zeros_(module.A_log) + torch.nn.init.ones_(module.dt_bias) + + elif "MoeExperts" in module.__class__.__name__: + torch.nn.init.normal_(module.gate_up_proj, mean=0.0, std=0.02) + torch.nn.init.normal_(module.down_proj, mean=0.0, std=0.02) + elif "TopKRouter" in module.__class__.__name__: + torch.nn.init.zeros_(module.weight) + + for name, param in model.named_parameters(): + if torch.isnan(param).any() or torch.isinf(param).any(): + print(f"WARNING: Fallback init applied to missed parameter: {name}") + if param.dim() >= 2: + torch.nn.init.normal_(param, mean=0.0, std=0.02) + else: + torch.nn.init.zeros_(param) diff --git a/scripts/finetune.sh b/scripts/finetune.sh index 3462bfb..d6a03dc 100755 --- a/scripts/finetune.sh +++ b/scripts/finetune.sh @@ -11,7 +11,7 @@ else NGPUS=$(echo $CUDA_VISIBLE_DEVICES | grep -o '[^,]\+' | wc -l) fi -torchrun --nproc_per_node=2 \ +torchrun --nproc_per_node=$NGPUS \ --master_addr=$MASTER_ADDR \ --master_port=$MASTER_PORT \ -m train.train_qwen \ diff --git a/train/train_qwen.py b/train/train_qwen.py index 68ae8f9..7b5ad8f 100644 --- a/train/train_qwen.py +++ b/train/train_qwen.py @@ -38,6 +38,7 @@ compile_model, ) from models.qwen3_5.utils import causal_lm_loss, load_stage_weights +from models.qwen3_5.model import initialize_missing_weights from train.utils import ( set_determinism, @@ -182,7 +183,7 @@ def __init__(self, cfg: Config): # Load the model on CPU with weights; for PP the split stage is moved to GPU below. 
self.model = select_model_class( - self.model_type, self.model_args, self.training_args, + self.model_type, self.model_args, self.training_args ) # we calculate the flops per token used to get the MFU number @@ -269,6 +270,7 @@ def pp_loss_fn(logits, labels): init_qwen3vl(self.model) else: logger.info('model not initlized, incompatible') + initialize_missing_weights(self.model) self.model.train() if self.training_args.bfloat16: @@ -556,6 +558,9 @@ def log(self, avg_loss, max_loss, global_tokens, global_assistant_tokens, global # GB200 (JUP) and SXM H100 (MN5) peak_tflops_per_gpu = 989.4 + # BLACKWELL 6000 + peak_tflops_per_gpu = 504 + # L40S #peak_tflops_per_gpu = 362 From 7cf1de7c3c0ee15fcc38a0588e97c06cf1a9c57b Mon Sep 17 00:00:00 2001 From: tomiock Date: Wed, 29 Apr 2026 10:28:16 +0200 Subject: [PATCH 17/23] sync --- configs/cvc/moe.toml | 4 ++-- train/train_qwen.py | 18 ++++++++++-------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/configs/cvc/moe.toml b/configs/cvc/moe.toml index d6e88df..b6b6c1a 100644 --- a/configs/cvc/moe.toml +++ b/configs/cvc/moe.toml @@ -4,7 +4,7 @@ model_impl = "native" train_llm = true train_mlp = true -train_vit = false +train_vit = true [wandb] run_name = "test" @@ -19,7 +19,7 @@ total_steps = 10000 random_init = false tp_size = 1 -pp_size = 1 +pp_size = 2 data_parallel = 'ddp' ac_mode = "full" diff --git a/train/train_qwen.py b/train/train_qwen.py index 984b92d..bf46afd 100644 --- a/train/train_qwen.py +++ b/train/train_qwen.py @@ -239,7 +239,7 @@ def __init__(self, cfg: Config): target_dtype = torch.bfloat16 if self.training_args.bfloat16 else torch.float32 # Load stage-specific weights when PP > 1 and not using random init - if self.pp_size > 1 and not self.training_args.random_init: + if False: logger.info(f"PP rank {pp_rank}: About to load stage weights for layers {layer_start}-{layer_end}") load_stage_weights( stage=self.model, @@ -556,7 +556,7 @@ def batch_generator(self): yield batch - def log(self, avg_loss, max_loss, global_tokens, global_assistant_tokens, global_samples, lr): + def log(self, avg_loss, aux_loss, max_loss, global_tokens, global_assistant_tokens, global_samples, lr): time_delta = time.perf_counter() - self.time_last_log @@ -584,6 +584,7 @@ def log(self, avg_loss, max_loss, global_tokens, global_assistant_tokens, global logger.info( f"{color.red}step {self.global_step} " f"{color.green}loss {avg_loss:.4f} " + f"{color.green}aux {aux_loss:.4f} " f"{color.blue}tps {tps:.2f} " f"{color.magenta}mfu {mfu:.1f}% " f"{color.reset}" @@ -664,7 +665,7 @@ def _train_step(self, data_iterator, optimizer): scaled_loss.backward() self.fwd_bwd_time = time.perf_counter() - s_model - return self._maybe_optimizer_step(loss, optimizer) + return self._maybe_optimizer_step(loss, ce_loss, aux_loss, optimizer) def train_step(self, data_iterator, optimizer): if self.pp_size == 1: @@ -672,7 +673,7 @@ def train_step(self, data_iterator, optimizer): else: return self._train_step_pp(data_iterator, optimizer) - def _maybe_optimizer_step(self, loss, optimizer): + def _maybe_optimizer_step(self, loss, ce_loss, aux_loss, optimizer): """Shared optimizer-step logic after fwd+bwd (regular and PP paths).""" self.current_accum_count += 1 @@ -684,9 +685,10 @@ def _maybe_optimizer_step(self, loss, optimizer): lr = optimizer.param_groups[0]['lr'] self.global_step += 1 - avg_loss, max_loss, global_tokens, global_assistant, global_samples = ( - dist_mean(loss, self.dp_group), - dist_max(loss, self.dp_group), + avg_loss, aux_loss, max_loss, global_tokens, 
global_assistant, global_samples = (
+            dist_mean(ce_loss, self.dp_group),
+            dist_mean(aux_loss, self.dp_group),
+            dist_max(ce_loss, self.dp_group),
             dist_sum(
                 torch.tensor(self.tokens_seen, dtype=torch.int64, device=self.device),
                 self.dp_group,
@@ -706,7 +708,7 @@
         )
 
         if self.if_log_rank():
-            self.log(avg_loss, max_loss, global_tokens, global_assistant, global_samples, lr)
+            self.log(avg_loss, aux_loss, max_loss, global_tokens, global_assistant, global_samples, lr)
 
         self.total_ntokens_since_last_log = 0
         self.ntokens_since_last_log = 0

From 22241ccbae2244eea3772087e8b9d16d2a8c56e3 Mon Sep 17 00:00:00 2001
From: tomiock
Date: Wed, 29 Apr 2026 12:02:03 +0200
Subject: [PATCH 18/23] [feat] MoE TP & PP (not at the same time)

---
 configs/cvc/moe.toml    |  4 ++--
 models/qwen3_5/model.py | 20 +++++++++++++++++++-
 train/infra.py          |  7 ++++---
 train/train_qwen.py     | 22 +++++++++++++++++-----
 4 files changed, 42 insertions(+), 11 deletions(-)

diff --git a/configs/cvc/moe.toml b/configs/cvc/moe.toml
index b6b6c1a..7fd6da3 100644
--- a/configs/cvc/moe.toml
+++ b/configs/cvc/moe.toml
@@ -19,7 +19,7 @@ total_steps = 10000
 random_init = false
 
 tp_size = 1
-pp_size = 2
+pp_size = 1
 data_parallel = 'ddp'
 
 ac_mode = "full"
@@ -29,6 +29,6 @@ compile = false
 data_path = "/data/151-1/datasets/synth_test_datasets/cap_pretrain"
 seq_len = 256
 
-packing_buffer_size = 1000
+packing_buffer_size = 100
 
 batch_size = 0
diff --git a/models/qwen3_5/model.py b/models/qwen3_5/model.py
index e82fa95..29fa02e 100644
--- a/models/qwen3_5/model.py
+++ b/models/qwen3_5/model.py
@@ -328,10 +328,13 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tens
         fraction_tokens = tokens_per_expert / (N * self.gate.top_k)
 
         router_probs = torch.nn.functional.softmax(router_logits, dim=-1).sum(dim=0)
-        fraction_probs = router_probs.sum(dim=0) / N
+        fraction_probs = router_probs / N
 
         aux_loss = num_experts * torch.sum(fraction_tokens * fraction_probs)
 
+        dummy = (self.experts.gate_up_proj * 0.0).sum() + (self.experts.down_proj * 0.0).sum()
+        aux_loss = aux_loss + dummy.to(aux_loss.dtype)
+
         return expert_output.reshape(batch_size, sequence_length, hidden_dim), aux_loss
 
 class DecoderLayer(nn.Module):
@@ -773,6 +776,7 @@ def _compute_cos_sin(self, position_ids: torch.Tensor) -> tuple[torch.Tensor, to
     def forward(
         self,
         hidden_states: torch.Tensor | None = None,
+        prev_aux_loss: torch.Tensor | None = None,
         *,
         input_ids: torch.Tensor | None = None,
         pixel_values: torch.Tensor | None = None,
@@ -889,6 +893,20 @@
             deepstack_visual_embeds=deepstack_visual_embeds,
         )
 
+        if hasattr(total_aux_loss, "to_local"):
+            total_aux_loss = total_aux_loss.to_local()
+
+        if prev_aux_loss is not None:
+            if hasattr(prev_aux_loss, "to_local"):
+                prev_aux_loss = prev_aux_loss.to_local()
+
+            total_aux_loss = total_aux_loss + prev_aux_loss
+
+        if isinstance(total_aux_loss, (int, float)):
+            total_aux_loss = torch.tensor([total_aux_loss], device=h.device, dtype=h.dtype)
+        elif total_aux_loss.dim() == 0:
+            total_aux_loss = total_aux_loss.unsqueeze(0)
+
         if getattr(self, "lm_head", None) is not None:
             return self.lm_head(h), total_aux_loss
         else:
diff --git a/train/infra.py b/train/infra.py
index 490c62a..1cde44e 
100644 --- a/train/infra.py +++ b/train/infra.py @@ -630,10 +630,11 @@ def _apply_tp_to_decoder_qwen3_5( } layer_plan.update({ - "mlp.gate_proj": colwise_parallel(), - "mlp.down_proj": rowwise_parallel(output_layouts=Replicate()), - "mlp.up_proj": colwise_parallel(), + "mlp.shared_expert.gate_proj": colwise_parallel(), + "mlp.shared_expert.down_proj": rowwise_parallel(output_layouts=Replicate()), + "mlp.shared_expert.up_proj": colwise_parallel(), }) + parallelize_module( module=transformer_block, device_mesh=tp_mesh, diff --git a/train/train_qwen.py b/train/train_qwen.py index bf46afd..a78e514 100644 --- a/train/train_qwen.py +++ b/train/train_qwen.py @@ -263,8 +263,15 @@ def __init__(self, cfg: Config): group=self.mesh.get_group(mesh_dim="pp"), ) - def pp_loss_fn(logits, labels): - return causal_lm_loss(logits, labels) / self.current_accum_target + def pp_loss_fn(outputs, labels): + logits, aux_loss = outputs + ce_loss = causal_lm_loss(logits, labels) + aux_loss = aux_loss.squeeze() + + self._recent_ce_loss = ce_loss.detach() + self._recent_aux_loss = aux_loss.detach() + + return (ce_loss + 0.01 * aux_loss) / self.current_accum_target self.pp_schedule = ScheduleGPipe( self.pp_stage, @@ -275,7 +282,7 @@ def pp_loss_fn(logits, labels): if self.training_args.random_init: if self.model_type == ModelType.Qwen3_5: logger.info('initilizing decoder and projecter of Qwen3.5') - init_qwen35(self.model) + #init_qwen35(self.model) elif self.model_type == ModelType.Qwen3_vl: logger.info('initilizing projector of Qwen3-VL') init_qwen3vl(self.model) @@ -642,11 +649,16 @@ def _train_step_pp(self, data_iterator, optimizer): self.fwd_bwd_time = time.perf_counter() - s_model scaled_loss = torch.stack(losses).sum() if losses else torch.tensor(0.0, device=self.device) - loss_for_logging = scaled_loss * self.current_accum_target torch.distributed.all_reduce(loss_for_logging, group=self.pp_group.get_group()) + + ce_loss = getattr(self, '_recent_ce_loss', torch.tensor(0.0, device=self.device)) + aux_loss = getattr(self, '_recent_aux_loss', torch.tensor(0.0, device=self.device)) + + torch.distributed.all_reduce(ce_loss, group=self.pp_group.get_group()) + torch.distributed.all_reduce(aux_loss, group=self.pp_group.get_group()) - return self._maybe_optimizer_step(loss_for_logging, optimizer) + return self._maybe_optimizer_step(loss_for_logging, ce_loss, aux_loss, optimizer) def _train_step(self, data_iterator, optimizer): batch = next(data_iterator) From 1b0e6c428fe5cf6949d430f1898b4e82ab6b6ed7 Mon Sep 17 00:00:00 2001 From: tomiock Date: Wed, 29 Apr 2026 16:47:52 +0200 Subject: [PATCH 19/23] test EP --- models/qwen3_5/dev_ep.py | 100 ++++++++++ models/qwen3_5/dispatcher.py | 110 ++++++++++ models/qwen3_5/test_ep.py | 375 +++++++++++++++++++++++++++++++++++ 3 files changed, 585 insertions(+) create mode 100644 models/qwen3_5/dev_ep.py create mode 100644 models/qwen3_5/dispatcher.py create mode 100644 models/qwen3_5/test_ep.py diff --git a/models/qwen3_5/dev_ep.py b/models/qwen3_5/dev_ep.py new file mode 100644 index 0000000..60fa13c --- /dev/null +++ b/models/qwen3_5/dev_ep.py @@ -0,0 +1,100 @@ +import os +import torch +import torch.distributed as dist +from torch.distributed.device_mesh import init_device_mesh + +# Import directly from your existing files +from config import Qwen3VLConfig, Qwen3_5TextConfig, Qwen3_5VisionConfig +from model import Qwen3_5ForCausalLM, initialize_missing_weights + +def get_debug_config() -> Qwen3VLConfig: + text_cfg = Qwen3_5TextConfig( + vocab_size=128, hidden_size=64, 
moe_intermediate_size=128, + num_hidden_layers=2, num_attention_heads=2, num_key_value_heads=1, + head_dim=32, max_position_embeddings=128, rms_norm_eps=1e-6, + tie_word_embeddings=False, num_experts=4, num_experts_per_tok=2, + router_aux_loss_coef=0.01, shared_expert_intermediate_size=128, + layer_types=["linear_attention", "full_attention"], full_attention_interval=2, + linear_conv_kernel_dim=4, linear_key_head_dim=32, linear_num_key_heads=2, + linear_num_value_heads=2, linear_value_head_dim=32, mtp_num_hidden_layers=0, + mtp_use_dedicated_embeddings=False, rope_parameters={"rope_theta": 10000.0, "mrope_section": [16, 16, 16]} + ) + + vision_cfg = Qwen3_5VisionConfig( + depth=1, hidden_size=64, intermediate_size=128, num_heads=2, + in_channels=3, patch_size=14, temporal_patch_size=2, spatial_merge_size=2, + num_position_embeddings=1024, out_hidden_size=64, hidden_act="silu", + deepstack_visual_indexes=[] + ) + + return Qwen3VLConfig( + text=text_cfg, vision=vision_cfg, + image_token_id=120, video_token_id=121, + vision_start_token_id=122, vision_end_token_id=123, + tie_word_embeddings=False, torch_dtype="bfloat16" + ) + +def main(): + # 1. Initialize Distributed Env + dist.init_process_group(backend="nccl") + rank = dist.get_rank() + world_size = dist.get_world_size() + torch.cuda.set_device(rank) + + # Set up a 2D mesh: (dp, ep). For this test, we map all ranks to EP. + dp_size = 1 + ep_size = world_size + mesh = init_device_mesh("cuda", (dp_size, ep_size), mesh_dim_names=("dp", "ep")) + + if rank == 0: + print(f"Initialized DeviceMesh: {mesh}") + + # 2. Build Model + cfg = get_debug_config() + # Note: As you modify model.py to accept DeviceMesh for EP, pass it here + model = Qwen3_5ForCausalLM(cfg).cuda().bfloat16() + initialize_missing_weights(model) + + # 3. Dummy Data (Packed varlen format as expected by your forward pass) + seq_len = 64 + # (1, total) shape expected by your model + input_ids = torch.randint(0, cfg.text.vocab_size, (1, seq_len), device="cuda") + + # attention_mask is used as cu_seqlens in your varlen implementation + cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device="cuda") + + # 4. Forward Pass + if rank == 0: print("\nStarting Forward Pass...") + logits, aux_loss = model( + input_ids=input_ids, + attention_mask=cu_seqlens + ) + + if rank == 0: + print(f"Forward Success. Logits shape: {logits.shape}, Aux Loss: {aux_loss.item()}") + + # 5. Backward Pass + if rank == 0: print("Starting Backward Pass...") + + # Ensure aux_loss is a scalar for summation + if aux_loss.dim() > 0: + aux_loss = aux_loss.sum() + + loss = logits.sum() + aux_loss + loss.backward() + + # Quick check if gradients flowed through the experts + # Path depends on your current model.py structure, adjust if needed + expert_grad_exists = False + for name, param in model.named_parameters(): + if "experts" in name and param.grad is not None: + expert_grad_exists = True + break + + if rank == 0: + print(f"Backward Success. Expert gradients populated: {expert_grad_exists}\n") + + dist.destroy_process_group() + +if __name__ == "__main__": + main() diff --git a/models/qwen3_5/dispatcher.py b/models/qwen3_5/dispatcher.py new file mode 100644 index 0000000..5f39984 --- /dev/null +++ b/models/qwen3_5/dispatcher.py @@ -0,0 +1,110 @@ +import torch +import torch.distributed as dist +from torch.distributed.device_mesh import DeviceMesh + +class TokenDispatcher: + """Consolidated EP/SP dispatcher. 
Handles local token reorder and all-to-all.""" + + def __init__(self, num_experts: int, top_k: int, score_before_experts: bool = True): + self.num_experts = num_experts + self.top_k = top_k + self.score_before_experts = score_before_experts + + self.ep_mesh: DeviceMesh | None = None + self.sp_size: int = 1 + self.sp_rank: int = -1 + + def _split_along_sp(self, *tensors: torch.Tensor) -> list[torch.Tensor]: + results = [] + for t in tensors: + local_num_tokens = t.shape[0] // self.sp_size + offset = self.sp_rank * local_num_tokens + results.append(t[offset : offset + local_num_tokens]) + return results + + def _permute(self, routed_input, num_tokens_per_expert_group, ep_size, num_local_experts): + device = num_tokens_per_expert_group.device + total = num_tokens_per_expert_group.sum().item() + + t_mat = num_tokens_per_expert_group.view(ep_size, num_local_experts) + input_starts = (num_tokens_per_expert_group.cumsum(0) - num_tokens_per_expert_group).view(ep_size, num_local_experts) + + segment_lens = t_mat.t().reshape(-1) + input_starts = input_starts.t().reshape(-1) + + seg_ids = torch.arange(segment_lens.shape[0], device=device).repeat_interleave(segment_lens.long()) + output_starts = segment_lens.cumsum(0) - segment_lens + permuted_indices = (input_starts[seg_ids] + torch.arange(total, device=device) - output_starts[seg_ids]).long() + + num_tokens_per_expert = t_mat.sum(0) + return routed_input.shape, routed_input[permuted_indices, :], permuted_indices, num_tokens_per_expert + + def _unpermute(self, routed_output, input_shape, permuted_indices): + out_unpermuted = routed_output.new_empty(input_shape) + out_unpermuted[permuted_indices, :] = routed_output + return out_unpermuted + + def dispatch(self, x: torch.Tensor, top_scores: torch.Tensor, selected_experts_indices: torch.Tensor): + if self.sp_size > 1: + x, top_scores, selected_experts_indices = self._split_along_sp(x, top_scores, selected_experts_indices) + + num_tokens_per_expert = torch.histc( + selected_experts_indices.view(-1).float(), + bins=self.num_experts, + min=0, + max=self.num_experts - 1, + ) + + token_indices_experts_sorted = torch.argsort(selected_experts_indices.view(-1), stable=True) + top_scores_experts_sorted = top_scores.view(-1)[token_indices_experts_sorted] + token_indices_experts_sorted = token_indices_experts_sorted // self.top_k + routed_input = x[token_indices_experts_sorted] + + if self.score_before_experts: + routed_input = (routed_input.to(torch.float32) * top_scores_experts_sorted.reshape(-1, 1)).to(x.dtype) + + # Skip all-to-all logic entirely if ep_mesh is missing (EP=1) + if self.ep_mesh is None: + metadata = AllToAllDispatchMetadata( + token_indices_experts_sorted, top_scores_experts_sorted, + None, None, None, None + ) + return routed_input, num_tokens_per_expert, metadata + + ep_size = self.ep_mesh.size() + num_tokens_per_expert_group = torch.empty_like(num_tokens_per_expert).repeat(ep_size) + dist.all_to_all_single(num_tokens_per_expert_group, num_tokens_per_expert, group=self.ep_mesh.get_group()) + + input_splits = num_tokens_per_expert.view(ep_size, -1).sum(dim=1).cpu().tolist() + output_splits = num_tokens_per_expert_group.view(ep_size, -1).sum(dim=1).cpu().tolist() + + routed_input = all_to_all_single_autograd(routed_input, output_splits, input_splits, self.ep_mesh) + + num_local_experts = num_tokens_per_expert_group.shape[0] // ep_size + input_shape, routed_input, permuted_indices, num_tokens_per_expert_group = self._permute( + routed_input, num_tokens_per_expert_group, ep_size, num_local_experts 
+ ) + + metadata = AllToAllDispatchMetadata( + token_indices_experts_sorted, top_scores_experts_sorted, + input_shape, permuted_indices, input_splits, output_splits + ) + return routed_input, num_tokens_per_expert_group, metadata + + def combine(self, routed_output: torch.Tensor, metadata: "AllToAllDispatchMetadata", x: torch.Tensor, shared_experts: torch.nn.Module | None = None) -> torch.Tensor: + if self.ep_mesh is not None: + routed_output = self._unpermute(routed_output, metadata.input_shape, metadata.permuted_indices) + routed_output = all_to_all_single_autograd(routed_output, metadata.input_splits, metadata.output_splits, self.ep_mesh) + + out = shared_experts(x) if shared_experts is not None else torch.zeros_like(x) + + if not self.score_before_experts: + routed_output = (routed_output.to(torch.float32) * metadata.top_scores_experts_sorted.reshape(-1, 1)).to(routed_output.dtype) + + token_indices_experts_sorted = metadata.token_indices_experts_sorted + if self.sp_size > 1: + local_num_tokens = x.shape[0] // self.sp_size + token_indices_experts_sorted = token_indices_experts_sorted + local_num_tokens * self.sp_rank + + out.scatter_add_(0, token_indices_experts_sorted.reshape(-1, 1).expand(-1, x.shape[-1]), routed_output) + return out diff --git a/models/qwen3_5/test_ep.py b/models/qwen3_5/test_ep.py new file mode 100644 index 0000000..668a275 --- /dev/null +++ b/models/qwen3_5/test_ep.py @@ -0,0 +1,375 @@ +import os +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist +from torch.distributed.device_mesh import init_device_mesh, DeviceMesh + +class LocalDispatchMetadata: + def __init__(self, token_indices_experts_sorted, top_scores_experts_sorted): + self.token_indices_experts_sorted = token_indices_experts_sorted + self.top_scores_experts_sorted = top_scores_experts_sorted + +class AllToAllDispatchMetadata: + def __init__(self, token_indices_experts_sorted, top_scores_experts_sorted, input_shape, permuted_indices, input_splits, output_splits): + self.token_indices_experts_sorted = token_indices_experts_sorted + self.top_scores_experts_sorted = top_scores_experts_sorted + self.input_shape = input_shape + self.permuted_indices = permuted_indices + self.input_splits = input_splits + self.output_splits = output_splits + +class _AllToAllSingleAutograd(torch.autograd.Function): + """ + wrapper around all-to-all to add backward pass + """ + @staticmethod + def forward(ctx, input_, output_splits, input_splits, group): + ctx.group = group + ctx.input_splits = input_splits + ctx.output_splits = output_splits + + out_total = int(sum(output_splits)) if output_splits else 0 + out = torch.empty((out_total, input_.size(1)), dtype=input_.dtype, device=input_.device) + + if group is not None: + dist.all_to_all_single(out, input_, output_splits, input_splits, group=group.get_group()) + return out + + @staticmethod + def backward(ctx, grad_output): + in_total = int(sum(ctx.input_splits)) if ctx.input_splits else 0 + grad_input = torch.empty((in_total, grad_output.size(1)), dtype=grad_output.dtype, device=grad_output.device) + + if ctx.group is not None: + dist.all_to_all_single(grad_input, grad_output, ctx.input_splits, ctx.output_splits, group=ctx.group.get_group()) + return grad_input, None, None, None + +def all_to_all_single_autograd(input_, output_splits, input_splits, group): + return _AllToAllSingleAutograd.apply(input_, output_splits, input_splits, group) + +class TokenDispatcher: + """Consolidated EP/SP dispatcher. 
Handles local token reorder and all-to-all.""" + + def __init__(self, num_experts: int, top_k: int, score_before_experts: bool = True): + self.num_experts = num_experts + self.top_k = top_k + self.score_before_experts = score_before_experts + + self.ep_mesh: DeviceMesh | None = None + self.sp_size: int = 1 + self.sp_rank: int = -1 + + def _split_along_sp(self, *tensors: torch.Tensor) -> list[torch.Tensor]: + results = [] + for t in tensors: + local_num_tokens = t.shape[0] // self.sp_size + offset = self.sp_rank * local_num_tokens + results.append(t[offset : offset + local_num_tokens]) + return results + + def _permute(self, routed_input, num_tokens_per_expert_group, ep_size, num_local_experts): + device = num_tokens_per_expert_group.device + total = num_tokens_per_expert_group.sum().item() + + t_mat = num_tokens_per_expert_group.view(ep_size, num_local_experts) + input_starts = (num_tokens_per_expert_group.cumsum(0) - num_tokens_per_expert_group).view(ep_size, num_local_experts) + + segment_lens = t_mat.t().reshape(-1) + input_starts = input_starts.t().reshape(-1) + + seg_ids = torch.arange(segment_lens.shape[0], device=device).repeat_interleave(segment_lens.long()) + output_starts = segment_lens.cumsum(0) - segment_lens + permuted_indices = (input_starts[seg_ids] + torch.arange(total, device=device) - output_starts[seg_ids]).long() + + num_tokens_per_expert = t_mat.sum(0) + return routed_input.shape, routed_input[permuted_indices, :], permuted_indices, num_tokens_per_expert + + def _unpermute(self, routed_output, input_shape, permuted_indices): + out_unpermuted = routed_output.new_empty(input_shape) + out_unpermuted[permuted_indices, :] = routed_output + return out_unpermuted + + def dispatch(self, x: torch.Tensor, top_scores: torch.Tensor, selected_experts_indices: torch.Tensor): + if self.sp_size > 1: + x, top_scores, selected_experts_indices = self._split_along_sp(x, top_scores, selected_experts_indices) + + flat_experts = selected_experts_indices.view(-1) + num_tokens_per_expert = torch.bincount(flat_experts, minlength=self.num_experts).float() + + token_indices_experts_sorted = torch.argsort(selected_experts_indices.view(-1), stable=True) + top_scores_experts_sorted = top_scores.view(-1)[token_indices_experts_sorted] + token_indices_experts_sorted = token_indices_experts_sorted // self.top_k + routed_input = x[token_indices_experts_sorted] + + if self.score_before_experts: + routed_input = (routed_input.to(torch.float32) * top_scores_experts_sorted.reshape(-1, 1)).to(x.dtype) + + # Skip all-to-all logic entirely if ep_mesh is missing (EP=1) + if self.ep_mesh is None: + metadata = AllToAllDispatchMetadata( + token_indices_experts_sorted, top_scores_experts_sorted, + None, None, None, None + ) + return routed_input, num_tokens_per_expert, metadata + + ep_size = self.ep_mesh.size() + num_tokens_per_expert_group = torch.empty_like(num_tokens_per_expert) + dist.all_to_all_single(num_tokens_per_expert_group, num_tokens_per_expert, group=self.ep_mesh.get_group()) + + input_splits = num_tokens_per_expert.view(ep_size, -1).sum(dim=1).int().cpu().tolist() + output_splits = num_tokens_per_expert_group.view(ep_size, -1).sum(dim=1).int().cpu().tolist() + + routed_input = all_to_all_single_autograd(routed_input, output_splits, input_splits, self.ep_mesh) + + num_local_experts = num_tokens_per_expert_group.shape[0] // ep_size + input_shape, routed_input, permuted_indices, num_tokens_per_expert_group = self._permute( + routed_input, num_tokens_per_expert_group, ep_size, num_local_experts + ) + + 
metadata = AllToAllDispatchMetadata(
+            token_indices_experts_sorted, top_scores_experts_sorted,
+            input_shape, permuted_indices, input_splits, output_splits
+        )
+        return routed_input, num_tokens_per_expert_group, metadata
+
+    def combine(self, routed_output: torch.Tensor, metadata: "AllToAllDispatchMetadata", x: torch.Tensor, shared_experts: torch.nn.Module | None = None) -> torch.Tensor:
+        if self.ep_mesh is not None:
+            routed_output = self._unpermute(routed_output, metadata.input_shape, metadata.permuted_indices)
+            routed_output = all_to_all_single_autograd(routed_output, metadata.input_splits, metadata.output_splits, self.ep_mesh)
+
+        out = shared_experts(x) if shared_experts is not None else torch.zeros_like(x)
+
+        if not self.score_before_experts:
+            routed_output = (routed_output.to(torch.float32) * metadata.top_scores_experts_sorted.reshape(-1, 1)).to(routed_output.dtype)
+
+        token_indices_experts_sorted = metadata.token_indices_experts_sorted
+        if self.sp_size > 1:
+            local_num_tokens = x.shape[0] // self.sp_size
+            token_indices_experts_sorted = token_indices_experts_sorted + local_num_tokens * self.sp_rank
+
+        out.index_add_(0, token_indices_experts_sorted, routed_output)
+        return out
+
+
+class MockLocalExperts(nn.Module):
+    """A dummy expert implementation holding only local weights."""
+    def __init__(self, num_local_experts, hidden_dim, intermediate_dim):
+        super().__init__()
+        self.num_local_experts = num_local_experts
+        self.w1 = nn.Parameter(torch.randn(num_local_experts, hidden_dim, intermediate_dim))
+        self.w2 = nn.Parameter(torch.randn(num_local_experts, intermediate_dim, hidden_dim))
+
+    def forward(self, x, num_tokens_per_expert):
+        """
+        x: (total_routed_tokens, hidden_dim)
+        num_tokens_per_expert: (num_local_experts,)
+        """
+        total_tokens = int(num_tokens_per_expert.sum().item())
+        if total_tokens == 0:
+            return torch.empty_like(x)
+        if total_tokens == x.shape[0] and self.num_local_experts == 1:
+            return torch.relu(x @ self.w1[0]) @ self.w2[0]
+
+        offsets = torch.zeros(self.num_local_experts + 1, dtype=torch.long, device=x.device)
+        offsets[1:] = num_tokens_per_expert.cumsum(0)
+
+        outputs = []
+        for i in range(self.num_local_experts):
+            start, end = int(offsets[i]), int(offsets[i+1])
+            if end > start:
+                chunk = x[start:end]
+                hidden = torch.relu(chunk @ self.w1[i])
+                outputs.append(hidden @ self.w2[i])
+
+        return torch.cat(outputs, dim=0) if outputs else torch.empty_like(x)
+
+def run_ep_benchmark(ep_size: int, num_experts: int = 16, hidden_dim: int = 4096, batch_size: int = 8, seq_len: int = 2048):
+    rank = dist.get_rank()
+    world_size = dist.get_world_size()
+    device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
+
+    dp_size = world_size // ep_size
+    global_mesh = init_device_mesh("cuda", (dp_size, ep_size), mesh_dim_names=("dp", "ep"))
+    ep_mesh = global_mesh["ep"]
+
+    num_local_experts = num_experts // ep_size
+    intermediate_dim = 14336
+
+    torch.cuda.empty_cache()
+    torch.cuda.reset_peak_memory_stats(device)
+
+    experts = MockLocalExperts(num_local_experts, hidden_dim, intermediate_dim).to(device)
+    shared_expert = nn.Linear(hidden_dim, hidden_dim).to(device)
+    dispatcher = TokenDispatcher(num_experts=num_experts, top_k=2, score_before_experts=True)
+
+    if ep_size > 1:
+        dispatcher.ep_mesh = ep_mesh
+
+    num_tokens = batch_size * seq_len
+    x = torch.randn(num_tokens, hidden_dim, device=device, requires_grad=True)
+    router_logits = torch.randn(num_tokens, num_experts, device=device)
+    top_scores, selected_experts_indices = torch.topk(torch.softmax(router_logits, dim=-1), k=2, dim=-1)
+
+    # WARMUP
+    for _ in range(3):
+        r_in, counts, meta = dispatcher.dispatch(x, top_scores, selected_experts_indices)
+        e_out = experts(r_in, counts)
+        out = dispatcher.combine(e_out, meta, x, shared_experts=shared_expert)
+        out.sum().backward()
+
+    torch.cuda.synchronize(device)
+
+    # BENCHMARK
+    iters = 10
+    dispatch_times = []
+    expert_times = []
+    shared_fwd_times = []
+    index_add_times = []
+    backward_times = []
+
+    for _ in range(iters):
+        torch.cuda.synchronize(device)
+
+        t0 = torch.cuda.Event(enable_timing=True)
+        t1 = torch.cuda.Event(enable_timing=True)
+        t2 = torch.cuda.Event(enable_timing=True)
+        t3 = torch.cuda.Event(enable_timing=True)
+        t4 = torch.cuda.Event(enable_timing=True)
+        t5 = torch.cuda.Event(enable_timing=True)
+
+        t0.record()
+        r_in, counts, meta = dispatcher.dispatch(x, top_scores, selected_experts_indices)
+        t1.record()
+        e_out = experts(r_in, counts)
+        t2.record()
+
+        out = shared_expert(x)
+        t3.record()
+
+        routed = e_out
+        if dispatcher.ep_mesh is not None:
+            routed = dispatcher._unpermute(routed, meta.input_shape, meta.permuted_indices)
+            routed = all_to_all_single_autograd(routed, meta.input_splits, meta.output_splits, dispatcher.ep_mesh)
+
+        token_indices_experts_sorted = meta.token_indices_experts_sorted
+        if dispatcher.sp_size > 1:
+            local_num_tokens = x.shape[0] // dispatcher.sp_size
+            token_indices_experts_sorted = token_indices_experts_sorted + local_num_tokens * dispatcher.sp_rank
+
+        out.index_add_(0, token_indices_experts_sorted, routed)
+        t4.record()
+
+        out.sum().backward()
+        t5.record()
+
+        torch.cuda.synchronize(device)
+        
dispatch_times.append(t0.elapsed_time(t1)) + expert_times.append(t1.elapsed_time(t2)) + shared_fwd_times.append(t2.elapsed_time(t3)) + index_add_times.append(t3.elapsed_time(t4)) + backward_times.append(t4.elapsed_time(t5)) + + avg_dispatch = sum(dispatch_times) / iters + avg_expert = sum(expert_times) / iters + avg_shared_fwd = sum(shared_fwd_times) / iters + avg_index_add = sum(index_add_times) / iters + avg_backward = sum(backward_times) / iters + avg_time_ms = avg_dispatch + avg_expert + avg_backward + peak_mem_gb = torch.cuda.max_memory_allocated(device) / (1024 ** 3) + + # FLOPs Math: FWD is 2*H*I per weight matrix, BWD is ~2x FWD. Total = 12 * H * I per token. + # Total routed tokens globally = num_tokens * dp_size * top_k + global_routed_tokens = num_tokens * dp_size * dispatcher.top_k + expert_flops_per_iter = 12 * global_routed_tokens * hidden_dim * intermediate_dim + + # Include shared expert FLOPs + shared_flops_per_iter = 12 * (num_tokens * dp_size) * hidden_dim * hidden_dim + total_flops = expert_flops_per_iter + shared_flops_per_iter + + tflops_per_sec = (total_flops / (avg_time_ms / 1000.0)) / (1e12) + tflops_per_gpu = tflops_per_sec / world_size + + dist.barrier() + + if rank == 0: + print(f"--- Configuration: DP={dp_size}, EP={ep_size} ---") + print(f"Peak Memory Allocated: {peak_mem_gb:.2f} GB") + print(f"Dispatch: {avg_dispatch:.2f}ms, Expert: {avg_expert:.2f}ms, SharedFwd: {avg_shared_fwd:.2f}ms, IndexAdd: {avg_index_add:.2f}ms, Backward: {avg_backward:.2f}ms") + print(f"Average Time / Iteration: {avg_time_ms:.2f} ms") + print(f"Total TFLOPS (Cluster): {tflops_per_sec:.2f} TFLOPS") + print(f"Per-GPU TFLOPS: {tflops_per_gpu:.2f} TFLOPS\n") + +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True +torch.backends.cudnn.benchmark = True + +if __name__ == "__main__": + dist.init_process_group(backend="nccl") + + world_size = dist.get_world_size() + assert world_size == 4, "This test is designed to run exactly on 4 devices." 
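+    # Example launch for this benchmark (assumes a single node with 4 GPUs; the
+    # file path matches where this patch creates it, adjust for your checkout):
+    #   torchrun --nproc_per_node=4 models/qwen3_5/test_ep.py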
+ + torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count()) + + if dist.get_rank() == 0: + print("Starting EP Memory Benchmark...\n") + + # Baseline: EP=1 (All 16 experts on every GPU) + run_ep_benchmark(ep_size=1) + + # Run EP=2 + run_ep_benchmark(ep_size=2) + + # Run EP=4 + run_ep_benchmark(ep_size=4) + + dist.destroy_process_group() From f036910fc38189363cb90ea61d60a3e3cc5d9785 Mon Sep 17 00:00:00 2001 From: tomiock Date: Wed, 29 Apr 2026 20:54:27 +0200 Subject: [PATCH 20/23] [feat] EP + TP implemented --- configs/cvc/moe.toml | 5 +- models/qwen3_5/dispatcher.py | 127 +++++++++++++++++++++++++++++++---- models/qwen3_5/model.py | 64 ++++++++++++++++-- train/config.py | 1 + train/infra.py | 79 +++++++++++++++++++++- train/train_qwen.py | 47 ++++++++++++- 6 files changed, 295 insertions(+), 28 deletions(-) diff --git a/configs/cvc/moe.toml b/configs/cvc/moe.toml index 7fd6da3..f36057b 100644 --- a/configs/cvc/moe.toml +++ b/configs/cvc/moe.toml @@ -15,11 +15,12 @@ model_dir = "/data/151-1/users/tockier/qwen_finetune/cache/qwen35_6b_a3b" output_dir = "/data/151-1/users/shared_cache/qwen_finetune/checkpoints" save_steps = 10000 -total_steps = 10000 -random_init = false +total_steps = 100 +random_init = true tp_size = 1 pp_size = 1 +ep_size = 4 data_parallel = 'ddp' ac_mode = "full" diff --git a/models/qwen3_5/dispatcher.py b/models/qwen3_5/dispatcher.py index 5f39984..343ef67 100644 --- a/models/qwen3_5/dispatcher.py +++ b/models/qwen3_5/dispatcher.py @@ -2,14 +2,111 @@ import torch.distributed as dist from torch.distributed.device_mesh import DeviceMesh +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True +torch.backends.cudnn.benchmark = True + + +class AllToAllDispatchMetadata: + def __init__(self, token_indices_experts_sorted, top_scores_experts_sorted, input_shape, permuted_indices, input_splits, output_splits): + self.token_indices_experts_sorted = token_indices_experts_sorted + self.top_scores_experts_sorted = top_scores_experts_sorted + self.input_shape = input_shape + self.permuted_indices = permuted_indices + self.input_splits = input_splits + self.output_splits = output_splits + + +class _AllToAllSingleAutograd(torch.autograd.Function): + @staticmethod + def forward(ctx, input_, output_splits, input_splits, group): + ctx.group = group + ctx.input_splits = input_splits + ctx.output_splits = output_splits + + out_total = int(sum(output_splits)) if output_splits else 0 + out = torch.empty((out_total, input_.size(1)), dtype=input_.dtype, device=input_.device) + + if group is not None: + dist.all_to_all_single(out, input_, output_splits, input_splits, group=group.get_group()) + return out + + @staticmethod + def backward(ctx, grad_output): + in_total = int(sum(ctx.input_splits)) if ctx.input_splits else 0 + grad_input = torch.empty((in_total, grad_output.size(1)), dtype=grad_output.dtype, device=grad_output.device) + + if ctx.group is not None: + dist.all_to_all_single(grad_input, grad_output, ctx.input_splits, ctx.output_splits, group=ctx.group.get_group()) + return grad_input, None, None, None + + +def all_to_all_single_autograd(input_, output_splits, input_splits, group): + return _AllToAllSingleAutograd.apply(input_, output_splits, input_splits, group) + + +class _AllReduceForward(torch.autograd.Function): + """All-reduce SUM in forward (partial → replicate). Identity in backward. + + Apply at the *output* of a TP-local computation whose value is a per-rank partial + sum. After the forward all-reduce all ranks hold the full sum. 
In backward, each + rank already sees the same upstream gradient (the output is replicated), so + no further communication is needed. + """ + + @staticmethod + def forward(ctx, input_, group): + ctx.group = group + if group is None: + return input_ + out = input_.contiguous().clone() + dist.all_reduce(out, op=dist.ReduceOp.SUM, group=group) + return out + + @staticmethod + def backward(ctx, grad_output): + return grad_output, None + + +class _AllReduceBackward(torch.autograd.Function): + """Identity in forward. All-reduce SUM in backward (per-rank partial → replicate gradient). + + Apply at the *input* of a TP-local computation. Each rank's autograd produces a + per-rank partial gradient (because each rank used a different weight slice + downstream); summing across TP gives the full Replicate gradient that the + upstream graph expects. + """ + + @staticmethod + def forward(ctx, input_, group): + ctx.group = group + return input_ + + @staticmethod + def backward(ctx, grad_output): + if ctx.group is None: + return grad_output, None + g = grad_output.contiguous().clone() + dist.all_reduce(g, op=dist.ReduceOp.SUM, group=ctx.group) + return g, None + + +def all_reduce_forward(input_, group): + return _AllReduceForward.apply(input_, group) + + +def all_reduce_backward(input_, group): + return _AllReduceBackward.apply(input_, group) + + class TokenDispatcher: """Consolidated EP/SP dispatcher. Handles local token reorder and all-to-all.""" - + def __init__(self, num_experts: int, top_k: int, score_before_experts: bool = True): self.num_experts = num_experts self.top_k = top_k self.score_before_experts = score_before_experts - + self.ep_mesh: DeviceMesh | None = None self.sp_size: int = 1 self.sp_rank: int = -1 @@ -40,6 +137,10 @@ def _permute(self, routed_input, num_tokens_per_expert_group, ep_size, num_local return routed_input.shape, routed_input[permuted_indices, :], permuted_indices, num_tokens_per_expert def _unpermute(self, routed_output, input_shape, permuted_indices): + # Empty path (rank received 0 tokens): pass through to preserve the autograd + # edge into the combine A2A so its backward fires on every rank in the EP group. 
+ if routed_output.shape[0] == 0: + return routed_output out_unpermuted = routed_output.new_empty(input_shape) out_unpermuted[permuted_indices, :] = routed_output return out_unpermuted @@ -48,12 +149,10 @@ def dispatch(self, x: torch.Tensor, top_scores: torch.Tensor, selected_experts_i if self.sp_size > 1: x, top_scores, selected_experts_indices = self._split_along_sp(x, top_scores, selected_experts_indices) - num_tokens_per_expert = torch.histc( - selected_experts_indices.view(-1).float(), - bins=self.num_experts, - min=0, - max=self.num_experts - 1, - ) + num_tokens_per_expert = torch.bincount( + selected_experts_indices.view(-1), + minlength=self.num_experts, + ).float() token_indices_experts_sorted = torch.argsort(selected_experts_indices.view(-1), stable=True) top_scores_experts_sorted = top_scores.view(-1)[token_indices_experts_sorted] @@ -66,17 +165,19 @@ def dispatch(self, x: torch.Tensor, top_scores: torch.Tensor, selected_experts_i # Skip all-to-all logic entirely if ep_mesh is missing (EP=1) if self.ep_mesh is None: metadata = AllToAllDispatchMetadata( - token_indices_experts_sorted, top_scores_experts_sorted, + token_indices_experts_sorted, top_scores_experts_sorted, None, None, None, None ) return routed_input, num_tokens_per_expert, metadata ep_size = self.ep_mesh.size() - num_tokens_per_expert_group = torch.empty_like(num_tokens_per_expert).repeat(ep_size) + # equal-split all-to-all: each rank sends num_local_experts counts to every rank, + # receiving num_local_experts from every rank → output same size as input. + num_tokens_per_expert_group = torch.empty_like(num_tokens_per_expert) dist.all_to_all_single(num_tokens_per_expert_group, num_tokens_per_expert, group=self.ep_mesh.get_group()) - input_splits = num_tokens_per_expert.view(ep_size, -1).sum(dim=1).cpu().tolist() - output_splits = num_tokens_per_expert_group.view(ep_size, -1).sum(dim=1).cpu().tolist() + input_splits = [int(x) for x in num_tokens_per_expert.long().view(ep_size, -1).sum(dim=1).tolist()] + output_splits = [int(x) for x in num_tokens_per_expert_group.long().view(ep_size, -1).sum(dim=1).tolist()] routed_input = all_to_all_single_autograd(routed_input, output_splits, input_splits, self.ep_mesh) @@ -106,5 +207,5 @@ def combine(self, routed_output: torch.Tensor, metadata: "AllToAllDispatchMetada local_num_tokens = x.shape[0] // self.sp_size token_indices_experts_sorted = token_indices_experts_sorted + local_num_tokens * self.sp_rank - out.scatter_add_(0, token_indices_experts_sorted.reshape(-1, 1).expand(-1, x.shape[-1]), routed_output) + out.index_add_(0, token_indices_experts_sorted, routed_output) return out diff --git a/models/qwen3_5/model.py b/models/qwen3_5/model.py index 29fa02e..076c640 100644 --- a/models/qwen3_5/model.py +++ b/models/qwen3_5/model.py @@ -274,14 +274,56 @@ def forward( continue top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) current_state = hidden_states[token_idx] - gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) + gate_up = F.linear(current_state, self.gate_up_proj[expert_idx]) + gate, up = gate_up.chunk(2, dim=-1) current_hidden_states = self.act_fn(gate) * up - current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) + current_hidden_states = F.linear(current_hidden_states, self.down_proj[expert_idx]) current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None] final_hidden_states.index_add_(0, token_idx, 
current_hidden_states.to(final_hidden_states.dtype)) return final_hidden_states + def forward_ep(self, routed_input: torch.Tensor, num_tokens_per_expert: torch.Tensor) -> torch.Tensor: + """Forward for EP mode: routed_input is already sorted/dispatched to local experts. + + With EP+TP, gate_up_proj and down_proj are sharded along the intermediate + dimension across TP. The input is Replicate across TP and the output is + Replicate after a forward all-reduce. The backward all-reduce on the input + is required so the gradient flowing back into the dispatch A2A is also + Replicate across TP. + """ + tp_mesh = getattr(self, 'tp_mesh', None) + use_tp = tp_mesh is not None and tp_mesh.size() > 1 + if use_tp: + from models.qwen3_5.dispatcher import all_reduce_backward + routed_input = all_reduce_backward(routed_input, tp_mesh.get_group()) + + num_local_experts = self.gate_up_proj.shape[0] + offsets = torch.zeros(num_local_experts + 1, dtype=torch.long, device=routed_input.device) + offsets[1:] = num_tokens_per_expert.long().cumsum(0) + + outputs = [] + for i in range(num_local_experts): + start, end = int(offsets[i]), int(offsets[i + 1]) + if end > start: + chunk = routed_input[start:end] + gate_up = F.linear(chunk, self.gate_up_proj[i]) + gate, up = gate_up.chunk(2, dim=-1) + outputs.append(F.linear(self.act_fn(gate) * up, self.down_proj[i])) + + if outputs: + result = torch.cat(outputs, dim=0) + else: + # Empty path: preserve autograd connection to routed_input so the EP A2A + # backward fires on this rank too (otherwise other ranks hang on the unmatched collective). + result = routed_input[:0] + + if use_tp: + from models.qwen3_5.dispatcher import all_reduce_forward + result = all_reduce_forward(result, tp_mesh.get_group()) + + return result + class TopKRouter(nn.Module): def __init__(self, config): super().__init__() @@ -307,6 +349,7 @@ def __init__(self, config): self.experts = MoeExperts(config) self.shared_expert = MoeMLP(config, intermediate_size=config.shared_expert_intermediate_size) self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False) + self.dispatcher = None # set by apply_ep() when EP > 1 def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: batch_size, sequence_length, hidden_dim = hidden_states.shape @@ -315,15 +358,22 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens hidden_states_reshaped = hidden_states.view(-1, hidden_dim) - shared_expert_output = self.shared_expert(hidden_states_reshaped) - router_logits, routing_weights, selected_experts = self.gate(hidden_states_reshaped) - expert_output = self.experts(hidden_states_reshaped, selected_experts, routing_weights) + if self.dispatcher is not None: + routed_input, num_tokens_per_expert, metadata = self.dispatcher.dispatch( + hidden_states_reshaped, routing_weights, selected_experts + ) + routed_output = self.experts.forward_ep(routed_input, num_tokens_per_expert) + expert_output = self.dispatcher.combine(routed_output, metadata, hidden_states_reshaped) + else: + expert_output = self.experts(hidden_states_reshaped, selected_experts, routing_weights) + + shared_expert_output = self.shared_expert(hidden_states_reshaped) shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states_reshaped)) * shared_expert_output expert_output = expert_output + shared_expert_output - - expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.experts.num_experts) + + expert_mask = torch.nn.functional.one_hot(selected_experts, 
num_classes=num_experts) tokens_per_expert = expert_mask.sum(dim=(0, 1), dtype=torch.float) fraction_tokens = tokens_per_expert / (N * self.gate.top_k) diff --git a/train/config.py b/train/config.py index 3b6dc21..a390d8c 100644 --- a/train/config.py +++ b/train/config.py @@ -96,6 +96,7 @@ class Training: tp_size: int = 1 # 1 means disabled pp_size: int = 1 # 1 means disabled; supported values: 2, 2, 4 + ep_size: int = 1 # 1 means disabled; must divide num_experts evenly pp_num_layers_first: int = 1 pp_num_layers_last: int = 1 diff --git a/train/infra.py b/train/infra.py index 1cde44e..2b3e43f 100644 --- a/train/infra.py +++ b/train/infra.py @@ -118,18 +118,25 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: def get_mesh(training_args, world_size): tp_size = training_args.tp_size pp_size = getattr(training_args, "pp_size", 1) + ep_size = getattr(training_args, "ep_size", 1) - if world_size % (tp_size * pp_size) != 0: + if world_size % (tp_size * pp_size * ep_size) != 0: raise ValueError( - f"world_size {world_size} not divisible by tp_size*pp_size={tp_size * pp_size}" + f"world_size {world_size} not divisible by tp_size*pp_size*ep_size={tp_size * pp_size * ep_size}" ) + dp_size = world_size // (tp_size * pp_size * ep_size) - dp_size = world_size // (tp_size * pp_size) + if pp_size > 1 and ep_size > 1: + raise NotImplementedError("PP + EP is not yet supported") if pp_size > 1: return init_device_mesh( "cuda", (dp_size, pp_size, tp_size), mesh_dim_names=("dp", "pp", "tp") ) + if ep_size > 1: + return init_device_mesh( + "cuda", (dp_size, ep_size, tp_size), mesh_dim_names=("dp", "ep", "tp") + ) return init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp")) def get_tp_group(mesh): @@ -147,6 +154,72 @@ def get_pp_group(mesh): return mesh["pp"] return None +def get_ep_group(mesh): + if "ep" in mesh.mesh_dim_names: + return mesh["ep"] + return None + +def apply_ep(model, ep_mesh, tp_mesh=None): + """Slice expert parameters to the local subset and attach a TokenDispatcher. + + EP shards the routed experts along ``num_experts``. When ``tp_mesh`` is provided + (EP+TP), each rank additionally holds only its slice of ``moe_intermediate_size``; + the partial down-projection is all-reduced across TP at the end of ``forward_ep``. + The shared_expert is sharded by ``apply_tp`` separately. + """ + from models.qwen3_5.dispatcher import TokenDispatcher + + ep_rank = ep_mesh.get_local_rank() + ep_size = ep_mesh.size() + tp_rank = tp_mesh.get_local_rank() if tp_mesh is not None else 0 + tp_size = tp_mesh.size() if tp_mesh is not None else 1 + + lm = model.model.language_model + for layer in lm.layers: + moe = layer.mlp + experts = moe.experts + num_experts = experts.num_experts + moe_inter = experts.intermediate_dim + + if num_experts % ep_size != 0: + raise ValueError( + f"num_experts={num_experts} must be divisible by ep_size={ep_size}" + ) + if tp_size > 1 and moe_inter % tp_size != 0: + raise ValueError( + f"moe_intermediate_size={moe_inter} must be divisible by tp_size={tp_size}" + ) + + num_local = num_experts // ep_size + e_start, e_end = ep_rank * num_local, (ep_rank + 1) * num_local + local_inter = moe_inter // tp_size + i_start, i_end = tp_rank * local_inter, (tp_rank + 1) * local_inter + + # gate_up_proj: [E, 2*I, H] (the 2*I is laid out as [gate(I) | up(I)]). + # Take the EP slice, then within each of gate and up keep only this TP rank's + # I/tp slice and re-concat so the local layout stays [gate_local | up_local]. 
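+        # Worked example with made-up sizes: moe_inter=4, tp_size=2 -> rank 0
+        # keeps gate rows [0:2) and up rows [4:6) of the 2*moe_inter dim, so the
+        # local layout is [g0, g1, u0, u1] and chunk(2, dim=-1) in forward_ep
+        # still yields clean gate/up halves.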
+ gate_up = experts.gate_up_proj.data[e_start:e_end] + if tp_size > 1: + gate_part = gate_up[:, :moe_inter, :][:, i_start:i_end, :] + up_part = gate_up[:, moe_inter:, :][:, i_start:i_end, :] + gate_up = torch.cat([gate_part, up_part], dim=1) + experts.gate_up_proj = nn.Parameter(gate_up.contiguous()) + + # down_proj: [E, H, I] → [E_local, H, I/tp] + down = experts.down_proj.data[e_start:e_end, :, i_start:i_end] + experts.down_proj = nn.Parameter(down.contiguous()) + + # forward_ep needs the TP mesh to all-reduce the partial down-projection + experts.tp_mesh = tp_mesh if tp_size > 1 else None + + dispatcher = TokenDispatcher( + num_experts=num_experts, + top_k=moe.gate.top_k, + score_before_experts=True, + ) + dispatcher.ep_mesh = ep_mesh + moe.dispatcher = dispatcher + def module_filter_float8_fn(mod: torch.nn.Module, fqn: str): if "visual" in fqn: return False diff --git a/train/train_qwen.py b/train/train_qwen.py index a78e514..50f8445 100644 --- a/train/train_qwen.py +++ b/train/train_qwen.py @@ -11,6 +11,7 @@ from transformers import AutoProcessor +import torch.distributed as dist from torch.distributed.elastic.multiprocessing.errors import record from torch.distributed._composable.replicate import replicate from torch.distributed.pipelining import PipelineStage @@ -31,8 +32,10 @@ get_tp_group, get_dp_group, get_pp_group, + get_ep_group, apply_fsdp, apply_tp, + apply_ep, apply_ac, ACConfig, compile_model, @@ -133,7 +136,9 @@ def __init__(self, cfg: Config): self.tp_group = get_tp_group(self.mesh) self.dp_group = get_dp_group(self.mesh) self.pp_group = get_pp_group(self.mesh) + self.ep_group = get_ep_group(self.mesh) self.pp_size = getattr(self.training_args, "pp_size", 1) + self.ep_size = getattr(self.training_args, "ep_size", 1) self.device = torch.device(f"cuda:{self.local_rank}") if self.if_log_rank(): @@ -299,6 +304,13 @@ def pp_loss_fn(outputs, labels): if self.training_args.tp_size > 1: apply_tp(self.model, self.model_type, self.tp_group, self.training_args.async_tp) + if self.ep_size > 1: + if self.model_type != ModelType.Qwen3_5: + raise NotImplementedError("EP is only supported for Qwen3.5 MoE models") + tp_mesh = self.tp_group if self.training_args.tp_size > 1 else None + apply_ep(self.model, self.ep_group, tp_mesh=tp_mesh) + logger.info(f"expert parallelism applied (ep_size={self.ep_size}, tp_size={self.training_args.tp_size})") + ac_mode = getattr(self.training_args, "ac_mode", "off") if ac_mode != "off": ac_cfg = ACConfig(enabled=True, full=(ac_mode == "full")) @@ -312,7 +324,14 @@ def pp_loss_fn(outputs, labels): if self.training_args.data_parallel == 'fsdp': apply_fsdp(self.model_type, self.model, mesh=self.dp_group) elif self.training_args.data_parallel == 'ddp': - self.model = replicate(self.model, device_mesh=self.dp_group) + if self.ep_size > 1 and self.dp_group.size() > 1: + # Skip DDP hook-based all_reduce when EP is active to avoid a NCCL deadlock: + # DDP hooks fire async on dp_comm while EP backward A2As block on ep_comm, + # creating a cross-communicator cycle. Gradients are synced manually after backward. 
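+                # (NCCL expects every rank to enqueue collectives on a communicator
+                # in a consistent order; hook-driven dp all-reduces racing the EP
+                # all-to-alls break that assumption.)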
+ logger.info(f"rank={self.rank()} EP+DP: skipping replicate(), will manually sync grads (ep={self.ep_size}, dp={self.dp_group.size()})") + elif self.dp_group.size() > 1: + self.model = replicate(self.model, device_mesh=self.dp_group) + logger.info(f"rank={self.rank()} DDP applied (dp={self.dp_group.size()})") else: raise Exception('invalid sharding strategy for Data Parallel') @@ -341,7 +360,7 @@ def pp_loss_fn(outputs, labels): self.processor = AutoProcessor.from_pretrained( self.training_args.model_dir, - + ) # set_model freezes/unfreezes param groups; skip for PP (stage module @@ -660,6 +679,23 @@ def _train_step_pp(self, data_iterator, optimizer): return self._maybe_optimizer_step(loss_for_logging, ce_loss, aux_loss, optimizer) + def _sync_gradients(self): + """All_reduce gradients across dp_group. Used instead of DDP hooks when EP is active.""" + dp_size = self.dp_group.size() + if dp_size <= 1: + return + grp = self.dp_group.get_group() + from torch.distributed.tensor import DTensor + for p in self.model.parameters(): + if p.grad is None: + continue + g = p.grad + # TP-sharded params (e.g. shared_expert.*) have DTensor grads; reduce the local shard. + if isinstance(g, DTensor): + g = g.to_local() + dist.all_reduce(g, group=grp) + g.div_(dp_size) + def _train_step(self, data_iterator, optimizer): batch = next(data_iterator) input_ids = batch.pop('input_ids') @@ -676,6 +712,11 @@ def _train_step(self, data_iterator, optimizer): scaled_loss = loss / self.current_accum_target scaled_loss.backward() + if self.ep_size > 1 and self.dp_group.size() > 1: + is_last_accum = (self.current_accum_count + 1 >= self.current_accum_target) + if is_last_accum: + self._sync_gradients() + self.fwd_bwd_time = time.perf_counter() - s_model return self._maybe_optimizer_step(loss, ce_loss, aux_loss, optimizer) @@ -770,7 +811,7 @@ def trace_handler(prof): try: while self.global_step < self.training_args.total_steps: self.micro_step += 1 - + # training step executed here optimizer_updated = self.train_step(data_iterator, optimizer) From d61392e30e721a2ee91badbab1c29d77807b8496 Mon Sep 17 00:00:00 2001 From: tomiock Date: Wed, 29 Apr 2026 21:41:05 +0200 Subject: [PATCH 21/23] [feat] 4D parallelism --- configs/cvc/moe.toml | 3 +- train/config.py | 6 +++ train/infra.py | 7 +++- train/train_qwen.py | 99 ++++++++++++++++++++++++++++++++++++++++---- 4 files changed, 103 insertions(+), 12 deletions(-) diff --git a/configs/cvc/moe.toml b/configs/cvc/moe.toml index f36057b..1408b1c 100644 --- a/configs/cvc/moe.toml +++ b/configs/cvc/moe.toml @@ -12,6 +12,7 @@ project_name = "moe" [training] model_dir = "/data/151-1/users/tockier/qwen_finetune/cache/qwen35_6b_a3b" +#model_dir = "/data/151-1/users/tockier/qwen_finetune/cache/qwen35_35b_a3b" output_dir = "/data/151-1/users/shared_cache/qwen_finetune/checkpoints" save_steps = 10000 @@ -20,7 +21,7 @@ random_init = true tp_size = 1 pp_size = 1 -ep_size = 4 +ep_size = 1 data_parallel = 'ddp' ac_mode = "full" diff --git a/train/config.py b/train/config.py index a390d8c..7f4061f 100644 --- a/train/config.py +++ b/train/config.py @@ -101,6 +101,12 @@ class Training: pp_num_layers_first: int = 1 pp_num_layers_last: int = 1 + # Pipeline schedule: "gpipe" or "1f1b". Single-stage-per-rank schedules only. + pp_schedule: str = "gpipe" + # Number of microbatches per optimizer step. Must be >= pp_size for 1F1B + # to actually pipeline (smaller values degrade to GPipe-like behavior). 
+ pp_microbatches: int = 1 + # compiler flag for TP (goes faster) async_tp: bool = True diff --git a/train/infra.py b/train/infra.py index 2b3e43f..1020b43 100644 --- a/train/infra.py +++ b/train/infra.py @@ -127,8 +127,11 @@ def get_mesh(training_args, world_size): dp_size = world_size // (tp_size * pp_size * ep_size) if pp_size > 1 and ep_size > 1: - raise NotImplementedError("PP + EP is not yet supported") - + return init_device_mesh( + "cuda", + (dp_size, pp_size, ep_size, tp_size), + mesh_dim_names=("dp", "pp", "ep", "tp"), + ) if pp_size > 1: return init_device_mesh( "cuda", (dp_size, pp_size, tp_size), mesh_dim_names=("dp", "pp", "tp") diff --git a/train/train_qwen.py b/train/train_qwen.py index 50f8445..07611e2 100644 --- a/train/train_qwen.py +++ b/train/train_qwen.py @@ -15,7 +15,8 @@ from torch.distributed.elastic.multiprocessing.errors import record from torch.distributed._composable.replicate import replicate from torch.distributed.pipelining import PipelineStage -from torch.distributed.pipelining.schedules import ScheduleGPipe +from torch.distributed.pipelining.microbatch import _Replicate +from torch.distributed.pipelining.schedules import Schedule1F1B, ScheduleGPipe from torch.profiler import profile, record_function, ProfilerActivity, schedule @@ -278,11 +279,46 @@ def pp_loss_fn(outputs, labels): return (ce_loss + 0.01 * aux_loss) / self.current_accum_target - self.pp_schedule = ScheduleGPipe( + schedule_name = getattr(self.training_args, "pp_schedule", "gpipe").lower() + n_microbatches = getattr(self.training_args, "pp_microbatches", 1) + schedule_cls = {"gpipe": ScheduleGPipe, "1f1b": Schedule1F1B}.get(schedule_name) + if schedule_cls is None: + raise ValueError( + f"unknown pp_schedule={schedule_name!r}; expected one of: gpipe, 1f1b" + ) + if schedule_cls is Schedule1F1B: + # The plumbing exists (schedule construction, kwargs_chunk_spec, + # tiled input in _train_step_pp), but 1F1B requires + # n_microbatches >= pp_size and the dataloader currently emits + # one packed (1, total) sample per step — so microbatches are + # tiled copies of the same content and the loss is meaningless. + # Re-enable once the data path produces n_microbatches independent + # packed rows per step (per-row cu_seqlens, labels, image scatter). + raise NotImplementedError( + "pp_schedule='1f1b' is disabled until the data path supports " + "n_microbatches independent packed rows per step. Use 'gpipe' " + "for now." + ) + # The dataloader emits a single packed (1, total) sample per step. + # When n_microbatches > 1 we tile input_ids/labels to (N, total) so + # the schedule can chunk along dim 0; everything else (cu_seqlens, + # pixel_values, image_grid_thw, etc.) is per-batch metadata that + # must be passed identically to every microbatch — mark it replicate. 
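+        # NB: _Replicate is a private torch.distributed.pipelining.microbatch
+        # helper with no public equivalent yet; pin the torch version if relying on it.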
+        self.pp_microbatches = n_microbatches
+        kwargs_chunk_spec = {
+            k: _Replicate() for k in (
+                "input_ids", "attention_mask", "original_mask",
+                "image_grid_thw", "pixel_values",
+                "pixel_values_videos", "video_grid_thw",
+            )
+        }
+        self.pp_schedule = schedule_cls(
             self.pp_stage,
-            n_microbatches=1,
+            n_microbatches=n_microbatches,
             loss_fn=pp_loss_fn,
+            kwargs_chunk_spec=kwargs_chunk_spec,
         )
+        logger.info(f"PP schedule: {schedule_name} (n_microbatches={n_microbatches})")
 
         if self.training_args.random_init:
             if self.model_type == ModelType.Qwen3_5:
@@ -649,22 +685,36 @@ def setup_accumulation(self, tpi_multiplier=1.5):
 
     def _train_step_pp(self, data_iterator, optimizer):
         batch = next(data_iterator)
-
+
         labels = batch.pop('labels', None)
         input_ids = batch.pop('input_ids')
         batch['input_ids'] = input_ids
 
+        # Schedule chunks the positional input_ids and target along dim 0;
+        # tile the (1, total) packed sample to (N, total) so n_microbatches > 1
+        # produces N actual chunks. Each microbatch is identical content — fine
+        # for benchmarking the schedule, not for real training.
+        n = self.pp_microbatches
+        tiled_input_ids = input_ids.repeat(n, 1) if n > 1 else input_ids
+        tiled_labels = labels.repeat(n, 1) if (n > 1 and labels is not None) else labels
+
         losses = [] if self.pp_has_last_stage else None
-        target = labels if self.pp_has_last_stage else None
+        target = tiled_labels if self.pp_has_last_stage else None
 
         s_model = time.perf_counter()
         with record_function("pp_forward_backward"):
             with torch.autocast('cuda', torch.bfloat16):
                 if self.pp_has_first_stage:
-                    self.pp_schedule.step(input_ids, **batch, target=target, losses=losses)
+                    self.pp_schedule.step(tiled_input_ids, **batch, target=target, losses=losses)
                 else:
                     self.pp_schedule.step(**batch, target=target, losses=losses)
 
+        if self.ep_size > 1 and self.dp_group.size() > 1:
+            is_last_accum = (self.current_accum_count + 1 >= self.current_accum_target)
+            if is_last_accum:
+                # we use a custom bucketing system instead of the replicate hooks
+                self._sync_gradients()
+
         self.fwd_bwd_time = time.perf_counter() - s_model
 
         scaled_loss = torch.stack(losses).sum() if losses else torch.tensor(0.0, device=self.device)
@@ -680,12 +730,21 @@ def _train_step_pp(self, data_iterator, optimizer):
         return self._maybe_optimizer_step(loss_for_logging, ce_loss, aux_loss, optimizer)
 
     def _sync_gradients(self):
-        """All_reduce gradients across dp_group. Used instead of DDP hooks when EP is active."""
+        """Bucketed grad all_reduce across dp_group. One collective per ~25M-element
+        bucket per dtype (about 50 MB in bf16, 100 MB in fp32) instead of one per
+        parameter, so DP scales with NCCL bandwidth rather than per-launch latency.
+        Used instead of DDP hooks when EP is active.
+        """
         dp_size = self.dp_group.size()
         if dp_size <= 1:
             return
         grp = self.dp_group.get_group()
         from torch.distributed.tensor import DTensor
+        from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+
+        # NCCL all_reduce requires uniform dtype within a call.
+        by_dtype: dict[torch.dtype, list[torch.Tensor]] = {}
         for p in self.model.parameters():
             if p.grad is None:
                 continue
             g = p.grad
             # TP-sharded params (e.g. shared_expert.*) have DTensor grads; reduce the local shard.
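+            # (to_local() returns the shard's underlying tensor with shared
+            # storage, so the bucketed g.copy_(synced) later writes straight
+            # into p.grad without an explicit copy-back into the DTensor.)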
if isinstance(g, DTensor): g = g.to_local() - dist.all_reduce(g, group=grp) - g.div_(dp_size) + by_dtype.setdefault(g.dtype, []).append(g) + + bucket_max_elems = 25 * 1024 * 1024 # ~50 MB at bf16, ~100 MB at fp32 + inv_dp = 1.0 / dp_size + + def _flush(bucket: list[torch.Tensor]) -> None: + flat = _flatten_dense_tensors(bucket) + dist.all_reduce(flat, group=grp) + flat.mul_(inv_dp) + for g, synced in zip(bucket, _unflatten_dense_tensors(flat, bucket)): + g.copy_(synced) + + for grads in by_dtype.values(): + bucket: list[torch.Tensor] = [] + bucket_elems = 0 + for g in grads: + n = g.numel() + if bucket and bucket_elems + n > bucket_max_elems: + _flush(bucket) + bucket, bucket_elems = [], 0 + bucket.append(g) + bucket_elems += n + if bucket: + _flush(bucket) def _train_step(self, data_iterator, optimizer): batch = next(data_iterator) From ae04d254f243610beffb3076fcf634082ace993a Mon Sep 17 00:00:00 2001 From: tomiock Date: Sat, 2 May 2026 12:04:42 +0200 Subject: [PATCH 22/23] [feat] properly behaving 4D moe set up --- configs/cvc/moe.toml | 14 +- models/qwen3_5/dev_ep.py | 100 ---------- models/qwen3_5/test_ep.py | 375 -------------------------------------- train/config.py | 18 +- train/infra.py | 257 ++++++++++++++++++++++---- train/train_qwen.py | 264 +++++++-------------------- 6 files changed, 297 insertions(+), 731 deletions(-) delete mode 100644 models/qwen3_5/dev_ep.py delete mode 100644 models/qwen3_5/test_ep.py diff --git a/configs/cvc/moe.toml b/configs/cvc/moe.toml index 1408b1c..648b505 100644 --- a/configs/cvc/moe.toml +++ b/configs/cvc/moe.toml @@ -1,5 +1,5 @@ [model] -model_name = "Qwen/Qwen3.5-6B-A3B" +model_name = "Qwen/Qwen3.5-8B-A1B" model_impl = "native" train_llm = true @@ -11,7 +11,7 @@ run_name = "test" project_name = "moe" [training] -model_dir = "/data/151-1/users/tockier/qwen_finetune/cache/qwen35_6b_a3b" +model_dir = "/data/151-1/users/tockier/qwen_finetune/cache/qwen35_8b_a1b" #model_dir = "/data/151-1/users/tockier/qwen_finetune/cache/qwen35_35b_a3b" output_dir = "/data/151-1/users/shared_cache/qwen_finetune/checkpoints" @@ -19,17 +19,19 @@ save_steps = 10000 total_steps = 100 random_init = true +compile = false + +[parallel] tp_size = 1 pp_size = 1 ep_size = 1 -data_parallel = 'ddp' +data_parallel = 'fsdp' -ac_mode = "full" -compile = false +ac_mode = "off" [data] data_path = "/data/151-1/datasets/synth_test_datasets/cap_pretrain" -seq_len = 256 +seq_len = 512 packing_buffer_size = 100 diff --git a/models/qwen3_5/dev_ep.py b/models/qwen3_5/dev_ep.py deleted file mode 100644 index 60fa13c..0000000 --- a/models/qwen3_5/dev_ep.py +++ /dev/null @@ -1,100 +0,0 @@ -import os -import torch -import torch.distributed as dist -from torch.distributed.device_mesh import init_device_mesh - -# Import directly from your existing files -from config import Qwen3VLConfig, Qwen3_5TextConfig, Qwen3_5VisionConfig -from model import Qwen3_5ForCausalLM, initialize_missing_weights - -def get_debug_config() -> Qwen3VLConfig: - text_cfg = Qwen3_5TextConfig( - vocab_size=128, hidden_size=64, moe_intermediate_size=128, - num_hidden_layers=2, num_attention_heads=2, num_key_value_heads=1, - head_dim=32, max_position_embeddings=128, rms_norm_eps=1e-6, - tie_word_embeddings=False, num_experts=4, num_experts_per_tok=2, - router_aux_loss_coef=0.01, shared_expert_intermediate_size=128, - layer_types=["linear_attention", "full_attention"], full_attention_interval=2, - linear_conv_kernel_dim=4, linear_key_head_dim=32, linear_num_key_heads=2, - linear_num_value_heads=2, linear_value_head_dim=32, 
mtp_num_hidden_layers=0, - mtp_use_dedicated_embeddings=False, rope_parameters={"rope_theta": 10000.0, "mrope_section": [16, 16, 16]} - ) - - vision_cfg = Qwen3_5VisionConfig( - depth=1, hidden_size=64, intermediate_size=128, num_heads=2, - in_channels=3, patch_size=14, temporal_patch_size=2, spatial_merge_size=2, - num_position_embeddings=1024, out_hidden_size=64, hidden_act="silu", - deepstack_visual_indexes=[] - ) - - return Qwen3VLConfig( - text=text_cfg, vision=vision_cfg, - image_token_id=120, video_token_id=121, - vision_start_token_id=122, vision_end_token_id=123, - tie_word_embeddings=False, torch_dtype="bfloat16" - ) - -def main(): - # 1. Initialize Distributed Env - dist.init_process_group(backend="nccl") - rank = dist.get_rank() - world_size = dist.get_world_size() - torch.cuda.set_device(rank) - - # Set up a 2D mesh: (dp, ep). For this test, we map all ranks to EP. - dp_size = 1 - ep_size = world_size - mesh = init_device_mesh("cuda", (dp_size, ep_size), mesh_dim_names=("dp", "ep")) - - if rank == 0: - print(f"Initialized DeviceMesh: {mesh}") - - # 2. Build Model - cfg = get_debug_config() - # Note: As you modify model.py to accept DeviceMesh for EP, pass it here - model = Qwen3_5ForCausalLM(cfg).cuda().bfloat16() - initialize_missing_weights(model) - - # 3. Dummy Data (Packed varlen format as expected by your forward pass) - seq_len = 64 - # (1, total) shape expected by your model - input_ids = torch.randint(0, cfg.text.vocab_size, (1, seq_len), device="cuda") - - # attention_mask is used as cu_seqlens in your varlen implementation - cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device="cuda") - - # 4. Forward Pass - if rank == 0: print("\nStarting Forward Pass...") - logits, aux_loss = model( - input_ids=input_ids, - attention_mask=cu_seqlens - ) - - if rank == 0: - print(f"Forward Success. Logits shape: {logits.shape}, Aux Loss: {aux_loss.item()}") - - # 5. Backward Pass - if rank == 0: print("Starting Backward Pass...") - - # Ensure aux_loss is a scalar for summation - if aux_loss.dim() > 0: - aux_loss = aux_loss.sum() - - loss = logits.sum() + aux_loss - loss.backward() - - # Quick check if gradients flowed through the experts - # Path depends on your current model.py structure, adjust if needed - expert_grad_exists = False - for name, param in model.named_parameters(): - if "experts" in name and param.grad is not None: - expert_grad_exists = True - break - - if rank == 0: - print(f"Backward Success. 
Expert gradients populated: {expert_grad_exists}\n") - - dist.destroy_process_group() - -if __name__ == "__main__": - main() diff --git a/models/qwen3_5/test_ep.py b/models/qwen3_5/test_ep.py deleted file mode 100644 index 668a275..0000000 --- a/models/qwen3_5/test_ep.py +++ /dev/null @@ -1,375 +0,0 @@ -import os -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.distributed as dist -from torch.distributed.device_mesh import init_device_mesh, DeviceMesh - -class LocalDispatchMetadata: - def __init__(self, token_indices_experts_sorted, top_scores_experts_sorted): - self.token_indices_experts_sorted = token_indices_experts_sorted - self.top_scores_experts_sorted = top_scores_experts_sorted - -class AllToAllDispatchMetadata: - def __init__(self, token_indices_experts_sorted, top_scores_experts_sorted, input_shape, permuted_indices, input_splits, output_splits): - self.token_indices_experts_sorted = token_indices_experts_sorted - self.top_scores_experts_sorted = top_scores_experts_sorted - self.input_shape = input_shape - self.permuted_indices = permuted_indices - self.input_splits = input_splits - self.output_splits = output_splits - -class _AllToAllSingleAutograd(torch.autograd.Function): - """ - wrapper around all-to-all to add backward pass - """ - @staticmethod - def forward(ctx, input_, output_splits, input_splits, group): - ctx.group = group - ctx.input_splits = input_splits - ctx.output_splits = output_splits - - out_total = int(sum(output_splits)) if output_splits else 0 - out = torch.empty((out_total, input_.size(1)), dtype=input_.dtype, device=input_.device) - - if group is not None: - dist.all_to_all_single(out, input_, output_splits, input_splits, group=group.get_group()) - return out - - @staticmethod - def backward(ctx, grad_output): - in_total = int(sum(ctx.input_splits)) if ctx.input_splits else 0 - grad_input = torch.empty((in_total, grad_output.size(1)), dtype=grad_output.dtype, device=grad_output.device) - - if ctx.group is not None: - dist.all_to_all_single(grad_input, grad_output, ctx.input_splits, ctx.output_splits, group=ctx.group.get_group()) - return grad_input, None, None, None - -def all_to_all_single_autograd(input_, output_splits, input_splits, group): - return _AllToAllSingleAutograd.apply(input_, output_splits, input_splits, group) - -class TokenDispatcher: - """Consolidated EP/SP dispatcher. 
Handles local token reorder and all-to-all.""" - - def __init__(self, num_experts: int, top_k: int, score_before_experts: bool = True): - self.num_experts = num_experts - self.top_k = top_k - self.score_before_experts = score_before_experts - - self.ep_mesh: DeviceMesh | None = None - self.sp_size: int = 1 - self.sp_rank: int = -1 - - def _split_along_sp(self, *tensors: torch.Tensor) -> list[torch.Tensor]: - results = [] - for t in tensors: - local_num_tokens = t.shape[0] // self.sp_size - offset = self.sp_rank * local_num_tokens - results.append(t[offset : offset + local_num_tokens]) - return results - - def _permute(self, routed_input, num_tokens_per_expert_group, ep_size, num_local_experts): - device = num_tokens_per_expert_group.device - total = num_tokens_per_expert_group.sum().item() - - t_mat = num_tokens_per_expert_group.view(ep_size, num_local_experts) - input_starts = (num_tokens_per_expert_group.cumsum(0) - num_tokens_per_expert_group).view(ep_size, num_local_experts) - - segment_lens = t_mat.t().reshape(-1) - input_starts = input_starts.t().reshape(-1) - - seg_ids = torch.arange(segment_lens.shape[0], device=device).repeat_interleave(segment_lens.long()) - output_starts = segment_lens.cumsum(0) - segment_lens - permuted_indices = (input_starts[seg_ids] + torch.arange(total, device=device) - output_starts[seg_ids]).long() - - num_tokens_per_expert = t_mat.sum(0) - return routed_input.shape, routed_input[permuted_indices, :], permuted_indices, num_tokens_per_expert - - def _unpermute(self, routed_output, input_shape, permuted_indices): - out_unpermuted = routed_output.new_empty(input_shape) - out_unpermuted[permuted_indices, :] = routed_output - return out_unpermuted - - def dispatch(self, x: torch.Tensor, top_scores: torch.Tensor, selected_experts_indices: torch.Tensor): - if self.sp_size > 1: - x, top_scores, selected_experts_indices = self._split_along_sp(x, top_scores, selected_experts_indices) - - flat_experts = selected_experts_indices.view(-1) - num_tokens_per_expert = torch.bincount(flat_experts, minlength=self.num_experts).float() - - token_indices_experts_sorted = torch.argsort(selected_experts_indices.view(-1), stable=True) - top_scores_experts_sorted = top_scores.view(-1)[token_indices_experts_sorted] - token_indices_experts_sorted = token_indices_experts_sorted // self.top_k - routed_input = x[token_indices_experts_sorted] - - if self.score_before_experts: - routed_input = (routed_input.to(torch.float32) * top_scores_experts_sorted.reshape(-1, 1)).to(x.dtype) - - # Skip all-to-all logic entirely if ep_mesh is missing (EP=1) - if self.ep_mesh is None: - metadata = AllToAllDispatchMetadata( - token_indices_experts_sorted, top_scores_experts_sorted, - None, None, None, None - ) - return routed_input, num_tokens_per_expert, metadata - - ep_size = self.ep_mesh.size() - num_tokens_per_expert_group = torch.empty_like(num_tokens_per_expert) - dist.all_to_all_single(num_tokens_per_expert_group, num_tokens_per_expert, group=self.ep_mesh.get_group()) - - input_splits = num_tokens_per_expert.view(ep_size, -1).sum(dim=1).int().cpu().tolist() - output_splits = num_tokens_per_expert_group.view(ep_size, -1).sum(dim=1).int().cpu().tolist() - - routed_input = all_to_all_single_autograd(routed_input, output_splits, input_splits, self.ep_mesh) - - num_local_experts = num_tokens_per_expert_group.shape[0] // ep_size - input_shape, routed_input, permuted_indices, num_tokens_per_expert_group = self._permute( - routed_input, num_tokens_per_expert_group, ep_size, num_local_experts - ) - - 
metadata = AllToAllDispatchMetadata( - token_indices_experts_sorted, top_scores_experts_sorted, - input_shape, permuted_indices, input_splits, output_splits - ) - return routed_input, num_tokens_per_expert_group, metadata - - def combine(self, routed_output: torch.Tensor, metadata: "AllToAllDispatchMetadata", x: torch.Tensor, shared_experts: torch.nn.Module | None = None) -> torch.Tensor: - if self.ep_mesh is not None: - routed_output = self._unpermute(routed_output, metadata.input_shape, metadata.permuted_indices) - routed_output = all_to_all_single_autograd(routed_output, metadata.input_splits, metadata.output_splits, self.ep_mesh) - - out = shared_experts(x) if shared_experts is not None else torch.zeros_like(x) - - if not self.score_before_experts: - routed_output = (routed_output.to(torch.float32) * metadata.top_scores_experts_sorted.reshape(-1, 1)).to(routed_output.dtype) - - token_indices_experts_sorted = metadata.token_indices_experts_sorted - if self.sp_size > 1: - local_num_tokens = x.shape[0] // self.sp_size - token_indices_experts_sorted = token_indices_experts_sorted + local_num_tokens * self.sp_rank - - out.index_add_(0, token_indices_experts_sorted, routed_output) - return out - - -class MockLocalExperts(nn.Module): - """A dummy expert implementation holding only local weights.""" - def __init__(self, num_local_experts, hidden_dim, intermediate_dim): - super().__init__() - self.num_local_experts = num_local_experts - self.w1 = nn.Parameter(torch.randn(num_local_experts, hidden_dim, intermediate_dim)) - self.w2 = nn.Parameter(torch.randn(num_local_experts, intermediate_dim, hidden_dim)) - - def forward(self, x, num_tokens_per_expert): - """ - x: (total_routed_tokens, hidden_dim) - num_tokens_per_expert: (num_local_experts,) - """ - total_tokens = int(num_tokens_per_expert.sum().item()) - if total_tokens == 0: - return torch.empty_like(x) - if total_tokens == x.shape[0] and self.num_local_experts == 1: - return torch.relu(x @ self.w1[0]) @ self.w2[0] - - offsets = torch.zeros(self.num_local_experts + 1, dtype=torch.long, device=x.device) - offsets[1:] = num_tokens_per_expert.cumsum(0) - - outputs = [] - for i in range(self.num_local_experts): - start, end = int(offsets[i]), int(offsets[i+1]) - if end > start: - chunk = x[start:end] - hidden = torch.relu(chunk @ self.w1[i]) - outputs.append(hidden @ self.w2[i]) - - return torch.cat(outputs, dim=0) if outputs else torch.empty_like(x) - - forward_compiled = None - -class MockLocalExperts(nn.Module): - def __init__(self, num_local_experts, hidden_dim, intermediate_dim): - super().__init__() - self.num_local_experts = num_local_experts - self.w1 = nn.Parameter(torch.randn(num_local_experts, hidden_dim, intermediate_dim)) - self.w2 = nn.Parameter(torch.randn(num_local_experts, intermediate_dim, hidden_dim)) - - def forward(self, x, num_tokens_per_expert): - total_tokens = int(num_tokens_per_expert.sum().item()) - if total_tokens == 0: - return torch.empty_like(x) - if total_tokens == x.shape[0] and self.num_local_experts == 1: - return torch.relu(F.linear(x, self.w1[0].t())) @ self.w2[0].t() - - offsets = torch.zeros(self.num_local_experts + 1, dtype=torch.long, device=x.device) - offsets[1:] = num_tokens_per_expert.cumsum(0) - - outputs = [] - for i in range(self.num_local_experts): - start, end = int(offsets[i]), int(offsets[i+1]) - if end > start: - chunk = x[start:end] - hidden = torch.relu(F.linear(chunk, self.w1[i].t())) - outputs.append(F.linear(hidden, self.w2[i].t())) - return torch.cat(outputs, dim=0) if outputs else 
torch.empty_like(x) - return compiled_forward - - def forward(self, x, num_tokens_per_expert): - total_tokens = int(num_tokens_per_expert.sum().item()) - if total_tokens == 0: - return torch.empty_like(x) - if total_tokens == x.shape[0] and self.num_local_experts == 1: - return torch.relu(F.linear(x, self.w1[0].t()) @ self.w2[0].t()) - - offsets = torch.zeros(self.num_local_experts + 1, dtype=torch.long, device=x.device) - offsets[1:] = num_tokens_per_expert.cumsum(0) - - outputs = [] - for i in range(self.num_local_experts): - start, end = int(offsets[i]), int(offsets[i+1]) - if end > start: - chunk = x[start:end] - hidden = torch.relu(F.linear(chunk, self.w1[i].t())) - outputs.append(F.linear(hidden, self.w2[i].t())) - return torch.cat(outputs, dim=0) if outputs else torch.empty_like(x) - -def run_ep_benchmark(ep_size: int, num_experts: int = 16, hidden_dim: int = 4096, batch_size: int = 8, seq_len: int = 2048): - rank = dist.get_rank() - world_size = dist.get_world_size() - device = torch.device(f"cuda:{rank % torch.cuda.device_count()}") - - dp_size = world_size // ep_size - global_mesh = init_device_mesh("cuda", (dp_size, ep_size), mesh_dim_names=("dp", "ep")) - ep_mesh = global_mesh["ep"] - - num_local_experts = num_experts // ep_size - intermediate_dim = 14336 - - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats(device) - - experts = MockLocalExperts(num_local_experts, hidden_dim, intermediate_dim).to(device) - shared_expert = nn.Linear(hidden_dim, hidden_dim).to(device) - dispatcher = TokenDispatcher(num_experts=num_experts, top_k=2, score_before_experts=True) - - if ep_size > 1: - dispatcher.ep_mesh = ep_mesh - - num_tokens = batch_size * seq_len - x = torch.randn(num_tokens, hidden_dim, device=device, requires_grad=True) - router_logits = torch.randn(num_tokens, num_experts, device=device) - top_scores, selected_experts_indices = torch.topk(torch.softmax(router_logits, dim=-1), k=2, dim=-1) - - # WARMUP - for _ in range(3): - r_in, counts, meta = dispatcher.dispatch(x, top_scores, selected_experts_indices) - e_out = experts(r_in, counts) - out = dispatcher.combine(e_out, meta, x, shared_experts=shared_expert) - out.sum().backward() - - torch.cuda.synchronize(device) - - # BENCHMARK - iters = 10 - dispatch_times = [] - expert_times = [] - shared_fwd_times = [] - index_add_times = [] - backward_times = [] - - for _ in range(iters): - torch.cuda.synchronize(device) - - t0 = torch.cuda.Event(enable_timing=True) - t1 = torch.cuda.Event(enable_timing=True) - t2 = torch.cuda.Event(enable_timing=True) - t3 = torch.cuda.Event(enable_timing=True) - t4 = torch.cuda.Event(enable_timing=True) - t5 = torch.cuda.Event(enable_timing=True) - - t0.record() - r_in, counts, meta = dispatcher.dispatch(x, top_scores, selected_experts_indices) - t1.record() - e_out = experts(r_in, counts) - t2.record() - - out = shared_expert(x) - t3.record() - - routed = e_out - if dispatcher.ep_mesh is not None: - routed = dispatcher._unpermute(routed, meta.input_shape, meta.permuted_indices) - routed = all_to_all_single_autograd(routed, meta.input_splits, meta.output_splits, dispatcher.ep_mesh) - - token_indices_experts_sorted = meta.token_indices_experts_sorted - if dispatcher.sp_size > 1: - local_num_tokens = x.shape[0] // dispatcher.sp_size - token_indices_experts_sorted = token_indices_experts_sorted + local_num_tokens * dispatcher.sp_rank - - out.index_add_(0, token_indices_experts_sorted, routed) - t4.record() - - out.sum().backward() - t5.record() - - torch.cuda.synchronize(device) - 
dispatch_times.append(t0.elapsed_time(t1)) - expert_times.append(t1.elapsed_time(t2)) - shared_fwd_times.append(t2.elapsed_time(t3)) - index_add_times.append(t3.elapsed_time(t4)) - backward_times.append(t4.elapsed_time(t5)) - - avg_dispatch = sum(dispatch_times) / iters - avg_expert = sum(expert_times) / iters - avg_shared_fwd = sum(shared_fwd_times) / iters - avg_index_add = sum(index_add_times) / iters - avg_backward = sum(backward_times) / iters - avg_time_ms = avg_dispatch + avg_expert + avg_backward - peak_mem_gb = torch.cuda.max_memory_allocated(device) / (1024 ** 3) - - # FLOPs Math: FWD is 2*H*I per weight matrix, BWD is ~2x FWD. Total = 12 * H * I per token. - # Total routed tokens globally = num_tokens * dp_size * top_k - global_routed_tokens = num_tokens * dp_size * dispatcher.top_k - expert_flops_per_iter = 12 * global_routed_tokens * hidden_dim * intermediate_dim - - # Include shared expert FLOPs - shared_flops_per_iter = 12 * (num_tokens * dp_size) * hidden_dim * hidden_dim - total_flops = expert_flops_per_iter + shared_flops_per_iter - - tflops_per_sec = (total_flops / (avg_time_ms / 1000.0)) / (1e12) - tflops_per_gpu = tflops_per_sec / world_size - - dist.barrier() - - if rank == 0: - print(f"--- Configuration: DP={dp_size}, EP={ep_size} ---") - print(f"Peak Memory Allocated: {peak_mem_gb:.2f} GB") - print(f"Dispatch: {avg_dispatch:.2f}ms, Expert: {avg_expert:.2f}ms, SharedFwd: {avg_shared_fwd:.2f}ms, IndexAdd: {avg_index_add:.2f}ms, Backward: {avg_backward:.2f}ms") - print(f"Average Time / Iteration: {avg_time_ms:.2f} ms") - print(f"Total TFLOPS (Cluster): {tflops_per_sec:.2f} TFLOPS") - print(f"Per-GPU TFLOPS: {tflops_per_gpu:.2f} TFLOPS\n") - -torch.backends.cuda.matmul.allow_tf32 = True -torch.backends.cudnn.allow_tf32 = True -torch.backends.cudnn.benchmark = True - -if __name__ == "__main__": - dist.init_process_group(backend="nccl") - - world_size = dist.get_world_size() - assert world_size == 4, "This test is designed to run exactly on 4 devices." - - torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count()) - - if dist.get_rank() == 0: - print("Starting EP Memory Benchmark...\n") - - # Baseline: EP=1 (All 16 experts on every GPU) - run_ep_benchmark(ep_size=1) - - # Run EP=2 - run_ep_benchmark(ep_size=2) - - # Run EP=4 - run_ep_benchmark(ep_size=4) - - dist.destroy_process_group() diff --git a/train/config.py b/train/config.py index 7f4061f..321f32f 100644 --- a/train/config.py +++ b/train/config.py @@ -89,13 +89,21 @@ class Training: min_lr_ratio: float = 0.1 # --------------- + # torch dynamo compiler + compile: bool = True + """ + Always on by default, unless you have an error. + """ + +@dataclass +class Parallel: data_parallel: str = "ddp" # fsdp, ddp """ Use `fsdp` when you want to decrease usage to increase seq_len/batch_size. """ tp_size: int = 1 # 1 means disabled - pp_size: int = 1 # 1 means disabled; supported values: 2, 2, 4 + pp_size: int = 1 # 1 means disabled; supported values: 2, 4 ep_size: int = 1 # 1 means disabled; must divide num_experts evenly pp_num_layers_first: int = 1 @@ -107,16 +115,9 @@ class Training: # to actually pipeline (smaller values degrade to GPipe-like behavior). pp_microbatches: int = 1 - # compiler flag for TP (goes faster) async_tp: bool = True - # torch dynamo compiler - compile: bool = True - """ - Always on by default, unless you have an error. 
- """ - # activation checkpointing ac_mode: str = "off" """ @@ -169,5 +170,6 @@ class Config: model: Model = field(default_factory=Model) data: Data = field(default_factory=Data) wandb: Wandb = field(default_factory=Wandb) + parallel: Parallel = field(default_factory=Parallel) config: str = '/home/tockier/vlm-training/configs/cvc_config.toml' diff --git a/train/infra.py b/train/infra.py index 1020b43..3197929 100644 --- a/train/infra.py +++ b/train/infra.py @@ -3,6 +3,7 @@ from typing import Optional from train.config import ModelType +from train.logger import logger import torch import torch._inductor.config @@ -46,6 +47,12 @@ CheckpointPolicy, create_selective_checkpoint_contexts, ) +from torch.distributed.pipelining import PipelineStage +from torch.distributed.pipelining.microbatch import _Replicate +from torch.distributed.pipelining.schedules import Schedule1F1B, ScheduleGPipe + +# used for PP +from models.qwen3_5.utils import causal_lm_loss, load_stage_weights class NoParallel(ParallelStyle): def __init__( @@ -115,51 +122,59 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: ), ) -def get_mesh(training_args, world_size): - tp_size = training_args.tp_size - pp_size = getattr(training_args, "pp_size", 1) - ep_size = getattr(training_args, "ep_size", 1) +def get_mesh(parallel_args, world_size): + tp = getattr(parallel_args, "tp_size", 1) + pp = getattr(parallel_args, "pp_size", 1) + ep = getattr(parallel_args, "ep_size", 1) - if world_size % (tp_size * pp_size * ep_size) != 0: - raise ValueError( - f"world_size {world_size} not divisible by tp_size*pp_size*ep_size={tp_size * pp_size * ep_size}" - ) - dp_size = world_size // (tp_size * pp_size * ep_size) + assert world_size % (tp * pp) == 0, f"world_size not divisible by tp*pp" + dp = world_size // (tp * pp) - if pp_size > 1 and ep_size > 1: - return init_device_mesh( + if ep > 1 or dp == "fsdp": + dp_shard = dp + dp_replicate = 1 + else: + dp_replicate = dp + dp_shard = 1 + + if ep > 1: + assert dp_shard % ep == 0, f"EP ({ep}) must divide dp_shard ({dp_shard})" + dp_mod_ep = dp_shard // ep + + mesh = init_device_mesh( "cuda", - (dp_size, pp_size, ep_size, tp_size), - mesh_dim_names=("dp", "pp", "ep", "tp"), - ) - if pp_size > 1: - return init_device_mesh( - "cuda", (dp_size, pp_size, tp_size), mesh_dim_names=("dp", "pp", "tp") + (pp, dp_replicate, dp_mod_ep, ep, tp), + mesh_dim_names=("pp", "dp_replicate", "dp_mod_ep", "ep", "tp") ) - if ep_size > 1: - return init_device_mesh( - "cuda", (dp_size, ep_size, tp_size), mesh_dim_names=("dp", "ep", "tp") + + mesh._flattened_submeshes = { + "dp": mesh["dp_replicate", "dp_mod_ep", "ep"]._flatten("dp"), + "dp_shard": mesh["dp_mod_ep", "ep"]._flatten("dp_shard"), + } + else: + mesh = init_device_mesh( + "cuda", + (pp, dp_replicate, dp_shard, tp), + mesh_dim_names=("pp", "dp_replicate", "dp_shard", "tp") ) - return init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp")) -def get_tp_group(mesh): - if "tp" in mesh.mesh_dim_names: - return mesh["tp"] - return None - -def get_dp_group(mesh): - if "dp" in mesh.mesh_dim_names: - return mesh["dp"] - return None - -def get_pp_group(mesh): - if "pp" in mesh.mesh_dim_names: - return mesh["pp"] - return None - -def get_ep_group(mesh): - if "ep" in mesh.mesh_dim_names: - return mesh["ep"] + print(mesh) + print(mesh) + print(mesh) + + mesh._flattened_submeshes = { + "dp": mesh["dp_replicate", "dp_shard"]._flatten("dp") + } + + return mesh + +def get_mesh_group(mesh, dim_name: str): + if dim_name in 
mesh.mesh_dim_names: + return mesh[dim_name] + + if hasattr(mesh, "_flattened_submeshes") and dim_name in mesh._flattened_submeshes: + return mesh._flattened_submeshes[dim_name] + return None def apply_ep(model, ep_mesh, tp_mesh=None): @@ -735,3 +750,167 @@ def _apply_tp_to_decoder_qwen3_5( if enable_async_tp: torch._inductor.config._micro_pipeline_tp = True + +def get_local_fqns( + num_layers: int, + pp_size: int, + pp_rank: int, + num_first: int, + num_last: int +) -> list[str]: + if pp_size == 1: + return [ + "model.visual", + "model.language_model.embed_tokens", + ] + [f"model.language_model.layers.{i}" for i in range(num_layers)] + [ + "model.language_model.norm", + "lm_head" + ] + + fqns = [] + + if pp_rank == 0: + fqns.extend(["model.visual", "model.language_model.embed_tokens"]) + + if pp_size == 2: + mid_point = num_layers // 2 + (num_layers % 2) + start_idx = 0 if pp_rank == 0 else mid_point + end_idx = mid_point if pp_rank == 0 else num_layers + else: + if pp_rank == 0: + start_idx, end_idx = 0, num_first + elif pp_rank == pp_size - 1: + start_idx, end_idx = num_layers - num_last, num_layers + else: + middle_layers = num_layers - num_first - num_last + middle_ranks = pp_size - 2 + + layers_per_mid = middle_layers // middle_ranks + remainder = middle_layers % middle_ranks + + mid_idx = pp_rank - 1 + start_idx = num_first + (mid_idx * layers_per_mid) + min(mid_idx, remainder) + num_layers_this_rank = layers_per_mid + (1 if mid_idx < remainder else 0) + end_idx = start_idx + num_layers_this_rank + + fqns.extend([f"model.language_model.layers.{i}" for i in range(start_idx, end_idx)]) + + if pp_rank == pp_size - 1: + fqns.extend(["model.language_model.norm", "lm_head"]) + + return fqns + + +def apply_pp( + model, + mesh, + parallel_args, + training_args, + device, + pp_loss_fn, + ): + logger.info("Applying Pipeline Parallelism module split...") + pp_rank = mesh.get_local_rank(mesh_dim="pp") + total_layers = model.cfg.text.num_hidden_layers + pp_size = parallel_args.pp_size + + local_fqns = get_local_fqns( + num_layers=total_layers, + pp_size=pp_size, + pp_rank=pp_rank, + num_first= parallel_args.pp_num_layers_first, + num_last= parallel_args.pp_num_layers_last + ) + + if "model.visual" not in local_fqns: + model.model.visual = None + if "model.language_model.embed_tokens" not in local_fqns: + model.model.language_model.embed_tokens = None + + layers = model.model.language_model.layers + kept_indices = {int(f.split('.')[-1]) for f in local_fqns if "layers." in f} + model.model.language_model.layers = torch.nn.ModuleList( + [m for i, m in enumerate(layers) if i in kept_indices] + ) + + if "model.language_model.norm" not in local_fqns: + model.model.language_model.norm = None + if "lm_head" not in local_fqns: + model.lm_head = None + + model.to(device=device) + + layer_indices = [int(f.split('.')[-1]) for f in local_fqns if "layers." 
in f]
+    layer_start = min(layer_indices) if layer_indices else 0
+    layer_end = max(layer_indices) + 1 if layer_indices else 0
+
+    target_dtype = torch.bfloat16 if training_args.bfloat16 else torch.float32
+
+    # Stage-specific weight loading (disabled for now: the random_init path is
+    # used instead). Note this is a free function, so the stage module and
+    # config come from the local arguments rather than `self`.
+    if False:
+        logger.info(f"PP rank {pp_rank}: About to load stage weights for layers {layer_start}-{layer_end}")
+        load_stage_weights(
+            stage=model,
+            snapshot_dir=training_args.model_dir,
+            layer_start=layer_start,
+            layer_end=layer_end,
+            is_first=pp_rank == 0,
+            is_last=pp_rank == pp_size - 1,
+            device=device,
+            dtype=target_dtype,
+        )
+        logger.info(f"PP rank {pp_rank}: Finished loading stage weights")
+
+    pp_stage = PipelineStage(
+        model,
+        stage_index=pp_rank,
+        num_stages=pp_size,
+        device=device,
+        group=mesh.get_group(mesh_dim="pp"),
+    )
+
+    schedule_name = getattr(parallel_args, "pp_schedule", "gpipe").lower()
+    n_microbatches = getattr(parallel_args, "pp_microbatches", 1)
+    schedule_cls = {"gpipe": ScheduleGPipe, "1f1b": Schedule1F1B}.get(schedule_name)
+    if schedule_cls is None:
+        raise ValueError(
+            f"unknown pp_schedule={schedule_name!r}; expected one of: gpipe, 1f1b"
+        )
+    if schedule_cls is Schedule1F1B:
+        # The plumbing exists (schedule construction, kwargs_chunk_spec,
+        # tiled input in _train_step_pp), but 1F1B requires
+        # n_microbatches >= pp_size and the dataloader currently emits
+        # one packed (1, total) sample per step — so microbatches are
+        # tiled copies of the same content and the loss is meaningless.
+        # Re-enable once the data path produces n_microbatches independent
+        # packed rows per step (per-row cu_seqlens, labels, image scatter).
+        raise NotImplementedError(
+            "pp_schedule='1f1b' is disabled until the data path supports "
+            "n_microbatches independent packed rows per step. Use 'gpipe' "
+            "for now."
+        )
+    # The dataloader emits a single packed (1, total) sample per step.
+    # When n_microbatches > 1 we tile input_ids/labels to (N, total) so
+    # the schedule can chunk along dim 0; everything else (cu_seqlens,
+    # pixel_values, image_grid_thw, etc.) is per-batch metadata that
+    # must be passed identically to every microbatch — mark it replicate.
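+    # Note: because pixel_values is replicated, the first stage re-runs the
+    # vision encoder once per microbatch; harmless at the n=1 default, but a
+    # real cost if n_microbatches ever grows.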
+ pp_microbatches = n_microbatches + kwargs_chunk_spec = { + k: _Replicate() for k in ( + "input_ids", "attention_mask", "original_mask", + "image_grid_thw", "pixel_values", + "pixel_values_videos", "video_grid_thw", + ) + } + pp_schedule = schedule_cls( + pp_stage, + n_microbatches=n_microbatches, + loss_fn=pp_loss_fn, + kwargs_chunk_spec=kwargs_chunk_spec, + ) + logger.info(f"PP schedule: {schedule_name} (n_microbatches={n_microbatches})") + + return pp_microbatches, pp_schedule diff --git a/train/train_qwen.py b/train/train_qwen.py index 07611e2..66fdf6f 100644 --- a/train/train_qwen.py +++ b/train/train_qwen.py @@ -14,9 +14,6 @@ import torch.distributed as dist from torch.distributed.elastic.multiprocessing.errors import record from torch.distributed._composable.replicate import replicate -from torch.distributed.pipelining import PipelineStage -from torch.distributed.pipelining.microbatch import _Replicate -from torch.distributed.pipelining.schedules import Schedule1F1B, ScheduleGPipe from torch.profiler import profile, record_function, ProfilerActivity, schedule @@ -24,24 +21,25 @@ from megatron.energon import get_train_dataset, get_loader, WorkerConfig from data.task_encoder_factory import build_task_encoder +from models.qwen3_5.utils import causal_lm_loss, load_stage_weights + # training imports from train.config_manager import ConfigManager from train.config import Config, ModelType from train.logger import init_logger, logger, Color from train.infra import ( get_mesh, - get_tp_group, - get_dp_group, - get_pp_group, - get_ep_group, + get_mesh_group, + apply_fsdp, apply_tp, apply_ep, apply_ac, + apply_pp, + ACConfig, compile_model, ) -from models.qwen3_5.utils import causal_lm_loss, load_stage_weights from models.qwen3_5.model import initialize_missing_weights from train.utils import ( @@ -69,55 +67,6 @@ torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True -def get_local_fqns( - num_layers: int, - pp_size: int, - pp_rank: int, - num_first: int, - num_last: int -) -> list[str]: - if pp_size == 1: - return [ - "model.visual", - "model.language_model.embed_tokens", - ] + [f"model.language_model.layers.{i}" for i in range(num_layers)] + [ - "model.language_model.norm", - "lm_head" - ] - - fqns = [] - - if pp_rank == 0: - fqns.extend(["model.visual", "model.language_model.embed_tokens"]) - - if pp_size == 2: - mid_point = num_layers // 2 + (num_layers % 2) - start_idx = 0 if pp_rank == 0 else mid_point - end_idx = mid_point if pp_rank == 0 else num_layers - else: - if pp_rank == 0: - start_idx, end_idx = 0, num_first - elif pp_rank == pp_size - 1: - start_idx, end_idx = num_layers - num_last, num_layers - else: - middle_layers = num_layers - num_first - num_last - middle_ranks = pp_size - 2 - - layers_per_mid = middle_layers // middle_ranks - remainder = middle_layers % middle_ranks - - mid_idx = pp_rank - 1 - start_idx = num_first + (mid_idx * layers_per_mid) + min(mid_idx, remainder) - num_layers_this_rank = layers_per_mid + (1 if mid_idx < remainder else 0) - end_idx = start_idx + num_layers_this_rank - - fqns.extend([f"model.language_model.layers.{i}" for i in range(start_idx, end_idx)]) - - if pp_rank == pp_size - 1: - fqns.extend(["model.language_model.norm", "lm_head"]) - - return fqns - class Trainer(torch.distributed.checkpoint.stateful.Stateful): @record @@ -125,6 +74,7 @@ def __init__(self, cfg: Config): self.model_args = cfg.model self.training_args = cfg.training self.data_args = cfg.data + self.p_args = cfg.parallel self.wandb_args = cfg.wandb 
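+        # all sharding knobs (tp/pp/ep sizes, dp strategy, schedules) now live
+        # under cfg.parallel ([parallel] in the TOML) rather than cfg.training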
self.debug_mode = bool(os.environ.get("DEBUG", False)) @@ -133,13 +83,18 @@ def __init__(self, cfg: Config): self.world_size = int(os.environ["WORLD_SIZE"]) torch.cuda.set_device(self.local_rank) - self.mesh = get_mesh(self.training_args, self.world_size) - self.tp_group = get_tp_group(self.mesh) - self.dp_group = get_dp_group(self.mesh) - self.pp_group = get_pp_group(self.mesh) - self.ep_group = get_ep_group(self.mesh) - self.pp_size = getattr(self.training_args, "pp_size", 1) - self.ep_size = getattr(self.training_args, "ep_size", 1) + self.mesh = get_mesh(self.p_args, self.world_size) + + self.tp_group = get_mesh_group(self.mesh, 'tp') + self.pp_group = get_mesh_group(self.mesh, 'pp') + self.ep_group = get_mesh_group(self.mesh, 'ep') + self.shard_group = get_mesh_group(self.mesh, 'dp_shard') + self.replicate_group = get_mesh_group(self.mesh, 'dp_replicate') + # this mesh group unifies `shard` and `replicate` + self.dp_group = get_mesh_group(self.mesh, "dp") + + self.pp_size = getattr(self.p_args, "pp_size", 1) + self.ep_size = getattr(self.p_args, "ep_size", 1) self.device = torch.device(f"cuda:{self.local_rank}") if self.if_log_rank(): @@ -151,10 +106,10 @@ def __init__(self, cfg: Config): **vars(self.model_args), **vars(self.training_args), **vars(self.data_args), + **vars(self.p_args), "mesh": self.mesh, "world_size": self.world_size, - "dp_group": self.dp_group, - "tp_group": self.tp_group, + # TODO: add all parallel args here }, ) @@ -170,6 +125,8 @@ def __init__(self, cfg: Config): set_determinism(seed=42 + self.local_rank, deterministic=True, world_mesh=self.mesh, debug_mode=self.debug_mode) + self.setup_accumulation(self.training_args.tpi_multiplier) + if self.rank() == 0: if not os.path.exists(self.training_args.output_dir): os.makedirs(self.training_args.output_dir) @@ -200,74 +157,8 @@ def __init__(self, cfg: Config): self.optimizer = None # defined later on - if self.training_args.load_text_model: - self.text_model = select_text_model(self.training_args) - self.model = load_text_model(self.model, self.text_model) - - if self.pp_size > 1: - logger.info("Applying Pipeline Parallelism module split...") - pp_rank = self.mesh.get_local_rank(mesh_dim="pp") - total_layers = self.model.cfg.text.num_hidden_layers - - local_fqns = get_local_fqns( - num_layers=total_layers, - pp_size=self.pp_size, - pp_rank=pp_rank, - num_first=self.training_args.pp_num_layers_first, - num_last=self.training_args.pp_num_layers_last - ) - - if "model.visual" not in local_fqns: - self.model.model.visual = None - if "model.language_model.embed_tokens" not in local_fqns: - self.model.model.language_model.embed_tokens = None - - layers = self.model.model.language_model.layers - kept_indices = {int(f.split('.')[-1]) for f in local_fqns if "layers." in f} - self.model.model.language_model.layers = torch.nn.ModuleList( - [m for i, m in enumerate(layers) if i in kept_indices] - ) - - if "model.language_model.norm" not in local_fqns: - self.model.model.language_model.norm = None - if "lm_head" not in local_fqns: - self.model.lm_head = None - - self.model.to(device=self.device) - - self.pp_has_first_stage = self.model.model.visual is not None - self.pp_has_last_stage = self.model.lm_head is not None - - layer_indices = [int(f.split('.')[-1]) for f in local_fqns if "layers." 
in f] - layer_start = min(layer_indices) if layer_indices else 0 - layer_end = max(layer_indices) + 1 if layer_indices else 0 - - target_dtype = torch.bfloat16 if self.training_args.bfloat16 else torch.float32 - - # Load stage-specific weights when PP > 1 and not using random init - if False: - logger.info(f"PP rank {pp_rank}: About to load stage weights for layers {layer_start}-{layer_end}") - load_stage_weights( - stage=self.model, - snapshot_dir=self.training_args.model_dir, - layer_start=layer_start, - layer_end=layer_end, - is_first=pp_rank == 0, - is_last=pp_rank == self.pp_size - 1, - device=self.device, - dtype=target_dtype, - ) - logger.info(f"PP rank {pp_rank}: Finished loading stage weights") - - self.model = self.model.to(self.device) - - self.pp_stage = PipelineStage( - self.model, - stage_index=pp_rank, - num_stages=self.pp_size, - device=self.device, - group=self.mesh.get_group(mesh_dim="pp"), - ) + # -- PIPELINE PARALLEL + if self.p_args.pp_size > 1: def pp_loss_fn(outputs, labels): logits, aux_loss = outputs @@ -279,47 +170,15 @@ def pp_loss_fn(outputs, labels): return (ce_loss + 0.01 * aux_loss) / self.current_accum_target - schedule_name = getattr(self.training_args, "pp_schedule", "gpipe").lower() - n_microbatches = getattr(self.training_args, "pp_microbatches", 1) - schedule_cls = {"gpipe": ScheduleGPipe, "1f1b": Schedule1F1B}.get(schedule_name) - if schedule_cls is None: - raise ValueError( - f"unknown pp_schedule={schedule_name!r}; expected one of: gpipe, 1f1b" - ) - if schedule_cls is Schedule1F1B: - # The plumbing exists (schedule construction, kwargs_chunk_spec, - # tiled input in _train_step_pp), but 1F1B requires - # n_microbatches >= pp_size and the dataloader currently emits - # one packed (1, total) sample per step — so microbatches are - # tiled copies of the same content and the loss is meaningless. - # Re-enable once the data path produces n_microbatches independent - # packed rows per step (per-row cu_seqlens, labels, image scatter). - raise NotImplementedError( - "pp_schedule='1f1b' is disabled until the data path supports " - "n_microbatches independent packed rows per step. Use 'gpipe' " - "for now." - ) - # The dataloader emits a single packed (1, total) sample per step. - # When n_microbatches > 1 we tile input_ids/labels to (N, total) so - # the schedule can chunk along dim 0; everything else (cu_seqlens, - # pixel_values, image_grid_thw, etc.) is per-batch metadata that - # must be passed identically to every microbatch — mark it replicate. 
-            self.pp_microbatches = n_microbatches
-            kwargs_chunk_spec = {
-                k: _Replicate() for k in (
-                    "input_ids", "attention_mask", "original_mask",
-                    "image_grid_thw", "pixel_values",
-                    "pixel_values_videos", "video_grid_thw",
-                )
-            }
-            self.pp_schedule = schedule_cls(
-                self.pp_stage,
-                n_microbatches=n_microbatches,
-                loss_fn=pp_loss_fn,
-                kwargs_chunk_spec=kwargs_chunk_spec,
+            self.pp_microbatches, self.pp_schedule = apply_pp(
+                self.model, self.mesh, self.p_args, self.training_args, self.device, pp_loss_fn
             )
-            logger.info(f"PP schedule: {schedule_name} (n_microbatches={n_microbatches})")
+            self.pp_has_first_stage = self.model.model.visual is not None
+            self.pp_has_last_stage = self.model.lm_head is not None
 
+        logger.info("model loaded")
+
+        # -- WEIGHT INIT
         if self.training_args.random_init:
             if self.model_type == ModelType.Qwen3_5:
                 logger.info('initializing decoder and projector of Qwen3.5')
@@ -331,22 +190,16 @@ def pp_loss_fn(outputs, labels):
                 logger.info('model not initialized, incompatible')
                 initialize_missing_weights(self.model)
 
+        # -- MIXED PRECISION
         self.model.train()
         if self.training_args.bfloat16:
             self.model = self.model.to(torch.bfloat16)
 
-        logger.info("model loaded")
-
-        if self.training_args.tp_size > 1:
-            apply_tp(self.model, self.model_type, self.tp_group, self.training_args.async_tp)
-
-        if self.ep_size > 1:
-            if self.model_type != ModelType.Qwen3_5:
-                raise NotImplementedError("EP is only supported for Qwen3.5 MoE models")
-            tp_mesh = self.tp_group if self.training_args.tp_size > 1 else None
-            apply_ep(self.model, self.ep_group, tp_mesh=tp_mesh)
-            logger.info(f"expert parallelism applied (ep_size={self.ep_size}, tp_size={self.training_args.tp_size})")
+        # -- TENSOR PARALLEL
+        if self.p_args.tp_size > 1:
+            apply_tp(self.model, self.model_type, self.tp_group, self.p_args.async_tp)
 
+        # -- ACTIVATION CHECKPOINTING
         ac_mode = getattr(self.training_args, "ac_mode", "off")
         if ac_mode != "off":
             ac_cfg = ACConfig(enabled=True, full=(ac_mode == "full"))
@@ -357,29 +210,32 @@ def pp_loss_fn(outputs, labels):
             )
             logger.info(f"activation checkpointing applied ({ac_mode})")
 
-        if self.training_args.data_parallel == 'fsdp':
-            apply_fsdp(self.model_type, self.model, mesh=self.dp_group)
-        elif self.training_args.data_parallel == 'ddp':
-            if self.ep_size > 1 and self.dp_group.size() > 1:
-                # Skip DDP hook-based all_reduce when EP is active to avoid a NCCL deadlock:
-                # DDP hooks fire async on dp_comm while EP backward A2As block on ep_comm,
-                # creating a cross-communicator cycle. Gradients are synced manually after backward.
-                logger.info(f"rank={self.rank()} EP+DP: skipping replicate(), will manually sync grads (ep={self.ep_size}, dp={self.dp_group.size()})")
-            elif self.dp_group.size() > 1:
-                self.model = replicate(self.model, device_mesh=self.dp_group)
-                logger.info(f"rank={self.rank()} DDP applied (dp={self.dp_group.size()})")
+        # -- EXPERT PARALLEL
+        if self.ep_size > 1:
+            if self.model_type != ModelType.Qwen3_5:
+                raise NotImplementedError("EP is only supported for Qwen3.5 MoE models")
+            tp_mesh = self.tp_group if self.p_args.tp_size > 1 else None
+            apply_ep(self.model, self.ep_group, tp_mesh=tp_mesh)
+            logger.info(f"expert parallelism applied (ep_size={self.ep_size})")
+
+        # -- DATA PARALLEL
+        dp_shard_mesh = get_mesh_group(self.mesh, "dp_shard")
+        dp_replicate_mesh = get_mesh_group(self.mesh, "dp_replicate")
+
+        if dp_shard_mesh is not None and dp_shard_mesh.size() > 1:
+            apply_fsdp(self.model_type, self.model, mesh=dp_shard_mesh)
+            logger.info(f"FSDP applied (dp_shard={dp_shard_mesh.size()})")
+        elif dp_replicate_mesh is not None and dp_replicate_mesh.size() > 1:
+            self.model = replicate(self.model, device_mesh=dp_replicate_mesh)
+            logger.info(f"DDP applied (dp_replicate={dp_replicate_mesh.size()})")
         else:
-            raise Exception('invalid sharding strategy for Data Parallel')
+            logger.info("no DP applied (dp=1)")

         # loading into GPU
         self.model = self.model.to(device=self.device)
         if self.training_args.bfloat16:
             self.model = self.model.to(torch.bfloat16)

-        # get rank of local GPU that belongs to the DP group
-        data_rank = self.dp_group.get_local_rank()
-        data_world_size = self.dp_group.size()
-
         logger.info('sharding/parallelism applied')

         if self.training_args.compile and self.pp_size == 1:
@@ -396,7 +252,6 @@ def pp_loss_fn(outputs, labels):

         self.processor = AutoProcessor.from_pretrained(
             self.training_args.model_dir,
-
         )

         # set_model freezes/unfreezes param groups; skip for PP (stage module
@@ -404,6 +259,10 @@ def pp_loss_fn(outputs, labels):
         if self.pp_size == 1:
             self.model = set_model(self.model_type, self.model_args, self.model)

+        # get rank of local GPU that belongs to the DP group
+        data_rank = self.dp_group.get_local_rank()
+        data_world_size = self.dp_group.size()
+
         worker_config = WorkerConfig(
             rank=data_rank,
             world_size=data_world_size,
@@ -428,8 +287,6 @@ def pp_loss_fn(outputs, labels):

         self.data_loader = get_loader(ds)

-        self.setup_accumulation(self.training_args.tpi_multiplier)
-
         self.global_step = 0
         self.micro_step = 0

@@ -721,6 +578,7 @@ def _train_step_pp(self, data_iterator, optimizer):
             loss_for_logging = scaled_loss * self.current_accum_target
             torch.distributed.all_reduce(loss_for_logging, group=self.pp_group.get_group())

+            # TODO: fix loss logging; _recent_ce_loss/_recent_aux_loss fall back to 0.0 on ranks that never set them
             ce_loss = getattr(self, '_recent_ce_loss', torch.tensor(0.0, device=self.device))
             aux_loss = getattr(self, '_recent_aux_loss', torch.tensor(0.0, device=self.device))


From 9e5ddfc0b684dc99758be3de52da7e0db81fddde Mon Sep 17 00:00:00 2001
From: tomiock
Date: Sat, 2 May 2026 13:40:15 +0200
Subject: [PATCH 23/23] [fix] better PP stages

---
 configs/cvc/moe.toml |  4 ++--
 train/infra.py       | 15 ++++-----------
 2 files changed, 6 insertions(+), 13 deletions(-)

diff --git a/configs/cvc/moe.toml b/configs/cvc/moe.toml
index 648b505..d7c4636 100644
--- a/configs/cvc/moe.toml
+++ b/configs/cvc/moe.toml
@@ -30,8 +30,8 @@ data_parallel = 'fsdp'
 ac_mode = "off"

 [data]
-data_path = "/data/151-1/datasets/synth_test_datasets/cap_pretrain"
-seq_len = 512
+data_path = "/data/151-1/datasets/synth_test_datasets/imagenet"
+seq_len = 8192

 packing_buffer_size = 100

diff --git 
a/train/infra.py b/train/infra.py
index 3197929..adafa68 100644
--- a/train/infra.py
+++ b/train/infra.py
@@ -758,17 +758,7 @@ def get_local_fqns(
     num_first: int,
     num_last: int
 ) -> list[str]:
-    if pp_size == 1:
-        return [
-            "model.visual",
-            "model.language_model.embed_tokens",
-        ] + [f"model.language_model.layers.{i}" for i in range(num_layers)] + [
-            "model.language_model.norm",
-            "lm_head"
-        ]
-
     fqns = []
-
     if pp_rank == 0:
         fqns.extend(["model.visual", "model.language_model.embed_tokens"])

@@ -778,6 +768,8 @@ def get_local_fqns(
         end_idx = mid_point if pp_rank == 0 else num_layers
     else:
         if pp_rank == 0:
+            pass  # stage 0 owns no decoder layers (see the pp_rank != 0 guard below)
+        elif pp_rank == 1:
             start_idx, end_idx = 0, num_first
         elif pp_rank == pp_size - 1:
             start_idx, end_idx = num_layers - num_last, num_layers
@@ -793,7 +785,8 @@ def get_local_fqns(
         num_layers_this_rank = layers_per_mid + (1 if mid_idx < remainder else 0)
         end_idx = start_idx + num_layers_this_rank

-    fqns.extend([f"model.language_model.layers.{i}" for i in range(start_idx, end_idx)])
+    if pp_rank != 0:
+        fqns.extend([f"model.language_model.layers.{i}" for i in range(start_idx, end_idx)])

     if pp_rank == pp_size - 1:
         fqns.extend(["model.language_model.norm", "lm_head"])
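

Net effect of the get_local_fqns change above: stage 0 owns only model.visual and embed_tokens, stage 1 takes the first num_first decoder layers, the last stage takes the final num_last layers plus norm and lm_head, and a middle stage takes what remains. A self-contained sketch of that layout, restricted to the pp sizes the repo supports; the function name is illustrative and this is a simplification, not the repo's exact get_local_fqns:

    def sketch_stage_fqns(pp_rank: int, pp_size: int, num_layers: int,
                          num_first: int, num_last: int) -> list[str]:
        assert pp_size in (2, 4)  # the pp sizes the repo supports
        # Stage 0 owns only the vision tower and token embeddings.
        if pp_rank == 0:
            return ["model.visual", "model.language_model.embed_tokens"]
        if pp_rank == pp_size - 1:
            # Last stage: trailing layers (all of them when pp_size == 2),
            # plus the final norm and the LM head.
            start = 0 if pp_size == 2 else num_layers - num_last
            end = num_layers
        elif pp_rank == 1:
            start, end = 0, num_first
        else:
            # Only pp_rank == 2 lands here when pp_size == 4.
            start, end = num_first, num_layers - num_last
        fqns = [f"model.language_model.layers.{i}" for i in range(start, end)]
        if pp_rank == pp_size - 1:
            fqns += ["model.language_model.norm", "lm_head"]
        return fqns

For example, sketch_stage_fqns(0, 4, 48, 16, 16) returns just the vision tower and embeddings, which is exactly why the new pp_rank != 0 guard stops stage 0 from claiming decoder layers.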
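The pp_loss_fn hunk keeps the signature and the final return but elides the cross-entropy computation in the middle. A sketch of a standard next-token CE that would fit there; the body is purely an assumption about the elided lines, only the weighted return is taken from the hunk:

    import torch.nn.functional as F

    def pp_loss_fn_sketch(outputs, labels, accum_target):
        logits, aux_loss = outputs
        # Shift so position t predicts token t+1; ignore_index skips padding
        # and non-target tokens in the packed sequence (assumed convention).
        ce_loss = F.cross_entropy(
            logits[:, :-1].reshape(-1, logits.size(-1)).float(),
            labels[:, 1:].reshape(-1),
            ignore_index=-100,
        )
        # 0.01 weights the MoE router aux loss; dividing by the accumulation
        # target keeps gradient magnitudes stable across accum settings.
        return (ce_loss + 0.01 * aux_loss) / accum_target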
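The schedule wiring that the first train_qwen.py hunk folds into apply_pp can be read as the following standalone sketch. PipelineStage and ScheduleGPipe are the stock torch.distributed.pipelining APIs the removed block already used; the helper name and the _Replicate import path are assumptions:

    from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
    from torch.distributed.pipelining.microbatch import _Replicate  # import path assumed

    def build_pp_schedule(stage_module, pp_rank, pp_size, device, pp_group,
                          n_microbatches, loss_fn):
        # One stage per pp rank, communicating over the pp process group.
        stage = PipelineStage(
            stage_module,
            stage_index=pp_rank,
            num_stages=pp_size,
            device=device,
            group=pp_group,
        )
        # Per-batch metadata (packed-sequence inputs, image tensors) must
        # reach every microbatch unchanged, so these kwargs are marked
        # replicate instead of being chunked along dim 0, mirroring the
        # removed block's spec.
        kwargs_chunk_spec = {
            k: _Replicate() for k in (
                "input_ids", "attention_mask", "original_mask",
                "image_grid_thw", "pixel_values",
                "pixel_values_videos", "video_grid_thw",
            )
        }
        return ScheduleGPipe(
            stage,
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            kwargs_chunk_spec=kwargs_chunk_spec,
        )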
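Likewise, the new data-parallel branch keys off mesh dimensions named "dp_shard" and "dp_replicate" rather than the old data_parallel string. A minimal sketch of the mesh shape that implies, with illustrative sizes; get_mesh_group is assumed to return the named sub-mesh, or None when the dimension is absent:

    from torch.distributed.device_mesh import init_device_mesh

    # 2 * 2 * 2 * 1 = 8 ranks; sizes are illustrative, not a repo config.
    mesh = init_device_mesh(
        "cuda",
        (2, 2, 2, 1),
        mesh_dim_names=("dp_replicate", "dp_shard", "pp", "tp"),
    )

    # dp_shard > 1 selects FSDP over that dim; otherwise dp_replicate > 1
    # selects DDP; otherwise no data parallelism is applied at all.
    dp_shard_mesh = mesh["dp_shard"]
    dp_replicate_mesh = mesh["dp_replicate"]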