class OSCReporter:
    """
    Lightweight OSC reporter for state, telemetry, and errors.

    Lives at module level and is created before any heavy imports, so it must
    not depend on torch, diffusers, or StreamDiffusion at import time
    (``torch`` is imported lazily inside the VRAM monitor thread).

    Args:
        transmit_ip: Host the OSC messages are sent to.
        transmit_port: UDP port the OSC messages are sent to.
    """

    def __init__(self, transmit_ip: str, transmit_port: int):
        self.client = udp_client.SimpleUDPClient(transmit_ip, transmit_port)
        self._heartbeat_running = False
        self._heartbeat_thread = None
        self._heartbeat_stop = threading.Event()
        self._current_state = "local_offline"
        self._vram_monitor_thread = None
        self._vram_monitor_running = False
        self._vram_monitor_stop = threading.Event()

    # === Heartbeat ===
    def start_heartbeat(self) -> None:
        """Start the server heartbeat thread (no-op if it is already running)."""
        if self._heartbeat_running:
            return
        self._heartbeat_running = True
        # A fresh event per thread: a quick stop -> start cannot revive the
        # previous loop, which keeps watching its own (already set) event.
        stop_event = threading.Event()
        self._heartbeat_stop = stop_event
        self._heartbeat_thread = threading.Thread(
            target=self._heartbeat_loop, args=(stop_event,), daemon=True
        )
        self._heartbeat_thread.start()

    def stop_heartbeat(self) -> None:
        """Ask the heartbeat thread to exit; wakes it immediately."""
        self._heartbeat_running = False
        self._heartbeat_stop.set()

    def _heartbeat_loop(self, stop_event=None) -> None:
        """Send ``/server_active`` and the current state once per second until stopped."""
        if stop_event is None:
            stop_event = self._heartbeat_stop
        while not stop_event.is_set():
            self.client.send_message("/server_active", 1)
            # Re-broadcast the state with every beat so a listener that
            # attaches late still learns it.
            if self._current_state:
                self.client.send_message("/stream-info/state", self._current_state)
            # Event.wait instead of time.sleep: returns early on stop.
            stop_event.wait(1.0)

    # === State Management ===
    def set_state(self, state: str) -> None:
        """Set and broadcast connection state."""
        self._current_state = state
        self.client.send_message("/stream-info/state", state)

    def send_error(self, error_msg: str, error_time: int = None) -> None:
        """
        Send an error message followed by its timestamp.

        Args:
            error_msg: Text sent to ``/stream-info/error``.
            error_time: Timestamp sent to ``/stream-info/error-time``;
                defaults to the current time in milliseconds since the epoch.
        """
        if error_time is None:
            error_time = int(time.time() * 1000)
        self.client.send_message("/stream-info/error", error_msg)
        self.client.send_message("/stream-info/error-time", error_time)

    # === Telemetry (called by manager) ===
    def send_output_name(self, name: str) -> None:
        self.client.send_message("/stream-info/output-name", name)

    def send_frame_count(self, count: int) -> None:
        self.client.send_message("/framecount", count)

    def send_frame_ready(self, count: int) -> None:
        self.client.send_message("/frame_ready", count)

    def send_fps(self, fps: float) -> None:
        self.client.send_message("/stream-info/fps", fps)

    def send_controlnet_processed_name(self, name: str) -> None:
        self.client.send_message("/stream-info/controlnet-processed-name", name)

    # === VRAM Monitoring (Non-Blocking) ===
    def start_vram_monitoring(self, interval: float = 2.0) -> None:
        """Start periodic VRAM monitoring (every ``interval`` seconds, 2 by default)."""
        if self._vram_monitor_running:
            return
        self._vram_monitor_running = True
        stop_event = threading.Event()
        self._vram_monitor_stop = stop_event
        self._vram_monitor_thread = threading.Thread(
            target=self._vram_monitor_loop, args=(interval, stop_event), daemon=True
        )
        self._vram_monitor_thread.start()

    def stop_vram_monitoring(self) -> None:
        """Ask the VRAM monitor thread to exit; wakes it immediately."""
        self._vram_monitor_running = False
        self._vram_monitor_stop.set()

    def _vram_monitor_loop(self, interval: float, stop_event=None) -> None:
        """
        Periodically send ``/vram/total`` and ``/vram/used`` (in GiB) for device 0.

        Uses ``torch.cuda.mem_get_info``, which reports device-wide free/total
        memory, so usage from outside PyTorch's allocator is included.
        """
        import torch  # deferred: keeps module import free of heavy dependencies

        if stop_event is None:
            stop_event = self._vram_monitor_stop
        gib = 1024**3
        failure_logged = False

        while not stop_event.is_set():
            try:
                if torch.cuda.is_available():
                    free_bytes, total_bytes = torch.cuda.mem_get_info(0)
                    self.client.send_message("/vram/total", total_bytes / gib)
                    self.client.send_message("/vram/used", (total_bytes - free_bytes) / gib)
            except Exception:
                # Monitoring is best-effort and must never take the app down,
                # but leave one trace instead of failing silently forever.
                if not failure_logged:
                    failure_logged = True
                    logging.getLogger(__name__).debug(
                        "VRAM monitoring sample failed", exc_info=True
                    )

            stop_event.wait(interval)

    # === Engine Building Progress ===
    def send_engine_progress(self, stage: str, model_name: str = "") -> None:
        """
        Send engine building progress.

        Stages: 'exporting_onnx', 'optimizing_onnx', 'building_engine', 'cached', 'complete'.
        ``model_name`` is only sent (to ``/engine/model``) when non-empty.
        """
        self.client.send_message("/engine/stage", stage)
        if model_name:
            self.client.send_message("/engine/model", model_name)


class OSCLoggingHandler(logging.Handler):
    """
    Logging handler that forwards ERROR/CRITICAL records to TouchDesigner via
    OSC and translates engine-building log messages into progress updates.
    """

    def __init__(self, osc_reporter: OSCReporter):
        super().__init__()
        self.osc_reporter = osc_reporter
        # Accept every level. With the level pinned at ERROR the logging
        # machinery never handed lower-level records to emit(), so the
        # progress parsing below could not fire for them. Error forwarding is
        # still restricted to ERROR and above inside emit().
        # NOTE(review): records must still pass the originating logger's own
        # level before any handler sees them.
        self.setLevel(logging.NOTSET)

    def emit(self, record: logging.LogRecord):
        try:
            # Forward ERROR/CRITICAL. (The former separate ``report_error``
            # branch did exactly the same thing for the same records.)
            if record.levelno >= logging.ERROR:
                self.osc_reporter.send_error(self.format(record))

            # Capture engine building progress from log messages
            msg = record.getMessage()
            if "Exporting model:" in msg:
                model_name = msg.split("Exporting model:")[-1].strip()
                self.osc_reporter.send_engine_progress("exporting_onnx", model_name)
            elif "Generating optimizing model:" in msg:
                model_name = msg.split("Generating optimizing model:")[-1].strip()
                self.osc_reporter.send_engine_progress("optimizing_onnx", model_name)
            elif "Found cached engine:" in msg:
                self.osc_reporter.send_engine_progress("cached")
            elif "Building TensorRT engine" in msg or "Building engine" in msg:
                self.osc_reporter.send_engine_progress("building_engine")

        except Exception:
            self.handleError(record)
+osc_log_handler = OSCLoggingHandler(osc_reporter) +logging.root.addHandler(osc_log_handler) + + +import signal +import warnings + + +# Configure warnings to display in dark cyan (low saturation, dark) BEFORE any imports +def warning_format(message, category, filename, lineno, file=None, line=None): + return f"\033[38;5;66m{filename}:{lineno}: {category.__name__}: {message}\033[0m\n" + + +warnings.formatwarning = warning_format + +# ─── CUDA / PyTorch env var defaults (set before any torch import) ─── +# Only set if not already provided by the launch script or environment +if "PYTORCH_CUDA_ALLOC_CONF" not in os.environ: + os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True" +if "CUDA_MODULE_LOADING" not in os.environ: + os.environ["CUDA_MODULE_LOADING"] = "LAZY" + +# Loading animation +print("\033[38;5;80mLoading StreamDiffusionTD", end="", flush=True) +for _ in range(3): + time.sleep(0.33) + print(".", end="", flush=True) +print("\033[0m") +time.sleep(0.01) + +# Clear the loading line and add spacing +print("\r" + " " * 50 + "\r", end="") # Clear line +print("\n") # One blank line + +# ASCII Art Logo +print("\033[38;5;208m ┌──────────────────────────────────────┐") +print(" │ ▶ StreamDiffusionTD │") +print(" │ ▓▓▒▒░░ real-time diffusion │") +print(" │ ░░▒▒▓▓ TOX by dotsimulate │") +print(" │ ────────────────────────── v0.3.0 │") +print(" └──────────────────────────────────────┘") +print(" StreamDiffusion: cumulo-autumn • Daydream\033[0m\n\n") + +# Add StreamDiffusion to path +sys.path.append(os.path.join(os.path.dirname(__file__), "..")) + +# Suppress known third-party deprecation warnings (timm, cuda-python, controlnet_aux) +import warnings + +warnings.filterwarnings( + "ignore", message=".*timm.models.layers.*", category=FutureWarning +) +warnings.filterwarnings( + "ignore", message=".*timm.models.registry.*", category=FutureWarning +) +warnings.filterwarnings("ignore", message=".*cuda.cudart.*", category=FutureWarning) 
+warnings.filterwarnings("ignore", message=".*not expected and will be ignored.*") +warnings.filterwarnings( + "ignore", message=".*Overwriting tiny_vit.*", category=UserWarning +) + +# Heavy imports with error handling +try: + print("\033[38;5;66mImporting StreamDiffusion modules...\033[0m") + from td_manager import TouchDesignerManager + from td_osc_handler import OSCParameterHandler + + print("\033[38;5;34m✓ Imports loaded\033[0m\n") +except ImportError as e: + error_msg = f"Import failed: {e}" + logging.error(error_msg) # Also log so it appears in console + osc_reporter.send_error(error_msg) + osc_reporter.set_state("local_error") + raise +except Exception as e: + error_msg = f"Initialization failed: {e}" + logging.error(error_msg) # Also log so it appears in console + osc_reporter.send_error(error_msg) + osc_reporter.set_state("local_error") + raise + +# Print YAML config in sections with animated display +print() # One blank line before YAML +script_dir = os.path.dirname(os.path.abspath(__file__)) +yaml_config_path = os.path.join(script_dir, "td_config.yaml") + +with open(yaml_config_path, "r") as f: + yaml_content = f.read() + +# Split into sections based on top-level keys or comment headers +current_section = [] +lines = yaml_content.split("\n") + +for line in lines: + stripped = line.lstrip() + + # Check if this is a section break (top-level key or major comment) + is_section_break = ( + (not stripped or stripped.startswith("#")) + and current_section + and any(l.strip() and not l.strip().startswith("#") for l in current_section) + ) + + if is_section_break: + # Print accumulated section + for section_line in current_section: + section_stripped = section_line.lstrip() + if not section_stripped or section_stripped.startswith("#"): + print(f"\033[38;5;66m {section_line}\033[0m") + else: + indent = len(section_line) - len(section_stripped) + if indent == 0: + print(f"\033[38;5;80m {section_line}\033[0m") + elif indent == 2: + print(f"\033[38;5;75m 
class ColoredFormatter(logging.Formatter):
    """Custom formatter with ANSI color codes for different log levels."""

    # ANSI color codes
    GREY = "\033[90m"
    CYAN = "\033[36m"
    YELLOW = "\033[33m"
    RED = "\033[31m"
    BOLD_RED = "\033[91m"
    RESET = "\033[0m"

    _LINE = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    _DATEFMT = "%H:%M:%S"

    FORMATS = {
        logging.DEBUG: GREY + _LINE + RESET,
        logging.INFO: GREY + _LINE + RESET,
        logging.WARNING: YELLOW + _LINE + RESET,
        logging.ERROR: RED + _LINE + RESET,
        logging.CRITICAL: BOLD_RED + _LINE + RESET,
    }

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Build one formatter per level up front instead of one per record.
        self._formatters = {
            level: logging.Formatter(fmt, datefmt=self._DATEFMT)
            for level, fmt in self.FORMATS.items()
        }
        # Levels without an entry fall back to logging's default format,
        # exactly as ``Formatter(None)`` did before.
        self._fallback = logging.Formatter(None, datefmt=self._DATEFMT)

    def format(self, record):
        """Format ``record`` with the colour assigned to its level."""
        return self._formatters.get(record.levelno, self._fallback).format(record)


class StreamDiffusionTD:
    """Main application class: loads models, wires OSC control, runs the stream."""

    def __init__(self, osc_reporter):
        """
        Load ``td_config.yaml``, configure logging, build the manager and the
        OSC parameter handler, install signal handlers and start the handler.

        Args:
            osc_reporter: Reporter used for state/error/telemetry messages.

        Raises:
            Exception: Whatever ``TouchDesignerManager`` raises while loading
                models; it is reported over OSC and re-raised.
        """
        self.osc_reporter = osc_reporter  # Store reporter reference

        # Application shutdown state (set first so it exists on every path)
        self.shutdown_requested = False
        self._shutdown_done = False

        # Load YAML config first to get debug_mode setting
        script_dir = os.path.dirname(os.path.abspath(__file__))
        yaml_config_path = os.path.join(script_dir, "td_config.yaml")

        from streamdiffusion.config import load_config

        yaml_config = load_config(yaml_config_path)

        # Get TouchDesigner-specific settings from YAML
        td_settings = yaml_config.get("td_settings", {})

        debug_mode = td_settings.get("debug_mode", False)
        self._configure_logging(debug_mode)

        # Show debug mode status if enabled
        if debug_mode:
            print("\033[38;5;208m [DEBUG MODE ENABLED]\033[0m\n")

        input_mem = td_settings.get("input_mem_name", "input_stream")

        # Make output memory name unique (prevents conflicts & different resolutions)
        base_output_name = td_settings.get("output_mem_name", "sd_to_td")
        output_mem = f"{base_output_name}_{int(time.time())}"

        # Get OSC ports from YAML
        listen_port = td_settings.get("osc_transmit_port", 8247)  # Python listens
        # NOTE(review): the module-level reporter reads this same key with a
        # different fallback (8567) -- confirm which default is intended.
        transmit_port = td_settings.get("osc_receive_port", 8248)  # Python transmits

        # Initialize core manager with clean YAML config, debug flag, AND osc_reporter
        osc_reporter.set_state("local_loading_models")
        print("\033[38;5;66mLoading AI models (this may take time)...\033[0m")

        try:
            self.manager = TouchDesignerManager(
                yaml_config,
                input_mem,
                output_mem,
                debug_mode=debug_mode,
                osc_reporter=osc_reporter,  # Pass reporter to manager
            )
            print("\033[38;5;34m✓ Models loaded\033[0m")
            osc_reporter.set_state("local_ready")
        except Exception as e:
            error_msg = f"Model loading failed: {e}"
            logging.error(error_msg, exc_info=True)
            osc_reporter.send_error(error_msg)
            osc_reporter.set_state("local_error")
            raise

        # Initialize OSC handler after manager (parameters only, no heartbeat)
        self.osc_handler = OSCParameterHandler(
            manager=self.manager,
            main_app=self,  # Pass main app for shutdown handling
            listen_port=listen_port,
            transmit_port=transmit_port,
            debug_mode=debug_mode,
        )

        # Now set OSC handler in manager
        self.manager.osc_handler = self.osc_handler

        # Setup signal handlers
        signal.signal(signal.SIGINT, self._signal_handler)
        signal.signal(signal.SIGTERM, self._signal_handler)

        print()  # Extra blank line before OSC
        print(
            f"\033[38;5;80mOSC: \033[37mListen {listen_port} -> Transmit {transmit_port}\033[0m"
        )
        print(f"\033[38;5;80mMemory: \033[37m{input_mem} -> {output_mem}\033[0m")

        # Start the OSC handler now so /stop is honoured before start() runs.
        # (Models are already loaded at this point; the handler cannot exist
        # earlier because it needs the manager.)
        self.osc_handler.start()

    @staticmethod
    def _configure_logging(debug_mode) -> None:
        """Point the root logger at a coloured console handler, keeping the OSC handler."""
        log_level = logging.DEBUG if debug_mode else logging.WARNING

        root_logger = logging.getLogger()
        root_logger.setLevel(log_level)

        # Remove existing handlers (except our OSC handler); iterate a copy
        # because removeHandler mutates the list.
        for handler in root_logger.handlers[:]:
            if not isinstance(handler, OSCLoggingHandler):
                root_logger.removeHandler(handler)

        console_handler = logging.StreamHandler()
        console_handler.setLevel(log_level)
        console_handler.setFormatter(ColoredFormatter())
        root_logger.addHandler(console_handler)

    def start(self):
        """Publish the output name, start streaming and block until shutdown."""
        try:
            # OSC handler already started in __init__

            # Send output name via reporter
            self.osc_reporter.send_output_name(self.manager.output_mem_name)

            # Auto-start streaming (matches the previous main_sdtd.py behavior)
            self.manager.start_streaming()

            # Keep main thread alive
            self._wait_for_shutdown()

        except KeyboardInterrupt:
            print("\n👋 Interrupted by user")
        except Exception as e:
            print(f"❌ Error: {e}")
            self.osc_reporter.send_error(f"Application error: {e}")
            raise
        finally:
            self.shutdown()

    def shutdown(self):
        """
        Graceful shutdown; safe to call more than once.

        The signal handler calls this and then ``sys.exit``, which unwinds
        through ``start()``'s ``finally`` and calls it again -- the second
        call is a no-op. Each stop step runs even if an earlier one fails,
        so a manager error cannot leave the heartbeat running.
        """
        if getattr(self, "_shutdown_done", False):
            return
        self._shutdown_done = True

        print("\n\nShutting down...")

        steps = (
            lambda: self.manager.stop_streaming(),
            lambda: self.osc_handler.stop(),
            lambda: self.osc_reporter.stop_heartbeat(),
            lambda: self.osc_reporter.stop_vram_monitoring(),
        )
        failed = False
        for step in steps:
            try:
                step()
            except Exception as e:
                failed = True
                print(f"Shutdown error: {e}")
        if not failed:
            print("Shutdown complete")

    def _signal_handler(self, sig, frame):
        """Handle shutdown signals."""
        print(f"\nReceived signal {sig}")
        self.shutdown()
        sys.exit(0)

    def request_shutdown(self):
        """Request application shutdown (called by OSC /stop command)."""
        print("\n\033[31mStop command received via OSC\033[0m")
        self.shutdown_requested = True

        # Hard exit shortly afterwards in case the main thread is blocked and
        # cannot run the graceful path. Runs on its own thread so the OSC
        # response can be sent first.
        # NOTE(review): this fires unconditionally after 0.1 s, so it can cut
        # short a graceful shutdown that is still in progress.
        def force_exit():
            time.sleep(0.1)  # Give OSC handler time to send response
            print("Forcing exit...")
            os._exit(0)  # Hard exit (bypasses cleanup but works when blocked)

        threading.Thread(target=force_exit, daemon=True).start()

    def _wait_for_shutdown(self):
        """Block the main thread until ``shutdown_requested`` is set."""
        # KeyboardInterrupt propagates to start(), which handles it.
        while not self.shutdown_requested:
            time.sleep(0.1)  # Check shutdown flag frequently


def main():
    """Main entry point - reads from td_config.yaml."""

    # Create and start application (no longer needs stream_config.json)
    app = StreamDiffusionTD(osc_reporter)
    app.start()


if __name__ == "__main__":
    main()
+ if self.current_fps == 0.0: + self.current_fps = instantaneous_fps + else: + self.current_fps = ( + self.current_fps * self.fps_smoothing + + instantaneous_fps * (1 - self.fps_smoothing) + ) self.last_frame_output_time = frame_output_time # Inference FPS: only frames that actually ran GPU inference (similar filter skips excluded) diff --git a/demo/realtime-img2img/main.py b/demo/realtime-img2img/main.py index b1c60946..795ee5d9 100644 --- a/demo/realtime-img2img/main.py +++ b/demo/realtime-img2img/main.py @@ -1,3 +1,12 @@ +import os + +# ─── CUDA / PyTorch env var defaults (must be before import torch) ─── +# Only set if not already provided by the launch script or environment +if "PYTORCH_CUDA_ALLOC_CONF" not in os.environ: + os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True" +if "CUDA_MODULE_LOADING" not in os.environ: + os.environ["CUDA_MODULE_LOADING"] = "LAZY" + from fastapi import FastAPI, WebSocket, HTTPException, WebSocketDisconnect, UploadFile, File, Response from fastapi.responses import StreamingResponse, JSONResponse from fastapi.middleware.cors import CORSMiddleware diff --git a/src/streamdiffusion/acceleration/tensorrt/models/attention_processors.py b/src/streamdiffusion/acceleration/tensorrt/models/attention_processors.py index 23ce1c05..a06fc0a5 100644 --- a/src/streamdiffusion/acceleration/tensorrt/models/attention_processors.py +++ b/src/streamdiffusion/acceleration/tensorrt/models/attention_processors.py @@ -5,7 +5,7 @@ from diffusers.models.attention_processor import Attention from diffusers.utils import USE_PEFT_BACKEND - + class CachedSTAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). 
@@ -14,7 +14,41 @@ class CachedSTAttnProcessor2_0: def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - + # Per-layer pre-allocated buffers (lazy-init on first call — shape is model-dependent) + self._curr_key_buf: Optional[torch.Tensor] = None + self._curr_value_buf: Optional[torch.Tensor] = None + self._kv_out_buf: Optional[torch.Tensor] = None # shape: (2, 1, B, seq, inner_dim) + + # Pre-allocated buffers for zero-alloc hot path (lazy init on first call). + # _use_prealloc is False by default so ONNX export tracing uses the original + # clone/contiguous path. Set to True by wrapper.py after engine build. + self._curr_key_buf: Optional[torch.Tensor] = None + self._curr_value_buf: Optional[torch.Tensor] = None + self._cached_key_tr_buf: Optional[torch.Tensor] = None # transposed cache key + self._cached_value_tr_buf: Optional[torch.Tensor] = None # transposed cache value + self._kvo_out_buf: Optional[torch.Tensor] = None # (2, 1, B, S, H) + self._use_prealloc: bool = False + + def _ensure_buffers( + self, + key: torch.Tensor, + cached_key: Optional[torch.Tensor], + ) -> None: + """Lazy-allocate or re-allocate buffers if tensor shapes changed.""" + if self._curr_key_buf is None or self._curr_key_buf.shape != key.shape: + B, S, H = key.shape + self._curr_key_buf = torch.empty_like(key) + self._curr_value_buf = torch.empty_like(key) + self._kvo_out_buf = torch.empty(2, 1, B, S, H, dtype=key.dtype, device=key.device) + + if cached_key is not None: + # cached_key shape: (maxframes, batch, seq_len, hidden_dim) + # transposed target shape: (batch, maxframes, seq_len, hidden_dim) + tr_shape = (cached_key.shape[1], cached_key.shape[0], cached_key.shape[2], cached_key.shape[3]) + if self._cached_key_tr_buf is None or self._cached_key_tr_buf.shape != tr_shape: + self._cached_key_tr_buf = torch.empty(tr_shape, dtype=cached_key.dtype, 
device=cached_key.device) + self._cached_value_tr_buf = torch.empty(tr_shape, dtype=cached_key.dtype, device=cached_key.device) + def __call__( self, attn: Attention, @@ -68,14 +102,34 @@ def __call__( cached_key, cached_value = None, None if is_selfattn: - curr_key = key.clone() - curr_value = value.clone() - - if cached_key is not None: - cached_key_reshaped = cached_key.transpose(0, 1).contiguous().flatten(1, 2) - cached_value_reshaped = cached_value.transpose(0, 1).contiguous().flatten(1, 2) - key = torch.cat([curr_key, cached_key_reshaped], dim=1) - value = torch.cat([curr_value, cached_value_reshaped], dim=1) + if self._use_prealloc: + # Zero-alloc hot path: copy into pre-allocated buffers + self._ensure_buffers(key, cached_key) + + self._curr_key_buf.copy_(key) + self._curr_value_buf.copy_(value) + curr_key = self._curr_key_buf + curr_value = self._curr_value_buf + + if cached_key is not None: + # transpose(0,1) makes non-contiguous; copy into contiguous buffer + self._cached_key_tr_buf.copy_(cached_key.transpose(0, 1)) + self._cached_value_tr_buf.copy_(cached_value.transpose(0, 1)) + # flatten is a free view on already-contiguous buffer + cached_key_reshaped = self._cached_key_tr_buf.flatten(1, 2) + cached_value_reshaped = self._cached_value_tr_buf.flatten(1, 2) + key = torch.cat([curr_key, cached_key_reshaped], dim=1) + value = torch.cat([curr_value, cached_value_reshaped], dim=1) + else: + # Original path — used during ONNX export tracing + curr_key = key.clone() + curr_value = value.clone() + + if cached_key is not None: + cached_key_reshaped = cached_key.transpose(0, 1).contiguous().flatten(1, 2) + cached_value_reshaped = cached_value.transpose(0, 1).contiguous().flatten(1, 2) + key = torch.cat([curr_key, cached_key_reshaped], dim=1) + value = torch.cat([curr_value, cached_value_reshaped], dim=1) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads @@ -106,8 +160,14 @@ def __call__( hidden_states = hidden_states + residual hidden_states = 
hidden_states / attn.rescale_output_factor - + if is_selfattn: - kvo_cache = torch.stack([curr_key.unsqueeze(0), curr_value.unsqueeze(0)], dim=0) - - return hidden_states, kvo_cache \ No newline at end of file + if self._use_prealloc: + # Write curr K/V into pre-allocated output buffer — zero alloc + self._kvo_out_buf[0, 0].copy_(curr_key) + self._kvo_out_buf[1, 0].copy_(curr_value) + kvo_cache = self._kvo_out_buf + else: + kvo_cache = torch.stack([curr_key.unsqueeze(0), curr_value.unsqueeze(0)], dim=0) + + return hidden_states, kvo_cache diff --git a/src/streamdiffusion/acceleration/tensorrt/runtime_engines/unet_engine.py b/src/streamdiffusion/acceleration/tensorrt/runtime_engines/unet_engine.py index caa14bb8..aaaad587 100644 --- a/src/streamdiffusion/acceleration/tensorrt/runtime_engines/unet_engine.py +++ b/src/streamdiffusion/acceleration/tensorrt/runtime_engines/unet_engine.py @@ -24,6 +24,7 @@ def __init__(self, filepath: str, stream: "cuda.Stream", use_cuda_graph: bool = self.use_cuda_graph = use_cuda_graph self.use_control = False # Will be set to True by wrapper if engine has ControlNet support self._cached_dummy_controlnet_inputs = None + self._cached_dummy_controlnet_tensors: Optional[Dict[str, torch.Tensor]] = None # pre-alloc zero tensors # Enable VRAM monitoring only if explicitly requested (defaults to False for performance) self.debug_vram = os.getenv("STREAMDIFFUSION_DEBUG_VRAM", "").lower() in ("1", "true") @@ -137,6 +138,7 @@ def __call__( latent_model_input ) self._cached_latent_dims = (current_latent_height, current_latent_width) + self._cached_dummy_controlnet_tensors = None # invalidate tensor cache except RuntimeError: self._cached_dummy_controlnet_inputs = None @@ -244,29 +246,38 @@ def _add_cached_dummy_inputs( self, dummy_inputs: Dict, latent_model_input: torch.Tensor, shape_dict: Dict, input_dict: Dict ): """ - Add cached dummy inputs to the shape dictionary and input dictionary + Add dummy ControlNet zero tensors to shape/input dicts. 
+ Tensors are pre-allocated once and reused — no per-call torch.zeros() malloc. Args: dummy_inputs: Dictionary containing dummy input specifications - latent_model_input: The main latent input tensor (used for device/dtype reference) + latent_model_input: The main latent input tensor (used for device/dtype/batch reference) shape_dict: Shape dictionary to update input_dict: Input dictionary to update """ - for input_name, shape_spec in dummy_inputs.items(): - channels = shape_spec["channels"] - height = shape_spec["height"] - width = shape_spec["width"] - - # Create zero tensor with appropriate shape - zero_tensor = torch.zeros( - latent_model_input.shape[0], - channels, - height, - width, - dtype=latent_model_input.dtype, - device=latent_model_input.device, - ) - + batch_size = latent_model_input.shape[0] + + # Invalidate cache if batch size or dtype changed (rare; only on pipeline reconfigure) + if self._cached_dummy_controlnet_tensors is not None: + first = next(iter(self._cached_dummy_controlnet_tensors.values())) + if first.shape[0] != batch_size or first.dtype != latent_model_input.dtype: + self._cached_dummy_controlnet_tensors = None + + # Build cache on first call (or after invalidation) + if self._cached_dummy_controlnet_tensors is None: + self._cached_dummy_controlnet_tensors = { + input_name: torch.zeros( + batch_size, + shape_spec["channels"], + shape_spec["height"], + shape_spec["width"], + dtype=latent_model_input.dtype, + device=latent_model_input.device, + ) + for input_name, shape_spec in dummy_inputs.items() + } + + for input_name, zero_tensor in self._cached_dummy_controlnet_tensors.items(): shape_dict[input_name] = zero_tensor.shape input_dict[input_name] = zero_tensor diff --git a/src/streamdiffusion/pipeline.py b/src/streamdiffusion/pipeline.py index 44521838..b8adb71f 100644 --- a/src/streamdiffusion/pipeline.py +++ b/src/streamdiffusion/pipeline.py @@ -110,6 +110,13 @@ def __init__( self._noise_buf = None # pre-allocated buffer for per-step 
noise in TCD non-batched path self._image_decode_buf = None # pre-allocated buffer for VAE decode output (avoids .clone()) self._prev_image_buf = None # pre-allocated buffer for skip-frame image cache + self._combined_latent_buf = None # pre-allocated: avoids torch.cat in predict_x0_batch + self._alpha_next = None # pre-computed: cat([alpha_prod_t_sqrt[1:], ones[0:1]]) + self._beta_next = None # pre-computed: cat([beta_prod_t_sqrt[1:], ones[0:1]]) + self._init_noise_rotated = None # pre-computed: cat([init_noise[1:], init_noise[0:1]]) + self._unet_kwargs: dict = {"return_dict": False} # pre-allocated: avoids per-frame dict creation + self._cfg_latent_buf = None # pre-allocated: avoids torch.concat for CFG latent doubling + self._cfg_t_buf = None # pre-allocated: avoids torch.concat for CFG timestep doubling self.pipe = pipe self.image_processor = VaeImageProcessor(pipe.vae_scale_factor) @@ -531,6 +538,46 @@ def prepare( self.c_skip = self.c_skip.to(self.device) self.c_out = self.c_out.to(self.device) + # Pre-compute shifted alpha/beta/init_noise (eliminates 5 mallocs + 8 kernel launches per frame) + if self.use_denoising_batch and (self.cfg_type == "self" or self.cfg_type == "initialize"): + self._alpha_next = torch.cat( + [self.alpha_prod_t_sqrt[1:], torch.ones_like(self.alpha_prod_t_sqrt[0:1])], dim=0 + ) + self._beta_next = torch.cat( + [self.beta_prod_t_sqrt[1:], torch.ones_like(self.beta_prod_t_sqrt[0:1])], dim=0 + ) + self._init_noise_rotated = torch.cat([self.init_noise[1:], self.init_noise[0:1]], dim=0) + else: + self._alpha_next = None + self._beta_next = None + self._init_noise_rotated = None + + # Pre-allocate combined latent buffer (eliminates 1 malloc + copy kernel per frame) + if self.denoising_steps_num > 1: + self._combined_latent_buf = torch.empty( + (self.batch_size, 4, self.latent_height, self.latent_width), + dtype=self.dtype, + device=self.device, + ) + else: + self._combined_latent_buf = None + + # Pre-allocate CFG expansion buffers 
(eliminates torch.concat mallocs for latent/timestep doubling) + if self.guidance_scale > 1.0 and (self.cfg_type == "initialize" or self.cfg_type == "full"): + cfg_batch = (1 + self.batch_size) if self.cfg_type == "initialize" else (2 * self.batch_size) + self._cfg_latent_buf = torch.empty( + (cfg_batch, 4, self.latent_height, self.latent_width), + dtype=self.dtype, + device=self.device, + ) + self._cfg_t_buf = torch.empty(cfg_batch, dtype=self.sub_timesteps_tensor.dtype, device=self.device) + else: + self._cfg_latent_buf = None + self._cfg_t_buf = None + + # Seed _unet_kwargs with the constant key so per-frame code only updates values + self._unet_kwargs = {"return_dict": False} + def _get_scheduler_scalings(self, timestep): """Get LCM/TCD-specific scaling factors for boundary conditions.""" if isinstance(self.scheduler, LCMScheduler): @@ -620,21 +667,28 @@ def unet_step( idx: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: if self.guidance_scale > 1.0 and (self.cfg_type == "initialize"): - x_t_latent_plus_uc = torch.concat([x_t_latent[0:1], x_t_latent], dim=0) - t_list = torch.concat([t_list[0:1], t_list], dim=0) + # Pre-allocated buf avoids torch.concat malloc for CFG latent doubling + self._cfg_latent_buf[:1].copy_(x_t_latent[0:1]) + self._cfg_latent_buf[1:].copy_(x_t_latent) + x_t_latent_plus_uc = self._cfg_latent_buf + self._cfg_t_buf[:1].copy_(t_list[0:1]) + self._cfg_t_buf[1:].copy_(t_list) + t_list = self._cfg_t_buf elif self.guidance_scale > 1.0 and (self.cfg_type == "full"): - x_t_latent_plus_uc = torch.concat([x_t_latent, x_t_latent], dim=0) - t_list = torch.concat([t_list, t_list], dim=0) + self._cfg_latent_buf[:len(x_t_latent)].copy_(x_t_latent) + self._cfg_latent_buf[len(x_t_latent):].copy_(x_t_latent) + x_t_latent_plus_uc = self._cfg_latent_buf + self._cfg_t_buf[:len(t_list)].copy_(t_list) + self._cfg_t_buf[len(t_list):].copy_(t_list) + t_list = self._cfg_t_buf else: x_t_latent_plus_uc = x_t_latent - # Prepare UNet call arguments - 
unet_kwargs = { - "sample": x_t_latent_plus_uc, - "timestep": t_list, - "encoder_hidden_states": self.prompt_embeds, - "return_dict": False, - } + # Prepare UNet call arguments (update pre-allocated dict in-place: avoids per-frame dict malloc) + self._unet_kwargs["sample"] = x_t_latent_plus_uc + self._unet_kwargs["timestep"] = t_list + self._unet_kwargs["encoder_hidden_states"] = self.prompt_embeds + unet_kwargs = self._unet_kwargs # Add SDXL-specific conditioning if this is an SDXL model if self.is_sdxl and hasattr(self, "add_text_embeds") and hasattr(self, "add_time_ids"): @@ -784,9 +838,7 @@ def unet_step( if self.guidance_scale > 1.0 and (self.cfg_type == "initialize"): noise_pred_text = model_pred[1:] - self.stock_noise = torch.concat( - [model_pred[0:1], self.stock_noise[1:]], dim=0 - ) # ここコメントアウトでself out cfg + self.stock_noise[0:1].copy_(model_pred[0:1]) # in-place: eliminates 1 malloc + copy kernel elif self.guidance_scale > 1.0 and (self.cfg_type == "full"): noise_pred_uncond, noise_pred_text = model_pred.chunk(2) else: @@ -807,24 +859,9 @@ def unet_step( if self.cfg_type == "self" or self.cfg_type == "initialize": scaled_noise = self.beta_prod_t_sqrt * self.stock_noise delta_x = self.scheduler_step_batch(model_pred, scaled_noise, idx) - alpha_next = torch.concat( - [ - self.alpha_prod_t_sqrt[1:], - torch.ones_like(self.alpha_prod_t_sqrt[0:1]), - ], - dim=0, - ) - delta_x = alpha_next * delta_x - beta_next = torch.concat( - [ - self.beta_prod_t_sqrt[1:], - torch.ones_like(self.beta_prod_t_sqrt[0:1]), - ], - dim=0, - ) - delta_x = delta_x / beta_next - init_noise = torch.concat([self.init_noise[1:], self.init_noise[0:1]], dim=0) - self.stock_noise = init_noise + delta_x + delta_x = self._alpha_next * delta_x + delta_x = delta_x / self._beta_next + self.stock_noise = self._init_noise_rotated + delta_x else: # denoised_batch = self.scheduler.step(model_pred, t_list[0], x_t_latent).denoised @@ -842,8 +879,10 @@ def update_kvo_cache(self, kvo_cache_out: 
List[torch.Tensor]) -> None: # Circular buffer: overwrite the oldest slot without shifting or cloning. # The attention processor reads all slots as an unordered K/V bag, so slot order is irrelevant. + # Use self.cache_maxframes (not tensor shape) so that when the buffer is allocated at + # max_cache_maxframes but the logical window is smaller, writes stay within the active range. for i, new_kv in enumerate(kvo_cache_out): - cache_size = self.kvo_cache[i].shape[1] + cache_size = self.cache_maxframes write_slot = (self.frame_idx // self.cache_interval - 1) % cache_size self.kvo_cache[i][:, write_slot].copy_(new_kv.squeeze(1)) @@ -875,7 +914,10 @@ def predict_x0_batch(self, x_t_latent: torch.Tensor) -> torch.Tensor: if self.use_denoising_batch and isinstance(self.scheduler, LCMScheduler): t_list = self.sub_timesteps_tensor if self.denoising_steps_num > 1: - x_t_latent = torch.cat((x_t_latent, prev_latent_batch), dim=0) + # Copy into pre-allocated buffer: eliminates 1 malloc + copy kernel vs torch.cat + self._combined_latent_buf[:self.frame_bff_size].copy_(x_t_latent) + self._combined_latent_buf[self.frame_bff_size:].copy_(prev_latent_batch) + x_t_latent = self._combined_latent_buf self.stock_noise = torch.cat((self.init_noise[0:1], self.stock_noise[:-1]), dim=0) x_0_pred_batch, model_pred = self.unet_step(x_t_latent, t_list) diff --git a/src/streamdiffusion/stream_parameter_updater.py b/src/streamdiffusion/stream_parameter_updater.py index 4efdde09..b8f68b67 100644 --- a/src/streamdiffusion/stream_parameter_updater.py +++ b/src/streamdiffusion/stream_parameter_updater.py @@ -370,24 +370,26 @@ def update_stream_params( if cache_maxframes is not None: old_cache_maxframes = self.stream.cache_maxframes - self.stream.cache_maxframes = cache_maxframes if old_cache_maxframes != cache_maxframes: - for i, cache_tensor in enumerate(self.stream.kvo_cache): - current_shape = cache_tensor.shape - new_shape = (current_shape[0], cache_maxframes, current_shape[2], current_shape[3], 
current_shape[4]) - new_cache_tensor = torch.zeros( - new_shape, - dtype=cache_tensor.dtype, - device=cache_tensor.device + # KVO cache tensors are allocated at max_cache_maxframes and never resized at + # runtime — resizing one-at-a-time races with TRT inference (causes "Dimensions + # with name C must be equal" errors). cache_maxframes is a logical write window. + actual_cache_size = ( + self.stream.kvo_cache[0].shape[1] + if self.stream.kvo_cache + else cache_maxframes + ) + if cache_maxframes > actual_cache_size: + logger.warning( + f"update_stream_params: Requested cache_maxframes={cache_maxframes} " + f"exceeds allocated buffer size={actual_cache_size}. Clamping." ) - - if cache_maxframes > old_cache_maxframes: - new_cache_tensor[:, :old_cache_maxframes] = cache_tensor - else: - new_cache_tensor[:, :] = cache_tensor[:, -cache_maxframes:] - - self.stream.kvo_cache[i] = new_cache_tensor - logger.info(f"update_stream_params: Cache maxframes updated from {old_cache_maxframes} to {cache_maxframes}, kvo_cache tensors resized") + cache_maxframes = actual_cache_size + self.stream.cache_maxframes = cache_maxframes + logger.info( + f"update_stream_params: Cache maxframes {old_cache_maxframes} -> " + f"{cache_maxframes} (buffer size: {actual_cache_size}, no tensor resize)" + ) else: logger.info(f"update_stream_params: Cache maxframes set to {cache_maxframes}") @@ -737,6 +739,12 @@ def _update_seed(self, seed: int) -> None: # Reset stock_noise to match the new init_noise self.stream.stock_noise = torch.zeros_like(self.stream.init_noise) + # Keep pre-computed rotation in sync with new init_noise + if self.stream._init_noise_rotated is not None: + self.stream._init_noise_rotated = torch.cat( + [self.stream.init_noise[1:], self.stream.init_noise[0:1]], dim=0 + ) + def _get_scheduler_scalings(self, timestep): """Get LCM/TCD-specific scaling factors for boundary conditions.""" from diffusers import LCMScheduler diff --git a/src/streamdiffusion/wrapper.py 
b/src/streamdiffusion/wrapper.py index 35d5bb79..0cadedd4 100644 --- a/src/streamdiffusion/wrapper.py +++ b/src/streamdiffusion/wrapper.py @@ -17,6 +17,7 @@ torch.set_grad_enabled(False) torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True +torch.backends.cudnn.benchmark = True class StreamDiffusionWrapper: @@ -353,10 +354,6 @@ def __init__( seed=seed, ) - # Offload text encoders to CPU after initial encoding to free ~1.6 GB VRAM (SDXL). - # They are reloaded on-demand before each prompt re-encoding call. - if acceleration == "tensorrt": - self._offload_text_encoders() # Set wrapper reference on parameter updater so it can access pipeline structure self.stream._param_updater.wrapper = self @@ -365,6 +362,10 @@ def __init__( self._acceleration = acceleration self._engine_dir = engine_dir + # Prompt change tracking: skip text encoder reload when text is identical + self._last_prompt_texts: Optional[List[str]] = None + self._last_negative_prompt: Optional[str] = None + if device_ids is not None: self.stream.unet = torch.nn.DataParallel( self.stream.unet, device_ids=device_ids @@ -473,10 +474,23 @@ def prepare( raise TypeError(f"prepare: prompt must be str or List[Tuple[str, float]], got {type(prompt)}") def _offload_text_encoders(self) -> None: - """Move text encoders to CPU to free VRAM (~1.6 GB for SDXL). + """No-op during inference: text encoders kept on GPU. + + Prompts change during inference while UNet/VAE/ControlNet are constant. + The ~1.6GB text encoder VRAM fits comfortably alongside other components + on a 24GB GPU, so the offload-reload cycle is pure overhead. + For maximum-VRAM scenarios (engine building, FP8 quantization), + use _force_offload_text_encoders() explicitly instead. + """ + + def _reload_text_encoders(self) -> None: + """No-op: text encoders remain on GPU (never offloaded during inference).""" - Called automatically after initial prepare() when using TRT acceleration. 
- Text encoders are reloaded to GPU before each prompt re-encoding call. + def _force_offload_text_encoders(self) -> None: + """Force-offload text encoders to CPU. Use during engine building or FP8 quantization only. + + Frees ~1.6GB VRAM for maximum headroom during one-time build processes. + Call _force_reload_text_encoders() afterwards to restore state. """ pipe = self.stream.pipe if hasattr(pipe, "text_encoder") and pipe.text_encoder is not None: @@ -486,16 +500,16 @@ def _offload_text_encoders(self) -> None: if next(pipe.text_encoder_2.parameters(), None) is not None: pipe.text_encoder_2 = pipe.text_encoder_2.to("cpu") torch.cuda.empty_cache() - logger.debug("[VRAM] Text encoders offloaded to CPU") + logger.debug("[VRAM] Text encoders force-offloaded to CPU (engine build/quantization)") - def _reload_text_encoders(self) -> None: - """Move text encoders back to GPU before prompt re-encoding.""" + def _force_reload_text_encoders(self) -> None: + """Force-reload text encoders to GPU after engine building or FP8 quantization.""" pipe = self.stream.pipe if hasattr(pipe, "text_encoder") and pipe.text_encoder is not None: pipe.text_encoder = pipe.text_encoder.to(self.device) if hasattr(pipe, "text_encoder_2") and pipe.text_encoder_2 is not None: pipe.text_encoder_2 = pipe.text_encoder_2.to(self.device) - logger.debug("[VRAM] Text encoders reloaded to GPU") + logger.debug("[VRAM] Text encoders force-reloaded to GPU") def update_prompt( self, @@ -644,9 +658,26 @@ def update_stream_params( safety_checker_threshold : Optional[float] The threshold for the safety checker. """ - # Reload text encoders to GPU if a new prompt needs encoding. - needs_encoding = prompt_list is not None or negative_prompt is not None + # Reload text encoders to GPU only when prompt text actually changed. + # OSC sends prompt updates at ~60 Hz even with identical text, so comparing + # against the last encoded texts avoids repeated ~1.6 GB CPU↔GPU transfers. 
+ _new_prompt_texts = ( + [p for p, _w in prompt_list] if prompt_list is not None else None + ) + _texts_changed = ( + _new_prompt_texts is not None + and _new_prompt_texts != self._last_prompt_texts + ) + _neg_changed = ( + negative_prompt is not None + and negative_prompt != self._last_negative_prompt + ) + needs_encoding = _texts_changed or _neg_changed if needs_encoding: + if _new_prompt_texts is not None: + self._last_prompt_texts = _new_prompt_texts + if negative_prompt is not None: + self._last_negative_prompt = negative_prompt self._reload_text_encoders() try: # Handle all parameters via parameter updater (including ControlNet) @@ -1757,6 +1788,12 @@ def _load_model( if name.endswith("attn1.processor") and not isinstance(processor, CachedSTAttnProcessor2_0): processors[name] = CachedSTAttnProcessor2_0() stream.unet.set_attn_processor(processors) + # Enable pre-allocated buffers for runtime — ONNX export already completed above, + # so the original clone/contiguous path was used for tracing. From here on, the + # processors run only at Python runtime (non-TRT paths) and buffer reuse is safe. + for proc in stream.unet.attn_processors.values(): + if isinstance(proc, CachedSTAttnProcessor2_0): + proc._use_prealloc = True # Compile VAE decoder engine using EngineManager vae_decoder_model = VAE(