Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 165 additions & 0 deletions StreamDiffusionTD/install_tensorrt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
"""
Standalone TensorRT installation script for StreamDiffusionTD
This is a self-contained version that doesn't rely on the streamdiffusion package imports
"""

import platform
import shlex
import subprocess
import sys
from typing import Optional


def run_pip(command: str):
    """Run a pip subcommand in the current interpreter and return its exit code.

    Args:
        command: pip arguments as a single shell-like string, e.g.
            'install "nvidia-modelopt[onnx]" --no-cache-dir'.

    Raises:
        subprocess.CalledProcessError: if pip exits non-zero.

    Uses shlex.split (not str.split) so quoted arguments keep their grouping
    and lose the literal quote characters — plain str.split would hand pip
    the token '"nvidia-modelopt[onnx]"' including the quotes, which pip
    cannot parse as a requirement.
    """
    return subprocess.check_call([sys.executable, "-m", "pip"] + shlex.split(command))


def is_installed(package_name: str) -> bool:
"""Check if a package is installed"""
try:
__import__(package_name.replace("-", "_"))
return True
except ImportError:
return False


def version(package_name: str) -> Optional[str]:
    """Return the installed version string of *package_name*, or None.

    Returns None both when the distribution is not installed and when
    importlib.metadata itself is unavailable (Python < 3.8). Catches only
    the specific failures instead of a bare ``except:`` so real errors
    (KeyboardInterrupt, SystemExit, typos) are not silently swallowed.
    """
    try:
        import importlib.metadata
    except ImportError:
        # Pre-3.8 interpreter without importlib.metadata.
        return None

    try:
        return importlib.metadata.version(package_name)
    except importlib.metadata.PackageNotFoundError:
        return None


def get_cuda_version_from_torch() -> Optional[str]:
    """Return torch's CUDA version as "major.minor", or None when unavailable.

    Returns None when torch is not installed or is a CPU-only build
    (torch.version.cuda is None/empty in that case).
    """
    try:
        import torch
    except ImportError:
        return None

    raw = torch.version.cuda
    if not raw:
        return None

    # Keep at most the first two components, e.g. "12.8.1" -> "12.8".
    return ".".join(raw.split(".")[:2])


def install(cu: Optional[str] = None):
    """Install TensorRT and supporting packages for the active CUDA toolkit.

    Args:
        cu: CUDA version string such as "12.8". When None it is auto-detected
            from the installed torch build; if detection also fails, a message
            is printed and nothing is installed.
    """
    if cu is None:
        cu = get_cuda_version_from_torch()

    if cu is None:
        print("Could not detect CUDA version. Please specify manually.")
        return

    print(f"Detected CUDA version: {cu}")
    print("Installing TensorRT requirements...")

    # Parse "major.minor" into integers. NOTE: the previous float(cu)
    # comparison misorders versions like "12.10" (float 12.1 < 12.8) and
    # crashes outright on three-component strings like "11.8.0"; integer
    # tuples compare correctly and tolerate a patch component.
    try:
        numeric = [int(p) for p in cu.split(".")[:2]]
    except ValueError:
        print(f"Unsupported CUDA version: {cu}")
        print("Supported versions: CUDA 11.x, 12.x")
        return
    cuda_major = numeric[0]
    cuda_minor = numeric[1] if len(numeric) > 1 else 0

    # Skip nvidia-pyindex - it's broken with pip 25.3+ and not actually needed
    # The NVIDIA index is already accessible via pip config or environment variables

    _uninstall_old_tensorrt()

    if cuda_major == 12:
        # All CUDA 12.x installs use the same cu12 wheels; 12.8+ additionally
        # covers Blackwell (RTX 5090) via TensorRT 10.8+, so only the message
        # differs — the previously duplicated branches are merged here.
        if (cuda_major, cuda_minor) >= (12, 8):
            print("Installing TensorRT 10.8+ for CUDA 12.8+ (Blackwell GPU support)...")
        else:
            print("Installing TensorRT for CUDA 12.x...")

        cudnn_name = "nvidia-cudnn-cu12"
        print(f"Installing cuDNN: {cudnn_name}")
        run_pip(f"install {cudnn_name} --no-cache-dir")

        tensorrt_version = "tensorrt-cu12"
        print(f"Installing TensorRT for CUDA {cu}: {tensorrt_version}")
        run_pip(f"install {tensorrt_version} --no-cache-dir")

    elif cuda_major == 11:
        print("Installing TensorRT for CUDA 11.x...")

        cudnn_name = "nvidia-cudnn-cu11==8.9.4.25"
        print(f"Installing cuDNN: {cudnn_name}")
        run_pip(f"install {cudnn_name} --no-cache-dir")

        # TensorRT 9 for CUDA 11 is only published on the NVIDIA index as a
        # pre-release, hence --pre and the extra index URL.
        tensorrt_version = "tensorrt==9.0.1.post11.dev4"
        print(f"Installing TensorRT for CUDA {cu}: {tensorrt_version}")
        run_pip(
            f"install --pre --extra-index-url https://pypi.nvidia.com {tensorrt_version} --no-cache-dir"
        )
    else:
        print(f"Unsupported CUDA version: {cu}")
        print("Supported versions: CUDA 11.x, 12.x")
        return

    _install_tensorrt_tools(cuda_major)

    print("TensorRT installation completed successfully!")


def _uninstall_old_tensorrt():
    """Remove TensorRT < 10.8 so the new cu12/cu11 wheels install cleanly."""
    if not is_installed("tensorrt"):
        return
    current_version_str = version("tensorrt")
    if not current_version_str:
        return
    try:
        from packaging.version import Version

        if Version(current_version_str) < Version("10.8.0"):
            print("Uninstalling old TensorRT version...")
            run_pip("uninstall -y tensorrt")
    except Exception:
        # packaging unavailable or version string unparseable — fall back to
        # a crude prefix check for the known-old 9.x series.
        if current_version_str.startswith("9."):
            print("Uninstalling old TensorRT version...")
            run_pip("uninstall -y tensorrt")


def _install_tensorrt_tools(cuda_major: int):
    """Install polygraphy/onnx-graphsurgeon tooling and pinned onnx runtime.

    Args:
        cuda_major: Parsed CUDA major version; FP8 deps are CUDA 12 only.
    """
    if not is_installed("polygraphy"):
        print("Installing polygraphy...")
        run_pip(
            "install polygraphy==0.49.24 --extra-index-url https://pypi.ngc.nvidia.com --no-cache-dir"
        )
    if not is_installed("onnx_graphsurgeon"):
        print("Installing onnx-graphsurgeon...")
        run_pip(
            "install onnx-graphsurgeon==0.5.8 --extra-index-url https://pypi.ngc.nvidia.com --no-cache-dir"
        )
    if platform.system() == "Windows" and not is_installed("pywin32"):
        print("Installing pywin32...")
        run_pip("install pywin32==306 --no-cache-dir")

    # Pin onnx 1.18 + onnxruntime-gpu 1.24 together:
    # - onnx 1.18 exports IR 11; modelopt needs FLOAT4E2M1 added in 1.18
    # - onnx 1.19+ exports IR 12 (ORT 1.24 max) and removes float32_to_bfloat16 (onnx-gs needs it)
    # - onnxruntime-gpu 1.24 supports IR 11; never co-install CPU onnxruntime (shared files conflict)
    print("Pinning onnx==1.18.0 + onnxruntime-gpu==1.24.3...")
    run_pip("install onnx==1.18.0 onnxruntime-gpu==1.24.3 --no-cache-dir")

    # FP8 quantization dependencies (CUDA 12 only)
    # nvidia-modelopt requires cupy; pin cupy 13.x + numpy<2 for mediapipe compat
    if cuda_major == 12:
        print("Installing FP8 quantization dependencies (nvidia-modelopt, cupy, numpy)...")
        run_pip(
            'install "nvidia-modelopt[onnx]" "cupy-cuda12x==13.6.0" "numpy==1.26.4" --no-cache-dir'
        )


if __name__ == "__main__":
install()
21 changes: 12 additions & 9 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@

from setuptools import find_packages, setup


# Copied from pip_utils.py to avoid import
def _check_torch_installed():
try:
import torch
import torchvision
except Exception:
msg = (
"Missing required pre-installed packages: torch, torchvision\n"
Expand All @@ -19,16 +19,18 @@ def _check_torch_installed():
raise RuntimeError(msg)

if not torch.version.cuda:
raise RuntimeError("Detected CPU-only PyTorch. Install CUDA-enabled torch/vision/audio before installing this package.")
raise RuntimeError(
"Detected CPU-only PyTorch. Install CUDA-enabled torch/vision/audio before installing this package."
)


def get_cuda_constraint():
cuda_version = os.environ.get("STREAMDIFFUSION_CUDA_VERSION") or \
os.environ.get("CUDA_VERSION")
cuda_version = os.environ.get("STREAMDIFFUSION_CUDA_VERSION") or os.environ.get("CUDA_VERSION")

if not cuda_version:
try:
import torch

cuda_version = torch.version.cuda
except Exception:
# might not be available during wheel build, so we have to ignore
Expand Down Expand Up @@ -56,10 +58,9 @@ def get_cuda_constraint():
"Pillow>=12.1.1", # CVE-2026-25990: out-of-bounds write in PSD loading
"fire==0.7.1",
"omegaconf==2.3.0",
"onnx==1.18.0", # onnx-graphsurgeon 0.5.8 requires onnx.helper.float32_to_bfloat16 (removed in onnx 1.19+)
"onnxruntime==1.24.3",
"onnxruntime-gpu==1.24.3",
"polygraphy==0.49.26",
"onnx==1.18.0", # IR 11 — modelopt needs FLOAT4E2M1 (added in 1.18); float32_to_bfloat16 present (removed in 1.19+)
"onnxruntime-gpu==1.24.3", # TRT EP, supports IR 11; never co-install CPU onnxruntime — shared files conflict
"polygraphy==0.49.24",
"protobuf>=4.25.8,<5", # mediapipe 0.10.21 requires protobuf 4.x; 4.25.8 fixes CVE-2025-4565; CVE-2026-0994 (JSON DoS) accepted risk for local pipeline
"colored==2.3.1",
"pywin32==311;sys_platform == 'win32'",
Expand All @@ -82,7 +83,9 @@ def deps_list(*pkgs):
extras = {}
extras["xformers"] = deps_list("xformers")
extras["torch"] = deps_list("torch", "accelerate")
extras["tensorrt"] = deps_list("protobuf", "cuda-python", "onnx", "onnxruntime", "onnxruntime-gpu", "colored", "polygraphy", "onnx-graphsurgeon")
extras["tensorrt"] = deps_list(
"protobuf", "cuda-python", "onnx", "onnxruntime-gpu", "colored", "polygraphy", "onnx-graphsurgeon"
)
extras["controlnet"] = deps_list("onnx-graphsurgeon", "controlnet-aux")
extras["ipadapter"] = deps_list("diffusers-ipadapter", "mediapipe", "insightface")

Expand Down
154 changes: 153 additions & 1 deletion src/streamdiffusion/modules/ipadapter_module.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, Tuple, Any
from typing import Dict, Optional, Tuple, Any
from enum import Enum
import torch

Expand Down Expand Up @@ -40,6 +40,158 @@ class IPAdapterConfig:
insightface_model_name: Optional[str] = None


# ---------------------------------------------------------------------------
# IP-Adapter model path mapping by base model architecture and adapter type
# ---------------------------------------------------------------------------
# A value of None marks a variant that was never published for that
# architecture — callers fall back to REGULAR automatically.
IPADAPTER_MODEL_MAP: Dict[tuple, Optional[Dict[str, str]]] = {
    ("SD1.5", IPAdapterType.REGULAR): {
        "model_path": "h94/IP-Adapter/models/ip-adapter_sd15.bin",
        "image_encoder_path": "h94/IP-Adapter/models/image_encoder",
    },
    ("SD1.5", IPAdapterType.PLUS): {
        "model_path": "h94/IP-Adapter/models/ip-adapter-plus_sd15.safetensors",
        "image_encoder_path": "h94/IP-Adapter/models/image_encoder",
    },
    ("SD1.5", IPAdapterType.FACEID): {
        "model_path": "h94/IP-Adapter-FaceID/ip-adapter-faceid_sd15.bin",
        "image_encoder_path": "h94/IP-Adapter/models/image_encoder",
    },
    ("SD2.1", IPAdapterType.REGULAR): None,  # not available from h94 (ip-adapter_sd21.bin was never released)
    ("SD2.1", IPAdapterType.PLUS): None,  # not available from h94
    ("SD2.1", IPAdapterType.FACEID): None,  # not available from h94
    ("SDXL", IPAdapterType.REGULAR): {
        "model_path": "h94/IP-Adapter/sdxl_models/ip-adapter_sdxl.bin",
        "image_encoder_path": "h94/IP-Adapter/sdxl_models/image_encoder",
    },
    ("SDXL", IPAdapterType.PLUS): {
        "model_path": "h94/IP-Adapter/sdxl_models/ip-adapter-plus_sdxl_vit-h.safetensors",
        "image_encoder_path": "h94/IP-Adapter/sdxl_models/image_encoder",
    },
    ("SDXL", IPAdapterType.FACEID): {
        "model_path": "h94/IP-Adapter-FaceID/ip-adapter-faceid_sdxl.bin",
        "image_encoder_path": "h94/IP-Adapter/sdxl_models/image_encoder",
    },
}

# Every known HF model path — anything outside this set is treated as a
# custom/local checkpoint and is never overridden.
_KNOWN_IPADAPTER_PATHS: frozenset = frozenset(
    entry["model_path"] for entry in IPADAPTER_MODEL_MAP.values() if entry is not None
)

# Known HF image-encoder paths, derived from the same table so the set and
# the map cannot drift apart when new entries are added.
_KNOWN_ENCODER_PATHS: frozenset = frozenset(
    entry["image_encoder_path"] for entry in IPADAPTER_MODEL_MAP.values() if entry is not None
)


def _normalize_model_type(detected_model_type: str, is_sdxl: bool) -> Optional[str]:
"""Map model detection strings to IPADAPTER_MODEL_MAP keys."""
if is_sdxl:
return "SDXL"
return {
"SD1.5": "SD1.5",
"SD15": "SD1.5",
"SD2.1": "SD2.1",
"SD21": "SD2.1",
"SDXL": "SDXL",
}.get(detected_model_type)


def resolve_ipadapter_paths(
    cfg: Dict[str, Any],
    detected_model_type: str,
    is_sdxl: bool,
) -> Dict[str, Any]:
    """Validate and auto-resolve IP-Adapter model/encoder paths for the detected base model.

    Mutates *cfg* in-place and returns it. Custom/local paths are never overridden.

    Args:
        cfg: Single IP-Adapter config dict (keys: ipadapter_model_path, image_encoder_path, type, ...).
        detected_model_type: Value from detect_model() e.g. "SD1.5", "SD2.1", "SDXL".
        is_sdxl: Whether the base model is SDXL-family (takes precedence over detected_model_type).

    Returns:
        The (potentially mutated) cfg dict.
    """
    model_path = cfg.get("ipadapter_model_path") or ""
    encoder_path = cfg.get("image_encoder_path") or ""

    # Invalid/unknown "type" strings silently degrade to the regular adapter.
    try:
        adapter_type = IPAdapterType(cfg.get("type", "regular"))
    except ValueError:
        adapter_type = IPAdapterType.REGULAR

    # Architecture we have no table entries for — nothing we can validate.
    arch = _normalize_model_type(detected_model_type, is_sdxl)
    if arch is None:
        logger.warning(
            f"IP-Adapter auto-resolution: unknown model type '{detected_model_type}' — "
            f"cannot validate compatibility. Ensure ipadapter_model_path is correct for this model."
        )
        return cfg

    # Custom/local checkpoint — respect it and leave resolution to the user.
    if model_path and model_path not in _KNOWN_IPADAPTER_PATHS:
        logger.info(
            f"IP-Adapter: custom model path '{model_path}' — "
            f"skipping auto-resolution (manual compatibility check required for {detected_model_type})."
        )
        return cfg

    entry = IPADAPTER_MODEL_MAP.get((arch, adapter_type))

    # Requested variant missing for this architecture: degrade to REGULAR.
    if entry is None:
        logger.warning(
            f"IP-Adapter type '{adapter_type.value}' is not available for {detected_model_type}. "
            f"Falling back to 'regular' adapter type."
        )
        adapter_type = IPAdapterType.REGULAR
        cfg["type"] = adapter_type.value
        entry = IPADAPTER_MODEL_MAP.get((arch, adapter_type))

    # Even REGULAR is unavailable (e.g. SD2.1) — disable the adapter entirely.
    if entry is None:
        logger.warning(
            f"IP-Adapter: no compatible adapter exists for {detected_model_type} "
            f"(type='{adapter_type.value}'). No IP-Adapter was released for this architecture. "
            f"IP-Adapter will be disabled for this model."
        )
        cfg["enabled"] = False
        return cfg

    resolved_model = entry["model_path"]
    resolved_encoder = entry["image_encoder_path"]

    # Resolve the model path, warning on a cross-architecture mismatch.
    if model_path == resolved_model:
        logger.info(f"IP-Adapter: '{model_path}' is compatible with {detected_model_type}.")
    else:
        logger.warning(
            f"IP-Adapter auto-resolution: '{model_path}' is incompatible with "
            f"{detected_model_type} (cross_attention_dim mismatch). "
            f"Resolving to '{resolved_model}'."
        )
        cfg["ipadapter_model_path"] = resolved_model

    # Resolve the encoder only when it is a known HF encoder — custom encoders untouched.
    if encoder_path in _KNOWN_ENCODER_PATHS and encoder_path != resolved_encoder:
        logger.info(
            f"IP-Adapter: resolving image encoder "
            f"'{encoder_path}' → '{resolved_encoder}'."
        )
        cfg["image_encoder_path"] = resolved_encoder

    return cfg


class IPAdapterModule(OrchestratorUser):
"""IP-Adapter embedding hook provider.

Expand Down
Loading