Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions src/streamdiffusion/acceleration/tensorrt/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ def build(
build_enable_refit: bool = False,
build_static_batch: bool = False,
build_dynamic_shape: bool = True,
build_all_tactics: bool = False,
onnx_opset: int = 17,
force_engine_build: bool = False,
force_onnx_export: bool = False,
Expand All @@ -84,7 +83,6 @@ def build(
"opt_resolution": f"{opt_image_width}x{opt_image_height}",
"dynamic_range": f"{min_image_resolution}-{max_image_resolution}" if build_dynamic_shape else "static",
"batch_size": opt_batch_size,
"build_all_tactics": build_all_tactics,
"stages": {},
}

Expand Down Expand Up @@ -206,7 +204,6 @@ def build(
opt_batch_size=opt_batch_size,
build_static_batch=build_static_batch,
build_dynamic_shape=build_dynamic_shape,
build_all_tactics=build_all_tactics,
build_enable_refit=build_enable_refit,
fp8=fp8_trt,
)
Expand All @@ -226,10 +223,10 @@ def build(
_build_logger.warning(f"[BUILD] {engine_filename} complete: {total_elapsed:.1f}s total")
_write_build_stats(engine_path, stats)

# Cleanup ONNX artifacts — preserve .engine, .fp8.onnx, and build_stats.json
# Cleanup ONNX artifacts — preserve .engine, .fp8.onnx, timing.cache, and build_stats.json
# Two-pass deletion to handle Windows file locks (gc.collect releases Python handles)
_keep_suffixes = (".engine", ".fp8.onnx")
_keep_exact = {"build_stats.json"}
_keep_suffixes = (".engine", ".fp8.onnx", ".cache")
_keep_exact = {"build_stats.json", "timing.cache"}
engine_dir = os.path.dirname(engine_path)
_to_delete = []
for file in os.listdir(engine_dir):
Expand Down
321 changes: 176 additions & 145 deletions src/streamdiffusion/acceleration/tensorrt/engine_manager.py

Large diffs are not rendered by default.

34 changes: 20 additions & 14 deletions src/streamdiffusion/acceleration/tensorrt/fp8_quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,9 +223,10 @@ def quantize_onnx_fp8(
onnx_fp8_path: Output FP8 quantized ONNX path (*.fp8.onnx).
calibration_data: Unused. Kept for backward compatibility.
quantize_mha: Enable FP8 quantization of multi-head attention ops.
Recommended: True. Requires TRT 10+ and compute 8.9+.
percentile: Percentile for activation range calibration.
1.0 = no clipping (safest for first run).
Kept False — MHA analysis via ORT inference adds ~3 hours to build.
Non-MHA ops (Conv, Gemm, MatMul outside MHA) are still FP8.
percentile: Unused. Kept for backward compatibility (entropy calibration
does not use percentile clipping).
alpha: SmoothQuant alpha — balances quantization difficulty between
activations (alpha→0) and weights (alpha→1). 0.8 is optimal
for transformer attention layers.
Expand Down Expand Up @@ -256,7 +257,7 @@ def quantize_onnx_fp8(
logger.info(f"[FP8] Starting ONNX FP8 quantization")
logger.info(f"[FP8] Input: {onnx_opt_path} ({input_size_mb:.0f} MB)")
logger.info(f"[FP8] Output: {onnx_fp8_path}")
logger.info(f"[FP8] Config: quantize_mha={quantize_mha}, percentile={percentile}, alpha={alpha}")
logger.info(f"[FP8] Config: quantize_mha={quantize_mha}, calibration=entropy, alpha={alpha}")
logger.info(f"[FP8] Calibration: RandomDataProvider with calibration_shapes (model_data={'provided' if model_data is not None else 'none'})")

# Patch ByteSize() for >2GB ONNX models: modelopt calls onnx_model.ByteSize()
Expand Down Expand Up @@ -373,8 +374,10 @@ def _safe_byte_size(self):
quantize_kwargs = {
"quantize_mode": "fp8",
"output_path": onnx_fp8_path,
"calibration_method": "percentile",
"percentile": percentile,
# entropy: minimizes KL divergence to find the optimal clipping point for each tensor.
# Better than percentile=1.0 (no clipping), which lets outliers stretch the
# quantization range and so reduces precision for the bulk of activations.
"calibration_method": "entropy",
"alpha": alpha,
"use_external_data_format": True,
# override_shapes replaces dynamic dims in the ONNX model itself with static
Expand All @@ -388,20 +391,23 @@ def _safe_byte_size(self):
# Use default EPs ["cpu","cuda:0","trt"] — CPU-only would fail on this FP16 SDXL
# model because ORT's mandatory CastFloat16Transformer inserts Cast nodes that
# conflict with existing Cast nodes in the upsampler conv.
# disable_mha_qdq controls modelopt's MHA analysis. When True, MHA MatMul
# nodes are excluded from FP8 quantization WITHOUT running ORT inference.
# Non-MHA ops (Conv, Linear, LayerNorm) still get FP8 Q/DQ nodes.
# disable_mha_qdq=True: skip MHA pattern analysis (avoids 3-hour ORT inference
# pass over the full model graph). Non-MHA ops (Conv, Gemm, MatMul outside MHA)
# still get FP8 Q/DQ nodes via the normal KGEN/CASK path.
"disable_mha_qdq": not quantize_mha,
# calibrate_per_node: calibrate one node at a time to reduce peak VRAM during
# calibration. Essential for large UNets (83 inputs, 7993 nodes) to avoid OOM.
"calibrate_per_node": True,
}

try:
modelopt_quantize(onnx_opt_path, **quantize_kwargs)
except TypeError as e:
# Older nvidia-modelopt versions may not support alpha/disable_mha_qdq.
# Retry with base parameters only.
logger.warning(f"[FP8] Retrying without alpha/disable_mha_qdq (TypeError: {e})")
quantize_kwargs.pop("alpha", None)
quantize_kwargs.pop("disable_mha_qdq", None)
# Older nvidia-modelopt versions may not support newer kwargs.
# Strip down to base parameters and retry.
logger.warning(f"[FP8] Retrying with reduced kwargs (TypeError: {e})")
for _k in ("alpha", "disable_mha_qdq", "calibrate_per_node"):
quantize_kwargs.pop(_k, None)
modelopt_quantize(onnx_opt_path, **quantize_kwargs)
except Exception as e:
# MHA analysis (disable_mha_qdq=False) requires an ORT inference run that
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ def __init__(self):
self._curr_key_buf: Optional[torch.Tensor] = None
self._curr_value_buf: Optional[torch.Tensor] = None
self._kv_out_buf: Optional[torch.Tensor] = None # shape: (2, 1, B, seq, inner_dim)
# When False (default): ONNX-safe .clone() path — used during torch.onnx.export() tracing.
# When True: zero-alloc .copy_() path — set after ONNX export for non-TRT runtime inference.
# NOTE: aten::copy has no ONNX symbolic and cannot be traced; never set True before export.
self._use_prealloc: bool = False

# Pre-allocated buffers for zero-alloc hot path (lazy init on first call).
# _use_prealloc is False by default so ONNX export tracing uses the original
Expand Down
Loading