diff --git a/lightx2v/models/networks/wan/infer/matrix_game3/pre_infer.py b/lightx2v/models/networks/wan/infer/matrix_game3/pre_infer.py index 0f36c43c3..35ae7bbbb 100644 --- a/lightx2v/models/networks/wan/infer/matrix_game3/pre_infer.py +++ b/lightx2v/models/networks/wan/infer/matrix_game3/pre_infer.py @@ -124,6 +124,7 @@ def infer(self, weights, inputs, kv_start=0, kv_end=0): if self.scheduler.infer_condition: plucker_emb = memory_plucker + # use memory_plucker first if plucker_emb is None: plucker_emb = camera_plucker mouse_cond = dit_cond_dict.get("mouse_cond", None) diff --git a/lightx2v/models/networks/wan/matrix_game3_model.py b/lightx2v/models/networks/wan/matrix_game3_model.py index 098c69b04..676a56631 100644 --- a/lightx2v/models/networks/wan/matrix_game3_model.py +++ b/lightx2v/models/networks/wan/matrix_game3_model.py @@ -1,10 +1,6 @@ import json import os -import sys -from functools import lru_cache -from pathlib import Path -import torch from safetensors import safe_open from lightx2v.models.networks.wan.infer.matrix_game3.post_infer import WanMtxg3PostInfer @@ -17,167 +13,6 @@ from lightx2v.utils.utils import * -@lru_cache(maxsize=1) -def _import_official_matrix_game3_wan_model(): - """Load the official Matrix-Game-3 WanModel implementation on demand.""" - official_root = Path(__file__).resolve().parents[4] / "Matrix-Game-3" / "Matrix-Game-3" - if not official_root.is_dir(): - raise FileNotFoundError(f"Official Matrix-Game-3 source directory not found: {official_root}") - official_root_str = str(official_root) - if official_root_str not in sys.path: - sys.path.insert(0, official_root_str) - from wan.modules.model import WanModel as OfficialWanModel - - return OfficialWanModel - - -def _matrix_game3_compute_max_seq_len(config, lat_h: int, lat_w: int) -> int: - first_clip_frame = int(config.get("first_clip_frame", config.get("target_video_length", 57))) - vae_stride_t = int(tuple(config.get("vae_stride", (4, 16, 16)))[0]) - patch_h, patch_w = tuple(config.get("patch_size", (1, 2, 2)))[1:] - max_lat_f = (first_clip_frame - 1) // vae_stride_t + 1 - max_mem_f = 5 - max_total_f = max_lat_f + max_mem_f - max_seq_len = max_total_f * lat_h * lat_w // (patch_h * patch_w) - sp_size = int(config.get("sp_size", 1)) - if sp_size > 1: - max_seq_len = int(((max_seq_len + sp_size - 1) // sp_size) * sp_size) - return max_seq_len - - -class WanMtxg3OfficialBaseModel: - """Base-model wrapper that delegates denoising to the official MG3 forward. - - The distilled MG3 path is numerically tolerant enough to run through the - custom LightX2V weight/infer stack, but the base checkpoint is much more - sensitive under 50-step CFG. Reusing the official DiT forward here removes - the remaining block/head precision mismatches from the adaptation. - """ - - def __init__(self, model_path, config, device, model_type="wan2.2", lora_path=None, lora_strength=1.0): - del model_type, lora_path, lora_strength - self.model_path = model_path - self.config = config - self.device = device - self.scheduler = None - self.transformer_infer = None - self._official_model = self._load_official_model() - - def _load_official_model(self): - sub_model_folder = self.config.get("sub_model_folder", "base_model") - model_dir = os.path.join(self.config["model_path"], sub_model_folder) - if not os.path.isdir(model_dir): - raise FileNotFoundError(f"Matrix-Game-3 base checkpoint directory not found: {model_dir}") - OfficialWanModel = _import_official_matrix_game3_wan_model() - model = OfficialWanModel.from_pretrained(model_dir, torch_dtype=torch.bfloat16) - model = model.eval().requires_grad_(False) - model.to(device=self.device, dtype=torch.bfloat16) - return model - - def set_scheduler(self, scheduler): - self.scheduler = scheduler - - def _build_official_timestep(self, latents): - t = self.scheduler.timestep_input - if t is None: - raise RuntimeError("Matrix-Game-3 base forward requested before scheduler.timestep_input was prepared") - if t.numel() != 1: - return t.reshape(1, -1).to(device=latents.device, dtype=latents.dtype) - - timestep_scalar = t.reshape(1).to(device=latents.device, dtype=latents.dtype) - patch_size = tuple(self.config.get("patch_size", (1, 2, 2))) - patch_h = int(patch_size[1]) - patch_w = int(patch_size[2]) - latent_frames = int(latents.shape[1]) - latent_h = int(latents.shape[2]) - latent_w = int(latents.shape[3]) - tokens_per_frame = latent_h * latent_w // (patch_h * patch_w) - timestep = latents.new_full( - (latent_frames, tokens_per_frame), - timestep_scalar.squeeze(0), - ) - mask = getattr(self.scheduler, "mask", None) - if mask is not None: - fixed_latent_frames = int((mask.to(dtype=torch.float32).amax(dim=(0, 2, 3)) == 0).sum().item()) - if fixed_latent_frames > 0: - timestep[:fixed_latent_frames].zero_() - return timestep.flatten().unsqueeze(0) - - def _build_forward_kwargs(self, inputs, infer_condition): - if self.scheduler is None: - raise RuntimeError("Matrix-Game-3 base model used before scheduler was attached") - - latents = self.scheduler.latents.unsqueeze(0) - timestep = self._build_official_timestep(self.scheduler.latents) - image_encoder_output = inputs.get("image_encoder_output", {}) - dit_cond_dict = image_encoder_output.get("dit_cond_dict") or {} - memory_plucker = dit_cond_dict.get("plucker_emb_with_memory") - camera_plucker = dit_cond_dict.get("c2ws_plucker_emb") - - if infer_condition: - context = inputs["text_encoder_output"]["context"] - plucker_emb = memory_plucker - if plucker_emb is None: - plucker_emb = camera_plucker - mouse_cond = dit_cond_dict.get("mouse_cond") - keyboard_cond = dit_cond_dict.get("keyboard_cond") - x_memory = dit_cond_dict.get("x_memory") - timestep_memory = dit_cond_dict.get("timestep_memory") - mouse_cond_memory = dit_cond_dict.get("mouse_cond_memory") - keyboard_cond_memory = dit_cond_dict.get("keyboard_cond_memory") - memory_latent_idx = dit_cond_dict.get("memory_latent_idx") - else: - context = inputs["text_encoder_output"]["context_null"] - mouse_source = dit_cond_dict.get("mouse_cond") - keyboard_source = dit_cond_dict.get("keyboard_cond") - plucker_emb = dit_cond_dict.get("c2ws_plucker_emb") - mouse_cond = torch.ones_like(mouse_source) if mouse_source is not None else None - keyboard_cond = -torch.ones_like(keyboard_source) if keyboard_source is not None else None - x_memory = None - timestep_memory = None - mouse_cond_memory = None - keyboard_cond_memory = None - memory_latent_idx = None - - seq_len = _matrix_game3_compute_max_seq_len(self.config, latents.shape[3], latents.shape[4]) - - forward_kwargs = { - "x": latents, - "t": timestep, - "context": context, - "seq_len": seq_len, - "mouse_cond": mouse_cond, - "keyboard_cond": keyboard_cond, - "x_memory": x_memory, - "timestep_memory": timestep_memory, - "mouse_cond_memory": mouse_cond_memory, - "keyboard_cond_memory": keyboard_cond_memory, - "plucker_emb": plucker_emb, - "memory_latent_idx": memory_latent_idx, - "predict_latent_idx": dit_cond_dict.get("predict_latent_idx"), - } - return forward_kwargs - - @torch.no_grad() - def _infer_cond_uncond(self, inputs, infer_condition=True): - self.scheduler.infer_condition = infer_condition - noise_pred = self._official_model(**self._build_forward_kwargs(inputs, infer_condition)) - if isinstance(noise_pred, list): - noise_pred = torch.stack(noise_pred) - if noise_pred.dim() == 5 and noise_pred.shape[0] == 1: - noise_pred = noise_pred.squeeze(0) - return noise_pred.float() - - @torch.no_grad() - def infer(self, inputs): - if self.config.get("enable_cfg", False): - cond_pred = self._infer_cond_uncond(inputs, infer_condition=True) - uncond_pred = self._infer_cond_uncond(inputs, infer_condition=False) - self.scheduler.noise_pred = uncond_pred + self.scheduler.sample_guide_scale * (cond_pred - uncond_pred) - else: - self.scheduler.noise_pred = self._infer_cond_uncond(inputs, infer_condition=True) - - class WanMtxg3Model(WanModel): """Network model for Matrix-Game-3.0. @@ -193,6 +28,7 @@ class WanMtxg3Model(WanModel): pre_weight_class = WanMtxg3PreWeights transformer_weight_class = WanMtxg3TransformerWeights + # replace the module def __init__(self, model_path, config, device, model_type="wan2.2", lora_path=None, lora_strength=1.0): super().__init__(model_path, config, device, model_type, lora_path, lora_strength) diff --git a/lightx2v/models/runners/wan/wan_matrix_game3_runner.py b/lightx2v/models/runners/wan/wan_matrix_game3_runner.py index d053f1fb1..11d06d7d7 100644 --- a/lightx2v/models/runners/wan/wan_matrix_game3_runner.py +++ b/lightx2v/models/runners/wan/wan_matrix_game3_runner.py @@ -1112,10 +1112,7 @@ def run_text_encoder(self, input_info): return super().run_text_encoder(input_info) def load_transformer(self): - from lightx2v.models.networks.wan.matrix_game3_model import ( - WanMtxg3Model, - WanMtxg3OfficialBaseModel, - ) + from lightx2v.models.networks.wan.matrix_game3_model import WanMtxg3Model # The backbone is still a Wan2.2 DiT, but Matrix-Game-3 swaps in a dedicated # network wrapper that understands keyboard / mouse / camera conditions. @@ -1126,15 +1123,7 @@ def load_transformer(self): } lora_configs = self.config.get("lora_configs") if not lora_configs: - if self.config.get("use_base_model", False): - try: - logger.info("[matrix-game-3] base-model path will use the official WanModel forward for denoising.") - return WanMtxg3OfficialBaseModel(**model_kwargs) - except Exception as exc: - logger.warning( - "[matrix-game-3] failed to initialize official base-model forward ({}); falling back to the custom LightX2V MG3 model.", - exc, - ) + logger.info("[matrix-game-3] loading MG3 {} checkpoint with the LightX2V inference stack.", self._get_sub_model_folder()) return WanMtxg3Model(**model_kwargs) return build_wan_model_with_lora(WanMtxg3Model, self.config, model_kwargs, lora_configs, model_type="wan2.2") diff --git a/lightx2v/models/runners/wan/wan_runner.py b/lightx2v/models/runners/wan/wan_runner.py index 813f7d11c..359ee2adb 100755 --- a/lightx2v/models/runners/wan/wan_runner.py +++ b/lightx2v/models/runners/wan/wan_runner.py @@ -852,6 +852,7 @@ def _build_wan22_vae_config(self, vae_offload): "vae_type": resolved_paths["vae_type"], "lightvae_pruning_rate": resolved_paths["lightvae_pruning_rate"], "lightvae_encoder_vae_pth": resolved_paths["lightvae_encoder_vae_pth"], + "dummy_model": self.config.get("dummy_model", False), } def load_vae_encoder(self): diff --git a/lightx2v/models/video_encoders/hf/wan/vae_2_2.py b/lightx2v/models/video_encoders/hf/wan/vae_2_2.py index 713cf5c05..8efebbd63 100755 --- a/lightx2v/models/video_encoders/hf/wan/vae_2_2.py +++ b/lightx2v/models/video_encoders/hf/wan/vae_2_2.py @@ -1002,6 +1002,7 @@ def _video_vae( load_from_rank0=False, normalize_state_dict=False, strict=True, + dummy_model=False, **kwargs, ): # params @@ -1024,18 +1025,18 @@ def _video_vae( with torch.device("meta"): model = WanVAE_(**cfg) - # load checkpoint - logging.info(f"loading {pretrained_path}") - raw_state = load_weights(pretrained_path, cpu_offload=cpu_offload, load_from_rank0=load_from_rank0) - weights_dict = _normalize_vae_state_dict(raw_state) if normalize_state_dict else raw_state - for key in list(weights_dict.keys()): - if hasattr(weights_dict[key], "dtype") and weights_dict[key].dtype != dtype: - weights_dict[key] = weights_dict[key].to(dtype) - if strict: - model.load_state_dict(weights_dict, assign=True) - else: - missing, unexpected = model.load_state_dict(weights_dict, strict=False, assign=True) - logging.info(f"VAE checkpoint loaded with strict=False (missing={len(missing)}, unexpected={len(unexpected)})") + # load checkpoint + logging.info(f"loading {pretrained_path}") + raw_state = load_weights(pretrained_path, cpu_offload=cpu_offload, load_from_rank0=load_from_rank0) + weights_dict = _normalize_vae_state_dict(raw_state) if normalize_state_dict else raw_state + for key in list(weights_dict.keys()): + if hasattr(weights_dict[key], "dtype") and weights_dict[key].dtype != dtype: + weights_dict[key] = weights_dict[key].to(dtype) + if strict: + model.load_state_dict(weights_dict, assign=True) + else: + missing, unexpected = model.load_state_dict(weights_dict, strict=False, assign=True) + logging.info(f"VAE checkpoint loaded with strict=False (missing={len(missing)}, unexpected={len(unexpected)})") # Convert Conv3d weights to channels_last_3d for cuDNN optimization if GET_USE_CHANNELS_LAST_3D(): @@ -1061,6 +1062,7 @@ def __init__( vae_type="wan2.2", lightvae_pruning_rate=None, lightvae_encoder_vae_pth=None, + dummy_model=False, **kwargs, ): self.dtype = dtype @@ -1194,6 +1196,7 @@ def __init__( normalize_state_dict=False, strict=True, pruning_rate=0.0, + dummy_model=dummy_model, ) .eval() .requires_grad_(False) @@ -1223,6 +1226,7 @@ def __init__( normalize_state_dict=True, strict=False, pruning_rate=0.0, + dummy_model=dummy_model, ) .eval() .requires_grad_(False) @@ -1242,6 +1246,7 @@ def __init__( normalize_state_dict=True, strict=False, pruning_rate=resolved_pruning_rate, + dummy_model=dummy_model, ) .eval() .requires_grad_(False)