diff --git a/README.md b/README.md index eabf020..b32295c 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,8 @@ Massive-scale VLM pre-training and finetuning on HPC environments. It is specifically designed and tested for **Marenostrum 5** and **JUPITER**. Works similary to torchtitan, only relying on native torch code for the distributed implementation. Compatibilty with HF state-dict, loads weights from HF snapshot directory. +See SCALABILITY.md and USAGE.md for more details. + ## Key Features * **Supported Architectures:** **Qwen3.5**, Qwen3-VL and Qwen3 (text). * **2D Parallelism:** FSDP/DDP (Single & Multi-node) and Tensor Parallelism (TP) support. Tested scaling up to 256 GPUs. @@ -32,19 +34,12 @@ Support for ROCm systems (LUMI) is work in progress. - `transformers=5.6.0` ## Datasets and Dataloading -Datasets are expected to be as a CrudeWebdataset. With https://github.com/NVIDIA/Megatron-Energon we handle the raw data and tokenize it on the fly. It is an asynchrnos process that does not have an impact on model performance. - -**Online datapacking is not yet supported** (no particular issue related to the HPC system, its skill issue on my part). We believe that data packing is a must-have for a visual-language training codebase with native resolution, as the varying image sizes on the datasets can be handled easily. +Datasets are expected to be as a CrudeWebdataset. With https://github.com/NVIDIA/Megatron-Energon we handle the raw data and tokenize it on the fly. It is an asynchrnos process that does not have an impact on model performance. **Online datapacking is used by default.** Support for Metadatasets (multiple sources). ## Model Weights & Offline Loading -Use `utils/down.py` on a login node to pre-download model weights and tokenizers to a shared filesystem: - -```bash -python utils/down.py -``` -Go into the file and change the arguments, it does not have CLI support. +Use `utils/down.py` on a login node to pre-download model weights and tokenizers to a shared filesystem. The models' archicture configuration relies on what is downloaded. -**Loading Mechanism:** During training, models are instantiated directly from these local paths. For Native Torch models, the architecture is initialized purely in PyTorch, and the offline weights are mapped and loaded directly into the native state dictionary. +**Loading Mechanism:** During training, models are instantiated directly from these local paths. The architecture is initialized purely in PyTorch, and the offline weights are mapped and loaded directly into the native state dictionary. ## Usage 1. Ensure your datasets are formatted as Nvidia Energon webdatasets. @@ -67,5 +62,6 @@ The codebase demonstrates linear scaling up to 256 GPUs using FSDP and Tensor Pa For a detailed breakdown of throughput, GPU efficiency, and scaling characteristics, please refer to [SCALABILITY.md](SCALABILITY.md). ## Known Issues & TODOs -* Online data packing for Energon dataloading is not yet supported. +* The entire workflow `training -> checkpoints -> eval/usage` needs a lot of work. * Static shape compilation (`torch.compile` with `fullgraph=True`) is pending. +* A better data packing implemented is needed. diff --git a/SCALABILITY.md b/SCALABILITY.md index 7672d9a..f8a65e5 100644 --- a/SCALABILITY.md +++ b/SCALABILITY.md @@ -1,6 +1,12 @@ ## Qwen3.5-2B @ JUPITER - 16 node (64 H200 96GB) tested -- 10,000 tks/sec/device +- +15,000 tks/sec/gpu + +## Qwen3.5-9B @ JUPITER +- scaling test from 16 to 256 nodes +- +500 TFLOPS/s/gpu +image + ## Qwen3-VL-8B @ JUPITER - ~380 TFLOPS with 4 nodes (16 GH200 96GB) @@ -16,4 +22,4 @@ ### Results Scalability throughput with 8B model on Marenostrum 5: -image +image diff --git a/configs/cvc/moe.toml b/configs/cvc/moe.toml new file mode 100644 index 0000000..d7c4636 --- /dev/null +++ b/configs/cvc/moe.toml @@ -0,0 +1,38 @@ +[model] +model_name = "Qwen/Qwen3.5-8B-A1B" +model_impl = "native" + +train_llm = true +train_mlp = true +train_vit = true + +[wandb] +run_name = "test" +project_name = "moe" + +[training] +model_dir = "/data/151-1/users/tockier/qwen_finetune/cache/qwen35_8b_a1b" +#model_dir = "/data/151-1/users/tockier/qwen_finetune/cache/qwen35_35b_a3b" +output_dir = "/data/151-1/users/shared_cache/qwen_finetune/checkpoints" + +save_steps = 10000 +total_steps = 100 +random_init = true + +compile = false + +[parallel] +tp_size = 1 +pp_size = 1 +ep_size = 1 +data_parallel = 'fsdp' + +ac_mode = "off" + +[data] +data_path = "/data/151-1/datasets/synth_test_datasets/imagenet" +seq_len = 8192 + +packing_buffer_size = 100 + +batch_size = 0 diff --git a/configs/cvc/qwen3_5_27b.toml b/configs/cvc/qwen3_5_27b.toml new file mode 100644 index 0000000..54554f2 --- /dev/null +++ b/configs/cvc/qwen3_5_27b.toml @@ -0,0 +1,33 @@ +[model] +model_name = "Qwen/Qwen3.5-27B" +model_impl = "native" + +train_llm = true +train_mlp = true +train_vit = false + +[wandb] +run_name = "test" +project_name = "qwen35_27b" + +[training] +model_dir = "/data/151-1/users/tockier/qwen_finetune/cache/qwen35_27b" +output_dir = "/data/151-1/users/shared_cache/qwen_finetune/checkpoints" + +save_steps = 10000 +total_steps = 10000 +random_init = false + +tp_size = 1 +pp_size = 2 +data_parallel = 'fsdp' + +ac_mode = "full" +compile = false + +[data] +data_path = "/data/151-1/datasets/synth_test_datasets/cap_pretrain" +seq_len = 200 + +packing_buffer_size = 100 +batch_size = 0 diff --git a/configs/cvc/qwen3_5_2b.toml b/configs/cvc/qwen3_5_2b.toml index 7875207..8f85558 100644 --- a/configs/cvc/qwen3_5_2b.toml +++ b/configs/cvc/qwen3_5_2b.toml @@ -17,24 +17,20 @@ random_init = false scheduler_type = "cosine" -total_steps: int = 1_000 -warmup_steps: int = 100 -wsd_decay_ratio: float = 0.1 -min_lr_ratio: float = 0.1 - tp_size = 1 +pp_size = 2 data_parallel = 'ddp' compile = false -tpi_multiplier = 12 +tpi_multiplier = 1 [data] -data_path = "/data/151-1/datasets/llava_recap" -seq_len = 8192 +data_path = "/data/151-1/datasets/synth_test_datasets/cap_pretrain" +seq_len = 256 shuffle_buffer_size = 1000 -packing_buffer_size = 1000 +packing_buffer_size = 100 max_samples_per_sequence = 100 batch_size = 0 diff --git a/configs/cvc/qwen3_5_9b.toml b/configs/cvc/qwen3_5_9b.toml index 221eb5e..65622ba 100644 --- a/configs/cvc/qwen3_5_9b.toml +++ b/configs/cvc/qwen3_5_9b.toml @@ -16,17 +16,19 @@ output_dir = "/data/151-1/users/shared_cache/qwen_finetune/checkpoints" save_steps = 10000 total_steps = 10000 -random_init_mlp = false +random_init = false -tp_size = 4 +tp_size = 1 +pp_size = 2 data_parallel = 'fsdp' ac_mode = "full" - compile = false [data] data_path = "/data/151-1/datasets/synth_test_datasets/cap_pretrain" -seq_len = 4096 +seq_len = 512 + +packing_buffer_size = 100 -batch_size = 32 +batch_size = 0 \ No newline at end of file diff --git a/configs/jupiter/qwen3_5_27b.toml b/configs/jupiter/qwen3_5_27b.toml new file mode 100644 index 0000000..6b9753d --- /dev/null +++ b/configs/jupiter/qwen3_5_27b.toml @@ -0,0 +1,49 @@ +[model] +model_name = "Qwen/Qwen3.5-27B" +model_impl = "native" + +train_llm = true +train_mlp = true +train_vit = false + +[wandb] +run_name = "test 27b" +project_name = "scaling_27b" +entity_name = "bsc_runs" + +[training] +model_dir = "/e/project1/reformo/ockier1/qwen_models/qwen3_5_27b" +output_dir = "/e/scratch/reformo/ockier1/checkpoints/test_35_27b" + +tpi_multiplier = 1.0 +save_steps = 1000 + +scheduler_type = "cosine" +total_steps = 18000 +warmup_steps = 100 +#wsd_decay_ratio = 0.1 +min_lr_ratio = 0.1 + +lr_llm = 0.00002 +lr_mlp = 0.0001 + +data_parallel = 'ddp' + +pp_size = 1 +tp_size = 1 + +resume_checkpoint = false +random_init = false + +ac_mode = 'off' + +compile = false +async_tp = false + +[data] +data_path = "/e/project1/jureap59/ockier1/datasets/cap_pretrain" + +packing_buffer_size = 100 + +seq_len = 512 +batch_size = 0 \ No newline at end of file diff --git a/configs/jupiter/qwen3_5_2b.toml b/configs/jupiter/qwen3_5_2b.toml index 2f4b2e4..cfa8502 100644 --- a/configs/jupiter/qwen3_5_2b.toml +++ b/configs/jupiter/qwen3_5_2b.toml @@ -17,27 +17,33 @@ model_dir = "/e/project1/reformo/ockier1/qwen_models/qwen3_5_2b" output_dir = "/e/scratch/reformo/ockier1/checkpoints/test_35_2b" tpi_multiplier = 1.0 -save_steps = 1000 +save_steps = 200 scheduler_type = "cosine" -total_steps = 18000 -warmup_steps = 100 +total_steps = 500 +warmup_steps = 10 #wsd_decay_ratio = 0.1 min_lr_ratio = 0.1 lr_llm = 0.00002 lr_mlp = 0.0001 -data_parallel = 'fsdp' +data_parallel = 'ddp' tp_size = 1 resume_checkpoint = false -random_init_mlp = false +random_init = true -compile = true +compile = false [data] -data_path = "/e/project1/jureap59/ockier1/datasets/cap_pretrain" +data_path = "/e/data1/datasets/products/llava_onevision_mid_training_85m/imagenet/EN" seq_len = 8192 -batch_size = 66 \ No newline at end of file + +shuffle_buffer_size = 1000 +packing_buffer_size = 1000 +max_samples_per_sequence = 100 + +batch_size = 0 + diff --git a/configs/jupiter/qwen3_5_9b.toml b/configs/jupiter/qwen3_5_9b.toml index 8b145b2..dbe49d4 100644 --- a/configs/jupiter/qwen3_5_9b.toml +++ b/configs/jupiter/qwen3_5_9b.toml @@ -28,18 +28,22 @@ lr_llm = 0.00002 lr_mlp = 0.0001 data_parallel = 'fsdp' -tp_size = 4 + +pp_size = 4 +tp_size = 1 resume_checkpoint = false -random_init_mlp = false +random_init = false ac_mode = 'off' -compile = true +compile = false async_tp = false [data] data_path = "/e/project1/jureap59/ockier1/datasets/cap_pretrain" -seq_len = 10240 -batch_size = 64 \ No newline at end of file +packing_buffer_size = 100 + +seq_len = 8192 +batch_size = 0 \ No newline at end of file diff --git a/data/energon_dataloader.py b/data/energon_dataloader.py index a60422e..8f139ff 100644 --- a/data/energon_dataloader.py +++ b/data/energon_dataloader.py @@ -93,6 +93,28 @@ class EnergonSample(Sample): image: torch.Tensor messages: list +@stateless +def cooker_llava_imagenet(sample: dict, add_system_prompt: bool = True) -> EnergonSample: + messages = [ + {'role': 'user', 'content': [ + {"type": "image"} + ]}, + {'role': 'assistant', 'content': [ + {"type": "text", "text": sample['txt']} + ]}, + ] + + if not add_system_prompt: + messages.append({"role": "system", "content": [{"type": "text", "text": ""}]}) + + image = sample['jpg'] + + return EnergonSample( + **basic_sample_keys(sample), + image=image, + messages=messages, + ) + @stateless def cooker_captioning(sample: dict, add_system_prompt: bool = True) -> EnergonSample: role_map = {'human': 'user', 'gpt': 'assistant', 'user': 'user', 'assistant': 'assistant'} @@ -254,11 +276,19 @@ def __init__(self, processor, max_seq_len): self.assistant_token = self.tokenizer.encode("assistant")[0] self.EOS_token = self.tokenizer.eos_token_id + """ cookers = [ # subflavors can be used to distinguish datasets when using a Metadataset - Cooker(cooker_captioning), + Cooker(cooker_captioning, has_subflavors={"type_dataset": "synth"}), + Cooker(cooker_llava_imagenet, has_subflavors={"type_dataset": "llava_onevision_midtraining"}), ] + """ + cookers = [ + # subflavors can be used to distinguish datasets when using a Metadataset + Cooker(cooker_captioning), + Cooker(cooker_llava_imagenet, has_subflavors={"type_dataset": "llava_onevision_midtraining"}), + ] # transform the RAW data, tokenize a single sample @stateless(restore_seeds=True) def encode_sample(self, sample: EnergonSample) -> EncodedSample: diff --git a/models/qwen3_5/config.py b/models/qwen3_5/config.py index 4acbc46..00f6804 100644 --- a/models/qwen3_5/config.py +++ b/models/qwen3_5/config.py @@ -5,7 +5,7 @@ class Qwen3_5TextConfig: vocab_size: int hidden_size: int - intermediate_size: int + moe_intermediate_size: int num_hidden_layers: int num_attention_heads: int num_key_value_heads: int @@ -14,6 +14,12 @@ class Qwen3_5TextConfig: rms_norm_eps: float tie_word_embeddings: bool + # moe + num_experts: int + num_experts_per_tok: int + router_aux_loss_coef: float + shared_expert_intermediate_size: int + # linear attention layer_types: list[str] full_attention_interval: int @@ -64,7 +70,7 @@ def from_json(cls, path: str) -> "Qwen3VLConfig": text = Qwen3_5TextConfig( vocab_size=tc["vocab_size"], hidden_size=tc["hidden_size"], - intermediate_size=tc["intermediate_size"], + moe_intermediate_size=tc["moe_intermediate_size"], num_hidden_layers=tc["num_hidden_layers"], num_attention_heads=tc["num_attention_heads"], num_key_value_heads=tc["num_key_value_heads"], @@ -81,7 +87,11 @@ def from_json(cls, path: str) -> "Qwen3VLConfig": mtp_num_hidden_layers=tc['mtp_num_hidden_layers'], mtp_use_dedicated_embeddings=tc['mtp_use_dedicated_embeddings'], tie_word_embeddings=tc.get("tie_word_embeddings", raw.get("tie_word_embeddings", False)), - rope_parameters=tc['rope_parameters'] + rope_parameters=tc['rope_parameters'], + num_experts=tc['num_experts'], + num_experts_per_tok=tc['num_experts_per_tok'], + router_aux_loss_coef=tc['router_aux_loss_coef'], + shared_expert_intermediate_size=tc['shared_expert_intermediate_size'], ) vc = raw["vision_config"] vision = Qwen3_5VisionConfig( diff --git a/models/qwen3_5/dispatcher.py b/models/qwen3_5/dispatcher.py new file mode 100644 index 0000000..343ef67 --- /dev/null +++ b/models/qwen3_5/dispatcher.py @@ -0,0 +1,211 @@ +import torch +import torch.distributed as dist +from torch.distributed.device_mesh import DeviceMesh + +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True +torch.backends.cudnn.benchmark = True + + +class AllToAllDispatchMetadata: + def __init__(self, token_indices_experts_sorted, top_scores_experts_sorted, input_shape, permuted_indices, input_splits, output_splits): + self.token_indices_experts_sorted = token_indices_experts_sorted + self.top_scores_experts_sorted = top_scores_experts_sorted + self.input_shape = input_shape + self.permuted_indices = permuted_indices + self.input_splits = input_splits + self.output_splits = output_splits + + +class _AllToAllSingleAutograd(torch.autograd.Function): + @staticmethod + def forward(ctx, input_, output_splits, input_splits, group): + ctx.group = group + ctx.input_splits = input_splits + ctx.output_splits = output_splits + + out_total = int(sum(output_splits)) if output_splits else 0 + out = torch.empty((out_total, input_.size(1)), dtype=input_.dtype, device=input_.device) + + if group is not None: + dist.all_to_all_single(out, input_, output_splits, input_splits, group=group.get_group()) + return out + + @staticmethod + def backward(ctx, grad_output): + in_total = int(sum(ctx.input_splits)) if ctx.input_splits else 0 + grad_input = torch.empty((in_total, grad_output.size(1)), dtype=grad_output.dtype, device=grad_output.device) + + if ctx.group is not None: + dist.all_to_all_single(grad_input, grad_output, ctx.input_splits, ctx.output_splits, group=ctx.group.get_group()) + return grad_input, None, None, None + + +def all_to_all_single_autograd(input_, output_splits, input_splits, group): + return _AllToAllSingleAutograd.apply(input_, output_splits, input_splits, group) + + +class _AllReduceForward(torch.autograd.Function): + """All-reduce SUM in forward (partial → replicate). Identity in backward. + + Apply at the *output* of a TP-local computation whose value is a per-rank partial + sum. After the forward all-reduce all ranks hold the full sum. In backward, each + rank already sees the same upstream gradient (the output is replicated), so + no further communication is needed. + """ + + @staticmethod + def forward(ctx, input_, group): + ctx.group = group + if group is None: + return input_ + out = input_.contiguous().clone() + dist.all_reduce(out, op=dist.ReduceOp.SUM, group=group) + return out + + @staticmethod + def backward(ctx, grad_output): + return grad_output, None + + +class _AllReduceBackward(torch.autograd.Function): + """Identity in forward. All-reduce SUM in backward (per-rank partial → replicate gradient). + + Apply at the *input* of a TP-local computation. Each rank's autograd produces a + per-rank partial gradient (because each rank used a different weight slice + downstream); summing across TP gives the full Replicate gradient that the + upstream graph expects. + """ + + @staticmethod + def forward(ctx, input_, group): + ctx.group = group + return input_ + + @staticmethod + def backward(ctx, grad_output): + if ctx.group is None: + return grad_output, None + g = grad_output.contiguous().clone() + dist.all_reduce(g, op=dist.ReduceOp.SUM, group=ctx.group) + return g, None + + +def all_reduce_forward(input_, group): + return _AllReduceForward.apply(input_, group) + + +def all_reduce_backward(input_, group): + return _AllReduceBackward.apply(input_, group) + + +class TokenDispatcher: + """Consolidated EP/SP dispatcher. Handles local token reorder and all-to-all.""" + + def __init__(self, num_experts: int, top_k: int, score_before_experts: bool = True): + self.num_experts = num_experts + self.top_k = top_k + self.score_before_experts = score_before_experts + + self.ep_mesh: DeviceMesh | None = None + self.sp_size: int = 1 + self.sp_rank: int = -1 + + def _split_along_sp(self, *tensors: torch.Tensor) -> list[torch.Tensor]: + results = [] + for t in tensors: + local_num_tokens = t.shape[0] // self.sp_size + offset = self.sp_rank * local_num_tokens + results.append(t[offset : offset + local_num_tokens]) + return results + + def _permute(self, routed_input, num_tokens_per_expert_group, ep_size, num_local_experts): + device = num_tokens_per_expert_group.device + total = num_tokens_per_expert_group.sum().item() + + t_mat = num_tokens_per_expert_group.view(ep_size, num_local_experts) + input_starts = (num_tokens_per_expert_group.cumsum(0) - num_tokens_per_expert_group).view(ep_size, num_local_experts) + + segment_lens = t_mat.t().reshape(-1) + input_starts = input_starts.t().reshape(-1) + + seg_ids = torch.arange(segment_lens.shape[0], device=device).repeat_interleave(segment_lens.long()) + output_starts = segment_lens.cumsum(0) - segment_lens + permuted_indices = (input_starts[seg_ids] + torch.arange(total, device=device) - output_starts[seg_ids]).long() + + num_tokens_per_expert = t_mat.sum(0) + return routed_input.shape, routed_input[permuted_indices, :], permuted_indices, num_tokens_per_expert + + def _unpermute(self, routed_output, input_shape, permuted_indices): + # Empty path (rank received 0 tokens): pass through to preserve the autograd + # edge into the combine A2A so its backward fires on every rank in the EP group. + if routed_output.shape[0] == 0: + return routed_output + out_unpermuted = routed_output.new_empty(input_shape) + out_unpermuted[permuted_indices, :] = routed_output + return out_unpermuted + + def dispatch(self, x: torch.Tensor, top_scores: torch.Tensor, selected_experts_indices: torch.Tensor): + if self.sp_size > 1: + x, top_scores, selected_experts_indices = self._split_along_sp(x, top_scores, selected_experts_indices) + + num_tokens_per_expert = torch.bincount( + selected_experts_indices.view(-1), + minlength=self.num_experts, + ).float() + + token_indices_experts_sorted = torch.argsort(selected_experts_indices.view(-1), stable=True) + top_scores_experts_sorted = top_scores.view(-1)[token_indices_experts_sorted] + token_indices_experts_sorted = token_indices_experts_sorted // self.top_k + routed_input = x[token_indices_experts_sorted] + + if self.score_before_experts: + routed_input = (routed_input.to(torch.float32) * top_scores_experts_sorted.reshape(-1, 1)).to(x.dtype) + + # Skip all-to-all logic entirely if ep_mesh is missing (EP=1) + if self.ep_mesh is None: + metadata = AllToAllDispatchMetadata( + token_indices_experts_sorted, top_scores_experts_sorted, + None, None, None, None + ) + return routed_input, num_tokens_per_expert, metadata + + ep_size = self.ep_mesh.size() + # equal-split all-to-all: each rank sends num_local_experts counts to every rank, + # receiving num_local_experts from every rank → output same size as input. + num_tokens_per_expert_group = torch.empty_like(num_tokens_per_expert) + dist.all_to_all_single(num_tokens_per_expert_group, num_tokens_per_expert, group=self.ep_mesh.get_group()) + + input_splits = [int(x) for x in num_tokens_per_expert.long().view(ep_size, -1).sum(dim=1).tolist()] + output_splits = [int(x) for x in num_tokens_per_expert_group.long().view(ep_size, -1).sum(dim=1).tolist()] + + routed_input = all_to_all_single_autograd(routed_input, output_splits, input_splits, self.ep_mesh) + + num_local_experts = num_tokens_per_expert_group.shape[0] // ep_size + input_shape, routed_input, permuted_indices, num_tokens_per_expert_group = self._permute( + routed_input, num_tokens_per_expert_group, ep_size, num_local_experts + ) + + metadata = AllToAllDispatchMetadata( + token_indices_experts_sorted, top_scores_experts_sorted, + input_shape, permuted_indices, input_splits, output_splits + ) + return routed_input, num_tokens_per_expert_group, metadata + + def combine(self, routed_output: torch.Tensor, metadata: "AllToAllDispatchMetadata", x: torch.Tensor, shared_experts: torch.nn.Module | None = None) -> torch.Tensor: + if self.ep_mesh is not None: + routed_output = self._unpermute(routed_output, metadata.input_shape, metadata.permuted_indices) + routed_output = all_to_all_single_autograd(routed_output, metadata.input_splits, metadata.output_splits, self.ep_mesh) + + out = shared_experts(x) if shared_experts is not None else torch.zeros_like(x) + + if not self.score_before_experts: + routed_output = (routed_output.to(torch.float32) * metadata.top_scores_experts_sorted.reshape(-1, 1)).to(routed_output.dtype) + + token_indices_experts_sorted = metadata.token_indices_experts_sorted + if self.sp_size > 1: + local_num_tokens = x.shape[0] // self.sp_size + token_indices_experts_sorted = token_indices_experts_sorted + local_num_tokens * self.sp_rank + + out.index_add_(0, token_indices_experts_sorted, routed_output) + return out diff --git a/models/qwen3_5/model.py b/models/qwen3_5/model.py index 5db7f2b..076c640 100644 --- a/models/qwen3_5/model.py +++ b/models/qwen3_5/model.py @@ -1,5 +1,6 @@ from __future__ import annotations +import time from pathlib import Path import torch @@ -230,6 +231,162 @@ def __init__(self, cfg: Qwen3_5TextConfig): def forward(self, x): return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) +class MoeMLP(nn.Module): + def __init__(self, config: Qwen3_5MoeConfig, intermediate_size: int): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = F.silu + + def forward(self, x): + down_proj = self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + +class MoeExperts(nn.Module): + def __init__(self, config): + super().__init__() + self.num_experts = config.num_experts + self.hidden_dim = config.hidden_size + self.intermediate_dim = config.moe_intermediate_size + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) + self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) + self.act_fn = F.silu + + def forward( + self, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + ) -> torch.Tensor: + final_hidden_states = torch.zeros_like(hidden_states) + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + + for expert_idx in expert_hit: + expert_idx = expert_idx[0] + if expert_idx == self.num_experts: + continue + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + gate_up = F.linear(current_state, self.gate_up_proj[expert_idx]) + gate, up = gate_up.chunk(2, dim=-1) + current_hidden_states = self.act_fn(gate) * up + current_hidden_states = F.linear(current_hidden_states, self.down_proj[expert_idx]) + current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None] + final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) + + return final_hidden_states + + def forward_ep(self, routed_input: torch.Tensor, num_tokens_per_expert: torch.Tensor) -> torch.Tensor: + """Forward for EP mode: routed_input is already sorted/dispatched to local experts. + + With EP+TP, gate_up_proj and down_proj are sharded along the intermediate + dimension across TP. The input is Replicate across TP and the output is + Replicate after a forward all-reduce. The backward all-reduce on the input + is required so the gradient flowing back into the dispatch A2A is also + Replicate across TP. + """ + tp_mesh = getattr(self, 'tp_mesh', None) + use_tp = tp_mesh is not None and tp_mesh.size() > 1 + if use_tp: + from models.qwen3_5.dispatcher import all_reduce_backward + routed_input = all_reduce_backward(routed_input, tp_mesh.get_group()) + + num_local_experts = self.gate_up_proj.shape[0] + offsets = torch.zeros(num_local_experts + 1, dtype=torch.long, device=routed_input.device) + offsets[1:] = num_tokens_per_expert.long().cumsum(0) + + outputs = [] + for i in range(num_local_experts): + start, end = int(offsets[i]), int(offsets[i + 1]) + if end > start: + chunk = routed_input[start:end] + gate_up = F.linear(chunk, self.gate_up_proj[i]) + gate, up = gate_up.chunk(2, dim=-1) + outputs.append(F.linear(self.act_fn(gate) * up, self.down_proj[i])) + + if outputs: + result = torch.cat(outputs, dim=0) + else: + # Empty path: preserve autograd connection to routed_input so the EP A2A + # backward fires on this rank too (otherwise other ranks hang on the unmatched collective). + result = routed_input[:0] + + if use_tp: + from models.qwen3_5.dispatcher import all_reduce_forward + result = all_reduce_forward(result, tp_mesh.get_group()) + + return result + +class TopKRouter(nn.Module): + def __init__(self, config): + super().__init__() + self.top_k = config.num_experts_per_tok + self.num_experts = config.num_experts + self.hidden_dim = config.hidden_size + self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim)) + + def forward(self, hidden_states): + hidden_states = hidden_states.reshape(-1, self.hidden_dim) + router_logits = F.linear(hidden_states, self.weight) # (seq_len, num_experts) + router_probs = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1) + router_top_value, router_indices = torch.topk(router_probs, self.top_k, dim=-1) # (seq_len, top_k) + router_top_value /= router_top_value.sum(dim=-1, keepdim=True) + router_top_value = router_top_value.to(router_logits.dtype) + router_scores = router_top_value + return router_logits, router_scores, router_indices + +class MoE(nn.Module): + def __init__(self, config): + super().__init__() + self.gate = TopKRouter(config) + self.experts = MoeExperts(config) + self.shared_expert = MoeMLP(config, intermediate_size=config.shared_expert_intermediate_size) + self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False) + self.dispatcher = None # set by apply_ep() when EP > 1 + + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + batch_size, sequence_length, hidden_dim = hidden_states.shape + N = batch_size * sequence_length + num_experts = self.experts.num_experts + + hidden_states_reshaped = hidden_states.view(-1, hidden_dim) + + router_logits, routing_weights, selected_experts = self.gate(hidden_states_reshaped) + + if self.dispatcher is not None: + routed_input, num_tokens_per_expert, metadata = self.dispatcher.dispatch( + hidden_states_reshaped, routing_weights, selected_experts + ) + routed_output = self.experts.forward_ep(routed_input, num_tokens_per_expert) + expert_output = self.dispatcher.combine(routed_output, metadata, hidden_states_reshaped) + else: + expert_output = self.experts(hidden_states_reshaped, selected_experts, routing_weights) + + shared_expert_output = self.shared_expert(hidden_states_reshaped) + shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states_reshaped)) * shared_expert_output + expert_output = expert_output + shared_expert_output + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=num_experts) + tokens_per_expert = expert_mask.sum(dim=(0, 1), dtype=torch.float) + fraction_tokens = tokens_per_expert / (N * self.gate.top_k) + + router_probs = torch.nn.functional.softmax(router_logits, dim=-1).sum(dim=0) + fraction_probs = router_probs / N + + aux_loss = num_experts * torch.sum(fraction_tokens * fraction_probs) + + dummy = (self.experts.gate_up_proj * 0.0).sum() + (self.experts.down_proj * 0.0).sum() + aux_loss = aux_loss + dummy.to(aux_loss.dtype) + + return expert_output.reshape(batch_size, sequence_length, hidden_dim), aux_loss + class DecoderLayer(nn.Module): def __init__(self, cfg: Qwen3_5TextConfig, layer_type: str): super().__init__() @@ -239,7 +396,7 @@ def __init__(self, cfg: Qwen3_5TextConfig, layer_type: str): else: self.linear_attn = GatedDeltaNet(cfg) - self.mlp = MLP(cfg) + self.mlp = MoE(cfg) self.input_layernorm = OffsetRMSNorm(cfg.hidden_size, eps=cfg.rms_norm_eps) self.post_attention_layernorm = OffsetRMSNorm(cfg.hidden_size, eps=cfg.rms_norm_eps) @@ -254,8 +411,9 @@ def forward(self, x, cos, sin, cu_seqlens, max_seqlen): cu_seqlens=cu_seqlens, max_seqlen=max_seqlen ) - x = x + self.mlp(self.post_attention_layernorm(x)) - return x + mlp_out, aux_loss = self.mlp(self.post_attention_layernorm(x)) + x = x + mlp_out + return x, aux_loss class LanguageModel(nn.Module): """HF name: `model.language_model`.""" @@ -286,14 +444,18 @@ def forward( deepstack_visual_embeds: list[torch.Tensor] | None = None, ) -> torch.Tensor: x = inputs_embeds + + total_aux_loss = 0 for i, layer in enumerate(self.layers): - x = layer(x, cos, sin, cu_seqlens, max_seqlen) + x, aux_loss = layer(x, cos, sin, cu_seqlens, max_seqlen) + total_aux_loss += aux_loss if deepstack_visual_embeds is not None and i < len(deepstack_visual_embeds): x = x.clone() x[visual_pos_masks] = ( x[visual_pos_masks] + deepstack_visual_embeds[i].to(x.dtype) ) - return self.norm(x) + + return self.norm(x) if self.norm is not None else x, total_aux_loss class VisionPatchEmbed(nn.Module): def __init__(self, cfg: Qwen3_5VisionConfig): @@ -548,10 +710,14 @@ def forward( merged = self.merger(hidden_states) return merged, deepstack +class Qwen3_5InnerLanguage(nn.Module): + def __init__(self, cfg: Qwen3VLConfig): + super().__init__() + self.language_model = LanguageModel(cfg.text) + class Qwen3_5Inner(nn.Module): """HF name: `model`. Groups `language_model` and `visual`. This is only used to match the state keys. """ - def __init__(self, cfg: Qwen3VLConfig): super().__init__() self.language_model = LanguageModel(cfg.text) @@ -659,9 +825,10 @@ def _compute_cos_sin(self, position_ids: torch.Tensor) -> tuple[torch.Tensor, to def forward( self, - input_ids: torch.Tensor | None = None, + hidden_states: torch.Tensor | None = None, + prev_aux_loss: torch.Tensor | None = None, *, - inputs_embeds: torch.Tensor | None = None, + input_ids: torch.Tensor | None = None, pixel_values: torch.Tensor | None = None, image_grid_thw: torch.Tensor | None = None, pixel_values_videos: torch.Tensor | None = None, @@ -679,17 +846,18 @@ def forward( (same tensor consumed by `torch.nn.attention.varlen.varlen_attn`). If `attention_mask` is None, the whole row is treated as one sample. """ - assert (input_ids is None) ^ (inputs_embeds is None) - if input_ids is not None and input_ids.dim() == 1: - input_ids = input_ids.unsqueeze(0) - if input_ids is not None: - assert input_ids.dim() == 2 and input_ids.shape[0] == 1, ( - f"varlen expects packed (1, total), got {tuple(input_ids.shape)}" - ) - - if inputs_embeds is None: + if getattr(self.model.language_model, "embed_tokens", None) is not None: + input_ids = hidden_states inputs_embeds = self.model.language_model.embed_tokens(input_ids) - assert inputs_embeds.dim() == 3 and inputs_embeds.shape[0] == 1 + else: + inputs_embeds = hidden_states + if input_ids is None: + raise ValueError("input_ids must be passed to intermediate stages for MRoPE calculation.") + + assert inputs_embeds.dim() == 3 and inputs_embeds.shape[0] == 1, ( + f"inputs_embeds should be (1, total, hidden_dim), got {tuple(inputs_embeds.shape)}" + ) + total = inputs_embeds.shape[1] device = inputs_embeds.device @@ -708,43 +876,42 @@ def forward( visual_pos_masks: torch.Tensor | None = None deepstack_visual_embeds: list[torch.Tensor] | None = None - if pixel_values is not None: - assert image_grid_thw is not None - merged, deepstack = self.model.visual(pixel_values, image_grid_thw) - merged = merged.to(inputs_embeds.dtype) - image_mask = input_ids == self.cfg.image_token_id - assert image_mask.sum().item() == merged.shape[0], ( - f"image tokens={image_mask.sum().item()} vs features={merged.shape[0]}" - ) - inputs_embeds = inputs_embeds.masked_scatter( - image_mask.unsqueeze(-1).expand_as(inputs_embeds), merged - ) - visual_pos_masks = image_mask - deepstack_visual_embeds = deepstack - - if pixel_values_videos is not None: - assert video_grid_thw is not None - merged_v, deepstack_v = self.model.visual(pixel_values_videos, video_grid_thw) - merged_v = merged_v.to(inputs_embeds.dtype) - video_mask = input_ids == self.cfg.video_token_id - inputs_embeds = inputs_embeds.masked_scatter( - video_mask.unsqueeze(-1).expand_as(inputs_embeds), merged_v - ) - if visual_pos_masks is None: - visual_pos_masks = video_mask - deepstack_visual_embeds = deepstack_v - else: - combined = visual_pos_masks | video_mask - image_only = visual_pos_masks[combined] - video_only = video_mask[combined] - merged_ds = [] - for img_ds, vid_ds in zip(deepstack_visual_embeds, deepstack_v): - e = img_ds.new_zeros(combined.sum().item(), img_ds.shape[-1]) - e[image_only] = img_ds - e[video_only] = vid_ds - merged_ds.append(e) - visual_pos_masks = combined - deepstack_visual_embeds = merged_ds + if getattr(self.model, "visual", None) is not None: + if pixel_values is not None: + assert image_grid_thw is not None + merged, deepstack = self.model.visual(pixel_values, image_grid_thw) + merged = merged.to(inputs_embeds.dtype) + image_mask = input_ids == self.cfg.image_token_id + #assert image_mask.sum().item() == merged.shape[0], (f"image tokens={image_mask.sum().item()} vs features={merged.shape[0]}") + inputs_embeds = inputs_embeds.masked_scatter( + image_mask.unsqueeze(-1).expand_as(inputs_embeds), merged + ) + visual_pos_masks = image_mask + deepstack_visual_embeds = deepstack + + if pixel_values_videos is not None: + assert video_grid_thw is not None + merged_v, deepstack_v = self.model.visual(pixel_values_videos, video_grid_thw) + merged_v = merged_v.to(inputs_embeds.dtype) + video_mask = input_ids == self.cfg.video_token_id + inputs_embeds = inputs_embeds.masked_scatter( + video_mask.unsqueeze(-1).expand_as(inputs_embeds), merged_v + ) + if visual_pos_masks is None: + visual_pos_masks = video_mask + deepstack_visual_embeds = deepstack_v + else: + combined = visual_pos_masks | video_mask + image_only = visual_pos_masks[combined] + video_only = video_mask[combined] + merged_ds = [] + for img_ds, vid_ds in zip(deepstack_visual_embeds, deepstack_v): + e = img_ds.new_zeros(combined.sum().item(), img_ds.shape[-1]) + e[image_only] = img_ds + e[video_only] = vid_ds + merged_ds.append(e) + visual_pos_masks = combined + deepstack_visual_embeds = merged_ds if position_ids is None: if image_grid_thw is not None or video_grid_thw is not None: @@ -766,7 +933,7 @@ def forward( cos = cos.to(inputs_embeds.dtype) sin = sin.to(inputs_embeds.dtype) - h = self.model.language_model( + h, total_aux_loss = self.model.language_model( inputs_embeds, cos, sin, @@ -775,14 +942,33 @@ def forward( visual_pos_masks=visual_pos_masks, deepstack_visual_embeds=deepstack_visual_embeds, ) - logits = self.lm_head(h) - if labels is None: - return logits - if labels.dim() == 1: - labels = labels.unsqueeze(0) - loss = causal_lm_loss(logits, labels) - return CausalLMOutput(loss=loss, logits=logits) + if hasattr(total_aux_loss, "to_local"): + total_aux_loss = total_aux_loss.to_local() + + if prev_aux_loss is not None: + if hasattr(prev_aux_loss, "to_local"): + prev_aux_loss = prev_aux_loss.to_local() + + total_aux_loss = total_aux_loss + prev_aux_loss + + if isinstance(total_aux_loss, (int, float)): + total_aux_loss = torch.tensor([total_aux_loss], device=h.device, dtype=h.dtype) + elif total_aux_loss.dim() == 0: + total_aux_loss = total_aux_loss.unsqueeze(0) + + if prev_aux_loss is not None: + total_aux_loss = total_aux_loss + prev_aux_loss + + if isinstance(total_aux_loss, (int, float)): + total_aux_loss = torch.tensor([total_aux_loss], device=h.device, dtype=h.dtype) + elif total_aux_loss.dim() == 0: + total_aux_loss = total_aux_loss.unsqueeze(0) + + if getattr(self, "lm_head", None) is not None: + return self.lm_head(h), total_aux_loss + else: + return h, total_aux_loss @classmethod def from_pretrained( @@ -792,6 +978,7 @@ def from_pretrained( device: str | torch.device = "cpu", *, load_vision: bool = True, + weights: bool = True, ) -> "Qwen3_5ForCausalLM": snapshot_dir = Path(snapshot_dir) cfg = Qwen3VLConfig.from_json(snapshot_dir / "config.json") @@ -804,19 +991,56 @@ def from_pretrained( with torch.device("meta"): model = cls(cfg) - model = model.to_empty(device=device).to(dtype=dtype) - - load_safetensors_into( - model, - snapshot_dir, - device=device, - dtype=dtype, - load_vision=load_vision, - ) - # `to_empty` above re-materializes every parameter and breaks the - # tie established in `__init__`. Re-tie here so `lm_head` (absent - # from checkpoints when tied) shares storage with the embedding. + model = model.to_empty(device='cuda').to(dtype=dtype) + if False: + load_safetensors_into( + model, + snapshot_dir, + device=device, + dtype=dtype, + load_vision=load_vision, + ) + else: + with torch.no_grad(): + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + elif isinstance(module, torch.nn.Embedding): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + elif isinstance(module, torch.nn.Conv3d) or isinstance(module, torch.nn.Conv1d): + torch.nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + + elif "Norm" in module.__class__.__name__: + if hasattr(module, "weight") and module.weight is not None: + if "Offset" in module.__class__.__name__: + torch.nn.init.zeros_(module.weight) + else: + torch.nn.init.ones_(module.weight) + if hasattr(module, "bias") and module.bias is not None: + torch.nn.init.zeros_(module.bias) + + elif "MoeExperts" in module.__class__.__name__: + torch.nn.init.normal_(module.gate_up_proj, mean=0.0, std=0.02) + torch.nn.init.normal_(module.down_proj, mean=0.0, std=0.02) + + elif "TopKRouter" in module.__class__.__name__: + # Router weights are typically initialized to 0 or very small values + torch.nn.init.zeros_(module.weight) + + elif "RMSNormGated" in module.__class__.__name__: + if hasattr(module, "weight") and module.weight is not None: + torch.nn.init.ones_(module.weight) + + elif "OffsetRMSNorm" in module.__class__.__name__: + if hasattr(module, "weight") and module.weight is not None: + # Offset norm weights start at 0 + torch.nn.init.zeros_(module.weight) + if cfg.tie_word_embeddings: model.lm_head.weight = model.model.language_model.embed_tokens.weight @@ -830,7 +1054,6 @@ def from_pretrained( ) model.model.visual.rotary_pos_emb.inv_freq = inv_freq_v - # Recompute text inv_freq (non-persistent buffer wiped by `to_empty`). head_dim = cfg.text.head_dim partial = cfg.text.rope_parameters.get('partial_rotary_factor', 1.0) rope_dim = int(head_dim * partial) @@ -841,3 +1064,31 @@ def from_pretrained( model.text_inv_freq = text_inv return model + +@torch.no_grad() +def initialize_missing_weights(model): + for module in model.modules(): + if hasattr(module, 'reset_parameters'): + module.reset_parameters() + + elif isinstance(module, OffsetRMSNorm): + torch.nn.init.zeros_(module.weight) + elif isinstance(module, RMSNormGated): + torch.nn.init.ones_(module.weight) + elif isinstance(module, GatedDeltaNet): + torch.nn.init.zeros_(module.A_log) + torch.nn.init.ones_(module.dt_bias) + + elif "MoeExperts" in module.__class__.__name__: + torch.nn.init.normal_(module.gate_up_proj, mean=0.0, std=0.02) + torch.nn.init.normal_(module.down_proj, mean=0.0, std=0.02) + elif "TopKRouter" in module.__class__.__name__: + torch.nn.init.zeros_(module.weight) + + for name, param in model.named_parameters(): + if torch.isnan(param).any() or torch.isinf(param).any(): + print(f"WARNING: Fallback init applied to missed parameter: {name}") + if param.dim() >= 2: + torch.nn.init.normal_(param, mean=0.0, std=0.02) + else: + torch.nn.init.zeros_(param) diff --git a/models/qwen3_5/utils.py b/models/qwen3_5/utils.py index 14bf20f..05bb74f 100644 --- a/models/qwen3_5/utils.py +++ b/models/qwen3_5/utils.py @@ -1,6 +1,7 @@ from __future__ import annotations import json +from collections import defaultdict from dataclasses import dataclass from pathlib import Path @@ -52,7 +53,6 @@ def causal_lm_loss( labels: torch.Tensor, ignore_index: int = -100, ) -> torch.Tensor: - # Match HF ForCausalLMLoss: upcast to fp32 before CE to avoid bf16 precision issues. shift_logits = logits[..., :-1, :].contiguous().float() shift_labels = labels[..., 1:].contiguous() return F.cross_entropy( @@ -195,3 +195,73 @@ def load_safetensors_into( missing = {m for m in missing if not m.startswith("model.visual.")} if missing: raise RuntimeError(f"Missing weights after load: {sorted(missing)[:8]} ... ({len(missing)} total)") + +def load_stage_weights( + stage: nn.Module, + snapshot_dir: Path, + layer_start: int, + layer_end: int, + is_first: bool, + is_last: bool, + device: torch.device | str, + dtype: torch.dtype, +) -> None: + from models.qwen3_5.config import Qwen3VLConfig + + snapshot_dir = Path(snapshot_dir) + cfg = Qwen3VLConfig.from_json(snapshot_dir / "config.json") + + stage_to_ckpt: dict[str, str] = {} + for key in stage.state_dict(): + if "inv_freq" in key: + continue + + if key.startswith("model.language_model.layers."): + rest = key[len("model.language_model.layers."):] + i_str, suffix = rest.split(".", 1) + global_layer_idx = layer_start + int(i_str) + stage_to_ckpt[key] = f"model.language_model.layers.{global_layer_idx}.{suffix}" + + else: + stage_to_ckpt[key] = key + + index_path = snapshot_dir / "model.safetensors.index.json" + if index_path.exists(): + with open(index_path) as f: + weight_map = json.load(f)["weight_map"] + shard_to_keys: dict[str, list[str]] = defaultdict(list) + for ckpt_key in stage_to_ckpt.values(): + if ckpt_key in weight_map: + shard_to_keys[weight_map[ckpt_key]].append(ckpt_key) + else: + single = snapshot_dir / "model.safetensors" + assert single.exists(), f"No safetensors found in {snapshot_dir}" + shard_to_keys = {single.name: list(stage_to_ckpt.values())} + + ckpt_tensors: dict[str, torch.Tensor] = {} + for shard_name, keys in shard_to_keys.items(): + with safe_open(str(snapshot_dir / shard_name), framework="pt", device=str(device)) as f: + for k in keys: + ckpt_tensors[k] = f.get_tensor(k).to(dtype=dtype) + + stage_sd = {sk: ckpt_tensors[ck] for sk, ck in stage_to_ckpt.items() if ck in ckpt_tensors} + + missing = set(stage_to_ckpt.keys()) - set(stage_sd.keys()) + if missing: + raise RuntimeError(f"Missing weights for stage: {list(missing)[:5]}... ({len(missing)} total)") + + stage.load_state_dict(stage_sd, assign=True, strict=False) + + if is_first: + head_dim = cfg.text.head_dim + partial = cfg.text.rope_parameters.get("partial_rotary_factor", 1.0) + rope_dim = int(head_dim * partial) + stage.text_inv_freq = 1.0 / ( + cfg.text.rope_parameters["rope_theta"] + ** (torch.arange(0, rope_dim, 2, dtype=torch.float32, device=device) / rope_dim) + ) + head_dim_v = cfg.vision.hidden_size // cfg.vision.num_heads + rdim = head_dim_v // 2 + stage.model.visual.rotary_pos_emb.inv_freq = 1.0 / ( + 10000.0 ** (torch.arange(0, rdim, 2, dtype=torch.float32, device=device) / rdim) + ) diff --git a/scripts/jup_finetune.sh b/scripts/jup_finetune.sh index 35bcb04..7d50387 100755 --- a/scripts/jup_finetune.sh +++ b/scripts/jup_finetune.sh @@ -2,7 +2,6 @@ MASTER_ADDR="127.0.0.1" MASTER_PORT=$(shuf -i 20000-29999 -n 1) -NGPUS=$(nvidia-smi --list-gpus | wc -l) export WANDB_MODE=offline export HF_HUB_OFFLINE=1 @@ -23,7 +22,7 @@ ulimit -s unlimited torchrun \ --nnodes=1 \ - --nproc_per_node=$NGPUS \ + --nproc_per_node=4 \ --rdzv_id 101 \ --rdzv_backend c10d \ --rdzv_endpoint="$MASTER_ADDR:$MASTER_PORT" \ diff --git a/train/config.py b/train/config.py index 6dcde76..321f32f 100644 --- a/train/config.py +++ b/train/config.py @@ -89,21 +89,35 @@ class Training: min_lr_ratio: float = 0.1 # --------------- + # torch dynamo compiler + compile: bool = True + """ + Always on by default, unless you have an error. + """ + +@dataclass +class Parallel: data_parallel: str = "ddp" # fsdp, ddp - tp_size: int = 1 # 1 means disabled """ Use `fsdp` when you want to decrease usage to increase seq_len/batch_size. """ + tp_size: int = 1 # 1 means disabled + pp_size: int = 1 # 1 means disabled; supported values: 2, 4 + ep_size: int = 1 # 1 means disabled; must divide num_experts evenly + + pp_num_layers_first: int = 1 + pp_num_layers_last: int = 1 + + # Pipeline schedule: "gpipe" or "1f1b". Single-stage-per-rank schedules only. + pp_schedule: str = "gpipe" + # Number of microbatches per optimizer step. Must be >= pp_size for 1F1B + # to actually pipeline (smaller values degrade to GPipe-like behavior). + pp_microbatches: int = 1 + # compiler flag for TP (goes faster) async_tp: bool = True - # torch dynamo compiler - compile: bool = True - """ - Always on by default, unless you have an error. - """ - # activation checkpointing ac_mode: str = "off" """ @@ -156,5 +170,6 @@ class Config: model: Model = field(default_factory=Model) data: Data = field(default_factory=Data) wandb: Wandb = field(default_factory=Wandb) + parallel: Parallel = field(default_factory=Parallel) config: str = '/home/tockier/vlm-training/configs/cvc_config.toml' diff --git a/train/infra.py b/train/infra.py index 55aa03f..adafa68 100644 --- a/train/infra.py +++ b/train/infra.py @@ -1,10 +1,14 @@ from dataclasses import dataclass from functools import partial +from typing import Optional from train.config import ModelType +from train.logger import logger import torch import torch._inductor.config +import torch.distributed as dist +import torch.nn.functional as F from torch.distributed.device_mesh import init_device_mesh from torch.distributed.tensor import Replicate, Shard @@ -43,6 +47,12 @@ CheckpointPolicy, create_selective_checkpoint_contexts, ) +from torch.distributed.pipelining import PipelineStage +from torch.distributed.pipelining.microbatch import _Replicate +from torch.distributed.pipelining.schedules import Schedule1F1B, ScheduleGPipe + +# used for PP +from models.qwen3_5.utils import causal_lm_loss, load_stage_weights class NoParallel(ParallelStyle): def __init__( @@ -112,29 +122,121 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: ), ) -def get_mesh(training_args, world_size): - """ - Creates a 2D DeviceMesh based on tp_size and world_size. - Always returns ('dp', 'tp'). - """ - tp_size = training_args.tp_size - - if world_size % tp_size != 0: - raise ValueError(f"World size {world_size} is not divisible by TP size {tp_size}") +def get_mesh(parallel_args, world_size): + tp = getattr(parallel_args, "tp_size", 1) + pp = getattr(parallel_args, "pp_size", 1) + ep = getattr(parallel_args, "ep_size", 1) - dp_size = world_size // tp_size + assert world_size % (tp * pp) == 0, f"world_size not divisible by tp*pp" + dp = world_size // (tp * pp) + + if ep > 1 or dp == "fsdp": + dp_shard = dp + dp_replicate = 1 + else: + dp_replicate = dp + dp_shard = 1 + + if ep > 1: + assert dp_shard % ep == 0, f"EP ({ep}) must divide dp_shard ({dp_shard})" + dp_mod_ep = dp_shard // ep + + mesh = init_device_mesh( + "cuda", + (pp, dp_replicate, dp_mod_ep, ep, tp), + mesh_dim_names=("pp", "dp_replicate", "dp_mod_ep", "ep", "tp") + ) + + mesh._flattened_submeshes = { + "dp": mesh["dp_replicate", "dp_mod_ep", "ep"]._flatten("dp"), + "dp_shard": mesh["dp_mod_ep", "ep"]._flatten("dp_shard"), + } + else: + mesh = init_device_mesh( + "cuda", + (pp, dp_replicate, dp_shard, tp), + mesh_dim_names=("pp", "dp_replicate", "dp_shard", "tp") + ) - return init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp")) + print(mesh) + print(mesh) + print(mesh) + + mesh._flattened_submeshes = { + "dp": mesh["dp_replicate", "dp_shard"]._flatten("dp") + } + + return mesh -def get_tp_group(mesh): - if "tp" in mesh.mesh_dim_names: - return mesh['tp'] +def get_mesh_group(mesh, dim_name: str): + if dim_name in mesh.mesh_dim_names: + return mesh[dim_name] + + if hasattr(mesh, "_flattened_submeshes") and dim_name in mesh._flattened_submeshes: + return mesh._flattened_submeshes[dim_name] + return None -def get_dp_group(mesh): - if "dp" in mesh.mesh_dim_names: - return mesh['dp'] - return None +def apply_ep(model, ep_mesh, tp_mesh=None): + """Slice expert parameters to the local subset and attach a TokenDispatcher. + + EP shards the routed experts along ``num_experts``. When ``tp_mesh`` is provided + (EP+TP), each rank additionally holds only its slice of ``moe_intermediate_size``; + the partial down-projection is all-reduced across TP at the end of ``forward_ep``. + The shared_expert is sharded by ``apply_tp`` separately. + """ + from models.qwen3_5.dispatcher import TokenDispatcher + + ep_rank = ep_mesh.get_local_rank() + ep_size = ep_mesh.size() + tp_rank = tp_mesh.get_local_rank() if tp_mesh is not None else 0 + tp_size = tp_mesh.size() if tp_mesh is not None else 1 + + lm = model.model.language_model + for layer in lm.layers: + moe = layer.mlp + experts = moe.experts + num_experts = experts.num_experts + moe_inter = experts.intermediate_dim + + if num_experts % ep_size != 0: + raise ValueError( + f"num_experts={num_experts} must be divisible by ep_size={ep_size}" + ) + if tp_size > 1 and moe_inter % tp_size != 0: + raise ValueError( + f"moe_intermediate_size={moe_inter} must be divisible by tp_size={tp_size}" + ) + + num_local = num_experts // ep_size + e_start, e_end = ep_rank * num_local, (ep_rank + 1) * num_local + local_inter = moe_inter // tp_size + i_start, i_end = tp_rank * local_inter, (tp_rank + 1) * local_inter + + # gate_up_proj: [E, 2*I, H] (the 2*I is laid out as [gate(I) | up(I)]). + # Take the EP slice, then within each of gate and up keep only this TP rank's + # I/tp slice and re-concat so the local layout stays [gate_local | up_local]. + gate_up = experts.gate_up_proj.data[e_start:e_end] + if tp_size > 1: + gate_part = gate_up[:, :moe_inter, :][:, i_start:i_end, :] + up_part = gate_up[:, moe_inter:, :][:, i_start:i_end, :] + gate_up = torch.cat([gate_part, up_part], dim=1) + experts.gate_up_proj = nn.Parameter(gate_up.contiguous()) + + # down_proj: [E, H, I] → [E_local, H, I/tp] + down = experts.down_proj.data[e_start:e_end, :, i_start:i_end] + experts.down_proj = nn.Parameter(down.contiguous()) + + # forward_ep needs the TP mesh to all-reduce the partial down-projection + experts.tp_mesh = tp_mesh if tp_size > 1 else None + + dispatcher = TokenDispatcher( + num_experts=num_experts, + top_k=moe.gate.top_k, + score_before_experts=True, + ) + dispatcher.ep_mesh = ep_mesh + moe.dispatcher = dispatcher def module_filter_float8_fn(mod: torch.nn.Module, fqn: str): if "visual" in fqn: @@ -238,7 +340,7 @@ def compile_model(model: torch.nn.Module): model.language_model = torch.compile(model.language_model, fullgraph=False, mode='default') model.visual = torch.compile(model.visual, fullgraph=False, mode='default') model.visual.merger = torch.compile(model.visual.merger, fullgraph=False, mode='default',) - #model = torch.compile(model, mode='default') + #model = torch.compile(model, mode='default') def apply_fsdp(model_type, model, **kwargs): if model_type == ModelType.Qwen3_text: @@ -417,7 +519,7 @@ def _apply_tp_to_decoder_qwen3_vl( desired_input_kwarg_layouts={ "hidden_states": Replicate(), }, - ), + ), "self_attn.q_proj": colwise_parallel(use_local_output=False), "self_attn.k_proj": colwise_parallel(use_local_output=False), "self_attn.v_proj": colwise_parallel(use_local_output=False), @@ -440,9 +542,6 @@ def _apply_tp_to_decoder_qwen3_vl( parallelize_plan=layer_plan, ) - if enable_async_tp: - torch._inductor.config._micro_pipeline_tp = True - def _register_tp_sum_hook(param, tp_mesh): """All-reduce SUM a parameter's grad on the TP process group. @@ -622,10 +721,11 @@ def _apply_tp_to_decoder_qwen3_5( } layer_plan.update({ - "mlp.gate_proj": colwise_parallel(), - "mlp.down_proj": rowwise_parallel(output_layouts=Replicate()), - "mlp.up_proj": colwise_parallel(), + "mlp.shared_expert.gate_proj": colwise_parallel(), + "mlp.shared_expert.down_proj": rowwise_parallel(output_layouts=Replicate()), + "mlp.shared_expert.up_proj": colwise_parallel(), }) + parallelize_module( module=transformer_block, device_mesh=tp_mesh, @@ -650,3 +750,160 @@ def _apply_tp_to_decoder_qwen3_5( if enable_async_tp: torch._inductor.config._micro_pipeline_tp = True + +def get_local_fqns( + num_layers: int, + pp_size: int, + pp_rank: int, + num_first: int, + num_last: int +) -> list[str]: + fqns = [] + if pp_rank == 0: + fqns.extend(["model.visual", "model.language_model.embed_tokens"]) + + if pp_size == 2: + mid_point = num_layers // 2 + (num_layers % 2) + start_idx = 0 if pp_rank == 0 else mid_point + end_idx = mid_point if pp_rank == 0 else num_layers + else: + if pp_rank == 0: + pass + elif pp_rank == 1: + start_idx, end_idx = 0, num_first + elif pp_rank == pp_size - 1: + start_idx, end_idx = num_layers - num_last, num_layers + else: + middle_layers = num_layers - num_first - num_last + middle_ranks = pp_size - 2 + + layers_per_mid = middle_layers // middle_ranks + remainder = middle_layers % middle_ranks + + mid_idx = pp_rank - 1 + start_idx = num_first + (mid_idx * layers_per_mid) + min(mid_idx, remainder) + num_layers_this_rank = layers_per_mid + (1 if mid_idx < remainder else 0) + end_idx = start_idx + num_layers_this_rank + + if pp_rank != 0: + fqns.extend([f"model.language_model.layers.{i}" for i in range(start_idx, end_idx)]) + + if pp_rank == pp_size - 1: + fqns.extend(["model.language_model.norm", "lm_head"]) + + return fqns + + +def apply_pp( + model, + mesh, + parallel_args, + training_args, + device, + pp_loss_fn, + ): + logger.info("Applying Pipeline Parallelism module split...") + pp_rank = mesh.get_local_rank(mesh_dim="pp") + total_layers = model.cfg.text.num_hidden_layers + pp_size = parallel_args.pp_size + + local_fqns = get_local_fqns( + num_layers=total_layers, + pp_size=pp_size, + pp_rank=pp_rank, + num_first= parallel_args.pp_num_layers_first, + num_last= parallel_args.pp_num_layers_last + ) + + if "model.visual" not in local_fqns: + model.model.visual = None + if "model.language_model.embed_tokens" not in local_fqns: + model.model.language_model.embed_tokens = None + + layers = model.model.language_model.layers + kept_indices = {int(f.split('.')[-1]) for f in local_fqns if "layers." in f} + model.model.language_model.layers = torch.nn.ModuleList( + [m for i, m in enumerate(layers) if i in kept_indices] + ) + + if "model.language_model.norm" not in local_fqns: + model.model.language_model.norm = None + if "lm_head" not in local_fqns: + model.lm_head = None + + model.to(device=device) + + layer_indices = [int(f.split('.')[-1]) for f in local_fqns if "layers." in f] + layer_start = min(layer_indices) if layer_indices else 0 + layer_end = max(layer_indices) + 1 if layer_indices else 0 + + target_dtype = torch.bfloat16 if training_args.bfloat16 else torch.float32 + + # Load stage-specific weights when PP > 1 and not using random init + if False: + logger.info(f"PP rank {pp_rank}: About to load stage weights for layers {layer_start}-{layer_end}") + load_stage_weights( + stage=self.model, + snapshot_dir=self.training_args.model_dir, + layer_start=layer_start, + layer_end=layer_end, + is_first=pp_rank == 0, + is_last=pp_rank == self.pp_size - 1, + device=self.device, + dtype=target_dtype, + ) + logger.info(f"PP rank {pp_rank}: Finished loading stage weights") + + # materialize model in GPU + model = model.to(device) + + pp_stage = PipelineStage( + model, + stage_index=pp_rank, + num_stages=pp_size, + device=device, + group=mesh.get_group(mesh_dim="pp"), + ) + + schedule_name = getattr(parallel_args, "pp_schedule", "gpipe").lower() + n_microbatches = getattr(parallel_args, "pp_microbatches", 1) + schedule_cls = {"gpipe": ScheduleGPipe, "1f1b": Schedule1F1B}.get(schedule_name) + if schedule_cls is None: + raise ValueError( + f"unknown pp_schedule={schedule_name!r}; expected one of: gpipe, 1f1b" + ) + if schedule_cls is Schedule1F1B: + # The plumbing exists (schedule construction, kwargs_chunk_spec, + # tiled input in _train_step_pp), but 1F1B requires + # n_microbatches >= pp_size and the dataloader currently emits + # one packed (1, total) sample per step — so microbatches are + # tiled copies of the same content and the loss is meaningless. + # Re-enable once the data path produces n_microbatches independent + # packed rows per step (per-row cu_seqlens, labels, image scatter). + raise NotImplementedError( + "pp_schedule='1f1b' is disabled until the data path supports " + "n_microbatches independent packed rows per step. Use 'gpipe' " + "for now." + ) + # The dataloader emits a single packed (1, total) sample per step. + # When n_microbatches > 1 we tile input_ids/labels to (N, total) so + # the schedule can chunk along dim 0; everything else (cu_seqlens, + # pixel_values, image_grid_thw, etc.) is per-batch metadata that + # must be passed identically to every microbatch — mark it replicate. + pp_microbatches = n_microbatches + kwargs_chunk_spec = { + k: _Replicate() for k in ( + "input_ids", "attention_mask", "original_mask", + "image_grid_thw", "pixel_values", + "pixel_values_videos", "video_grid_thw", + ) + } + pp_schedule = schedule_cls( + pp_stage, + n_microbatches=n_microbatches, + loss_fn=pp_loss_fn, + kwargs_chunk_spec=kwargs_chunk_spec, + ) + logger.info(f"PP schedule: {schedule_name} (n_microbatches={n_microbatches})") + + return pp_microbatches, pp_schedule diff --git a/train/train_qwen.py b/train/train_qwen.py index 1003439..66fdf6f 100644 --- a/train/train_qwen.py +++ b/train/train_qwen.py @@ -5,11 +5,13 @@ import wandb import transformers from itertools import cycle +from pathlib import Path import time from transformers import AutoProcessor +import torch.distributed as dist from torch.distributed.elastic.multiprocessing.errors import record from torch.distributed._composable.replicate import replicate @@ -19,20 +21,27 @@ from megatron.energon import get_train_dataset, get_loader, WorkerConfig from data.task_encoder_factory import build_task_encoder +from models.qwen3_5.utils import causal_lm_loss, load_stage_weights + # training imports from train.config_manager import ConfigManager from train.config import Config, ModelType from train.logger import init_logger, logger, Color from train.infra import ( get_mesh, - get_tp_group, - get_dp_group, + get_mesh_group, + apply_fsdp, apply_tp, + apply_ep, apply_ac, + apply_pp, + ACConfig, compile_model, ) +from models.qwen3_5.model import initialize_missing_weights + from train.utils import ( set_determinism, generate_accumulation_pattern, @@ -58,7 +67,6 @@ torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True - class Trainer(torch.distributed.checkpoint.stateful.Stateful): @record @@ -66,6 +74,7 @@ def __init__(self, cfg: Config): self.model_args = cfg.model self.training_args = cfg.training self.data_args = cfg.data + self.p_args = cfg.parallel self.wandb_args = cfg.wandb self.debug_mode = bool(os.environ.get("DEBUG", False)) @@ -74,9 +83,18 @@ def __init__(self, cfg: Config): self.world_size = int(os.environ["WORLD_SIZE"]) torch.cuda.set_device(self.local_rank) - self.mesh = get_mesh(self.training_args, self.world_size) - self.tp_group = get_tp_group(self.mesh) - self.dp_group = get_dp_group(self.mesh) + self.mesh = get_mesh(self.p_args, self.world_size) + + self.tp_group = get_mesh_group(self.mesh, 'tp') + self.pp_group = get_mesh_group(self.mesh, 'pp') + self.ep_group = get_mesh_group(self.mesh, 'ep') + self.shard_group = get_mesh_group(self.mesh, 'dp_shard') + self.replicate_group = get_mesh_group(self.mesh, 'dp_replicate') + # this mesh group unifies `shard` and `replicate` + self.dp_group = get_mesh_group(self.mesh, "dp") + + self.pp_size = getattr(self.p_args, "pp_size", 1) + self.ep_size = getattr(self.p_args, "ep_size", 1) self.device = torch.device(f"cuda:{self.local_rank}") if self.if_log_rank(): @@ -88,10 +106,10 @@ def __init__(self, cfg: Config): **vars(self.model_args), **vars(self.training_args), **vars(self.data_args), + **vars(self.p_args), "mesh": self.mesh, "world_size": self.world_size, - "dp_group": self.dp_group, - "tp_group": self.tp_group, + # TODO: add all parallel args here }, ) @@ -107,6 +125,8 @@ def __init__(self, cfg: Config): set_determinism(seed=42 + self.local_rank, deterministic=True, world_mesh=self.mesh, debug_mode=self.debug_mode) + self.setup_accumulation(self.training_args.tpi_multiplier) + if self.rank() == 0: if not os.path.exists(self.training_args.output_dir): os.makedirs(self.training_args.output_dir) @@ -120,9 +140,13 @@ def __init__(self, cfg: Config): else: raise NotImplementedError(f"model not supported: {self.model_args.model_name}") - self.model = select_model_class(self.model_type, self.model_args, self.training_args) + # Load the model on CPU with weights; for PP the split stage is moved to GPU below. + self.model = select_model_class( + self.model_type, self.model_args, self.training_args + ) # we calculate the flops per token used to get the MFU number + # (works on meta tensors: shapes are valid even without data) num_params, self.flops_per_token = get_dense_model_nparams_and_flops( self.model_args.model_name, self.model, @@ -131,37 +155,51 @@ def __init__(self, cfg: Config): logger.info(f"Number params: {num_params}") - if self.training_args.load_text_model: - self.text_model = select_text_model(self.training_args) - self.model = load_text_model(self.model, self.text_model) + self.optimizer = None # defined later on - # MOVE TO cuda:{self.local_rank} - self.model.to(self.device) - + # -- PIPELINE PARALLEL + if self.p_args.pp_size > 1: + + def pp_loss_fn(outputs, labels): + logits, aux_loss = outputs + ce_loss = causal_lm_loss(logits, labels) + aux_loss = aux_loss.squeeze() + + self._recent_ce_loss = ce_loss.detach() + self._recent_aux_loss = aux_loss.detach() + + return (ce_loss + 0.01 * aux_loss) / self.current_accum_target + + self.pp_microbatches, self.pp_schedule = apply_pp( + self.model, self.mesh, self.p_args, self.training_args, self.device, pp_loss_fn + ) + self.pp_has_first_stage = self.model.model.visual is not None + self.pp_has_last_stage = self.model.lm_head is not None + + logger.info("model loaded") + + # -- WEIGHT INIT if self.training_args.random_init: if self.model_type == ModelType.Qwen3_5: logger.info('initilizing decoder and projecter of Qwen3.5') - init_qwen35(self.model) + #init_qwen35(self.model) elif self.model_type == ModelType.Qwen3_vl: logger.info('initilizing projector of Qwen3-VL') init_qwen3vl(self.model) else: logger.info('model not initlized, incompatible') + initialize_missing_weights(self.model) - # replace flash_attn + # -- MIXED PRECISION self.model.train() - if self.model_args.model_impl == "hf": - self.model.enable_input_require_grads() - self.optimizer = None # its defined later on - if self.training_args.bfloat16: self.model = self.model.to(torch.bfloat16) - logger.info("model loaded") - - if self.training_args.tp_size > 1: - apply_tp(self.model, self.model_type, self.tp_group, self.training_args.async_tp) + # -- TENSOR PARALLEL + if self.p_args.tp_size > 1: + apply_tp(self.model, self.model_type, self.tp_group, self.p_args.async_tp) + # -- ACTIVATION CHECKPOINTING ac_mode = getattr(self.training_args, "ac_mode", "off") if ac_mode != "off": ac_cfg = ACConfig(enabled=True, full=(ac_mode == "full")) @@ -172,20 +210,35 @@ def __init__(self, cfg: Config): ) logger.info(f"activation checkpointing applied ({ac_mode})") - if self.training_args.data_parallel == 'fsdp': - apply_fsdp(self.model_type, self.model, mesh=self.dp_group) - elif self.training_args.data_parallel == 'ddp': - self.model = replicate(self.model, device_mesh=self.dp_group) + # -- EXPERT PARALLEL + if self.ep_size > 1: + if self.model_type != ModelType.Qwen3_5: + raise NotImplementedError("EP is only supported for Qwen3.5 MoE models") + tp_mesh = self.tp_group if self.p_args.tp_size > 1 else None + apply_ep(self.model, self.ep_group, tp_mesh=tp_mesh) + logger.info(f"expert parallelism applied (ep_size={self.ep_size})") + + # -- DATA PARALLEL + dp_shard_mesh = get_mesh_group(self.mesh, "dp_shard") + dp_replicate_mesh = get_mesh_group(self.mesh, "dp_replicate") + + if dp_shard_mesh is not None and dp_shard_mesh.size() > 1: + apply_fsdp(self.model_type, self.model, mesh=dp_shard_mesh) + logger.info(f"FSDP applied (dp_shard={dp_shard_mesh.size()})") + elif dp_replicate_mesh is not None and dp_replicate_mesh.size() > 1: + self.model = replicate(self.model, device_mesh=dp_replicate_mesh) + logger.info(f"DDP applied (dp_replicate={dp_replicate_mesh.size()})") else: - raise Exception('invalid sharding strategy for Data Parallel') + logger.info(f"no DP applied (dp=1)") - # get rank of local GPU that belongs to the DP group - data_rank = self.dp_group.get_local_rank() - data_world_size = self.dp_group.size() + # loading into GPU + self.model = self.model.to(device=self.device) + if self.training_args.bfloat16: + self.model = self.model.to(torch.bfloat16) logger.info('sharding/parallelism applied') - if self.training_args.compile: + if self.training_args.compile and self.pp_size == 1: compile_model(self.model) logger.info("model (will be) compiled") @@ -199,10 +252,16 @@ def __init__(self, cfg: Config): self.processor = AutoProcessor.from_pretrained( self.training_args.model_dir, - max_pixels=1048576, ) - self.model = set_model(self.model_type, self.model_args, self.model) + # set_model freezes/unfreezes param groups; skip for PP (stage module + # doesn't have the full VLM wrapper structure) + if self.pp_size == 1: + self.model = set_model(self.model_type, self.model_args, self.model) + + # get rank of local GPU that belongs to the DP group + data_rank = self.dp_group.get_local_rank() + data_world_size = self.dp_group.size() worker_config = WorkerConfig( rank=data_rank, @@ -228,8 +287,6 @@ def __init__(self, cfg: Config): self.data_loader = get_loader(ds) - self.setup_accumulation(self.training_args.tpi_multiplier) - self.global_step = 0 self.micro_step = 0 @@ -247,6 +304,7 @@ def rank(self): return torch.distributed.get_rank() def if_log_rank(self): + # Log only from global rank 0 (always pp_rank=0 and dp_rank=0) return self.rank() == 0 def create_optimizer(self): @@ -417,7 +475,7 @@ def batch_generator(self): yield batch - def log(self, avg_loss, max_loss, global_tokens, global_assistant_tokens, global_samples, lr): + def log(self, avg_loss, aux_loss, max_loss, global_tokens, global_assistant_tokens, global_samples, lr): time_delta = time.perf_counter() - self.time_last_log @@ -430,6 +488,9 @@ def log(self, avg_loss, max_loss, global_tokens, global_assistant_tokens, global # GB200 (JUP) and SXM H100 (MN5) peak_tflops_per_gpu = 989.4 + # BLACKWELL 6000 + peak_tflops_per_gpu = 504 + # L40S #peak_tflops_per_gpu = 362 @@ -442,6 +503,7 @@ def log(self, avg_loss, max_loss, global_tokens, global_assistant_tokens, global logger.info( f"{color.red}step {self.global_step} " f"{color.green}loss {avg_loss:.4f} " + f"{color.green}aux {aux_loss:.4f} " f"{color.blue}tps {tps:.2f} " f"{color.magenta}mfu {mfu:.1f}% " f"{color.reset}" @@ -478,24 +540,133 @@ def setup_accumulation(self, tpi_multiplier=1.5): self.current_accum_target = next(self.accum_schedule) self.current_accum_count = 0 - def train_step(self, data_iterator, optimizer): + def _train_step_pp(self, data_iterator, optimizer): + batch = next(data_iterator) + + labels = batch.pop('labels', None) + input_ids = batch.pop('input_ids') + batch['input_ids'] = input_ids + + # Schedule chunks the positional input_ids and target along dim 0; + # tile the (1, total) packed sample to (N, total) so n_microbatches > 1 + # produces N actual chunks. Each microbatch is identical content — fine + # for benchmarking the schedule, not for real training. + n = self.pp_microbatches + tiled_input_ids = input_ids.repeat(n, 1) if n > 1 else input_ids + tiled_labels = labels.repeat(n, 1) if (n > 1 and labels is not None) else labels + + losses = [] if self.pp_has_last_stage else None + target = tiled_labels if self.pp_has_last_stage else None + + s_model = time.perf_counter() + with record_function("pp_forward_backward"): + with torch.autocast('cuda', torch.bfloat16): + if self.pp_has_first_stage: + self.pp_schedule.step(tiled_input_ids, **batch, target=target, losses=losses) + else: + self.pp_schedule.step(**batch, target=target, losses=losses) + + if self.ep_size > 1 and self.dp_group.size() > 1: + is_last_accum = (self.current_accum_count + 1 >= self.current_accum_target) + if is_last_accum: + # we use a custom bucking system instead of the replicate hooks + self._sync_gradients() + + self.fwd_bwd_time = time.perf_counter() - s_model + + scaled_loss = torch.stack(losses).sum() if losses else torch.tensor(0.0, device=self.device) + loss_for_logging = scaled_loss * self.current_accum_target + torch.distributed.all_reduce(loss_for_logging, group=self.pp_group.get_group()) + + # TODO: FIX THIS + ce_loss = getattr(self, '_recent_ce_loss', torch.tensor(0.0, device=self.device)) + aux_loss = getattr(self, '_recent_aux_loss', torch.tensor(0.0, device=self.device)) + + torch.distributed.all_reduce(ce_loss, group=self.pp_group.get_group()) + torch.distributed.all_reduce(aux_loss, group=self.pp_group.get_group()) + + return self._maybe_optimizer_step(loss_for_logging, ce_loss, aux_loss, optimizer) + + def _sync_gradients(self): + """Bucketed grad all_reduce across dp_group. One collective per ~25 MB + bucket (per dtype) instead of one per parameter, so DP scales by NCCL + bandwidth rather than per-launch latency. Used instead of DDP hooks when + EP is active. + """ + dp_size = self.dp_group.size() + if dp_size <= 1: + return + grp = self.dp_group.get_group() + + from torch.distributed.tensor import DTensor + from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + + # NCCL all_reduce requires uniform dtype within a call. + by_dtype: dict[torch.dtype, list[torch.Tensor]] = {} + for p in self.model.parameters(): + if p.grad is None: + continue + g = p.grad + # TP-sharded params (e.g. shared_expert.*) have DTensor grads; reduce the local shard. + if isinstance(g, DTensor): + g = g.to_local() + by_dtype.setdefault(g.dtype, []).append(g) + + bucket_max_elems = 25 * 1024 * 1024 # ~50 MB at bf16, ~100 MB at fp32 + inv_dp = 1.0 / dp_size + + def _flush(bucket: list[torch.Tensor]) -> None: + flat = _flatten_dense_tensors(bucket) + dist.all_reduce(flat, group=grp) + flat.mul_(inv_dp) + for g, synced in zip(bucket, _unflatten_dense_tensors(flat, bucket)): + g.copy_(synced) + + for grads in by_dtype.values(): + bucket: list[torch.Tensor] = [] + bucket_elems = 0 + for g in grads: + n = g.numel() + if bucket and bucket_elems + n > bucket_max_elems: + _flush(bucket) + bucket, bucket_elems = [], 0 + bucket.append(g) + bucket_elems += n + if bucket: + _flush(bucket) + + def _train_step(self, data_iterator, optimizer): batch = next(data_iterator) + input_ids = batch.pop('input_ids') + labels = batch.pop('labels', None) s_model = time.perf_counter() with record_function("forward_pass"): with torch.autocast('cuda', torch.bfloat16): - outputs = self.model( - **batch - ) - loss = outputs.loss + logits, aux_loss = self.model(input_ids, **batch) + ce_loss = causal_lm_loss(logits, labels) + loss = ce_loss + (.01 * aux_loss) with record_function("backward_pass"): scaled_loss = loss / self.current_accum_target - with torch.autocast('cuda', torch.bfloat16): - scaled_loss.backward() + scaled_loss.backward() + + if self.ep_size > 1 and self.dp_group.size() > 1: + is_last_accum = (self.current_accum_count + 1 >= self.current_accum_target) + if is_last_accum: + self._sync_gradients() self.fwd_bwd_time = time.perf_counter() - s_model + return self._maybe_optimizer_step(loss, ce_loss, aux_loss, optimizer) + def train_step(self, data_iterator, optimizer): + if self.pp_size == 1: + return self._train_step(data_iterator, optimizer) + else: + return self._train_step_pp(data_iterator, optimizer) + + def _maybe_optimizer_step(self, loss, ce_loss, aux_loss, optimizer): + """Shared optimizer-step logic after fwd+bwd (regular and PP paths).""" self.current_accum_count += 1 if self.current_accum_count >= self.current_accum_target: @@ -504,40 +675,37 @@ def train_step(self, data_iterator, optimizer): optimizer.zero_grad() lr = optimizer.param_groups[0]['lr'] - self.global_step += 1 - avg_loss, max_loss, global_tokens, global_assistant, global_samples = ( - dist_mean(loss, self.dp_group), - dist_max(loss, self.dp_group), + avg_loss, aux_loss, max_loss, global_tokens, global_assistant, global_samples = ( + dist_mean(ce_loss, self.dp_group), + dist_mean(aux_loss, self.dp_group), + dist_max(ce_loss, self.dp_group), dist_sum( - torch.tensor( - self.tokens_seen, dtype=torch.int64, device=self.device - ), + torch.tensor(self.tokens_seen, dtype=torch.int64, device=self.device), self.dp_group, ), dist_sum( - torch.tensor( - self.tokens_seen_assistant, dtype=torch.int64, device=self.device - ), + torch.tensor(self.tokens_seen_assistant, dtype=torch.int64, device=self.device), self.dp_group, ), dist_sum( torch.tensor(self.samples_since_last_log, dtype=torch.int32, device=self.device), self.dp_group, - ) + ), ) - self.train_step_delta = (time.perf_counter() - self.time_last_log) / self.current_accum_target + self.train_step_delta = ( + (time.perf_counter() - self.time_last_log) / self.current_accum_target + ) if self.if_log_rank(): - self.log(avg_loss, max_loss, global_tokens, global_assistant, global_samples, lr) + self.log(avg_loss, aux_loss, max_loss, global_tokens, global_assistant, global_samples, lr) self.total_ntokens_since_last_log = 0 self.ntokens_since_last_log = 0 self.samples_since_last_log = 0 self.time_last_log = time.perf_counter() - self.current_accum_count = 0 self.current_accum_target = next(self.accum_schedule) @@ -582,7 +750,7 @@ def trace_handler(prof): try: while self.global_step < self.training_args.total_steps: self.micro_step += 1 - + # training step executed here optimizer_updated = self.train_step(data_iterator, optimizer) diff --git a/train/utils.py b/train/utils.py index c920568..9e50527 100644 --- a/train/utils.py +++ b/train/utils.py @@ -233,7 +233,7 @@ def collect(reason: str, generation: int = 1): logger.info("[GC] %s took %.2f seconds", reason, time.monotonic() - begin) -def select_model_class(model_type: ModelType, model_args: ModelArgs, training_args: TrainArgs): +def select_model_class(model_type: ModelType, model_args: ModelArgs, training_args: TrainArgs, meta_only: bool = False): """ TODO: use ModelType instead of model name """ @@ -245,7 +245,7 @@ def select_model_class(model_type: ModelType, model_args: ModelArgs, training_ar model_name = model_args.model_name.lower() if model_args.model_impl == "native": - return _select_native_model_class(training_args, model_name) + return _select_native_model_class(training_args, model_name, meta_only=meta_only) elif model_args.model_impl != "hf": raise ValueError( f"Unknown model_impl '{model_args.model_impl}'. Expected 'hf' or 'native'." @@ -307,8 +307,13 @@ def select_model_class(model_type: ModelType, model_args: ModelArgs, training_ar return model -def _select_native_model_class(training_args: TrainArgs, model_name: str): - """Dispatch to our torch-native model implementations under `models/`.""" +def _select_native_model_class(training_args: TrainArgs, model_name: str, meta_only: bool = False): + """Dispatch to our torch-native model implementations under `models/`. + + When ``meta_only=True`` the model is returned on ``torch.device("meta")`` + with no weights loaded. The caller is responsible for materialising + parameters and loading weights (e.g. via ``load_stage_weights`` for PP). + """ dtype = torch.bfloat16 if training_args.bfloat16 else torch.float32 if "qwen3-vl" in model_name: @@ -326,8 +331,10 @@ def _select_native_model_class(training_args: TrainArgs, model_name: str): training_args.model_dir, dtype=dtype, device="cpu", + weights=not meta_only, ) - logger.info(f"Loaded native {model_name} from {training_args.model_dir}") + if not meta_only: + logger.info(f"Loaded native {model_name} from {training_args.model_dir}") return model def select_text_model(training_args): @@ -508,20 +515,15 @@ def get_dense_model_nparams_and_flops( if isinstance(m, torch.nn.Embedding) ) - if "8B" in model_name: - tied = False - elif "9B" in model_name: - tied = False - elif "2B" in model_name: + if "2B" in model_name: tied = True elif "4B" in model_name: tied = True elif "1.7B" in model_name: tied = True else: - # ValueError - return 0, 0 - + tied = False + # we take into account the embedding params num_flops_per_token = 6 * nparams diff --git a/utils/down.py b/utils/down.py index 09b895e..e61fbb9 100644 --- a/utils/down.py +++ b/utils/down.py @@ -1,6 +1,6 @@ from huggingface_hub import snapshot_download snapshot_download( - repo_id="Qwen/Qwen3.5-9B", - local_dir="/data/151-1/users/tockier/qwen_finetune/cache/qwen35_9b", + repo_id="Qwen/Qwen3.5-35B-A3B", + local_dir="/data/151-1/users/tockier/qwen_finetune/cache/qwen35_35b_a3", )