diff --git a/src/streamdiffusion/acceleration/tensorrt/__init__.py b/src/streamdiffusion/acceleration/tensorrt/__init__.py index 3918bc20..52bbcb3c 100644 --- a/src/streamdiffusion/acceleration/tensorrt/__init__.py +++ b/src/streamdiffusion/acceleration/tensorrt/__init__.py @@ -113,6 +113,12 @@ def compile_unet( opt_batch_size: int = 1, engine_build_options: dict = {}, ): + # Extract FP8-specific options before passing the rest to EngineBuilder.build(). + # These are not valid kwargs for build_engine() and must be handled here. + build_options = dict(engine_build_options) + fp8 = build_options.pop("fp8", False) + calibration_data_fn = build_options.pop("calibration_data_fn", None) + unet = unet.to(torch.device("cuda"), dtype=torch.float16) builder = EngineBuilder(model_data, unet, device=torch.device("cuda")) builder.build( @@ -120,7 +126,9 @@ def compile_unet( onnx_opt_path, engine_path, opt_batch_size=opt_batch_size, - **engine_build_options, + fp8=fp8, + calibration_data_fn=calibration_data_fn, + **build_options, ) diff --git a/src/streamdiffusion/acceleration/tensorrt/builder.py b/src/streamdiffusion/acceleration/tensorrt/builder.py index dbc5f34f..e12451aa 100644 --- a/src/streamdiffusion/acceleration/tensorrt/builder.py +++ b/src/streamdiffusion/acceleration/tensorrt/builder.py @@ -1,5 +1,6 @@ import gc import json +import logging import os import time from datetime import datetime, timezone @@ -15,7 +16,7 @@ optimize_onnx, ) -import logging + _build_logger = logging.getLogger(__name__) @@ -70,6 +71,8 @@ def build( force_engine_build: bool = False, force_onnx_export: bool = False, force_onnx_optimize: bool = False, + fp8: bool = False, + calibration_data_fn=None, ): build_total_start = time.perf_counter() engine_name = Path(engine_path).parent.name @@ -145,6 +148,49 @@ def build( ) _build_logger.info(f"Verified ONNX opt file: {onnx_opt_path} ({opt_file_size / (1024**2):.1f} MB)") + # --- FP8 Quantization (if enabled) --- + # Inserts Q/DQ nodes into the 
optimized ONNX and replaces onnx_opt_path with + # the FP8-annotated ONNX for the TRT build step below. + onnx_trt_input = onnx_opt_path # default: use FP16 opt ONNX + fp8_trt = fp8 # may be set to False below if FP8 quantization fails + if fp8: + onnx_fp8_path = onnx_opt_path.replace(".opt.onnx", ".fp8.onnx") + if not os.path.exists(onnx_fp8_path): + _build_logger.warning(f"[BUILD] FP8 quantization starting...") + t0 = time.perf_counter() + from .fp8_quantize import quantize_onnx_fp8 + try: + quantize_onnx_fp8( + onnx_opt_path, + onnx_fp8_path, + model_data=self.model, + opt_batch_size=opt_batch_size, + opt_image_height=opt_image_height, + opt_image_width=opt_image_width, + ) + elapsed = time.perf_counter() - t0 + stats["stages"]["fp8_quantize"] = {"status": "built", "elapsed_s": round(elapsed, 2)} + _build_logger.warning(f"[BUILD] FP8 quantization ({engine_filename}): {elapsed:.1f}s") + onnx_trt_input = onnx_fp8_path + except Exception as fp8_err: + elapsed = time.perf_counter() - t0 + _build_logger.warning( + f"[BUILD] FP8 quantization failed after {elapsed:.1f}s: {fp8_err}. " + f"Falling back to FP16 TensorRT engine (onnx_trt_input unchanged)." 
+ ) + stats["stages"]["fp8_quantize"] = { + "status": "failed_fallback_fp16", + "elapsed_s": round(elapsed, 2), + "error": str(fp8_err), + } + # onnx_trt_input remains onnx_opt_path (FP16 ONNX) + # Disable FP8 engine build path (avoids STRONGLY_TYPED flag) + fp8_trt = False + else: + _build_logger.info(f"[BUILD] Found cached FP8 ONNX: {onnx_fp8_path}") + stats["stages"]["fp8_quantize"] = {"status": "cached"} + onnx_trt_input = onnx_fp8_path + # --- TRT Engine Build --- if not force_engine_build and os.path.exists(engine_path): print(f"Found cached engine: {engine_path}") @@ -153,7 +199,7 @@ def build( t0 = time.perf_counter() build_engine( engine_path=engine_path, - onnx_opt_path=onnx_opt_path, + onnx_opt_path=onnx_trt_input, model_data=self.model, opt_image_height=opt_image_height, opt_image_width=opt_image_width, @@ -162,21 +208,13 @@ def build( build_dynamic_shape=build_dynamic_shape, build_all_tactics=build_all_tactics, build_enable_refit=build_enable_refit, + fp8=fp8_trt, ) elapsed = time.perf_counter() - t0 stats["stages"]["trt_build"] = {"status": "built", "elapsed_s": round(elapsed, 2)} _build_logger.warning(f"[BUILD] TRT engine build ({engine_filename}): {elapsed:.1f}s") - # Cleanup ONNX artifacts — tolerate Windows file-lock failures (Issue #4) - for file in os.listdir(os.path.dirname(engine_path)): - if file.endswith('.engine'): - continue - try: - os.remove(os.path.join(os.path.dirname(engine_path), file)) - except OSError as cleanup_err: - _build_logger.warning(f"[BUILD] Could not delete temp file {file}: {cleanup_err}") - - # Record totals + # Record totals (before cleanup so build_stats.json is preserved) total_elapsed = time.perf_counter() - build_total_start stats["total_elapsed_s"] = round(total_elapsed, 2) stats["build_end"] = datetime.now(timezone.utc).isoformat() @@ -188,5 +226,46 @@ def build( _build_logger.warning(f"[BUILD] {engine_filename} complete: {total_elapsed:.1f}s total") _write_build_stats(engine_path, stats) - gc.collect() - 
torch.cuda.empty_cache() + # Cleanup ONNX artifacts — preserve .engine, .fp8.onnx, and build_stats.json + # Two-pass deletion to handle Windows file locks (gc.collect releases Python handles) + _keep_suffixes = (".engine", ".fp8.onnx") + _keep_exact = {"build_stats.json"} + engine_dir = os.path.dirname(engine_path) + _to_delete = [] + for file in os.listdir(engine_dir): + if file in _keep_exact or any(file.endswith(s) for s in _keep_suffixes): + continue + _to_delete.append(os.path.join(engine_dir, file)) + + if _to_delete: + _failed = [] + for fpath in _to_delete: + try: + os.remove(fpath) + except OSError: + _failed.append(fpath) + + # Release Python-held file handles (ONNX model refs), retry failures + if _failed: + gc.collect() + torch.cuda.empty_cache() + time.sleep(0.5) + _still_failed = [] + for fpath in _failed: + try: + os.remove(fpath) + except OSError as cleanup_err: + _still_failed.append(os.path.basename(fpath)) + _build_logger.warning(f"[BUILD] Could not delete temp file {os.path.basename(fpath)}: {cleanup_err}") + if _still_failed: + _build_logger.warning( + f"[BUILD] {len(_still_failed)} intermediate files could not be cleaned. 
" + f"Manual cleanup: delete all files except *.engine and *.fp8.onnx from {engine_dir}" + ) + cleaned = len(_to_delete) - len(_still_failed) + else: + cleaned = len(_to_delete) + _build_logger.info(f"[BUILD] Cleaned {cleaned}/{len(_to_delete)} intermediate files") + else: + gc.collect() + torch.cuda.empty_cache() diff --git a/src/streamdiffusion/acceleration/tensorrt/engine_manager.py b/src/streamdiffusion/acceleration/tensorrt/engine_manager.py index 6b54d022..9b2e1df0 100644 --- a/src/streamdiffusion/acceleration/tensorrt/engine_manager.py +++ b/src/streamdiffusion/acceleration/tensorrt/engine_manager.py @@ -106,7 +106,8 @@ def get_engine_path(self, controlnet_model_id: Optional[str] = None, is_faceid: Optional[bool] = None, use_cached_attn: bool = False, - use_controlnet: bool = False + use_controlnet: bool = False, + fp8: bool = False ) -> Path: """ Generate engine path using wrapper.py's current logic. @@ -151,6 +152,8 @@ def get_engine_path(self, prefix += f"--use_cached_attn-{use_cached_attn}" if use_controlnet: prefix += "--controlnet" + if fp8: + prefix += "--fp8" prefix += f"--mode-{mode}" diff --git a/src/streamdiffusion/acceleration/tensorrt/fp8_quantize.py b/src/streamdiffusion/acceleration/tensorrt/fp8_quantize.py new file mode 100644 index 00000000..2a0c2d2e --- /dev/null +++ b/src/streamdiffusion/acceleration/tensorrt/fp8_quantize.py @@ -0,0 +1,464 @@ +""" +FP8 Quantization for StreamDiffusion TensorRT UNet engine. + +Uses nvidia-modelopt for ONNX-level FP8 quantization via Q/DQ node insertion. +The quantized ONNX is then compiled to TRT with STRONGLY_TYPED + FP8 builder flags. + +Requirements: + nvidia-modelopt[onnx] >= 0.35.0 + TensorRT >= 10.0 (FP8 support) + RTX 4090+ (Ada Lovelace, compute 8.9, FP8 E4M3 hardware support) + +This module is called from builder.py when fp8=True is passed to EngineBuilder.build(). 
import contextlib
import itertools
import logging
import os
from typing import Dict, Iterable, Iterator, List, Optional, Tuple

import numpy as np

logger = logging.getLogger(__name__)

# Intermediate artifacts that a failed modelopt attempt may leave next to the
# input ONNX (suffixes are appended to the input path with ".onnx" stripped).
_MODELOPT_INTERMEDIATE_SUFFIXES = (
    "_static.onnx", "_static.onnx_data",  # from override_shapes
    "_named.onnx", "_named.onnx_data",
    "_named_extended.onnx", "_named_extended.onnx_data",
    "_ir10.onnx", "_ir10.onnx_data",
    "_static_named.onnx", "_static_named.onnx_data",
    "_static_ir10.onnx", "_static_ir10.onnx_data",
)


def _apply_dynamic_axes(value_infos: Iterable, dynamic_axes: dict) -> int:
    """Replace static dims with symbolic ones on the given graph inputs/outputs.

    Args:
        value_infos: ONNX ValueInfo entries (graph inputs and/or outputs).
        dynamic_axes: Mapping ``tensor name -> {dim index: symbolic name}``.

    Returns:
        Number of dimensions that were switched to ``dim_param``.
    """
    restored = 0
    for value_info in value_infos:
        if value_info.name not in dynamic_axes:
            continue
        dims = value_info.type.tensor_type.shape.dim
        for dim_idx, symbolic_name in dynamic_axes[value_info.name].items():
            # Axes beyond the tensor's rank are ignored rather than treated as errors.
            if dim_idx < len(dims):
                dim = dims[dim_idx]
                dim.ClearField("dim_value")
                dim.dim_param = symbolic_name
                restored += 1
    return restored


def _restore_dynamic_axes(onnx_fp8_path: str, model_data) -> None:
    """Restore dynamic dim_param symbols in the FP8 ONNX after quantization.

    The quantization step is run with ``override_shapes``, which leaves static
    ``dim_value`` entries on graph inputs/outputs. This reads the dynamic axes
    from ``model_data.get_dynamic_axes()`` and writes the symbolic names back,
    then saves the model in place. The model is loaded with
    ``load_external_data=False`` so only the protobuf is read and rewritten.

    The function is a no-op (with a warning) when the dynamic axes cannot be
    obtained, are empty, or match nothing in the graph.

    Args:
        onnx_fp8_path: Path of the quantized ONNX file to patch in place.
        model_data: Object exposing ``get_dynamic_axes()``.
    """
    import onnx

    try:
        dynamic_axes = model_data.get_dynamic_axes()
    except Exception as e:
        logger.warning(f"[FP8] Could not get dynamic_axes from model_data: {e}. Skipping restore.")
        return

    if not dynamic_axes:
        logger.warning("[FP8] dynamic_axes is empty — skipping dynamic dim restore.")
        return

    model = onnx.load(onnx_fp8_path, load_external_data=False)

    restored_count = _apply_dynamic_axes(
        itertools.chain(model.graph.input, model.graph.output), dynamic_axes
    )
    if restored_count == 0:
        logger.warning("[FP8] No dynamic dimensions restored — graph inputs may already be dynamic.")
        return

    # Tensor data was never loaded, so this writes only the protobuf back.
    onnx.save(model, onnx_fp8_path)
    logger.info(
        f"[FP8] Restored {restored_count} dynamic dimensions in {os.path.basename(onnx_fp8_path)}"
    )


def _calibration_input_shapes(
    model_data,
    input_names: Iterable[str],
    effective_batch: int,
    latent_h: int,
    latent_w: int,
    cache_maxframes: int,
) -> List[Tuple[str, Tuple[int, ...]]]:
    """Compute calibration shapes for the recognised UNet inputs.

    Inputs that are not recognised, ControlNet inputs missing from
    ``model_data.control_inputs``, and KVO cache inputs whose index is outside
    ``model_data.kvo_cache_shapes`` are omitted from the result.

    Returns:
        ``(input name, shape)`` pairs in the order of ``input_names``.
    """
    text_maxlen = getattr(model_data, "text_maxlen", 77)
    embedding_dim = getattr(model_data, "embedding_dim", 2048)
    kvo_cache_shapes = getattr(model_data, "kvo_cache_shapes", [])
    num_ip_layers = getattr(model_data, "num_ip_layers", 1)
    control_inputs = getattr(model_data, "control_inputs", {})

    shapes: List[Tuple[str, Tuple[int, ...]]] = []
    for name in input_names:
        if name == "sample":
            shape = (effective_batch, 4, latent_h, latent_w)
        elif name == "timestep":
            shape = (effective_batch,)
        elif name == "encoder_hidden_states":
            shape = (effective_batch, text_maxlen, embedding_dim)
        elif name == "ipadapter_scale":
            shape = (num_ip_layers,)
        elif name.startswith("input_control_"):
            if name not in control_inputs:
                continue
            spec = control_inputs[name]
            shape = (effective_batch, spec["channels"], spec["height"], spec["width"])
        elif name.startswith("kvo_cache_in_"):
            idx = int(name.rsplit("_", 1)[-1])
            if idx >= len(kvo_cache_shapes):
                continue
            seq_len, hidden_dim = kvo_cache_shapes[idx]
            # dim 2 uses the same batch as "sample" so the two stay consistent.
            shape = (2, cache_maxframes, effective_batch, seq_len, hidden_dim)
        else:
            continue
        shapes.append((name, shape))
    return shapes


def _random_calibration_input(name: str, shape: Tuple[int, ...], rng) -> np.ndarray:
    """Create one calibration array for input ``name`` with the given shape."""
    if name == "sample":
        # Standard-normal noise scaled by 0.18215, stored as float32.
        return (rng.standard_normal(shape) * 0.18215).astype(np.float32)
    if name == "timestep":
        # Integer timesteps drawn from [0, 999], stored as float32.
        return rng.integers(0, 1000, size=shape).astype(np.float32)
    if name == "encoder_hidden_states":
        return (rng.standard_normal(shape) * 0.01).astype(np.float16)
    if name == "ipadapter_scale":
        return np.ones(shape, dtype=np.float32)
    if name.startswith("input_control_"):
        return rng.standard_normal(shape).astype(np.float16)
    # Remaining recognised inputs are KVO caches: zeros represent a cold cache.
    return np.zeros(shape, dtype=np.float16)


def generate_unet_calibration_data(
    model_data,
    opt_batch_size: int,
    opt_image_height: int,
    opt_image_width: int,
    num_batches: int = 8,
    seed: Optional[int] = 42,
) -> List[Dict[str, np.ndarray]]:
    """Generate synthetic calibration batches for UNet FP8 quantization.

    Args:
        model_data: Object providing ``get_input_names()`` and, optionally,
            ``text_maxlen``, ``embedding_dim``, ``cache_maxframes``,
            ``kvo_cache_shapes``, ``num_ip_layers`` and ``control_inputs``
            (defaults are used for missing attributes).
        opt_batch_size: Optimal batch size; the generated batch dimension is
            ``2 * opt_batch_size``.
        opt_image_height: Image height in pixels (latent height is ``// 8``).
        opt_image_width: Image width in pixels (latent width is ``// 8``).
        num_batches: Number of batches to generate. Defaults to 8; no upper
            bound is enforced here, so callers are responsible for memory use.
        seed: Seed for the NumPy generator. The default (42) reproduces the
            historical output; ``None`` yields non-deterministic data.

    Returns:
        One ``{input_name: np.ndarray}`` dict per batch. Inputs without a known
        shape (see ``_calibration_input_shapes``) are left out of the dicts.
    """
    latent_h = opt_image_height // 8
    latent_w = opt_image_width // 8
    effective_batch = 2 * opt_batch_size

    input_names = model_data.get_input_names()
    kvo_cache_shapes = getattr(model_data, "kvo_cache_shapes", [])
    shapes = _calibration_input_shapes(
        model_data,
        input_names,
        effective_batch,
        latent_h,
        latent_w,
        cache_maxframes=getattr(model_data, "cache_maxframes", 4),
    )

    rng = np.random.default_rng(seed=seed)
    # Draw order (batch-major, then input order) determines the random stream,
    # so it must stay stable for reproducible calibration data.
    calibration_dataset = [
        {name: _random_calibration_input(name, shape, rng) for name, shape in shapes}
        for _ in range(num_batches)
    ]

    logger.info(
        f"[FP8] Generated {num_batches} calibration batches "
        f"(effective_batch={effective_batch}, latent={latent_h}x{latent_w}, "
        f"inputs={len(input_names)}, kvo_count={len(kvo_cache_shapes)})"
    )
    return calibration_dataset


@contextlib.contextmanager
def _verbose_ort_logging() -> Iterator[None]:
    """Raise ONNX Runtime's default log severity to INFO for the duration.

    Best-effort: a missing or failing onnxruntime never aborts quantization.
    On exit the severity is set to 2 (WARNING), including on error paths.
    """
    try:
        import onnxruntime as _ort
        _ort.set_default_logger_severity(1)
        logger.info("[FP8] ORT log_severity_level set to 1 (INFO) for Memcpy diagnostics")
    except Exception as ort_err:
        logger.debug(f"[FP8] Could not raise ORT log verbosity: {ort_err}")
    try:
        yield
    finally:
        try:
            import onnxruntime as _ort
            _ort.set_default_logger_severity(2)  # Restore to WARNING
        except Exception as ort_err:
            logger.debug(f"[FP8] Could not restore ORT log verbosity: {ort_err}")


@contextlib.contextmanager
def _patched_onnx_byte_size() -> Iterator[None]:
    """Temporarily make ``onnx.ModelProto.ByteSize`` tolerate ``EncodeError``.

    While active, a ``ByteSize()`` call that raises protobuf's ``EncodeError``
    returns 3 GiB instead, so size-based detection of the external data format
    treats the model as large. The original method is always restored on exit.
    """
    import onnx as _onnx
    from google.protobuf.message import EncodeError as _EncodeError

    _orig_byte_size = _onnx.ModelProto.ByteSize

    def _safe_byte_size(self):
        try:
            return _orig_byte_size(self)
        except _EncodeError:
            return 3 * (1024**3)  # >2GB → triggers external data format

    _onnx.ModelProto.ByteSize = _safe_byte_size
    try:
        yield
    finally:
        _onnx.ModelProto.ByteSize = _orig_byte_size


def _add_nvidia_dll_dirs_to_path() -> None:
    """Prepend the ``bin`` dirs of pip-installed NVIDIA packages to PATH.

    Looks for ``venv/Lib/site-packages/nvidia`` relative to this file first and
    falls back to locating the ``nvidia.cudnn`` package. Directories that are
    already PATH entries are not added again.
    """
    pkg_dir: Optional[str] = os.path.normpath(os.path.join(
        os.path.dirname(os.path.dirname(os.path.dirname(
            os.path.dirname(os.path.abspath(__file__))))),
        os.pardir, "venv", "Lib", "site-packages", "nvidia"))
    if not os.path.isdir(pkg_dir):
        # Fallback: find via importlib
        try:
            import nvidia.cudnn
            pkg_dir = os.path.dirname(os.path.dirname(nvidia.cudnn.__file__))
        except ImportError:
            pkg_dir = None

    if not pkg_dir or not os.path.isdir(pkg_dir):
        return

    current_path = os.environ.get("PATH", "")
    # Compare whole PATH entries (not substrings) so e.g. ".../bin" is not
    # mistaken as present because ".../bin_backup" is on PATH.
    present = {os.path.normcase(entry) for entry in current_path.split(os.pathsep) if entry}
    bin_dirs = []
    for subpkg in ("cudnn", "cublas", "cuda_runtime", "cufft", "curand"):
        bin_dir = os.path.join(pkg_dir, subpkg, "bin")
        if os.path.isdir(bin_dir) and os.path.normcase(bin_dir) not in present:
            bin_dirs.append(bin_dir)
    if bin_dirs:
        os.environ["PATH"] = os.pathsep.join(bin_dirs) + os.pathsep + current_path
        logger.info(f"[FP8] Added {len(bin_dirs)} NVIDIA DLL dirs to PATH")


def _build_calibration_shapes_spec(
    model_data,
    opt_batch_size: int,
    opt_image_height: int,
    opt_image_width: int,
) -> Optional[str]:
    """Build the ``"input0:d0xd1x...,input1:..."`` shapes string for modelopt.

    Returns ``None`` when ``model_data`` is missing or yields no usable inputs.
    """
    if model_data is None:
        logger.warning(
            "[FP8] model_data not provided — RandomDataProvider will default all "
            "dynamic dims to 1. UNet Concat nodes may fail for non-trivial models."
        )
        return None

    latent_h = opt_image_height // 8
    latent_w = opt_image_width // 8
    effective_batch = 2 * opt_batch_size
    # Calibration always uses a single cached frame, independent of the
    # model's own cache_maxframes setting.
    calib_cache_maxframes = 1

    try:
        input_names = model_data.get_input_names()
    except Exception:
        input_names = []

    shapes = _calibration_input_shapes(
        model_data, input_names, effective_batch, latent_h, latent_w, calib_cache_maxframes
    )
    if not shapes:
        return None

    kvo_count = sum(1 for name, _ in shapes if "kvo_cache_in" in name)
    logger.info(
        f"[FP8] calibration_shapes: {len(shapes)} inputs "
        f"(sample={effective_batch}x4x{latent_h}x{latent_w}, "
        f"kvo={kvo_count} caches "
        f"calib_frames={calib_cache_maxframes})"
    )
    return ",".join(
        f"{name}:{'x'.join(str(dim) for dim in shape)}" for name, shape in shapes
    )


def _remove_modelopt_intermediates(onnx_opt_path: str) -> None:
    """Delete intermediates of a failed quantization attempt (best-effort)."""
    base = os.path.splitext(onnx_opt_path)[0]  # strip .onnx
    for suffix in _MODELOPT_INTERMEDIATE_SUFFIXES:
        path = base + suffix
        if not os.path.exists(path):
            continue
        try:
            os.remove(path)
        except OSError as cleanup_err:
            # A locked/undeletable file must not prevent the retry itself.
            logger.warning(
                f"[FP8] Could not remove intermediate {os.path.basename(path)}: {cleanup_err}"
            )
            continue
        logger.info(f"[FP8] Cleaned up intermediate: {os.path.basename(path)}")


def _run_modelopt_quantize(modelopt_quantize, onnx_opt_path: str, quantize_kwargs: dict) -> None:
    """Call modelopt's quantize with the two retry strategies.

    * ``TypeError``: retried once without ``alpha`` / ``disable_mha_qdq``.
    * Any other exception while MHA quantization is enabled: intermediates are
      removed and the call is retried once with ``disable_mha_qdq=True``.

    ``quantize_kwargs`` is modified in place by the retries.
    """
    try:
        modelopt_quantize(onnx_opt_path, **quantize_kwargs)
    except TypeError as e:
        logger.warning(f"[FP8] Retrying without alpha/disable_mha_qdq (TypeError: {e})")
        quantize_kwargs.pop("alpha", None)
        quantize_kwargs.pop("disable_mha_qdq", None)
        modelopt_quantize(onnx_opt_path, **quantize_kwargs)
    except Exception as e:
        if quantize_kwargs.get("disable_mha_qdq", True):
            # MHA Q/DQ was already disabled, so there is nothing left to relax.
            raise
        _remove_modelopt_intermediates(onnx_opt_path)
        logger.warning(
            f"[FP8] MHA analysis failed ({type(e).__name__}: {e}). "
            "Retrying with disable_mha_qdq=True (MHA layers will use FP16 precision)."
        )
        quantize_kwargs["disable_mha_qdq"] = True
        modelopt_quantize(onnx_opt_path, **quantize_kwargs)


def quantize_onnx_fp8(
    onnx_opt_path: str,
    onnx_fp8_path: str,
    calibration_data: Optional[List[Dict[str, np.ndarray]]] = None,
    quantize_mha: bool = False,
    percentile: float = 1.0,
    alpha: float = 0.8,
    model_data=None,
    opt_batch_size: int = 1,
    opt_image_height: int = 512,
    opt_image_width: int = 512,
) -> None:
    """Insert FP8 Q/DQ nodes into an optimized ONNX model via nvidia-modelopt.

    Process-global changes made for the duration of the call (ONNX Runtime log
    severity and the ``onnx.ModelProto.ByteSize`` patch) are undone on every
    exit path. PATH additions for NVIDIA DLL directories are kept.

    Args:
        onnx_opt_path: Input optimized ONNX path (``*.opt.onnx``).
        onnx_fp8_path: Output quantized ONNX path (``*.fp8.onnx``).
        calibration_data: Ignored; kept for backward compatibility.
        quantize_mha: When False, ``disable_mha_qdq=True`` is passed to modelopt.
        percentile: Forwarded to modelopt's percentile calibration.
        alpha: Forwarded to modelopt as ``alpha``.
        model_data: Object used to derive calibration shapes and, afterwards,
            to restore dynamic axes. When None both steps are skipped.
        opt_batch_size: Optimal batch size used for calibration shapes.
        opt_image_height: Optimal image height in pixels.
        opt_image_width: Optimal image width in pixels.

    Raises:
        ImportError: nvidia-modelopt is not installed.
        RuntimeError: modelopt returned but the output file does not exist.
    """
    try:
        from modelopt.onnx.quantization import quantize as modelopt_quantize
    except ImportError as e:
        raise ImportError(
            "nvidia-modelopt is required for FP8 quantization. "
            "Install with: pip install 'nvidia-modelopt[onnx]'"
        ) from e

    with _verbose_ort_logging():
        input_size_mb = os.path.getsize(onnx_opt_path) / (1024 * 1024)
        logger.info("[FP8] Starting ONNX FP8 quantization")
        logger.info(f"[FP8] Input: {onnx_opt_path} ({input_size_mb:.0f} MB)")
        logger.info(f"[FP8] Output: {onnx_fp8_path}")
        logger.info(f"[FP8] Config: quantize_mha={quantize_mha}, percentile={percentile}, alpha={alpha}")
        logger.info(f"[FP8] Calibration: RandomDataProvider with calibration_shapes (model_data={'provided' if model_data is not None else 'none'})")

        with _patched_onnx_byte_size():
            _add_nvidia_dll_dirs_to_path()

            calibration_shapes_str = _build_calibration_shapes_spec(
                model_data, opt_batch_size, opt_image_height, opt_image_width
            )

            quantize_kwargs = {
                "quantize_mode": "fp8",
                "output_path": onnx_fp8_path,
                "calibration_method": "percentile",
                "percentile": percentile,
                "alpha": alpha,
                "use_external_data_format": True,
                # The same spec is used both to pin the model's dynamic dims
                # and to shape the randomly generated calibration inputs.
                "override_shapes": calibration_shapes_str,
                "calibration_shapes": calibration_shapes_str,
                "disable_mha_qdq": not quantize_mha,
            }
            _run_modelopt_quantize(modelopt_quantize, onnx_opt_path, quantize_kwargs)

    if not os.path.exists(onnx_fp8_path):
        raise RuntimeError(
            f"[FP8] Quantization completed but output file not found: {onnx_fp8_path}"
        )

    # --- Restore dynamic axes ---
    # override_shapes left static dims on the graph; put the symbolic ones back.
    # Failure here is reported but not fatal.
    if model_data is not None:
        try:
            _restore_dynamic_axes(onnx_fp8_path, model_data)
        except Exception as restore_err:
            logger.warning(
                f"[FP8] Failed to restore dynamic axes: {restore_err}. "
                "TRT engine build may fail with static shape profile mismatch."
            )

    output_size_mb = os.path.getsize(onnx_fp8_path) / (1024 * 1024)
    ratio = output_size_mb / input_size_mb if input_size_mb > 0 else 0
    logger.info(
        f"[FP8] Quantization complete: {input_size_mb:.0f} MB → {output_size_mb:.0f} MB "
        f"(ratio: {ratio:.2f}x)"
    )
@@ -14,7 +14,37 @@ class CachedSTAttnProcessor2_0: def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - + + # Pre-allocated buffers for zero-alloc hot path (lazy init on first call). + # _use_prealloc is False by default so ONNX export tracing uses the original + # clone/contiguous path. Set to True by wrapper.py after engine build. + self._curr_key_buf: Optional[torch.Tensor] = None + self._curr_value_buf: Optional[torch.Tensor] = None + self._cached_key_tr_buf: Optional[torch.Tensor] = None # transposed cache key + self._cached_value_tr_buf: Optional[torch.Tensor] = None # transposed cache value + self._kvo_out_buf: Optional[torch.Tensor] = None # (2, 1, B, S, H) + self._use_prealloc: bool = False + + def _ensure_buffers( + self, + key: torch.Tensor, + cached_key: Optional[torch.Tensor], + ) -> None: + """Lazy-allocate or re-allocate buffers if tensor shapes changed.""" + if self._curr_key_buf is None or self._curr_key_buf.shape != key.shape: + B, S, H = key.shape + self._curr_key_buf = torch.empty_like(key) + self._curr_value_buf = torch.empty_like(key) + self._kvo_out_buf = torch.empty(2, 1, B, S, H, dtype=key.dtype, device=key.device) + + if cached_key is not None: + # cached_key shape: (maxframes, batch, seq_len, hidden_dim) + # transposed target shape: (batch, maxframes, seq_len, hidden_dim) + tr_shape = (cached_key.shape[1], cached_key.shape[0], cached_key.shape[2], cached_key.shape[3]) + if self._cached_key_tr_buf is None or self._cached_key_tr_buf.shape != tr_shape: + self._cached_key_tr_buf = torch.empty(tr_shape, dtype=cached_key.dtype, device=cached_key.device) + self._cached_value_tr_buf = torch.empty(tr_shape, dtype=cached_key.dtype, device=cached_key.device) + def __call__( self, attn: Attention, @@ -68,14 +98,34 @@ def __call__( cached_key, cached_value = None, None if is_selfattn: - curr_key = key.clone() - curr_value = 
value.clone() - - if cached_key is not None: - cached_key_reshaped = cached_key.transpose(0, 1).contiguous().flatten(1, 2) - cached_value_reshaped = cached_value.transpose(0, 1).contiguous().flatten(1, 2) - key = torch.cat([curr_key, cached_key_reshaped], dim=1) - value = torch.cat([curr_value, cached_value_reshaped], dim=1) + if self._use_prealloc: + # Zero-alloc hot path: copy into pre-allocated buffers + self._ensure_buffers(key, cached_key) + + self._curr_key_buf.copy_(key) + self._curr_value_buf.copy_(value) + curr_key = self._curr_key_buf + curr_value = self._curr_value_buf + + if cached_key is not None: + # transpose(0,1) makes non-contiguous; copy into contiguous buffer + self._cached_key_tr_buf.copy_(cached_key.transpose(0, 1)) + self._cached_value_tr_buf.copy_(cached_value.transpose(0, 1)) + # flatten is a free view on already-contiguous buffer + cached_key_reshaped = self._cached_key_tr_buf.flatten(1, 2) + cached_value_reshaped = self._cached_value_tr_buf.flatten(1, 2) + key = torch.cat([curr_key, cached_key_reshaped], dim=1) + value = torch.cat([curr_value, cached_value_reshaped], dim=1) + else: + # Original path — used during ONNX export tracing + curr_key = key.clone() + curr_value = value.clone() + + if cached_key is not None: + cached_key_reshaped = cached_key.transpose(0, 1).contiguous().flatten(1, 2) + cached_value_reshaped = cached_value.transpose(0, 1).contiguous().flatten(1, 2) + key = torch.cat([curr_key, cached_key_reshaped], dim=1) + value = torch.cat([curr_value, cached_value_reshaped], dim=1) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads @@ -106,8 +156,14 @@ def __call__( hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor - + if is_selfattn: - kvo_cache = torch.stack([curr_key.unsqueeze(0), curr_value.unsqueeze(0)], dim=0) - - return hidden_states, kvo_cache \ No newline at end of file + if self._use_prealloc: + # Write curr K/V into pre-allocated output buffer — zero alloc + 
self._kvo_out_buf[0, 0].copy_(curr_key) + self._kvo_out_buf[1, 0].copy_(curr_value) + kvo_cache = self._kvo_out_buf + else: + kvo_cache = torch.stack([curr_key.unsqueeze(0), curr_value.unsqueeze(0)], dim=0) + + return hidden_states, kvo_cache diff --git a/src/streamdiffusion/acceleration/tensorrt/models/models.py b/src/streamdiffusion/acceleration/tensorrt/models/models.py index f92fc7fc..6de37236 100644 --- a/src/streamdiffusion/acceleration/tensorrt/models/models.py +++ b/src/streamdiffusion/acceleration/tensorrt/models/models.py @@ -19,9 +19,9 @@ import onnx_graphsurgeon as gs import torch +from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel from onnx import shape_inference from polygraphy.backend.onnx.loader import fold_constants -from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel class Optimizer: @@ -55,7 +55,9 @@ def fold_constants(self, return_onnx=False): def infer_shapes(self, return_onnx=False): onnx_graph = gs.export_onnx(self.graph) if onnx_graph.ByteSize() > 2147483648: - print(f"⚠️ Model size ({onnx_graph.ByteSize() / (1024**3):.2f} GB) exceeds 2GB - this is normal for SDXL models") + print( + f"⚠️ Model size ({onnx_graph.ByteSize() / (1024**3):.2f} GB) exceeds 2GB - this is normal for SDXL models" + ) print("🔧 ONNX shape inference will be skipped for large models to avoid memory issues") # For large models like SDXL, skip shape inference to avoid memory/size issues # The model will still work with TensorRT's own shape inference during engine building @@ -129,16 +131,17 @@ def optimize(self, onnx_graph): def check_dims(self, batch_size, image_height, image_width): # Make batch size check more flexible for ONNX export - if hasattr(self, '_allow_export_batch_override') and self._allow_export_batch_override: + if hasattr(self, "_allow_export_batch_override") and self._allow_export_batch_override: # During ONNX export, allow different batch sizes effective_min_batch = min(self.min_batch, batch_size) 
effective_max_batch = max(self.max_batch, batch_size) else: effective_min_batch = self.min_batch effective_max_batch = self.max_batch - - assert batch_size >= effective_min_batch and batch_size <= effective_max_batch, \ + + assert batch_size >= effective_min_batch and batch_size <= effective_max_batch, ( f"Batch size {batch_size} not in range [{effective_min_batch}, {effective_max_batch}]" + ) assert image_height % 8 == 0 or image_width % 8 == 0 latent_height = image_height // 8 latent_width = image_width // 8 @@ -149,7 +152,7 @@ def check_dims(self, batch_size, image_height, image_width): def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_shape): # Following ComfyUI TensorRT approach: ensure proper min ≤ opt ≤ max constraints # Even with static_batch=True, we need different min/max to avoid TensorRT constraint violations - + if static_batch: # For static batch, still provide range to avoid min=opt=max constraint violation min_batch = max(1, batch_size - 1) # At least 1, but allow some range @@ -157,10 +160,10 @@ def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, s else: min_batch = self.min_batch max_batch = self.max_batch - + latent_height = image_height // 8 latent_width = image_width // 8 - + # Force dynamic shapes for height/width to enable runtime resolution changes # Always use 384-1024 range regardless of static_shape flag min_image_height = self.min_image_shape @@ -171,7 +174,7 @@ def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, s max_latent_height = self.max_latent_shape min_latent_width = self.min_latent_shape max_latent_width = self.max_latent_shape - + return ( min_batch, max_batch, @@ -247,7 +250,7 @@ def optimize(self, onnx_graph): class SafetyChecker(BaseModel): - def __init__(self, device, max_batch_size = 1, min_batch_size = 1): + def __init__(self, device, max_batch_size=1, min_batch_size=1): super(SafetyChecker, self).__init__( device=device, 
max_batch_size=max_batch_size, @@ -280,28 +283,27 @@ def get_shape_dict(self, batch_size, *args, **kwargs): } def get_sample_input(self, batch_size, *args, **kwargs): - return ( - torch.randn(batch_size, 3, 224, 224, dtype=torch.float16, device=self.device), - ) + return (torch.randn(batch_size, 3, 224, 224, dtype=torch.float16, device=self.device),) + class NSFWDetector(BaseModel): - def __init__(self, device, max_batch_size = 1, min_batch_size = 1): + def __init__(self, device, max_batch_size=1, min_batch_size=1): super(NSFWDetector, self).__init__( device=device, max_batch_size=max_batch_size, min_batch_size=min_batch_size, ) self.name = "nsfw_detector" - + def get_input_names(self): return ["pixel_values"] - + def get_output_names(self): return ["logits"] - + def get_dynamic_axes(self): return {"pixel_values": {0: "B"}} - + def get_input_profile(self, batch_size, *args, **kwargs): return { "pixel_values": [ @@ -310,17 +312,16 @@ def get_input_profile(self, batch_size, *args, **kwargs): (self.max_batch, 3, 448, 448), ], } - + def get_shape_dict(self, batch_size, *args, **kwargs): return { "pixel_values": (batch_size, 3, 448, 448), "logits": (batch_size, 2), } - + def get_sample_input(self, batch_size, *args, **kwargs): - return ( - torch.randn(batch_size, 3, 448, 448, dtype=torch.float16, device=self.device), - ) + return (torch.randn(batch_size, 3, 448, 448, dtype=torch.float16, device=self.device),) + class UNet(BaseModel): def __init__( @@ -358,13 +359,13 @@ def __init__( self.name = "UNet" self.image_height = image_height self.image_width = image_width - + self.use_control = use_control self.unet_arch = unet_arch or {} self.use_ipadapter = use_ipadapter self.num_image_tokens = num_image_tokens self.num_ip_layers = num_ip_layers - + # Baked-in IPAdapter configuration if self.use_ipadapter: # With baked-in processors, we extend text_maxlen to include image tokens @@ -375,7 +376,6 @@ def __init__( if self.num_ip_layers is None: raise ValueError("UNet model 
requires num_ip_layers when use_ipadapter=True") - if self.use_control and self.unet_arch: self.control_inputs = self.get_control(image_height, image_width) self._add_control_inputs() @@ -388,21 +388,24 @@ def __init__( self.max_cache_maxframes = max_cache_maxframes if self.use_cached_attn and self.unet is not None: from .utils import get_kvo_cache_info - self.kvo_cache_shapes, self.kvo_cache_structure, self.kvo_cache_count = get_kvo_cache_info(self.unet, image_height, image_width) - + + self.kvo_cache_shapes, self.kvo_cache_structure, self.kvo_cache_count = get_kvo_cache_info( + self.unet, image_height, image_width + ) + self.min_kvo_cache_shapes, _, _ = get_kvo_cache_info(self.unet, image_height, image_width) self.max_kvo_cache_shapes, _, _ = get_kvo_cache_info(self.unet, image_height, image_width) def get_control(self, image_height: int = 512, image_width: int = 512) -> dict: """Generate ControlNet input configurations with dynamic spatial dimensions based on input resolution.""" - block_out_channels = self.unet_arch.get('block_out_channels', (320, 640, 1280, 1280)) - + block_out_channels = self.unet_arch.get("block_out_channels", (320, 640, 1280, 1280)) + # Calculate latent space dimensions latent_height = image_height // 8 latent_width = image_width // 8 - + control_inputs = {} - + if len(block_out_channels) == 3: # SDXL architecture: Match UNet's exact down_block_res_samples structure # UNet down_block_res_samples = [initial_sample] + [block0_residuals] + [block1_residuals] + [block2_residuals] @@ -411,17 +414,14 @@ def get_control(self, image_height: int = 512, image_width: int = 512) -> dict: control_tensors = [ # Initial sample (after conv_in: 4->320 channels, no downsampling) (block_out_channels[0], 1), # 320 channels, 88x88 - # Block 0 residuals (320 channels) - (block_out_channels[0], 1), # 320 channels, 88x88 + (block_out_channels[0], 1), # 320 channels, 88x88 (block_out_channels[0], 1), # 320 channels, 88x88 (block_out_channels[0], 2), # 320 channels, 
44x44 (downsampled) - - # Block 1 residuals (640 channels) + # Block 1 residuals (640 channels) (block_out_channels[1], 2), # 640 channels, 44x44 (block_out_channels[1], 2), # 640 channels, 44x44 (block_out_channels[1], 4), # 640 channels, 22x22 (downsampled) - # Block 2 residuals (1280 channels) (block_out_channels[2], 4), # 1280 channels, 22x22 (block_out_channels[2], 4), # 1280 channels, 22x22 @@ -430,31 +430,39 @@ def get_control(self, image_height: int = 512, image_width: int = 512) -> dict: # SD1.5/SD2.1 architecture: 4 down blocks with 12 control tensors control_tensors = [ # Block 0: No downsampling from latent space (factor = 1) - (320, 1), (320, 1), (320, 1), - # Block 1: 2x downsampling from latent space (factor = 2) - (320, 2), (640, 2), (640, 2), + (320, 1), + (320, 1), + (320, 1), + # Block 1: 2x downsampling from latent space (factor = 2) + (320, 2), + (640, 2), + (640, 2), # Block 2: 4x downsampling from latent space (factor = 4) - (640, 4), (1280, 4), (1280, 4), + (640, 4), + (1280, 4), + (1280, 4), # Block 3: 8x downsampling from latent space (factor = 8) - (1280, 8), (1280, 8), (1280, 8) + (1280, 8), + (1280, 8), + (1280, 8), ] - + # Generate control inputs with proper spatial dimensions for i, (channels, downsample_factor) in enumerate(control_tensors): input_name = f"input_control_{i:02d}" - + # Calculate spatial dimensions for this level control_height = max(1, latent_height // downsample_factor) control_width = max(1, latent_width // downsample_factor) - + control_inputs[input_name] = { - 'batch': self.min_batch, - 'channels': channels, - 'height': control_height, - 'width': control_width, - 'downsampling_factor': downsample_factor + "batch": self.min_batch, + "channels": channels, + "height": control_height, + "width": control_width, + "downsampling_factor": downsample_factor, } - + # Middle block uses the most downsampled resolution based on architecture if len(block_out_channels) == 3: # SDXL: middle block at 4x downsampling (after 3 down 
blocks) @@ -462,15 +470,15 @@ def get_control(self, image_height: int = 512, image_width: int = 512) -> dict: else: # SD1.5: middle block at 8x downsampling (after 4 down blocks) middle_downsample_factor = 8 - + control_inputs["input_control_middle"] = { - 'batch': self.min_batch, - 'channels': 1280, - 'height': max(1, latent_height // middle_downsample_factor), - 'width': max(1, latent_width // middle_downsample_factor), - 'downsampling_factor': middle_downsample_factor + "batch": self.min_batch, + "channels": 1280, + "height": max(1, latent_height // middle_downsample_factor), + "width": max(1, latent_width // middle_downsample_factor), + "downsampling_factor": middle_downsample_factor, } - + return control_inputs def get_kvo_cache_names(self, in_out: str): @@ -480,7 +488,7 @@ def _add_control_inputs(self): """Add ControlNet inputs to the model's input/output specifications""" if not self.control_inputs: return - + self._original_get_input_names = self.get_input_names self._original_get_dynamic_axes = self.get_dynamic_axes self._original_get_input_profile = self.get_input_profile @@ -494,6 +502,7 @@ def get_input_names(self): base_names.append("ipadapter_scale") try: import logging + logging.getLogger(__name__).debug(f"TRT Models: get_input_names with ipadapter -> {base_names}") except Exception: pass @@ -512,8 +521,14 @@ def get_output_names(self): def get_kvo_cache_input_profile(self, min_batch, batch_size, max_batch): profiles = [] - for min_shape, shape, max_shape in zip(self.min_kvo_cache_shapes, self.kvo_cache_shapes, self.max_kvo_cache_shapes): - profile = [(2, self.min_cache_maxframes, min_batch, min_shape[0], min_shape[1]), (2, self.cache_maxframes, batch_size, shape[0], shape[1]), (2, self.max_cache_maxframes, max_batch, max_shape[0], max_shape[1])] + for min_shape, shape, max_shape in zip( + self.min_kvo_cache_shapes, self.kvo_cache_shapes, self.max_kvo_cache_shapes + ): + profile = [ + (2, self.min_cache_maxframes, min_batch, min_shape[0], 
min_shape[1]), + (2, self.cache_maxframes, batch_size, shape[0], shape[1]), + (2, self.max_cache_maxframes, max_batch, max_shape[0], max_shape[1]), + ] profiles.append(profile) return profiles @@ -528,26 +543,27 @@ def get_dynamic_axes(self): base_axes["ipadapter_scale"] = {0: "L_ip"} try: import logging - logging.getLogger(__name__).debug(f"TRT Models: dynamic axes include ipadapter_scale with L_ip={getattr(self, 'num_ip_layers', None)}") + + logging.getLogger(__name__).debug( + f"TRT Models: dynamic axes include ipadapter_scale with L_ip={getattr(self, 'num_ip_layers', None)}" + ) except Exception: pass - + if self.use_control and self.control_inputs: for name, shape_spec in self.control_inputs.items(): height = shape_spec["height"] width = shape_spec["width"] spatial_suffix = f"{height}x{width}" - base_axes[name] = { - 0: "2B", - 2: f"H_{spatial_suffix}", - 3: f"W_{spatial_suffix}" - } + base_axes[name] = {0: "2B", 2: f"H_{spatial_suffix}", 3: f"W_{spatial_suffix}"} if self.use_cached_attn: # hardcoded resolution for now due to VRAM limitations + # NOTE: dim[0]=2 (K/V pair) must stay static — attention Gather nodes + # index into it at idx=0 and idx=1, so dim[0]<2 causes OOB errors. 
for i in range(self.kvo_cache_count): base_axes[f"kvo_cache_in_{i}"] = {1: "C", 2: "2B"} base_axes[f"kvo_cache_out_{i}"] = {2: "2B"} - + return base_axes def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): @@ -564,30 +580,30 @@ def get_input_profile(self, batch_size, image_height, image_width, static_batch, min_latent_width, max_latent_width, ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape) - + # Following TensorRT documentation: ensure proper min ≤ opt ≤ max constraints for ALL dimensions # Calculate optimal latent dimensions that fall within min/max range opt_latent_height = min(max(latent_height, min_latent_height), max_latent_height) opt_latent_width = min(max(latent_width, min_latent_width), max_latent_width) - + # Ensure no dimension equality that causes constraint violations if opt_latent_height == min_latent_height and min_latent_height < max_latent_height: opt_latent_height = min(min_latent_height + 8, max_latent_height) # Add 8 pixels for separation if opt_latent_width == min_latent_width and min_latent_width < max_latent_width: opt_latent_width = min(min_latent_width + 8, max_latent_width) - + # Image dimensions for ControlNet inputs min_image_h, max_image_h = self.min_image_shape, self.max_image_shape min_image_w, max_image_w = self.min_image_shape, self.max_image_shape opt_image_height = min(max(image_height, min_image_h), max_image_h) opt_image_width = min(max(image_width, min_image_w), max_image_w) - + # Ensure image dimension separation as well if opt_image_height == min_image_h and min_image_h < max_image_h: opt_image_height = min(min_image_h + 64, max_image_h) # Add 64 pixels for separation if opt_image_width == min_image_w and min_image_w < max_image_w: opt_image_width = min(min_image_w + 64, max_image_w) - + profile = { "sample": [ (min_batch, self.unet_dim, min_latent_height, min_latent_width), @@ -610,10 +626,13 @@ def get_input_profile(self, batch_size, 
image_height, image_width, static_batch, ] try: import logging - logging.getLogger(__name__).debug(f"TRT Models: profile ipadapter_scale min/opt/max={(1,),(self.num_ip_layers,),(self.num_ip_layers,)}") + + logging.getLogger(__name__).debug( + f"TRT Models: profile ipadapter_scale min/opt/max={(1,), (self.num_ip_layers,), (self.num_ip_layers,)}" + ) except Exception: pass - + if self.use_control and self.control_inputs: # Use the actual calculated spatial dimensions for each ControlNet input # Each control input has its own specific spatial resolution based on UNet architecture @@ -621,29 +640,31 @@ def get_input_profile(self, batch_size, image_height, image_width, static_batch, channels = shape_spec["channels"] control_height = shape_spec["height"] control_width = shape_spec["width"] - + # Create optimization profile with proper spatial dimension scaling # Scale the spatial dimensions proportionally with the main latent dimensions scale_h = opt_latent_height / latent_height if latent_height > 0 else 1.0 scale_w = opt_latent_width / latent_width if latent_width > 0 else 1.0 - + min_control_h = max(1, int(control_height * min_latent_height / latent_height)) max_control_h = max(min_control_h + 1, int(control_height * max_latent_height / latent_height)) opt_control_h = max(min_control_h, min(int(control_height * scale_h), max_control_h)) - + min_control_w = max(1, int(control_width * min_latent_width / latent_width)) max_control_w = max(min_control_w + 1, int(control_width * max_latent_width / latent_width)) opt_control_w = max(min_control_w, min(int(control_width * scale_w), max_control_w)) - + profile[name] = [ - (min_batch, channels, min_control_h, min_control_w), # min - (batch_size, channels, opt_control_h, opt_control_w), # opt - (max_batch, channels, max_control_h, max_control_w), # max + (min_batch, channels, min_control_h, min_control_w), # min + (batch_size, channels, opt_control_h, opt_control_w), # opt + (max_batch, channels, max_control_h, max_control_w), 
# max ] if self.use_cached_attn: - for name, _profile in zip(self.get_kvo_cache_names("in"), self.get_kvo_cache_input_profile(min_batch, batch_size, max_batch)): + for name, _profile in zip( + self.get_kvo_cache_names("in"), self.get_kvo_cache_input_profile(min_batch, batch_size, max_batch) + ): profile[name] = _profile - + return profile def get_shape_dict(self, batch_size, image_height, image_width): @@ -658,10 +679,11 @@ def get_shape_dict(self, batch_size, image_height, image_width): shape_dict["ipadapter_scale"] = (self.num_ip_layers,) try: import logging + logging.getLogger(__name__).debug(f"TRT Models: shape_dict ipadapter_scale={(self.num_ip_layers,)}") except Exception: pass - + if self.use_control and self.control_inputs: # Use the actual calculated spatial dimensions for each ControlNet input for name, shape_spec in self.control_inputs.items(): @@ -671,68 +693,78 @@ def get_shape_dict(self, batch_size, image_height, image_width): shape_dict[name] = (2 * batch_size, channels, control_height, control_width) if self.use_cached_attn: - for in_name, out_name, shape in zip(self.get_kvo_cache_names("in"), self.get_kvo_cache_names("out"), self.get_kvo_cache_shapes): + for in_name, out_name, shape in zip( + self.get_kvo_cache_names("in"), self.get_kvo_cache_names("out"), self.kvo_cache_shapes + ): shape_dict[in_name] = (2, self.cache_maxframes, batch_size, shape[0], shape[1]) shape_dict[out_name] = (2, 1, batch_size, shape[0], shape[1]) - + return shape_dict def get_sample_input(self, batch_size, image_height, image_width): # Enable flexible batch size checking for ONNX export self._allow_export_batch_override = True - + try: latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) finally: # Clean up the override flag - if hasattr(self, '_allow_export_batch_override'): - delattr(self, '_allow_export_batch_override') - + if hasattr(self, "_allow_export_batch_override"): + delattr(self, "_allow_export_batch_override") + dtype = 
torch.float16 if self.fp16 else torch.float32 - + # Use smaller batch size for memory efficiency during ONNX export export_batch_size = min(batch_size, 1) # Use batch size 1 for ONNX export to save memory - + base_inputs = [ torch.randn( - 2 * export_batch_size, self.unet_dim, latent_height, latent_width, - dtype=torch.float32, device=self.device + 2 * export_batch_size, + self.unet_dim, + latent_height, + latent_width, + dtype=torch.float32, + device=self.device, ), torch.ones((2 * export_batch_size,), dtype=torch.float32, device=self.device), torch.randn(2 * export_batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device), ] - + if self.use_ipadapter: base_inputs.append(torch.ones(self.num_ip_layers, dtype=torch.float32, device=self.device)) - + if self.use_control and self.control_inputs: control_inputs = [] - + # Use the ACTUAL calculated spatial dimensions for each control input # This ensures each control input matches its expected UNet feature map resolution - + for name in sorted(self.control_inputs.keys()): shape_spec = self.control_inputs[name] channels = shape_spec["channels"] - + # KEY FIX: Use the specific spatial dimensions calculated for this control input control_height = shape_spec["height"] control_width = shape_spec["width"] - + control_input = torch.randn( - 2 * export_batch_size, channels, control_height, control_width, - dtype=dtype, device=self.device + 2 * export_batch_size, channels, control_height, control_width, dtype=dtype, device=self.device ) control_inputs.append(control_input) - + # Clear cache periodically to prevent memory buildup if len(control_inputs) % 4 == 0: torch.cuda.empty_cache() - + base_inputs = base_inputs + control_inputs - + if self.use_cached_attn: - base_inputs = base_inputs + [torch.randn(2, self.cache_maxframes, 2 * export_batch_size, shape[0], shape[1], dtype=torch.float16).to(self.device) for shape in self.kvo_cache_shapes] + base_inputs = base_inputs + [ + torch.randn( + 2, 
self.cache_maxframes, 2 * export_batch_size, shape[0], shape[1], dtype=torch.float16 + ).to(self.device) + for shape in self.kvo_cache_shapes + ] return tuple(base_inputs) diff --git a/src/streamdiffusion/acceleration/tensorrt/utilities.py b/src/streamdiffusion/acceleration/tensorrt/utilities.py index 1f980aaa..48a6c319 100644 --- a/src/streamdiffusion/acceleration/tensorrt/utilities.py +++ b/src/streamdiffusion/acceleration/tensorrt/utilities.py @@ -31,6 +31,7 @@ import tensorrt as trt import torch + # cuda-python 13.x renamed 'cudart' to 'cuda.bindings.runtime' try: from cuda.bindings import runtime as cudart @@ -242,8 +243,14 @@ def build( enable_all_tactics=False, timing_cache=None, workspace_size=0, + fp8=False, ): logger.info(f"Building TensorRT engine for {onnx_path}: {self.engine_path}") + + if fp8: + self._build_fp8(onnx_path, input_profile, workspace_size, enable_all_tactics) + return + p = Profile() if input_profile: for name, dims in input_profile.items(): @@ -274,6 +281,77 @@ def build( ) save_engine(engine, path=self.engine_path) + def _build_fp8(self, onnx_path, input_profile, workspace_size, enable_all_tactics): + """ + Build a TRT engine from a Q/DQ-annotated FP8 ONNX using the raw TRT builder API. + + Polygraphy 0.49.26's CreateConfig does not support fp8=, so we use the raw + TensorRT Python API directly. The STRONGLY_TYPED network flag is required to + preserve the Q/DQ precision annotations inserted by nvidia-modelopt. + + Args: + onnx_path: Path to *.fp8.onnx (Q/DQ-annotated by fp8_quantize.py). + input_profile: Dict of {name: (min, opt, max)} shapes. + workspace_size: TRT workspace limit in bytes. + enable_all_tactics: If True, allow all TRT tactic sources. + """ + TRT_LOGGER = trt.Logger(trt.Logger.WARNING) + + builder = trt.Builder(TRT_LOGGER) + + # STRONGLY_TYPED: required for FP8. Tells TRT to use the data-type annotations + # from Q/DQ nodes rather than running its own precision heuristics. 
+ network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED) + network = builder.create_network(network_flags) + + parser = trt.OnnxParser(network, TRT_LOGGER) + success = parser.parse_from_file(onnx_path) + if not success: + errors = [parser.get_error(i) for i in range(parser.num_errors)] + raise RuntimeError( + f"TRT ONNX parser failed for FP8 engine: {onnx_path}\n" + + "\n".join(str(e) for e in errors) + ) + + config = builder.create_builder_config() + # TRT 10.12+ with STRONGLY_TYPED network: precision flags (FP8, FP16, TF32) + # must NOT be set — the Q/DQ node annotations dictate precision directly. + # Older TRT versions need both the BuilderFlag and the network flag. + if hasattr(trt.BuilderFlag, 'STRONGLY_TYPED'): + # TRT < 10.12: set all precision flags + STRONGLY_TYPED on config + config.set_flag(trt.BuilderFlag.FP8) + config.set_flag(trt.BuilderFlag.FP16) + config.set_flag(trt.BuilderFlag.TF32) + config.set_flag(trt.BuilderFlag.STRONGLY_TYPED) + else: + # TRT 10.12+: NetworkDefinitionCreationFlag.STRONGLY_TYPED (line 304) + # handles precision; setting FP8 flag causes API Usage Error. + pass + + if workspace_size > 0: + config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_size) + + if input_profile: + profile = builder.create_optimization_profile() + for name, dims in input_profile.items(): + assert len(dims) == 3, f"Expected (min, opt, max) for {name}" + profile.set_shape(name, min=dims[0], opt=dims[1], max=dims[2]) + config.add_optimization_profile(profile) + + logger.info(f"[FP8] Building TRT FP8 engine (STRONGLY_TYPED): {self.engine_path}") + serialized = builder.build_serialized_network(network, config) + if serialized is None: + raise RuntimeError( + f"TRT FP8 engine build failed for {onnx_path}. " + "Check TRT logs above for details." 
+ ) + + with open(self.engine_path, "wb") as f: + f.write(serialized) + + size_bytes = getattr(serialized, 'nbytes', None) or len(serialized) + logger.info(f"[FP8] Engine saved: {self.engine_path} ({size_bytes / 1024 / 1024:.0f} MB)") + def load(self): logger.info(f"Loading TensorRT engine: {self.engine_path}") self.engine = engine_from_bytes(bytes_from_path(self.engine_path)) @@ -516,6 +594,7 @@ def build_engine( build_dynamic_shape: bool = False, build_all_tactics: bool = False, build_enable_refit: bool = False, + fp8: bool = False, ): _, free_mem, _ = cudart.cudaMemGetInfo() GiB = 2**30 @@ -540,6 +619,7 @@ def build_engine( enable_refit=build_enable_refit, enable_all_tactics=build_all_tactics, workspace_size=max_workspace_size, + fp8=fp8, ) return engine @@ -652,6 +732,15 @@ def export_onnx( ) logger.info("Converted to external data format with weights in weights.pb") + # Delete individual tensor files left by torch.onnx.export (~4 GB for SDXL) + # They are now consolidated into weights.pb and no longer needed + for f in os.listdir(onnx_dir): + if f.startswith("onnx__"): + try: + os.remove(os.path.join(onnx_dir, f)) + except OSError: + pass # Caught by builder.py final cleanup if still present + del onnx_model del wrapped_model gc.collect() diff --git a/src/streamdiffusion/wrapper.py b/src/streamdiffusion/wrapper.py index 47604f1a..b16a9585 100644 --- a/src/streamdiffusion/wrapper.py +++ b/src/streamdiffusion/wrapper.py @@ -122,6 +122,7 @@ def __init__( cache_interval: int = 1, min_cache_maxframes: int = 1, max_cache_maxframes: int = 4, + fp8: bool = False, ): """ Initializes the StreamDiffusionWrapper. 
@@ -296,6 +297,7 @@ def __init__( self.set_nsfw_fallback_img(height, width) self.safety_checker_fallback_type = safety_checker_fallback_type self.safety_checker_threshold = safety_checker_threshold + self.fp8 = fp8 self.stream: StreamDiffusion = self._load_model( model_id_or_path=model_id_or_path, @@ -328,6 +330,7 @@ def __init__( cache_interval=cache_interval, min_cache_maxframes=min_cache_maxframes, max_cache_maxframes=max_cache_maxframes, + fp8=fp8, ) # Store skip_diffusion on wrapper for execution flow control @@ -1054,6 +1057,7 @@ def _load_model( cache_interval: int = 1, min_cache_maxframes: int = 1, max_cache_maxframes: int = 4, + fp8: bool = False, ) -> StreamDiffusion: """ Loads the model. @@ -1517,6 +1521,7 @@ def _load_model( is_faceid=is_faceid if use_ipadapter_trt else None, use_cached_attn=use_cached_attn, use_controlnet=use_controlnet_trt, + fp8=fp8, ) vae_encoder_path = engine_manager.get_engine_path( EngineType.VAE_ENCODER, @@ -1752,6 +1757,12 @@ def _load_model( if name.endswith("attn1.processor") and not isinstance(processor, CachedSTAttnProcessor2_0): processors[name] = CachedSTAttnProcessor2_0() stream.unet.set_attn_processor(processors) + # Enable pre-allocated buffers for runtime — ONNX export already completed above, + # so the original clone/contiguous path was used for tracing. From here on, the + # processors run only at Python runtime (non-TRT paths) and buffer reuse is safe. + for proc in stream.unet.attn_processors.values(): + if isinstance(proc, CachedSTAttnProcessor2_0): + proc._use_prealloc = True # Compile VAE decoder engine using EngineManager vae_decoder_model = VAE( @@ -1829,6 +1840,7 @@ def _load_model( 'opt_image_height': self.height, 'opt_image_width': self.width, 'build_all_tactics': True, + 'fp8': fp8, } ) if load_engine: