
fix: release transformer weights after infer when cpu_offload is on (flux2 + qwen_image) #1051

Closed
HanFa wants to merge 2 commits into ModelTC:main from sutro-planet:fix/cpu-offload-release-weights-after-infer

Conversation


@HanFa commented May 6, 2026

Problem

Flux2KleinTransformerModel.infer, Flux2DevTransformerModel.infer, and QwenImageTransformerModel.infer all call self.to_cuda() (or partial swaps on pre_weight/post_weight) at the start of an inference step, but never swap weights back to CPU at the end of a step. Once the first step's to_cuda() runs, the transformer weights stay GPU-resident permanently, which defeats the purpose of cpu_offload and OOMs memory-tight cards on the next text-encoder swap-in.
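In simplified form, the per-step flow looks like this (a toy sketch for illustration only; the class and attribute names here are placeholders, not the lightx2v source):

```python
# Toy sketch of the asymmetry (placeholder names, not the lightx2v source).
import torch
import torch.nn as nn

class ToyDiT(nn.Module):
    def __init__(self, cpu_offload: bool = True):
        super().__init__()
        self.cpu_offload = cpu_offload
        self.blocks = nn.Sequential(*[nn.Linear(256, 256) for _ in range(4)])

    def infer(self, x: torch.Tensor) -> torch.Tensor:
        if self.cpu_offload and torch.cuda.is_available():
            self.to("cuda")  # swap-in at step start
        out = self.blocks(x.to(next(self.parameters()).device))
        # Missing: a symmetric swap back to CPU, so after the first step the
        # weights stay GPU-resident for the rest of the process.
        return out
```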

Concrete repro: flux2_klein with FLUX.2-Klein-9B on a 32 GB Blackwell card. Text encoder is 16 GB bf16, DiT is 17 GB bf16. First call works because the TE is on CPU during DiT load. Second call's text_encoder.to(AI_DEVICE) OOMs because DiT (17 GB) was never released:

File "lightx2v/models/input_encoders/hf/flux2/qwen3_model.py", line 48, in infer
    self.text_encoder.to(AI_DEVICE)
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 32.00 MiB.
GPU 0 has a total capacity of 31.35 GiB of which 31.12 MiB is free.
Including non-PyTorch memory, this process has 30.96 GiB memory in use.

PR #1034 added a lazy_load=true mode that side-steps this by deleting and disk-reloading TE/DiT/VAE per call, but that's ~30s/call from disk reads, vs ~16s warm with this fix on the same hardware.

What this PR does

Adds the symmetric end-of-step swap-to-CPU in both files. It mirrors the block at the start of each infer() exactly, so the semantics are: every weight that gets moved to GPU at step start gets moved back to CPU at step end (or, for offload_granularity == "model", at the last step).
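In code, the added end-of-step block looks roughly like this (a sketch assembled from the diff hunks quoted in the review comments below, not a verbatim copy of either file):

```python
# Sketch of the tail of infer(), mirroring the start-of-step swap-in
# (assembled from the review excerpts below; not a verbatim diff).
if self.cpu_offload:
    if self.offload_granularity == "model":
        # Full-model offload: release only once, after the final step.
        if self.scheduler.step_index == self.scheduler.infer_steps - 1:
            self.to_cpu()
    else:
        # Phase/block offload: per-block swaps happen inside the offload
        # infer class, so only the pre/post weights need releasing here.
        self.pre_weight.to_cpu()
        self.post_weight.to_cpu()
```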

This matches the canonical pattern already in ZImageTransformerModel.infer (z_image/model.py:140-146) and WanTransformerModel.infer. Survey:

| Model | `to_cuda` at start | `to_cpu` at end |
| --- | --- | --- |
| `z_image/model.py` | ✅ | ✅ |
| `wan/model.py` | ✅ | ✅ |
| `flux2/model.py` | ✅ | ❌ → ✅ (this PR) |
| `qwen_image/model.py` | ✅ | ❌ → ✅ (this PR) |

Compatibility with PR #1034 (lazy_load)

Additive — does not touch the `lazy_load` paths. Users running `lazy_load=true` will see the new `to_cpu()` call before the runner deletes & reloads the transformer next call; the operation is a no-op on already-released weights and adds at most a single attribute walk.

Testing

Validated on `flux2_klein` with FLUX.2-Klein-9B on RTX 5090 32 GB, `offload_granularity=phase`, `cpu_offload=true`, 4 steps, torch_sdpa attn, rope split, multi-image i2i (2 reference images):

  • Before: 1st call succeeds (~65s), 2nd call returns 500 OOM at `text_encoder.to(AI_DEVICE)` with `30.96 GiB / 31.35 GiB in use`.
  • After: 3 sequential calls all succeed (65s cold, 16s × 2 warm), no OOM, peak GPU usage stays under 25 GiB between calls.

Notes

  • `qwen_image`'s `transformer_weights` does not currently expose a `non_block_weights_to_cuda/to_cpu` pair the way `z_image` does, so the qwen_image fix only covers what its start-of-infer touches. Adding the non-block-weights pair would close the remaining gap but is out of scope; a rough sketch of what that pair could look like follows these notes.
  • A wan2.2_moe carve-out is not relevant here (z_image excludes it because of a model-cls-specific reload pattern — neither flux2 nor qwen_image has the equivalent path).
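For completeness, the missing pair could look something like this (a purely hypothetical sketch: the method names, the assumed `non_block_weights` container, and its layout are not part of qwen_image today):

```python
# Hypothetical sketch only: QwenImageTransformerWeights does not expose these
# methods today, and the `non_block_weights` attribute assumed here is a
# placeholder for whatever embeddings/norms live outside the blocks.
def non_block_weights_to_cuda(self):
    for weight in self.non_block_weights:
        weight.to_cuda()

def non_block_weights_to_cpu(self):
    for weight in self.non_block_weights:
        weight.to_cpu()
```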

HanFa and others added 2 commits May 5, 2026 18:52
_Flux2TransformerModelBase has `if self.cpu_offload: self.to_cuda()`
at the start of infer(), but no symmetric `to_cpu()` at the end. After
the first inference, the entire DiT (≈17 GB for Klein-9B bf16) stays
GPU-resident permanently, defeating cpu_offload's purpose: the next
call's text-encoder swap-in OOMs on memory-tight cards (32 GB Blackwell
in our case, with TE 16 GB + DiT 17 GB > 32 GB).

The pattern in this same repo:
  - z_image/model.py — has both to_cuda + to_cpu (correct)
  - wan/model.py     — has both to_cuda + to_cpu (correct)
  - flux2/model.py   — only to_cuda  (this fix)
  - qwen_image/model.py — only to_cuda (same bug, separate fix)

PR ModelTC#1034 added a `lazy_load=true` mode that side-steps this by
deleting & disk-reloading TE/DiT/VAE per call, but that costs 30s+
per call and isn't necessary if we just release weights after infer.
With this one-line addition, Klein-9B i2i runs back-to-back in 16s
warm on a 32 GB card without lazy_load.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Same asymmetry as flux2 in the previous commit: QwenImageTransformerModel.infer()
swaps weights to GPU at step start (full to_cuda when offload_granularity=="model"
on step 0; pre+post weights otherwise) but never swaps them back to CPU. After
the first inference, transformer weights stay GPU-resident permanently.

This mirror of the start-of-infer block matches the canonical pattern in
ZImageTransformerModel.infer (z_image/model.py:140-146):

  - granularity=="model" + last step  → self.to_cpu()
  - granularity!="model"               → pre_weight.to_cpu() + post_weight.to_cpu()
                                         (block-level swap is handled inside
                                          QwenImageOffloadTransformerInfer)

Note: z_image's branch also calls non_block_weights_to_cpu() on the
transformer_weights; qwen_image does not currently expose that method, so
this fix only covers what qwen_image's start-of-infer touches. Adding a
non_block_weights_to_{cuda,cpu} pair on QwenImageTransformerWeights would
close the remaining gap but is out of scope here.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

@gemini-code-assist (bot) left a comment


Code Review

This pull request implements CPU offloading for Flux2 and Qwen Image models to improve memory efficiency. Key feedback includes a potential bug in the Qwen model where an incorrect attribute name for inference steps might cause an AttributeError, and a performance optimization for the Flux2 model to make CPU offloading granularity-aware, avoiding redundant weight transfers during the denoising process.

self.scheduler.noise_pred = noise_pred

if self.cpu_offload:
    if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1:

high

Please verify if self.scheduler.infer_steps is the correct attribute name for the total number of inference steps. In many common scheduler implementations (such as those in the diffusers library), this attribute is named num_inference_steps. If infer_steps is not defined on the scheduler object, this will trigger an AttributeError at the end of the very first inference step, causing the process to crash.
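If the attribute name did turn out to vary between schedulers, a defensive lookup would avoid the crash (an illustrative sketch only; whether any lightx2v scheduler actually uses num_inference_steps is not verified here):

```python
# Illustrative sketch: fall back between the two common attribute names
# rather than assuming one (not verified against every lightx2v scheduler).
total_steps = getattr(
    self.scheduler,
    "infer_steps",
    getattr(self.scheduler, "num_inference_steps", None),
)
if total_steps is not None and self.scheduler.step_index == total_steps - 1:
    self.to_cpu()
```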

Comment on lines +194 to +195
if self.cpu_offload:
    self.to_cpu()

medium

While this fix correctly addresses the memory leak by ensuring weights are moved back to CPU, calling self.to_cpu() unconditionally at the end of every infer call is inefficient when offload_granularity is set to "model". In that mode, weights should ideally stay on the GPU across all denoising steps and only be moved back to the CPU after the final step.

Consider making this block granularity-aware (similar to the logic implemented in qwen_image/model.py), provided that the scheduler exposes the current step index and total steps. This would avoid the significant overhead of moving ~17GB of weights back and forth on every step.
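Concretely, the suggested change would amount to something like this at the end of the Flux2 infer() (a sketch; it assumes the Flux2 scheduler exposes the same step_index/infer_steps attributes used in the qwen_image hunk above):

```python
# Sketch of the granularity-aware variant suggested above (assumes the Flux2
# scheduler exposes step_index / infer_steps like the qwen_image hunk does).
if self.cpu_offload:
    if self.offload_granularity == "model":
        # Keep the DiT resident across denoising steps; release it once,
        # after the last step, instead of a ~17 GB round trip every step.
        if self.scheduler.step_index == self.scheduler.infer_steps - 1:
            self.to_cpu()
    else:
        self.to_cpu()
```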

helloyongyang pushed a commit that referenced this pull request May 6, 2026
The fix is adapted from HanFa's contribution:
#1051

Adjustments were made to ensure proper synchronization behavior.
@helloyongyang

Thanks! We just fixed it.
