Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/streamdiffusion/acceleration/tensorrt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,14 +113,22 @@ def compile_unet(
opt_batch_size: int = 1,
engine_build_options: dict = {},
):
# Extract FP8-specific options before passing the rest to EngineBuilder.build().
# These are not valid kwargs for build_engine() and must be handled here.
build_options = dict(engine_build_options)
fp8 = build_options.pop("fp8", False)
calibration_data_fn = build_options.pop("calibration_data_fn", None)

unet = unet.to(torch.device("cuda"), dtype=torch.float16)
builder = EngineBuilder(model_data, unet, device=torch.device("cuda"))
builder.build(
onnx_path,
onnx_opt_path,
engine_path,
opt_batch_size=opt_batch_size,
**engine_build_options,
fp8=fp8,
calibration_data_fn=calibration_data_fn,
**build_options,
)


Expand Down
107 changes: 93 additions & 14 deletions src/streamdiffusion/acceleration/tensorrt/builder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import gc
import json
import logging
import os
import time
from datetime import datetime, timezone
Expand All @@ -15,7 +16,7 @@
optimize_onnx,
)

import logging

_build_logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -70,6 +71,8 @@ def build(
force_engine_build: bool = False,
force_onnx_export: bool = False,
force_onnx_optimize: bool = False,
fp8: bool = False,
calibration_data_fn=None,
):
build_total_start = time.perf_counter()
engine_name = Path(engine_path).parent.name
Expand Down Expand Up @@ -145,6 +148,49 @@ def build(
)
_build_logger.info(f"Verified ONNX opt file: {onnx_opt_path} ({opt_file_size / (1024**2):.1f} MB)")

# --- FP8 Quantization (if enabled) ---
# Inserts Q/DQ nodes into the optimized ONNX and replaces onnx_opt_path with
# the FP8-annotated ONNX for the TRT build step below.
onnx_trt_input = onnx_opt_path # default: use FP16 opt ONNX
fp8_trt = fp8 # may be set to False below if FP8 quantization fails
if fp8:
onnx_fp8_path = onnx_opt_path.replace(".opt.onnx", ".fp8.onnx")
if not os.path.exists(onnx_fp8_path):
_build_logger.warning(f"[BUILD] FP8 quantization starting...")
t0 = time.perf_counter()
from .fp8_quantize import quantize_onnx_fp8
try:
quantize_onnx_fp8(
onnx_opt_path,
onnx_fp8_path,
model_data=self.model,
opt_batch_size=opt_batch_size,
opt_image_height=opt_image_height,
opt_image_width=opt_image_width,
)
elapsed = time.perf_counter() - t0
stats["stages"]["fp8_quantize"] = {"status": "built", "elapsed_s": round(elapsed, 2)}
_build_logger.warning(f"[BUILD] FP8 quantization ({engine_filename}): {elapsed:.1f}s")
onnx_trt_input = onnx_fp8_path
except Exception as fp8_err:
elapsed = time.perf_counter() - t0
_build_logger.warning(
f"[BUILD] FP8 quantization failed after {elapsed:.1f}s: {fp8_err}. "
f"Falling back to FP16 TensorRT engine (onnx_trt_input unchanged)."
)
stats["stages"]["fp8_quantize"] = {
"status": "failed_fallback_fp16",
"elapsed_s": round(elapsed, 2),
"error": str(fp8_err),
}
# onnx_trt_input remains onnx_opt_path (FP16 ONNX)
# Disable FP8 engine build path (avoids STRONGLY_TYPED flag)
fp8_trt = False
else:
_build_logger.info(f"[BUILD] Found cached FP8 ONNX: {onnx_fp8_path}")
stats["stages"]["fp8_quantize"] = {"status": "cached"}
onnx_trt_input = onnx_fp8_path

# --- TRT Engine Build ---
if not force_engine_build and os.path.exists(engine_path):
print(f"Found cached engine: {engine_path}")
Expand All @@ -153,7 +199,7 @@ def build(
t0 = time.perf_counter()
build_engine(
engine_path=engine_path,
onnx_opt_path=onnx_opt_path,
onnx_opt_path=onnx_trt_input,
model_data=self.model,
opt_image_height=opt_image_height,
opt_image_width=opt_image_width,
Expand All @@ -162,21 +208,13 @@ def build(
build_dynamic_shape=build_dynamic_shape,
build_all_tactics=build_all_tactics,
build_enable_refit=build_enable_refit,
fp8=fp8_trt,
)
elapsed = time.perf_counter() - t0
stats["stages"]["trt_build"] = {"status": "built", "elapsed_s": round(elapsed, 2)}
_build_logger.warning(f"[BUILD] TRT engine build ({engine_filename}): {elapsed:.1f}s")

# Cleanup ONNX artifacts — tolerate Windows file-lock failures (Issue #4)
for file in os.listdir(os.path.dirname(engine_path)):
if file.endswith('.engine'):
continue
try:
os.remove(os.path.join(os.path.dirname(engine_path), file))
except OSError as cleanup_err:
_build_logger.warning(f"[BUILD] Could not delete temp file {file}: {cleanup_err}")

# Record totals
# Record totals (before cleanup so build_stats.json is preserved)
total_elapsed = time.perf_counter() - build_total_start
stats["total_elapsed_s"] = round(total_elapsed, 2)
stats["build_end"] = datetime.now(timezone.utc).isoformat()
Expand All @@ -188,5 +226,46 @@ def build(
_build_logger.warning(f"[BUILD] {engine_filename} complete: {total_elapsed:.1f}s total")
_write_build_stats(engine_path, stats)

gc.collect()
torch.cuda.empty_cache()
# Cleanup ONNX artifacts — preserve .engine, .fp8.onnx, and build_stats.json
# Two-pass deletion to handle Windows file locks (gc.collect releases Python handles)
_keep_suffixes = (".engine", ".fp8.onnx")
_keep_exact = {"build_stats.json"}
engine_dir = os.path.dirname(engine_path)
_to_delete = []
for file in os.listdir(engine_dir):
if file in _keep_exact or any(file.endswith(s) for s in _keep_suffixes):
continue
_to_delete.append(os.path.join(engine_dir, file))

if _to_delete:
_failed = []
for fpath in _to_delete:
try:
os.remove(fpath)
except OSError:
_failed.append(fpath)

# Release Python-held file handles (ONNX model refs), retry failures
if _failed:
gc.collect()
torch.cuda.empty_cache()
time.sleep(0.5)
_still_failed = []
for fpath in _failed:
try:
os.remove(fpath)
except OSError as cleanup_err:
_still_failed.append(os.path.basename(fpath))
_build_logger.warning(f"[BUILD] Could not delete temp file {os.path.basename(fpath)}: {cleanup_err}")
if _still_failed:
_build_logger.warning(
f"[BUILD] {len(_still_failed)} intermediate files could not be cleaned. "
f"Manual cleanup: delete all files except *.engine and *.fp8.onnx from {engine_dir}"
)
cleaned = len(_to_delete) - len(_still_failed)
else:
cleaned = len(_to_delete)
_build_logger.info(f"[BUILD] Cleaned {cleaned}/{len(_to_delete)} intermediate files")
else:
gc.collect()
torch.cuda.empty_cache()
5 changes: 4 additions & 1 deletion src/streamdiffusion/acceleration/tensorrt/engine_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ def get_engine_path(self,
controlnet_model_id: Optional[str] = None,
is_faceid: Optional[bool] = None,
use_cached_attn: bool = False,
use_controlnet: bool = False
use_controlnet: bool = False,
fp8: bool = False
) -> Path:
"""
Generate engine path using wrapper.py's current logic.
Expand Down Expand Up @@ -151,6 +152,8 @@ def get_engine_path(self,
prefix += f"--use_cached_attn-{use_cached_attn}"
if use_controlnet:
prefix += "--controlnet"
if fp8:
prefix += "--fp8"

prefix += f"--mode-{mode}"

Expand Down
Loading