@@ -124,6 +124,7 @@ def infer(self, weights, inputs, kv_start=0, kv_end=0):

if self.scheduler.infer_condition:
plucker_emb = memory_plucker
# use memory_plucker first
if plucker_emb is None:
plucker_emb = camera_plucker
mouse_cond = dit_cond_dict.get("mouse_cond", None)
166 changes: 1 addition & 165 deletions lightx2v/models/networks/wan/matrix_game3_model.py
@@ -1,10 +1,6 @@
import json
import os
import sys
from functools import lru_cache
from pathlib import Path

import torch
from safetensors import safe_open

from lightx2v.models.networks.wan.infer.matrix_game3.post_infer import WanMtxg3PostInfer
@@ -17,167 +13,6 @@
from lightx2v.utils.utils import *


@lru_cache(maxsize=1)
def _import_official_matrix_game3_wan_model():
"""Load the official Matrix-Game-3 WanModel implementation on demand."""
official_root = Path(__file__).resolve().parents[4] / "Matrix-Game-3" / "Matrix-Game-3"
if not official_root.is_dir():
raise FileNotFoundError(f"Official Matrix-Game-3 source directory not found: {official_root}")
official_root_str = str(official_root)
if official_root_str not in sys.path:
sys.path.insert(0, official_root_str)
from wan.modules.model import WanModel as OfficialWanModel

return OfficialWanModel


def _matrix_game3_compute_max_seq_len(config, lat_h: int, lat_w: int) -> int:
first_clip_frame = int(config.get("first_clip_frame", config.get("target_video_length", 57)))
vae_stride_t = int(tuple(config.get("vae_stride", (4, 16, 16)))[0])
patch_h, patch_w = tuple(config.get("patch_size", (1, 2, 2)))[1:]
max_lat_f = (first_clip_frame - 1) // vae_stride_t + 1
max_mem_f = 5
max_total_f = max_lat_f + max_mem_f
max_seq_len = max_total_f * lat_h * lat_w // (patch_h * patch_w)
sp_size = int(config.get("sp_size", 1))
if sp_size > 1:
max_seq_len = int(((max_seq_len + sp_size - 1) // sp_size) * sp_size)
return max_seq_len


class WanMtxg3OfficialBaseModel:
"""Base-model wrapper that delegates denoising to the official MG3 forward.

The distilled MG3 path is numerically tolerant enough to run through the
custom LightX2V weight/infer stack, but the base checkpoint is much more
sensitive under 50-step CFG. Reusing the official DiT forward here removes
the remaining block/head precision mismatches from the adaptation.
"""

def __init__(self, model_path, config, device, model_type="wan2.2", lora_path=None, lora_strength=1.0):
del model_type, lora_path, lora_strength
self.model_path = model_path
self.config = config
self.device = device
self.scheduler = None
self.transformer_infer = None
self._official_model = self._load_official_model()

def _load_official_model(self):
sub_model_folder = self.config.get("sub_model_folder", "base_model")
model_dir = os.path.join(self.config["model_path"], sub_model_folder)
if not os.path.isdir(model_dir):
raise FileNotFoundError(f"Matrix-Game-3 base checkpoint directory not found: {model_dir}")
OfficialWanModel = _import_official_matrix_game3_wan_model()
model = OfficialWanModel.from_pretrained(model_dir, torch_dtype=torch.bfloat16)
model = model.eval().requires_grad_(False)
model.to(device=self.device, dtype=torch.bfloat16)
return model

def set_scheduler(self, scheduler):
self.scheduler = scheduler

def _build_official_timestep(self, latents):
t = self.scheduler.timestep_input
if t is None:
raise RuntimeError("Matrix-Game-3 base forward requested before scheduler.timestep_input was prepared")
if t.numel() != 1:
return t.reshape(1, -1).to(device=latents.device, dtype=latents.dtype)

timestep_scalar = t.reshape(1).to(device=latents.device, dtype=latents.dtype)
patch_size = tuple(self.config.get("patch_size", (1, 2, 2)))
patch_h = int(patch_size[1])
patch_w = int(patch_size[2])
latent_frames = int(latents.shape[1])
latent_h = int(latents.shape[2])
latent_w = int(latents.shape[3])
tokens_per_frame = latent_h * latent_w // (patch_h * patch_w)
timestep = latents.new_full(
(latent_frames, tokens_per_frame),
timestep_scalar.squeeze(0),
)
mask = getattr(self.scheduler, "mask", None)
if mask is not None:
fixed_latent_frames = int((mask.to(dtype=torch.float32).amax(dim=(0, 2, 3)) == 0).sum().item())
if fixed_latent_frames > 0:
timestep[:fixed_latent_frames].zero_()
return timestep.flatten().unsqueeze(0)

def _build_forward_kwargs(self, inputs, infer_condition):
if self.scheduler is None:
raise RuntimeError("Matrix-Game-3 base model used before scheduler was attached")

latents = self.scheduler.latents.unsqueeze(0)
timestep = self._build_official_timestep(self.scheduler.latents)
image_encoder_output = inputs.get("image_encoder_output", {})
dit_cond_dict = image_encoder_output.get("dit_cond_dict") or {}
memory_plucker = dit_cond_dict.get("plucker_emb_with_memory")
camera_plucker = dit_cond_dict.get("c2ws_plucker_emb")

if infer_condition:
context = inputs["text_encoder_output"]["context"]
plucker_emb = memory_plucker
if plucker_emb is None:
plucker_emb = camera_plucker
mouse_cond = dit_cond_dict.get("mouse_cond")
keyboard_cond = dit_cond_dict.get("keyboard_cond")
x_memory = dit_cond_dict.get("x_memory")
timestep_memory = dit_cond_dict.get("timestep_memory")
mouse_cond_memory = dit_cond_dict.get("mouse_cond_memory")
keyboard_cond_memory = dit_cond_dict.get("keyboard_cond_memory")
memory_latent_idx = dit_cond_dict.get("memory_latent_idx")
else:
context = inputs["text_encoder_output"]["context_null"]
mouse_source = dit_cond_dict.get("mouse_cond")
keyboard_source = dit_cond_dict.get("keyboard_cond")
plucker_emb = dit_cond_dict.get("c2ws_plucker_emb")
mouse_cond = torch.ones_like(mouse_source) if mouse_source is not None else None
keyboard_cond = -torch.ones_like(keyboard_source) if keyboard_source is not None else None
x_memory = None
timestep_memory = None
mouse_cond_memory = None
keyboard_cond_memory = None
memory_latent_idx = None

seq_len = _matrix_game3_compute_max_seq_len(self.config, latents.shape[3], latents.shape[4])

forward_kwargs = {
"x": latents,
"t": timestep,
"context": context,
"seq_len": seq_len,
"mouse_cond": mouse_cond,
"keyboard_cond": keyboard_cond,
"x_memory": x_memory,
"timestep_memory": timestep_memory,
"mouse_cond_memory": mouse_cond_memory,
"keyboard_cond_memory": keyboard_cond_memory,
"plucker_emb": plucker_emb,
"memory_latent_idx": memory_latent_idx,
"predict_latent_idx": dit_cond_dict.get("predict_latent_idx"),
}
return forward_kwargs

@torch.no_grad()
def _infer_cond_uncond(self, inputs, infer_condition=True):
self.scheduler.infer_condition = infer_condition
noise_pred = self._official_model(**self._build_forward_kwargs(inputs, infer_condition))
if isinstance(noise_pred, list):
noise_pred = torch.stack(noise_pred)
if noise_pred.dim() == 5 and noise_pred.shape[0] == 1:
noise_pred = noise_pred.squeeze(0)
return noise_pred.float()

@torch.no_grad()
def infer(self, inputs):
if self.config.get("enable_cfg", False):
cond_pred = self._infer_cond_uncond(inputs, infer_condition=True)
uncond_pred = self._infer_cond_uncond(inputs, infer_condition=False)
self.scheduler.noise_pred = uncond_pred + self.scheduler.sample_guide_scale * (cond_pred - uncond_pred)
else:
self.scheduler.noise_pred = self._infer_cond_uncond(inputs, infer_condition=True)


class WanMtxg3Model(WanModel):
"""Network model for Matrix-Game-3.0.

@@ -193,6 +28,7 @@ class WanMtxg3Model(WanModel):

pre_weight_class = WanMtxg3PreWeights
transformer_weight_class = WanMtxg3TransformerWeights
    # replaced with the MG3-specific weight modules

def __init__(self, model_path, config, device, model_type="wan2.2", lora_path=None, lora_strength=1.0):
super().__init__(model_path, config, device, model_type, lora_path, lora_strength)
15 changes: 2 additions & 13 deletions lightx2v/models/runners/wan/wan_matrix_game3_runner.py
@@ -1112,10 +1112,7 @@ def run_text_encoder(self, input_info):
return super().run_text_encoder(input_info)

def load_transformer(self):
from lightx2v.models.networks.wan.matrix_game3_model import (
WanMtxg3Model,
WanMtxg3OfficialBaseModel,
)
from lightx2v.models.networks.wan.matrix_game3_model import WanMtxg3Model

# The backbone is still a Wan2.2 DiT, but Matrix-Game-3 swaps in a dedicated
# network wrapper that understands keyboard / mouse / camera conditions.
@@ -1126,15 +1123,7 @@ def load_transformer(self):
}
lora_configs = self.config.get("lora_configs")
if not lora_configs:
if self.config.get("use_base_model", False):
try:
logger.info("[matrix-game-3] base-model path will use the official WanModel forward for denoising.")
return WanMtxg3OfficialBaseModel(**model_kwargs)
except Exception as exc:
logger.warning(
"[matrix-game-3] failed to initialize official base-model forward ({}); falling back to the custom LightX2V MG3 model.",
exc,
)
logger.info("[matrix-game-3] loading MG3 {} checkpoint with the LightX2V inference stack.", self._get_sub_model_folder())
return WanMtxg3Model(**model_kwargs)
return build_wan_model_with_lora(WanMtxg3Model, self.config, model_kwargs, lora_configs, model_type="wan2.2")

1 change: 1 addition & 0 deletions lightx2v/models/runners/wan/wan_runner.py
@@ -852,6 +852,7 @@ def _build_wan22_vae_config(self, vae_offload):
"vae_type": resolved_paths["vae_type"],
"lightvae_pruning_rate": resolved_paths["lightvae_pruning_rate"],
"lightvae_encoder_vae_pth": resolved_paths["lightvae_encoder_vae_pth"],
"dummy_model": self.config.get("dummy_model", False),
}

def load_vae_encoder(self):
29 changes: 17 additions & 12 deletions lightx2v/models/video_encoders/hf/wan/vae_2_2.py
@@ -1002,6 +1002,7 @@ def _video_vae(
load_from_rank0=False,
normalize_state_dict=False,
strict=True,
dummy_model=False,
**kwargs,
):
# params
@@ -1024,18 +1025,18 @@
with torch.device("meta"):
model = WanVAE_(**cfg)

# load checkpoint
logging.info(f"loading {pretrained_path}")
raw_state = load_weights(pretrained_path, cpu_offload=cpu_offload, load_from_rank0=load_from_rank0)
weights_dict = _normalize_vae_state_dict(raw_state) if normalize_state_dict else raw_state
for key in list(weights_dict.keys()):
if hasattr(weights_dict[key], "dtype") and weights_dict[key].dtype != dtype:
weights_dict[key] = weights_dict[key].to(dtype)
if strict:
model.load_state_dict(weights_dict, assign=True)
else:
missing, unexpected = model.load_state_dict(weights_dict, strict=False, assign=True)
logging.info(f"VAE checkpoint loaded with strict=False (missing={len(missing)}, unexpected={len(unexpected)})")
# load checkpoint
logging.info(f"loading {pretrained_path}")
raw_state = load_weights(pretrained_path, cpu_offload=cpu_offload, load_from_rank0=load_from_rank0)
weights_dict = _normalize_vae_state_dict(raw_state) if normalize_state_dict else raw_state
for key in list(weights_dict.keys()):
if hasattr(weights_dict[key], "dtype") and weights_dict[key].dtype != dtype:
weights_dict[key] = weights_dict[key].to(dtype)
if strict:
model.load_state_dict(weights_dict, assign=True)
else:
missing, unexpected = model.load_state_dict(weights_dict, strict=False, assign=True)
logging.info(f"VAE checkpoint loaded with strict=False (missing={len(missing)}, unexpected={len(unexpected)})")

# Convert Conv3d weights to channels_last_3d for cuDNN optimization
if GET_USE_CHANNELS_LAST_3D():
@@ -1061,6 +1062,7 @@ def __init__(
vae_type="wan2.2",
lightvae_pruning_rate=None,
lightvae_encoder_vae_pth=None,
dummy_model=False,
**kwargs,
):
self.dtype = dtype
@@ -1194,6 +1196,7 @@ def __init__(
normalize_state_dict=False,
strict=True,
pruning_rate=0.0,
dummy_model=dummy_model,
)
.eval()
.requires_grad_(False)
@@ -1223,6 +1226,7 @@ def __init__(
normalize_state_dict=True,
strict=False,
pruning_rate=0.0,
dummy_model=dummy_model,
)
.eval()
.requires_grad_(False)
@@ -1242,6 +1246,7 @@
normalize_state_dict=True,
strict=False,
pruning_rate=resolved_pruning_rate,
dummy_model=dummy_model,
)
.eval()
.requires_grad_(False)