Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 53 additions & 18 deletions src/scope/core/pipelines/krea_realtime_video/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,27 @@ def __init__(
)

print(f"Quantized diffusion model to fp8 in {time.time() - start:.3f}s")

if compile:
# Float8DynamicActivationFloat8WeightConfig is incompatible with
# torch.compile(fullgraph=False): AOT autograd's gen_alias_from_base
# calls aten.as_strided on Float8Tensor outputs, which is not
# implemented. Skip block compilation when FP8 is active.
# See: https://github.com/daydreamlive/scope/issues/669
logger.warning(
"Skipping torch.compile for attention blocks: "
"Float8DynamicActivationFloat8WeightConfig is not compatible "
"with fullgraph=False compilation (aten.as_strided unsupported "
"on Float8Tensor). FP8 quantization is still active."
)
else:
generator = generator.to(device=device, dtype=dtype)

if compile:
# Only compile the attention blocks
for block in generator.model.blocks:
# Disable fullgraph right now due to issues with RoPE
block.compile(fullgraph=False)
if compile:
# Only compile the attention blocks
for block in generator.model.blocks:
# Disable fullgraph right now due to issues with RoPE
block.compile(fullgraph=False)

# Load VAE using create_vae factory (supports multiple VAE types)
# Note: VAE is shared across all Wan model sizes, stored in Wan2.1-T2V-1.3B
Expand Down Expand Up @@ -189,7 +202,18 @@ def __init__(
# does not work properly
self.state.set("current_start_frame", 0)
self.state.set("manage_cache", True)
self.state.set("kv_cache_attention_bias", DEFAULT_KV_CACHE_ATTENTION_BIAS)
# When compile=False the flex_attention path (and its torch.compile call)
# must be bypassed entirely. KV_CACHE_ATTENTION_BIAS_DISABLED (1.0) is
# the sentinel that makes causal_model.py skip flex_attention and take the
# standard attention path, so use it whenever compilation is disabled.
from .modules.causal_model import KV_CACHE_ATTENTION_BIAS_DISABLED

initial_kv_bias = (
DEFAULT_KV_CACHE_ATTENTION_BIAS
if compile
else KV_CACHE_ATTENTION_BIAS_DISABLED
)
self.state.set("kv_cache_attention_bias", initial_kv_bias)

self.state.set("height", config.height)
self.state.set("width", config.width)
Expand All @@ -198,26 +222,37 @@ def __init__(
# Warm-up: Run enough iterations to fill the KV cache completely.
# This ensures torch.compile compiles the flex_attention kernel at the
# steady-state cache size, avoiding recompilation during actual streaming.
# Skipped when compile=False because there is no compiled kernel to prime
# and the warmup loop would otherwise enter the flex_attention code path
# (via DEFAULT_KV_CACHE_ATTENTION_BIAS) and trigger torch._dynamo tracing
# even though block.compile() was never called.
#
# Cache fills at: num_frame_per_block frames per iteration
# Cache capacity: local_attn_size frames
# Iterations needed: ceil(local_attn_size / num_frame_per_block) + 1
# (+1 to exercise the "cache full with eviction" path)
local_attn_size = getattr(model_config, "local_attn_size", 6)
num_frame_per_block = getattr(model_config, "num_frame_per_block", 3)
warmup_runs = (local_attn_size // num_frame_per_block) + 1
if compile:
local_attn_size = getattr(model_config, "local_attn_size", 6)
num_frame_per_block = getattr(model_config, "num_frame_per_block", 3)
warmup_runs = (
(local_attn_size + num_frame_per_block - 1) // num_frame_per_block
) + 1

if stage_callback:
stage_callback("Warming up model...")
start = time.time()
for i in range(warmup_runs):
self._generate(
prompts=WARMUP_PROMPT,
init_cache=(i == 0), # Only init on first run, then accumulate
)

if stage_callback:
stage_callback("Warming up model...")
start = time.time()
for i in range(warmup_runs):
self._generate(
prompts=WARMUP_PROMPT,
init_cache=(i == 0), # Only init on first run, then accumulate
print(f"Warmed up ({warmup_runs} runs) in {time.time() - start:.2f}s")
else:
logger.info(
"torch.compile disabled — skipping warmup (no compiled kernel to prime)"
)

print(f"Warmed up ({warmup_runs} runs) in {time.time() - start:.2f}s")

self.first_call = True
self.last_mode = None # Track mode for transition detection

Expand Down
35 changes: 35 additions & 0 deletions src/scope/core/pipelines/wan2_1/vace/blocks/vace_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,6 +708,32 @@ def _encode_with_conditioning(self, components, block_state, current_start):

batch_size, channels, num_frames, height, width = input_frames_data.shape

# Guard against temporal underflow: the WAN VAE encoder has a 3x1x1 temporal
# convolution kernel. After downsampling by vae_temporal_downsample_factor (4),
# the latent temporal dimension must be ≥ 3. That means the pixel-space input
# needs at least num_frame_per_block * vae_temporal_downsample_factor frames.
# When the input chunk is shorter (e.g. the very first tiny chunk in
# StreamDiffusionV2 or a user-supplied clip with < 12 frames) we pad the
# input by repeating the last frame rather than hard-failing with a PyTorch
# convolution error.
min_frames = (
components.config.num_frame_per_block
* components.config.vae_temporal_downsample_factor
)
if num_frames < min_frames:
logger.warning(
f"VaceEncodingBlock._encode_with_conditioning: vace_input_frames has only "
f"{num_frames} frames but the WAN VAE temporal kernel requires at least "
f"{min_frames} (num_frame_per_block={components.config.num_frame_per_block} × "
f"vae_temporal_downsample_factor={components.config.vae_temporal_downsample_factor}). "
f"Padding to {min_frames} frames by repeating the last frame."
)
pad_amount = min_frames - num_frames
last_frame = input_frames_data[:, :, -1:, :, :] # [B, C, 1, H, W]
padding = last_frame.expand(-1, -1, pad_amount, -1, -1)
input_frames_data = torch.cat([input_frames_data, padding], dim=2)
num_frames = min_frames

# Validate resolution
if height != block_state.height or width != block_state.width:
raise ValueError(
Expand Down Expand Up @@ -767,6 +793,15 @@ def _encode_with_conditioning(self, components, block_state, current_start):
raise ValueError(
f"VaceEncodingBlock._encode_with_conditioning: vace_input_masks must have 1 channel, got {mask_channels}"
)

# Pad masks to match padded frames length if needed (mirrors frame padding above)
if mask_frames < num_frames:
pad_amount = num_frames - mask_frames
last_mask = input_masks_data[:, :, -1:, :, :]
mask_padding = last_mask.expand(-1, -1, pad_amount, -1, -1)
input_masks_data = torch.cat([input_masks_data, mask_padding], dim=2)
mask_frames = num_frames

if (
mask_frames != num_frames
or mask_height != height
Expand Down
38 changes: 33 additions & 5 deletions src/scope/server/pipeline_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ def __init__(self):
# Loading stage for frontend display (e.g., "Loading diffusion model...")
self._loading_stage: str | None = None

# Set to True if torch._dynamo.reset() failed during an unload; stale
# Dynamo/FP8 compile caches may still be present, so force compile=False
# on all subsequent pipeline loads until the worker process restarts.
self._dynamo_reset_failed: bool = False

def set_loading_stage(self, stage: str | None) -> None:
"""Set the current loading stage (thread-safe)."""
with self._lock:
Expand Down Expand Up @@ -691,6 +696,21 @@ def _unload_pipeline_by_id_unsafe(
except Exception as e:
logger.warning(f"CUDA cleanup failed: {e}")

# Reset torch.compile compilation cache to prevent stale compiled graphs
# (especially those specialized for Float8Tensor weights) from leaking into
# subsequently loaded pipelines. Without this, longlive's FP8-compiled graph
# cache can corrupt Krea's compile attempt, causing as_strided dispatch errors.
try:
torch._dynamo.reset()
logger.info("torch._dynamo cache reset")
except Exception as e:
logger.warning(
f"torch._dynamo reset failed: {e}. "
"Stale compile caches may remain in this worker; "
"forcing compile=False for all subsequent pipeline loads."
)
self._dynamo_reset_failed = True

# Publish pipeline_unloaded event
publish_event(
event_type="pipeline_unloaded",
Expand Down Expand Up @@ -959,14 +979,22 @@ def _load_pipeline_implementation(
if load_params:
quantization = load_params.get("quantization", None)

# Only compile diffusion model for hopper; skip if a prior
# torch._dynamo.reset() failed (stale caches would cause a crash).
_hopper_gpu = torch.cuda.is_available() and any(
x in torch.cuda.get_device_name(0).lower() for x in ("h100", "hopper")
)
_should_compile = _hopper_gpu and not self._dynamo_reset_failed
if _hopper_gpu and self._dynamo_reset_failed:
logger.warning(
"torch._dynamo reset previously failed; disabling torch.compile "
"for krea-realtime-video to avoid stale-cache crash. "
"Restart the worker process to re-enable compilation."
)
pipeline = KreaRealtimeVideoPipeline(
config,
quantization=quantization,
# Only compile diffusion model for hopper right now
compile=any(
x in torch.cuda.get_device_name(0).lower()
for x in ("h100", "hopper")
),
compile=_should_compile,
device=torch.device("cuda"),
dtype=torch.bfloat16,
stage_callback=stage_callback,
Expand Down
Loading