Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion src/streamdiffusion/acceleration/tensorrt/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,16 @@ def __del__(self):
CUASSERT(cudart.cudaGraphDestroy(self.graph))
except:
pass


# Destroy execution context before engine to release CUDA memory properly.
# TensorRT requires the context to be destroyed before the engine; relying
# on `del` or implicit garbage collection alone is unreliable — set the
# attributes to None explicitly so the C++ destructor runs before the
# engine is released.
if hasattr(self, 'context') and self.context is not None:
self.context = None
if hasattr(self, 'engine') and self.engine is not None:
self.engine = None

del self.engine
del self.context
del self.buffers
Expand Down
250 changes: 163 additions & 87 deletions src/streamdiffusion/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1096,7 +1096,35 @@ def _load_model(
self.cleanup_gpu_memory()
except Exception as e:
logger.warning(f"GPU cleanup warning: {e}")


# ------------------------------------------------------------------
# VRAM pre-flight check — abort early rather than OOM mid-load.
#
# SDXL models with TensorRT typically require ~14–18 GB of free VRAM
# for engine activation + text_encoder_2 + inference buffers. If less
# than MIN_FREE_VRAM_GB is available after cleanup we raise immediately
# instead of attempting to load and crashing after partial allocation.
# ------------------------------------------------------------------
_MIN_FREE_VRAM_GB = 2.0 # conservative floor; adjust per deployment
try:
_free_bytes, _total_bytes = torch.cuda.mem_get_info()
_free_gb = _free_bytes / (1024 ** 3)
_total_gb = _total_bytes / (1024 ** 3)
logger.info(
f"_load_model: VRAM pre-check — {_free_gb:.2f} GB free / {_total_gb:.2f} GB total"
)
if _free_gb < _MIN_FREE_VRAM_GB:
raise RuntimeError(
f"_load_model: Insufficient free VRAM before model load: "
f"{_free_gb:.2f} GB available, {_MIN_FREE_VRAM_GB:.2f} GB required. "
f"GPU memory cleanup did not release enough non-PyTorch memory "
f"(likely residual TensorRT context). Restart the worker process."
)
except RuntimeError:
raise
except Exception as _vram_check_err:
logger.warning(f"_load_model: VRAM pre-check skipped: {_vram_check_err}")

# Reset CUDA context to prevent corruption from previous runs
torch.cuda.empty_cache()
torch.cuda.synchronize()
Expand Down Expand Up @@ -2156,114 +2184,162 @@ def get_stream_state(self, include_caches: bool = False) -> Dict[str, Any]:

return state

@staticmethod
def _destroy_trt_engine(engine_obj) -> None:
"""
Explicitly destroy a TensorRT ``Engine`` wrapper and its underlying
ICudaEngine / IExecutionContext objects.

TensorRT requires the execution context to be released *before* the
engine, otherwise GPU memory can remain pinned until the CUDA context
is torn down. Simply ``del``-ing the Python wrapper is not reliable
because CPython may defer the destructor call; setting the attributes
to ``None`` forces the reference count to drop immediately.
"""
if engine_obj is None:
return
# Step 1: free output/input tensor buffers
try:
if hasattr(engine_obj, 'buffers'):
from cuda import cuda as _cuda
for buf in engine_obj.buffers.values():
try:
buf.free()
except Exception:
pass
engine_obj.buffers.clear()
except Exception:
pass
# Step 2: destroy execution context first (order matters for TRT)
try:
if hasattr(engine_obj, 'context') and engine_obj.context is not None:
engine_obj.context = None
except Exception:
pass
# Step 3: destroy the ICudaEngine
try:
if hasattr(engine_obj, 'engine') and engine_obj.engine is not None:
engine_obj.engine = None
except Exception:
pass

def cleanup_gpu_memory(self) -> None:
    """
    Comprehensive GPU memory cleanup for model switching.

    Calling ``__del__()`` manually (the old approach) is unreliable —
    CPython may not run the destructor immediately, leaving
    TensorRT-allocated GPU memory (engines, execution contexts, CUDA
    scratch buffers) pinned even after the call returns. This version:

    1. Explicitly nullifies the TRT context *then* engine on every engine
       wrapper object (the required teardown order for TensorRT).
    2. Runs ``gc.collect()`` after each wrapper is dropped so CPython
       releases the objects before the next teardown step.
    3. Calls ``torch.cuda.empty_cache()`` / ``synchronize()`` after all
       Python objects are gone to flush PyTorch's caching allocator.
    4. Reports both PyTorch and total GPU memory so callers can detect
       residual non-PyTorch allocations (e.g. a leftover TRT CUDA context).
    """
    import gc
    import torch

    logger.info("Cleaning up GPU memory...")

    # --- 1. Clear prompt caches ------------------------------------------------
    if hasattr(self, 'stream') and self.stream:
        try:
            self.stream._param_updater.clear_caches()
            logger.info(" Cleared prompt caches")
        except Exception:
            pass

    # --- 2. Explicit TensorRT engine teardown ----------------------------------
    if hasattr(self, 'stream') and self.stream:
        try:
            # UNet engine
            if hasattr(self.stream, 'unet') and self.stream.unet is not None:
                logger.info(" Destroying TensorRT UNet engine...")
                unet_wrapper = self.stream.unet
                # UNet2DConditionModelEngine.engine is an Engine instance
                if hasattr(unet_wrapper, 'engine'):
                    self._destroy_trt_engine(unet_wrapper.engine)
                    unet_wrapper.engine = None
                self.stream.unet = None
                del unet_wrapper
                gc.collect()
                logger.info(" UNet engine destroyed")

            # VAE encoder / decoder engines
            if hasattr(self.stream, 'vae') and self.stream.vae is not None:
                logger.info(" Destroying TensorRT VAE engines...")
                vae_wrapper = self.stream.vae
                for attr in ('vae_encoder', 'vae_decoder'):
                    if hasattr(vae_wrapper, attr):
                        sub = getattr(vae_wrapper, attr)
                        if sub is not None and hasattr(sub, 'engine'):
                            self._destroy_trt_engine(sub.engine)
                            sub.engine = None
                        setattr(vae_wrapper, attr, None)
                self.stream.vae = None
                del vae_wrapper
                gc.collect()
                logger.info(" VAE engines destroyed")

            # ControlNet engine pool
            if hasattr(self.stream, 'controlnet_engine_pool') and \
                    self.stream.controlnet_engine_pool is not None:
                logger.info(" Destroying ControlNet engine pool...")
                try:
                    # Best-effort: pool cleanup failure must not abort the
                    # rest of the teardown.
                    self.stream.controlnet_engine_pool.cleanup()
                except Exception:
                    pass
                self.stream.controlnet_engine_pool = None
                gc.collect()
                logger.info(" ControlNet engine pool destroyed")

        except Exception as exc:
            logger.error(f" TensorRT cleanup warning: {exc}")

    # --- 3. Drop the entire StreamDiffusion object -----------------------------
    if hasattr(self, 'stream') and self.stream is not None:
        try:
            del self.stream
            logger.info(" Cleared stream object")
        except Exception:
            pass
        # Re-create the attribute as None so later `self.stream` reads
        # don't raise AttributeError.
        self.stream = None

    # --- 4. Python GC + CUDA allocator flush -----------------------------------
    # Multiple GC passes break up reference cycles created by the wrappers.
    for _ in range(3):
        gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    # Second empty_cache after synchronize catches blocks freed by
    # in-flight kernels; ipc_collect releases CUDA IPC handles.
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

    # --- 5. Diagnostics --------------------------------------------------------
    allocated = torch.cuda.memory_allocated() / (1024 ** 3)
    cached = torch.cuda.memory_reserved() / (1024 ** 3)
    logger.info(f" PyTorch GPU memory after cleanup: {allocated:.2f} GB allocated, {cached:.2f} GB cached")

    # Warn when non-PyTorch GPU memory is substantial (likely residual TRT context)
    try:
        total_mem = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)
        free_mem_gb = (torch.cuda.mem_get_info()[0]) / (1024 ** 3)
        # FIX: subtract the caching allocator's *reserved* pool, not just the
        # allocated bytes — mem_get_info's "free" already excludes everything
        # PyTorch has reserved, so using `allocated` here overstated the
        # non-PyTorch share by (reserved - allocated).
        non_torch = total_mem - free_mem_gb - cached
        if non_torch > 0.5:
            logger.warning(
                f" Non-PyTorch GPU memory still in use: ~{non_torch:.2f} GB "
                f"(TensorRT context / driver overhead). "
                f"Free: {free_mem_gb:.2f} GB / {total_mem:.2f} GB total."
            )
    except Exception:
        pass

    logger.info(" GPU memory cleanup complete")

def check_gpu_memory_for_engine(self, engine_size_gb: float) -> bool:
"""
Expand Down