
[BugFix]: Fix missing synchronization issue in cpu_offload #1053

Merged
helloyongyang merged 1 commit into main from fix/no_symmetric_offload on May 6, 2026

Conversation

@wangshankun (Collaborator)

The fix is adapted from HanFa's contribution:
#1051

Adjustments were made to ensure proper synchronization behavior.


@gemini-code-assist (Bot) left a comment


Code Review

This pull request implements block-level CPU offloading for Flux2 models, allowing for more efficient memory management during inference by asynchronously prefetching and swapping transformer blocks. It introduces the Flux2OffloadTransformerInfer class, updates weight modules to support dedicated offload buffers, and adds relevant configurations and scripts. Review feedback highlights the need to use consistent indices for offload buffer naming to prevent state dict loading issues, recommends extending non-block weight offloading to Qwen models for consistency, and suggests refactoring duplicated rotary embedding logic into a shared base class method.
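
For context, the core pattern under review is double-buffered asynchronous offloading: while block i runs on the compute stream, block i+1 is copied into the spare GPU buffer on a side stream, and the two streams must be ordered explicitly. A minimal PyTorch sketch of that pattern (hypothetical names such as prefetch, run_blocks, and buffers; not the repository's actual API):

import torch

# Minimal sketch (hypothetical names, not lightx2v's actual API) of
# double-buffered block prefetch with explicit stream synchronization.
prefetch_stream = torch.cuda.Stream()

def prefetch(dst_block, src_block):
    # Async host-to-device copies; CPU weights are assumed pinned so the
    # copies can overlap with compute.
    for dst, src in zip(dst_block.parameters(), src_block.parameters()):
        dst.data.copy_(src.data, non_blocking=True)

def run_blocks(blocks, hidden_states, buffers):
    with torch.cuda.stream(prefetch_stream):
        prefetch(buffers[0], blocks[0])  # warm up the first buffer
    for i in range(len(blocks)):
        cur, nxt = buffers[i % 2], buffers[(i + 1) % 2]
        # Compute must wait for the prefetch of `cur` to finish; skipping
        # this wait is exactly the kind of missing synchronization the PR
        # title refers to.
        torch.cuda.current_stream().wait_stream(prefetch_stream)
        if i + 1 < len(blocks):
            # The prefetch in turn must wait until the last compute that
            # read `nxt` has finished before overwriting it.
            prefetch_stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(prefetch_stream):
                prefetch(nxt, blocks[i + 1])
        hidden_states = cur(hidden_states)
    return hidden_states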

Comment on lines +293 to +296
self.offload_double_block_cuda_buffers = WeightModuleList([Flux2DoubleBlockWeights(config, i, create_cuda_buffer=True) for i in range(2)])
self.add_module("offload_double_block_cuda_buffers", self.offload_double_block_cuda_buffers)

self.offload_single_block_cuda_buffers = WeightModuleList([Flux2SingleBlockWeights(config, i, create_cuda_buffer=True) for i in range(2)])

Severity: high

The offload buffers are currently created using the loop index i as the block_idx. This results in the buffers having block-specific parameter names (e.g., transformer_blocks.0... and transformer_blocks.1...). When load_state_dict is called to load weights from an arbitrary block (e.g., block 15), the parameter name prefixes will not match, which may cause the loading to fail or require complex mapping logic. It is safer to use a fixed index (like 0) for all buffers to ensure consistent internal naming.

Suggested change
self.offload_double_block_cuda_buffers = WeightModuleList([Flux2DoubleBlockWeights(config, i, create_cuda_buffer=True) for i in range(2)])
self.add_module("offload_double_block_cuda_buffers", self.offload_double_block_cuda_buffers)
self.offload_single_block_cuda_buffers = WeightModuleList([Flux2SingleBlockWeights(config, i, create_cuda_buffer=True) for i in range(2)])
self.offload_double_block_cuda_buffers = WeightModuleList([Flux2DoubleBlockWeights(config, 0, create_cuda_buffer=True) for _ in range(2)])
self.add_module("offload_double_block_cuda_buffers", self.offload_double_block_cuda_buffers)
self.offload_single_block_cuda_buffers = WeightModuleList([Flux2SingleBlockWeights(config, 0, create_cuda_buffer=True) for _ in range(2)])
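
To see the failure mode concretely, here is a toy illustration of the key-prefix mismatch (hypothetical module layout, not the repository's actual classes):

import torch.nn as nn

# Buffer built with block_idx=1 vs. incoming weights for block 15:
buf = nn.ModuleDict({"transformer_blocks": nn.ModuleDict({"1": nn.Linear(8, 8)})})
src = nn.ModuleDict({"transformer_blocks": nn.ModuleDict({"15": nn.Linear(8, 8)})})

result = buf.load_state_dict(src.state_dict(), strict=False)
print(result.missing_keys)     # ['transformer_blocks.1.weight', 'transformer_blocks.1.bias']
print(result.unexpected_keys)  # ['transformer_blocks.15.weight', 'transformer_blocks.15.bias']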

Comment on lines +150 to +152
elif self.offload_granularity != "model":
self.pre_weight.to_cpu()
self.post_weight.to_cpu()

Severity: medium

To maintain consistency with the offloading logic implemented for Flux2 models in this PR, you should also consider offloading the non-block transformer weights (e.g., modulation layers) here. If QwenImageTransformerWeights contains such weights, they should be moved back to CPU to maximize memory savings when offload_granularity is not model.

Suggested change
elif self.offload_granularity != "model":
self.pre_weight.to_cpu()
self.post_weight.to_cpu()
elif self.offload_granularity != "model":
self.pre_weight.to_cpu()
self.post_weight.to_cpu()
self.transformer_weights.non_block_weights_to_cpu()
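
For reference, a minimal sketch of what such a helper could look like on the transformer weights class (hypothetical method and attribute names; the suggestion assumes a non_block_weights_to_cpu method exists):

def non_block_weights_to_cpu(self):
    # Hypothetical helper: move the non-block weight modules (e.g.
    # modulation layers) back to CPU to free GPU memory between steps.
    for module in self.non_block_weights:  # assumed attribute
        module.to_cpu()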

Comment on lines +38 to +71
if self.seq_p_group is not None and image_rotary_emb is not None:
    world_size = dist.get_world_size(self.seq_p_group)
    cur_rank = dist.get_rank(self.seq_p_group)

    if isinstance(image_rotary_emb, tuple):
        freqs_cos, freqs_sin = image_rotary_emb

        txt_cos = freqs_cos[:num_txt_tokens]
        img_cos = freqs_cos[num_txt_tokens:]
        txt_sin = freqs_sin[:num_txt_tokens]
        img_sin = freqs_sin[num_txt_tokens:]

        seqlen = img_cos.shape[0]
        padding_size = (world_size - (seqlen % world_size)) % world_size
        if padding_size > 0:
            img_cos = F.pad(img_cos, (0, 0, 0, padding_size))
            img_sin = F.pad(img_sin, (0, 0, 0, padding_size))
        img_cos = torch.chunk(img_cos, world_size, dim=0)[cur_rank]
        img_sin = torch.chunk(img_sin, world_size, dim=0)[cur_rank]

        freqs_cos = torch.cat([txt_cos, img_cos], dim=0)
        freqs_sin = torch.cat([txt_sin, img_sin], dim=0)
        image_rotary_emb = (freqs_cos, freqs_sin)
    else:
        txt_emb = image_rotary_emb[:num_txt_tokens]
        img_emb = image_rotary_emb[num_txt_tokens:]

        seqlen = img_emb.shape[0]
        padding_size = (world_size - (seqlen % world_size)) % world_size
        if padding_size > 0:
            img_emb = F.pad(img_emb, (0, 0, 0, padding_size))
        img_emb = torch.chunk(img_emb, world_size, dim=0)[cur_rank]

        image_rotary_emb = torch.cat([txt_emb, img_emb], dim=0)

Severity: medium

The logic for handling image_rotary_emb and applying RoPE padding is identical to the implementation in the base class Flux2TransformerInfer.infer. To improve maintainability and adhere to the DRY (Don't Repeat Yourself) principle, this logic should be refactored into a shared helper method in the base class that can be called by both the standard and offloaded inference paths.
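
One way to realize this, sketched as a hypothetical _shard_rotary_emb method on the base class (a sketch of the suggested refactor, not the actual change; it reuses the dist, F, and torch names from the excerpt above):

import torch
import torch.distributed as dist
import torch.nn.functional as F

class Flux2TransformerInfer:  # base class, per the review comment
    def _shard_rotary_emb(self, image_rotary_emb, num_txt_tokens):
        # Hypothetical shared helper: split the RoPE frequencies into text
        # and image parts, pad the image part to a multiple of the
        # sequence-parallel world size, and keep only this rank's chunk.
        if self.seq_p_group is None or image_rotary_emb is None:
            return image_rotary_emb
        world_size = dist.get_world_size(self.seq_p_group)
        cur_rank = dist.get_rank(self.seq_p_group)

        def shard(emb):
            txt, img = emb[:num_txt_tokens], emb[num_txt_tokens:]
            padding_size = (world_size - img.shape[0] % world_size) % world_size
            if padding_size > 0:
                img = F.pad(img, (0, 0, 0, padding_size))
            img = torch.chunk(img, world_size, dim=0)[cur_rank]
            return torch.cat([txt, img], dim=0)

        if isinstance(image_rotary_emb, tuple):
            freqs_cos, freqs_sin = image_rotary_emb
            return (shard(freqs_cos), shard(freqs_sin))
        return shard(image_rotary_emb)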

@helloyongyang merged commit c177fd8 into main on May 6, 2026
2 checks passed
@helloyongyang deleted the fix/no_symmetric_offload branch on May 6, 2026 at 07:36
@HanFa commented May 6, 2026

awesome!
