[BugFix]: Fix missing synchronization issue in cpu_offload #1053
helloyongyang merged 1 commit into main
Conversation
The fix is adapted from HanFa's contribution in #1051. Adjustments were made to ensure proper synchronization behavior.
Code Review
This pull request implements block-level CPU offloading for Flux2 models, allowing for more efficient memory management during inference by asynchronously prefetching and swapping transformer blocks. It introduces the Flux2OffloadTransformerInfer class, updates weight modules to support dedicated offload buffers, and adds relevant configurations and scripts. Review feedback highlights the need to use consistent indices for offload buffer naming to prevent state dict loading issues, recommends extending non-block weight offloading to Qwen models for consistency, and suggests refactoring duplicated rotary embedding logic into a shared base class method.
```python
self.offload_double_block_cuda_buffers = WeightModuleList([Flux2DoubleBlockWeights(config, i, create_cuda_buffer=True) for i in range(2)])
self.add_module("offload_double_block_cuda_buffers", self.offload_double_block_cuda_buffers)
self.offload_single_block_cuda_buffers = WeightModuleList([Flux2SingleBlockWeights(config, i, create_cuda_buffer=True) for i in range(2)])
```
The offload buffers are currently created using the loop index i as the block_idx. This results in the buffers having block-specific parameter names (e.g., transformer_blocks.0... and transformer_blocks.1...). When load_state_dict is called to load weights from an arbitrary block (e.g., block 15), the parameter name prefixes will not match, which may cause the loading to fail or require complex mapping logic. It is safer to use a fixed index (like 0) for all buffers to ensure consistent internal naming.
```diff
-self.offload_double_block_cuda_buffers = WeightModuleList([Flux2DoubleBlockWeights(config, i, create_cuda_buffer=True) for i in range(2)])
+self.offload_double_block_cuda_buffers = WeightModuleList([Flux2DoubleBlockWeights(config, 0, create_cuda_buffer=True) for _ in range(2)])
 self.add_module("offload_double_block_cuda_buffers", self.offload_double_block_cuda_buffers)
-self.offload_single_block_cuda_buffers = WeightModuleList([Flux2SingleBlockWeights(config, i, create_cuda_buffer=True) for i in range(2)])
+self.offload_single_block_cuda_buffers = WeightModuleList([Flux2SingleBlockWeights(config, 0, create_cuda_buffer=True) for _ in range(2)])
```
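To make the reviewer's point concrete, here is a hedged sketch of the key remapping that a fixed buffer index makes uniform. The function name and key layout are illustrative, not the actual lightx2v API: `load_state_dict` matches parameters by exact name, so every source block's keys must be renamed to the buffer's fixed `transformer_blocks.0.` prefix before loading.

```python
# Hypothetical helper (names are illustrative): rename a source block's
# parameter keys to the offload buffer's fixed "transformer_blocks.0."
# prefix so that load_state_dict finds exact name matches.
def remap_block_state_dict(state_dict, src_block_idx, buffer_block_idx=0):
    src = f"transformer_blocks.{src_block_idx}."
    dst = f"transformer_blocks.{buffer_block_idx}."
    return {
        (dst + k[len(src):]) if k.startswith(src) else k: v
        for k, v in state_dict.items()
    }
```

With a fixed index, the same remapping works for any source block; with loop-index naming, buffer 1 would need a different target prefix than buffer 0.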
```python
elif self.offload_granularity != "model":
    self.pre_weight.to_cpu()
    self.post_weight.to_cpu()
```
To maintain consistency with the offloading logic implemented for Flux2 models in this PR, you should also consider offloading the non-block transformer weights (e.g., modulation layers) here. If QwenImageTransformerWeights contains such weights, they should be moved back to CPU to maximize memory savings when offload_granularity is not model.
```diff
 elif self.offload_granularity != "model":
     self.pre_weight.to_cpu()
     self.post_weight.to_cpu()
+    self.transformer_weights.non_block_weights_to_cpu()
```
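If `QwenImageTransformerWeights` keeps its non-block weights in a list, the suggested helper could look like the sketch below. The class layout and attribute names are assumptions mirroring the Flux2 path in this PR, not the actual Qwen implementation:

```python
# Hedged sketch of the suggested helper; names are assumptions, not the
# real lightx2v Qwen weight classes.
class QwenImageTransformerWeights:
    def __init__(self, blocks, non_block_weights):
        self.blocks = blocks                        # offloaded per-block elsewhere
        self.non_block_weights = non_block_weights  # e.g. modulation/norm layers

    def non_block_weights_to_cpu(self):
        # Mirror the Flux2 offload path: push every weight module that
        # lives outside the block list back to host memory.
        for w in self.non_block_weights:
            w.to_cpu()
```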
```python
if self.seq_p_group is not None and image_rotary_emb is not None:
    world_size = dist.get_world_size(self.seq_p_group)
    cur_rank = dist.get_rank(self.seq_p_group)

    if isinstance(image_rotary_emb, tuple):
        freqs_cos, freqs_sin = image_rotary_emb

        txt_cos = freqs_cos[:num_txt_tokens]
        img_cos = freqs_cos[num_txt_tokens:]
        txt_sin = freqs_sin[:num_txt_tokens]
        img_sin = freqs_sin[num_txt_tokens:]

        seqlen = img_cos.shape[0]
        padding_size = (world_size - (seqlen % world_size)) % world_size
        if padding_size > 0:
            img_cos = F.pad(img_cos, (0, 0, 0, padding_size))
            img_sin = F.pad(img_sin, (0, 0, 0, padding_size))
        img_cos = torch.chunk(img_cos, world_size, dim=0)[cur_rank]
        img_sin = torch.chunk(img_sin, world_size, dim=0)[cur_rank]

        freqs_cos = torch.cat([txt_cos, img_cos], dim=0)
        freqs_sin = torch.cat([txt_sin, img_sin], dim=0)
        image_rotary_emb = (freqs_cos, freqs_sin)
    else:
        txt_emb = image_rotary_emb[:num_txt_tokens]
        img_emb = image_rotary_emb[num_txt_tokens:]

        seqlen = img_emb.shape[0]
        padding_size = (world_size - (seqlen % world_size)) % world_size
        if padding_size > 0:
            img_emb = F.pad(img_emb, (0, 0, 0, padding_size))
        img_emb = torch.chunk(img_emb, world_size, dim=0)[cur_rank]

        image_rotary_emb = torch.cat([txt_emb, img_emb], dim=0)
```
The logic for handling image_rotary_emb and applying RoPE padding is identical to the implementation in the base class Flux2TransformerInfer.infer. To improve maintainability and adhere to the DRY (Don't Repeat Yourself) principle, this logic should be refactored into a shared helper method in the base class that can be called by both the standard and offloaded inference paths.
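One way the suggested refactor could look is sketched below. The method name and the explicit `world_size`/`cur_rank` parameters are assumptions; the real base-class helper would presumably derive them from `self.seq_p_group`. It keeps the text portion replicated on every rank, pads the image portion to a multiple of the sequence-parallel world size, and keeps this rank's chunk:

```python
import torch
import torch.nn.functional as F


def shard_image_rotary_emb(image_rotary_emb, num_txt_tokens, world_size, cur_rank):
    """Hypothetical shared helper for both inference paths: shard the image
    part of a rotary embedding (tensor or (cos, sin) tuple) across ranks."""

    def shard(t):
        txt, img = t[:num_txt_tokens], t[num_txt_tokens:]
        # Pad the image sequence dim up to a multiple of world_size.
        pad = (world_size - img.shape[0] % world_size) % world_size
        if pad > 0:
            img = F.pad(img, (0, 0, 0, pad))
        # Keep only this rank's contiguous chunk of the image tokens.
        img = torch.chunk(img, world_size, dim=0)[cur_rank]
        return torch.cat([txt, img], dim=0)

    if isinstance(image_rotary_emb, tuple):
        freqs_cos, freqs_sin = image_rotary_emb
        return shard(freqs_cos), shard(freqs_sin)
    return shard(image_rotary_emb)
```

Both `Flux2TransformerInfer.infer` and the offloaded path could then reduce their duplicated blocks to a single call to this helper.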
awesome!