Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion src/scope/core/pipelines/wan2_1/components/generator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Modified from https://github.com/guandeh17/Self-Forcing
import inspect
import json
import logging
import os
import types

Expand All @@ -10,6 +11,8 @@

from .scheduler import FlowMatchScheduler, SchedulerInterface

logger = logging.getLogger(__name__)


def filter_causal_model_cls_config(causal_model_cls, config):
# Filter config to only include parameters accepted by the model's __init__
Expand Down Expand Up @@ -73,7 +76,33 @@ def __init__(
self.model = self.model.to_empty(device="cpu")
# Then load the state dict weights
# Use strict=False to allow partial loading (e.g., VACE model with non-VACE checkpoint)
self.model.load_state_dict(state_dict, assign=True, strict=False)
try:
self.model.load_state_dict(state_dict, assign=True, strict=False)
except RuntimeError as e:
err_str = str(e)
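# Detect a tensor shape/size mismatch. strict=False only tolerates
# missing or unexpected keys; a key that is present with the wrong
# shape still raises RuntimeError. Such a mismatch is a symptom of a
# stale or corrupt model checkpoint whose weights no longer match the
# current model architecture (e.g. old 1.3B weights cached on a worker
# that now expects 14B shapes).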
if "size of tensor" in err_str or "size mismatch" in err_str:
logger.error(
f"Checkpoint shape mismatch loading '{generator_path}': {e}. "
"This indicates a stale or incompatible cached model file. "
"Deleting the checkpoint so it will be re-downloaded on next startup."
)
try:
os.remove(generator_path)
logger.info(f"Deleted stale checkpoint: {generator_path}")
except OSError as rm_err:
logger.warning(
f"Could not delete stale checkpoint '{generator_path}': {rm_err}"
)
raise RuntimeError(
f"Checkpoint '{generator_path}' has incompatible tensor shapes "
f"and has been deleted ({e}). "
"Please retry — the model will be re-downloaded automatically."
) from e
raise

# HACK!
# Reinitialize self.freqs properly on CPU (it's not in state_dict)
Expand Down
Loading