Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 50 additions & 19 deletions src/scope/core/pipelines/krea_realtime_video/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,27 @@ def __init__(
)

print(f"Quantized diffusion model to fp8 in {time.time() - start:.3f}s")

if compile:
# Float8DynamicActivationFloat8WeightConfig is incompatible with
# torch.compile(fullgraph=False): AOT autograd's gen_alias_from_base
# calls aten.as_strided on Float8Tensor outputs, which is not
# implemented. Skip block compilation when FP8 is active.
# See: https://github.com/daydreamlive/scope/issues/669
logger.warning(
"Skipping torch.compile for attention blocks: "
"Float8DynamicActivationFloat8WeightConfig is not compatible "
"with fullgraph=False compilation (aten.as_strided unsupported "
"on Float8Tensor). FP8 quantization is still active."
)
Comment on lines +141 to +152
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

FP8 + compile=True still triggers flex_attention warmup.

When FP8 quantization is active and compile=True, the code:

  1. Logs the warning and skips block.compile() (correct)
  2. But compile remains True, so initial_kv_bias is set to 0.3 (line 212)
  3. Warmup runs (line 232) with bias < 1.0, which per lines 223-226 "would otherwise enter the flex_attention code path... and trigger torch._dynamo tracing"

This defeats the purpose of skipping compilation for FP8. Consider tracking whether compilation actually occurred:

Proposed fix
+        # Track whether block compilation actually happens (FP8 is incompatible)
+        did_compile = False
+
         if quantization == Quantization.FP8_E4M3FN:
             # Cast before optional quantization
             generator = generator.to(dtype=dtype)
@@ -140,6 +143,7 @@
         else:
             generator = generator.to(device=device, dtype=dtype)

             if compile:
                 # Only compile the attention blocks
                 for block in generator.model.blocks:
                     # Disable fullgraph right now due to issues with RoPE
                     block.compile(fullgraph=False)
+                did_compile = True

         # ... later ...

         initial_kv_bias = (
-            DEFAULT_KV_CACHE_ATTENTION_BIAS if compile else KV_CACHE_ATTENTION_BIAS_DISABLED
+            DEFAULT_KV_CACHE_ATTENTION_BIAS if did_compile else KV_CACHE_ATTENTION_BIAS_DISABLED
         )

         # ... and ...

-        if compile:
+        if did_compile:
             local_attn_size = getattr(model_config, "local_attn_size", 6)

Also applies to: 211-214, 232-250

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/scope/core/pipelines/krea_realtime_video/pipeline.py`, around lines 141-152: the current logic logs that `block.compile()` is being skipped when FP8 is active, but it leaves the `compile` flag set to true. As a result, `initial_kv_bias` is still set to the compiled-path value and the warmup loop still runs, triggering flex_attention tracing even though no block was compiled.

To fix this, track whether compilation actually happened. Immediately after the Float8 check where `block.compile()` is skipped, either set a new boolean (e.g., `did_compile`) or flip `compile` to False. Then use that actual-compilation indicator in the subsequent logic: the assignment of `initial_kv_bias` and the branch that controls the warmup/flex_attention path. This ensures warmup does not run when compilation was skipped due to FP8.

else:
generator = generator.to(device=device, dtype=dtype)

if compile:
# Only compile the attention blocks
for block in generator.model.blocks:
# Disable fullgraph right now due to issues with RoPE
block.compile(fullgraph=False)
if compile:
# Only compile the attention blocks
for block in generator.model.blocks:
# Disable fullgraph right now due to issues with RoPE
block.compile(fullgraph=False)

# Load VAE using create_vae factory (supports multiple VAE types)
# Note: VAE is shared across all Wan model sizes, stored in Wan2.1-T2V-1.3B
Expand Down Expand Up @@ -189,7 +202,16 @@ def __init__(
# does not work properly
self.state.set("current_start_frame", 0)
self.state.set("manage_cache", True)
self.state.set("kv_cache_attention_bias", DEFAULT_KV_CACHE_ATTENTION_BIAS)
# When compile=False the flex_attention path (and its torch.compile call)
# must be bypassed entirely. KV_CACHE_ATTENTION_BIAS_DISABLED (1.0) is
# the sentinel that makes causal_model.py skip flex_attention and take the
# standard attention path, so use it whenever compilation is disabled.
from .modules.causal_model import KV_CACHE_ATTENTION_BIAS_DISABLED

initial_kv_bias = (
DEFAULT_KV_CACHE_ATTENTION_BIAS if compile else KV_CACHE_ATTENTION_BIAS_DISABLED
)
self.state.set("kv_cache_attention_bias", initial_kv_bias)

self.state.set("height", config.height)
self.state.set("width", config.width)
Expand All @@ -198,25 +220,34 @@ def __init__(
# Warm-up: Run enough iterations to fill the KV cache completely.
# This ensures torch.compile compiles the flex_attention kernel at the
# steady-state cache size, avoiding recompilation during actual streaming.
# Skipped when compile=False because there is no compiled kernel to prime
# and the warmup loop would otherwise enter the flex_attention code path
# (via DEFAULT_KV_CACHE_ATTENTION_BIAS) and trigger torch._dynamo tracing
# even though block.compile() was never called.
#
# Cache fills at: num_frame_per_block frames per iteration
# Cache capacity: local_attn_size frames
# Iterations needed: ceil(local_attn_size / num_frame_per_block) + 1
# (+1 to exercise the "cache full with eviction" path)
local_attn_size = getattr(model_config, "local_attn_size", 6)
num_frame_per_block = getattr(model_config, "num_frame_per_block", 3)
warmup_runs = (local_attn_size // num_frame_per_block) + 1

if stage_callback:
stage_callback("Warming up model...")
start = time.time()
for i in range(warmup_runs):
self._generate(
prompts=WARMUP_PROMPT,
init_cache=(i == 0), # Only init on first run, then accumulate
)
if compile:
local_attn_size = getattr(model_config, "local_attn_size", 6)
num_frame_per_block = getattr(model_config, "num_frame_per_block", 3)
warmup_runs = (
(local_attn_size + num_frame_per_block - 1) // num_frame_per_block
) + 1

if stage_callback:
stage_callback("Warming up model...")
start = time.time()
for i in range(warmup_runs):
self._generate(
prompts=WARMUP_PROMPT,
init_cache=(i == 0), # Only init on first run, then accumulate
)

print(f"Warmed up ({warmup_runs} runs) in {time.time() - start:.2f}s")
print(f"Warmed up ({warmup_runs} runs) in {time.time() - start:.2f}s")
else:
logger.info("torch.compile disabled — skipping warmup (no compiled kernel to prime)")

self.first_call = True
self.last_mode = None # Track mode for transition detection
Expand Down
39 changes: 34 additions & 5 deletions src/scope/server/pipeline_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ def __init__(self):
# Loading stage for frontend display (e.g., "Loading diffusion model...")
self._loading_stage: str | None = None

# Set to True if torch._dynamo.reset() failed during an unload; stale
# Dynamo/FP8 compile caches may still be present, so force compile=False
# on all subsequent pipeline loads until the worker process restarts.
self._dynamo_reset_failed: bool = False

def set_loading_stage(self, stage: str | None) -> None:
"""Set the current loading stage (thread-safe)."""
with self._lock:
Expand Down Expand Up @@ -691,6 +696,21 @@ def _unload_pipeline_by_id_unsafe(
except Exception as e:
logger.warning(f"CUDA cleanup failed: {e}")

# Reset torch.compile compilation cache to prevent stale compiled graphs
# (especially those specialized for Float8Tensor weights) from leaking into
# subsequently loaded pipelines. Without this, longlive's FP8-compiled graph
# cache can corrupt Krea's compile attempt, causing as_strided dispatch errors.
try:
torch._dynamo.reset()
logger.info("torch._dynamo cache reset")
except Exception as e:
logger.warning(
f"torch._dynamo reset failed: {e}. "
"Stale compile caches may remain in this worker; "
"forcing compile=False for all subsequent pipeline loads."
)
self._dynamo_reset_failed = True

# Publish pipeline_unloaded event
publish_event(
event_type="pipeline_unloaded",
Expand Down Expand Up @@ -959,14 +979,23 @@ def _load_pipeline_implementation(
if load_params:
quantization = load_params.get("quantization", None)

# Only compile diffusion model for hopper; skip if a prior
# torch._dynamo.reset() failed (stale caches would cause a crash).
_hopper_gpu = torch.cuda.is_available() and any(
x in torch.cuda.get_device_name(0).lower()
for x in ("h100", "hopper")
)
_should_compile = _hopper_gpu and not self._dynamo_reset_failed
if _hopper_gpu and self._dynamo_reset_failed:
logger.warning(
"torch._dynamo reset previously failed; disabling torch.compile "
"for krea-realtime-video to avoid stale-cache crash. "
"Restart the worker process to re-enable compilation."
)
pipeline = KreaRealtimeVideoPipeline(
config,
quantization=quantization,
# Only compile diffusion model for hopper right now
compile=any(
x in torch.cuda.get_device_name(0).lower()
for x in ("h100", "hopper")
),
compile=_should_compile,
device=torch.device("cuda"),
Comment thread
coderabbitai[bot] marked this conversation as resolved.
dtype=torch.bfloat16,
stage_callback=stage_callback,
Expand Down
Loading