Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/scope/core/pipelines/krea_realtime_video/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(
config,
quantization: Quantization | None = None,
compile: bool = False,
compile_vae: bool = False,
device: torch.device | None = None,
dtype: torch.dtype = torch.bfloat16,
stage_callback=None,
Expand Down Expand Up @@ -218,6 +219,18 @@ def __init__(

print(f"Warmed up ({warmup_runs} runs) in {time.time() - start:.2f}s")

if compile_vae:
if stage_callback:
stage_callback(
"Compiling VAE decoder (this may take several minutes)..."
)
self.components.vae.compile_decoder(config.height, config.width)
if stage_callback:
stage_callback(
"Compiling VAE encoder (this may take several minutes)..."
)
self.components.vae.compile_encoder(config.height, config.width)

self.first_call = True
self.last_mode = None # Track mode for transition detection

Expand Down
7 changes: 7 additions & 0 deletions src/scope/core/pipelines/krea_realtime_video/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,13 @@ class KreaRealtimeVideoConfig(BasePipelineConfig):
description="VAE type to use. 'wan' is the full VAE, 'lightvae' is 75% pruned (faster but lower quality).",
json_schema_extra=ui_field_config(order=3, is_load_param=True, label="VAE"),
)
# Load-time opt-in flag. When set, the pipeline constructor compiles BOTH the
# VAE decoder and the VAE encoder (compile_decoder + compile_encoder), so the
# user-facing description must mention both rather than the decoder alone.
compile_vae: bool = Field(
    default=False,
    description="Use torch.compile on the VAE encoder and decoder (~1.4x faster decoding). First-time compilation takes several minutes.",
    json_schema_extra=ui_field_config(
        order=3, is_load_param=True, label="Compile VAE"
    ),
)
height: int = Field(
default=320,
ge=1,
Expand Down
13 changes: 13 additions & 0 deletions src/scope/core/pipelines/longlive/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(
self,
config,
quantization: Quantization | None = None,
compile_vae: bool = False,
device: torch.device | None = None,
dtype: torch.dtype = torch.bfloat16,
stage_callback=None,
Expand Down Expand Up @@ -198,6 +199,18 @@ def __init__(
self.first_call = True
self.last_mode = None # Track mode for transition detection

if compile_vae:
if stage_callback:
stage_callback(
"Compiling VAE decoder (this may take several minutes)..."
)
self.components.vae.compile_decoder(config.height, config.width)
if stage_callback:
stage_callback(
"Compiling VAE encoder (this may take several minutes)..."
)
self.components.vae.compile_encoder(config.height, config.width)

def prepare(self, **kwargs) -> Requirements | None:
"""Return input requirements based on current mode."""
return prepare_for_mode(self.__class__, self.components.config, kwargs)
Expand Down
17 changes: 17 additions & 0 deletions src/scope/core/pipelines/longlive/schema.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import ClassVar

from pydantic import Field

from ..artifacts import HuggingfaceRepoArtifact
Expand Down Expand Up @@ -66,6 +68,13 @@ class LongLiveConfig(BasePipelineConfig):
description="VAE type to use. 'wan' is the full VAE, 'lightvae' is 75% pruned (faster but lower quality).",
json_schema_extra=ui_field_config(order=3, is_load_param=True, label="VAE"),
)
# NOTE: `compile_vae` is deliberately not declared at this position. The class
# declares it once further down, after the quantization field (order=9,
# component="compile"). A second declaration here would be silently overridden
# by that later assignment in the class body, while advertising a decoder-only
# description and an `order=3` that collides with the `vae_type` field above.
height: int = Field(
default=320,
ge=1,
Expand Down Expand Up @@ -125,6 +134,14 @@ class LongLiveConfig(BasePipelineConfig):
order=8, component="quantization", is_load_param=True
),
)
# Load-time opt-in flag (off by default). When enabled, the LongLive pipeline
# constructor calls vae.compile_decoder(...) and then vae.compile_encoder(...)
# with the configured height and width, reporting progress via stage_callback.
compile_vae: bool = Field(
default=False,
description="Use torch.compile for VAE encoder/decoder. First-time compilation takes several minutes.",
json_schema_extra=ui_field_config(
order=9, component="compile", is_load_param=True, label="Compile VAE"
),
)
vace_fp8_compatible: ClassVar[bool] = True

modes = {
"text": ModeDefaults(default=True),
Expand Down
21 changes: 15 additions & 6 deletions src/scope/core/pipelines/longlive/test_vace.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@
# When extension + depth: anchor frames at first/last, follow depth structure in between
# When extension + r2v: anchor frames + style conditioning from reference images
"use_r2v": False, # Reference-to-Video: condition on reference images
"use_depth": True, # Depth guidance: structural control via depth maps
"use_depth": False, # Depth guidance: structural control via depth maps
"use_inpainting": True, # Inpainting: masked video-to-video generation
"use_extension": True, # Extension mode: temporal generation (firstframe/lastframe/firstlastframe)
"use_extension": False, # Extension mode: temporal generation (firstframe/lastframe/firstlastframe)
# ===== INPUT PATHS =====
# R2V: List of reference image paths (condition entire video, don't appear in output)
"ref_images": [
Expand All @@ -76,7 +76,7 @@
"prompt_depth": "a cat walking towards the camera", # Default prompt for depth mode
"prompt_inpainting": "a fireball", # Default prompt for inpainting mode
"prompt_extension": "", # Default prompt for extension mode
"num_chunks": 7, # Number of generation chunks
"num_chunks": 3, # Number of generation chunks
"frames_per_chunk": 12, # Frames per chunk (12 = 3 latent * 4 temporal upsample)
"height": 512,
"width": 512,
Expand All @@ -86,7 +86,9 @@
"mask_value": 127, # Gray value for masked regions (0-255)
# ===== OUTPUT =====
"output_dir": "vace_tests/unified", # path/to/output_dir
"vae_type": "tae",
"vae_type": "wan",
# ===== COMPILATION =====
"compile_vae": True, # Compile VAE encoder+decoder with torch.compile
}

# ========================= END CONFIGURATION =========================
Expand Down Expand Up @@ -471,6 +473,7 @@ def main():
print(f" Depth Guidance: {use_depth}")
print(f" Inpainting: {use_inpainting}")
print(f" Extension: {use_extension}")
print(f" Compile VAE: {config.get('compile_vae', False)}")
if use_extension:
print(f" Mode: {config['extension_mode']}")
print(f" Prompt: '{prompt}'")
Expand Down Expand Up @@ -529,8 +532,14 @@ def main():
)
pipeline_config.model_config.base_model_kwargs["vace_in_dim"] = 96

pipeline = LongLivePipeline(pipeline_config, device=device, dtype=torch.bfloat16)
print("Pipeline ready\n")
compile_vae = config.get("compile_vae", False)
pipeline = LongLivePipeline(
pipeline_config,
compile_vae=compile_vae,
device=device,
dtype=torch.bfloat16,
)
print(f"Pipeline ready (compile_vae={compile_vae})\n")

# Prepare inputs
total_frames = config["num_chunks"] * config["frames_per_chunk"]
Expand Down
8 changes: 8 additions & 0 deletions src/scope/core/pipelines/memflow/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(
self,
config,
quantization: Quantization | None = None,
compile_vae: bool = False,
device: torch.device | None = None,
dtype: torch.dtype = torch.bfloat16,
stage_callback=None,
Expand Down Expand Up @@ -195,6 +196,13 @@ def __init__(
self.state.set("width", config.width)
self.state.set("base_seed", getattr(config, "base_seed", 42))

if compile_vae:
if stage_callback:
stage_callback(
"Compiling VAE decoder (this may take several minutes)..."
)
self.components.vae.compile_decoder(config.height, config.width)

self.first_call = True
self.last_mode = None # Track mode for transition detection

Expand Down
7 changes: 7 additions & 0 deletions src/scope/core/pipelines/memflow/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ class MemFlowConfig(BasePipelineConfig):
description="VAE type to use. 'wan' is the full VAE, 'lightvae' is 75% pruned (faster but lower quality).",
json_schema_extra=ui_field_config(order=3, is_load_param=True, label="VAE"),
)
# Load-time opt-in flag (off by default). When enabled, the MemFlow pipeline
# constructor calls vae.compile_decoder(config.height, config.width); only the
# decoder is compiled for this pipeline.
compile_vae: bool = Field(
default=False,
description="Use torch.compile on the VAE decoder for ~1.4x faster decoding. First-time compilation takes several minutes.",
json_schema_extra=ui_field_config(
order=3, is_load_param=True, label="Compile VAE"
),
)
height: int = Field(
default=320,
ge=1,
Expand Down
8 changes: 8 additions & 0 deletions src/scope/core/pipelines/reward_forcing/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(
self,
config,
quantization: Quantization | None = None,
compile_vae: bool = False,
device: torch.device | None = None,
dtype: torch.dtype = torch.bfloat16,
stage_callback=None,
Expand Down Expand Up @@ -169,6 +170,13 @@ def __init__(
self.state.set("width", config.width)
self.state.set("base_seed", getattr(config, "base_seed", 42))

if compile_vae:
if stage_callback:
stage_callback(
"Compiling VAE decoder (this may take several minutes)..."
)
self.components.vae.compile_decoder(config.height, config.width)

self.first_call = True
self.last_mode = None # Track mode for transition detection

Expand Down
7 changes: 7 additions & 0 deletions src/scope/core/pipelines/reward_forcing/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ class RewardForcingConfig(BasePipelineConfig):
description="VAE type to use. 'wan' is the full VAE, 'lightvae' is 75% pruned (faster but lower quality).",
json_schema_extra=ui_field_config(order=3, is_load_param=True, label="VAE"),
)
# Load-time opt-in flag (off by default). When enabled, the RewardForcing
# pipeline constructor calls vae.compile_decoder(config.height, config.width);
# only the decoder is compiled for this pipeline.
compile_vae: bool = Field(
default=False,
description="Use torch.compile on the VAE decoder for ~1.4x faster decoding. First-time compilation takes several minutes.",
json_schema_extra=ui_field_config(
order=3, is_load_param=True, label="Compile VAE"
),
)
height: int = Field(
default=320,
ge=1,
Expand Down
8 changes: 8 additions & 0 deletions src/scope/core/pipelines/streamdiffusionv2/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(
self,
config,
quantization: Quantization | None = None,
compile_vae: bool = False,
device: torch.device | None = None,
dtype: torch.dtype = torch.bfloat16,
stage_callback=None,
Expand Down Expand Up @@ -173,6 +174,13 @@ def __init__(
self.state.set("width", config.width)
self.state.set("base_seed", getattr(config, "base_seed", 42))

if compile_vae:
if stage_callback:
stage_callback(
"Compiling VAE decoder (this may take several minutes)..."
)
self.components.vae.compile_decoder(config.height, config.width)

self.first_call = True
self.last_mode = None # Track mode for transition detection

Expand Down
7 changes: 7 additions & 0 deletions src/scope/core/pipelines/streamdiffusionv2/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,13 @@ class StreamDiffusionV2Config(BasePipelineConfig):
description="VAE type to use. 'wan' is the full VAE, 'lightvae' is 75% pruned (faster but lower quality).",
json_schema_extra=ui_field_config(order=3, is_load_param=True, label="VAE"),
)
# Load-time opt-in flag (off by default). When enabled, the StreamDiffusionV2
# pipeline constructor calls vae.compile_decoder(config.height, config.width);
# only the decoder is compiled for this pipeline.
compile_vae: bool = Field(
default=False,
description="Use torch.compile on the VAE decoder for ~1.4x faster decoding. First-time compilation takes several minutes.",
json_schema_extra=ui_field_config(
order=3, is_load_param=True, label="Compile VAE"
),
)
height: int = Field(
default=512,
ge=1,
Expand Down
Loading
Loading