diff --git a/configs/flux2/flux2_dev_offload.json b/configs/flux2/flux2_dev_offload.json
index 1528f198a..6f5ba2523 100644
--- a/configs/flux2/flux2_dev_offload.json
+++ b/configs/flux2/flux2_dev_offload.json
@@ -11,5 +11,5 @@
     "rope_type": "flashinfer",
     "text_encoder_out_layers": [10, 20, 30],
     "cpu_offload": true,
-    "offload_granularity": "block"
+    "offload_granularity": "model"
 }
diff --git a/configs/flux2/flux2_klein_distill_offload.json b/configs/flux2/flux2_klein_distill_offload.json
new file mode 100644
index 000000000..31dca84d5
--- /dev/null
+++ b/configs/flux2/flux2_klein_distill_offload.json
@@ -0,0 +1,14 @@
+{
+    "model_cls": "flux2_klein",
+    "task": "t2i",
+    "infer_steps": 4,
+    "sample_guide_scale": 1.0,
+    "vae_scale_factor": 16,
+    "feature_caching": "None",
+    "enable_cfg": false,
+    "patch_size": 2,
+    "tokenizer_max_length": 512,
+    "rope_type": "flashinfer",
+    "cpu_offload": true,
+    "offload_granularity": "model"
+}
diff --git a/lightx2v/models/networks/flux2/infer/offload/transformer_infer.py b/lightx2v/models/networks/flux2/infer/offload/transformer_infer.py
new file mode 100644
index 000000000..b079e9e2b
--- /dev/null
+++ b/lightx2v/models/networks/flux2/infer/offload/transformer_infer.py
@@ -0,0 +1,141 @@
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+
+from lightx2v.common.offload.manager import WeightAsyncStreamManager
+from lightx2v.models.networks.flux2.infer.transformer_infer import Flux2TransformerInfer
+from lightx2v_platform.base.global_var import AI_DEVICE
+
+torch_device_module = getattr(torch, AI_DEVICE)
+
+
+class Flux2OffloadTransformerInfer(Flux2TransformerInfer):
+    """Flux2 transformer inference with CPU offload (block or model granularity)."""
+
+    def __init__(self, config):
+        super().__init__(config)
+        if self.config.get("cpu_offload", False):
+            offload_granularity = self.config.get("offload_granularity", "block")
+            if offload_granularity == "block":
+                self.infer_func = self.infer_with_blocks_offload
+                self.offload_manager_double = WeightAsyncStreamManager(offload_granularity=offload_granularity)
+                self.offload_manager_single = WeightAsyncStreamManager(offload_granularity=offload_granularity)
+            elif offload_granularity == "model":
+                self.infer_func = super().infer
+            else:
+                raise ValueError(f"Unsupported offload_granularity: {offload_granularity}")
+        else:
+            self.infer_func = super().infer
+
+    def infer_with_blocks_offload(self, block_weights, pre_infer_out):
+        hidden_states = pre_infer_out.hidden_states
+        encoder_hidden_states = pre_infer_out.encoder_hidden_states
+        timestep = pre_infer_out.timestep
+        image_rotary_emb = pre_infer_out.image_rotary_emb
+
+        num_txt_tokens = encoder_hidden_states.shape[0]
+
+        # Under sequence parallelism the image tokens are sharded across ranks,
+        # so the rotary embeddings must be sharded the same way: keep all text
+        # positions, pad the image positions to a multiple of world_size, and
+        # take this rank's chunk.
+        if self.seq_p_group is not None and image_rotary_emb is not None:
+            world_size = dist.get_world_size(self.seq_p_group)
+            cur_rank = dist.get_rank(self.seq_p_group)
+
+            if isinstance(image_rotary_emb, tuple):
+                freqs_cos, freqs_sin = image_rotary_emb
+
+                txt_cos = freqs_cos[:num_txt_tokens]
+                img_cos = freqs_cos[num_txt_tokens:]
+                txt_sin = freqs_sin[:num_txt_tokens]
+                img_sin = freqs_sin[num_txt_tokens:]
+
+                seqlen = img_cos.shape[0]
+                padding_size = (world_size - (seqlen % world_size)) % world_size
+                if padding_size > 0:
+                    img_cos = F.pad(img_cos, (0, 0, 0, padding_size))
+                    img_sin = F.pad(img_sin, (0, 0, 0, padding_size))
+                img_cos = torch.chunk(img_cos, world_size, dim=0)[cur_rank]
+                img_sin = torch.chunk(img_sin, world_size, dim=0)[cur_rank]
+
+                freqs_cos = torch.cat([txt_cos, img_cos], dim=0)
+                freqs_sin = torch.cat([txt_sin, img_sin], dim=0)
+                image_rotary_emb = (freqs_cos, freqs_sin)
+            else:
+                txt_emb = image_rotary_emb[:num_txt_tokens]
+                img_emb = image_rotary_emb[num_txt_tokens:]
+
+                seqlen = img_emb.shape[0]
+                padding_size = (world_size - (seqlen % world_size)) % world_size
+                if padding_size > 0:
+                    img_emb = F.pad(img_emb, (0, 0, 0, padding_size))
+                img_emb = torch.chunk(img_emb, world_size, dim=0)[cur_rank]
+
+                image_rotary_emb = torch.cat([txt_emb, img_emb], dim=0)
+
+        # The modulation projections live outside the blocks, so compute them
+        # once per step and share the results with every block.
+        timestep_act = F.silu(timestep)
+        double_stream_mod_img = block_weights.double_stream_modulation_img_linear.apply(timestep_act)
+        double_stream_mod_txt = block_weights.double_stream_modulation_txt_linear.apply(timestep_act)
+        single_stream_mod = block_weights.single_stream_modulation_linear.apply(timestep_act)
+
+        current_stream = torch_device_module.current_stream()
+        self.offload_manager_double.compute_stream.wait_stream(current_stream)
+        for block_idx in range(len(block_weights.double_blocks)):
+            self.block_idx = block_idx
+
+            if self.offload_manager_double.need_init_first_buffer:
+                self.offload_manager_double.init_first_buffer(block_weights.double_blocks)
+
+            # Prefetch the next block's weights while the current block computes;
+            # the modulo wraps the final iteration's prefetch back to block 0.
+            self.offload_manager_double.prefetch_weights((block_idx + 1) % len(block_weights.double_blocks), block_weights.double_blocks)
+
+            with torch_device_module.stream(self.offload_manager_double.compute_stream):
+                encoder_hidden_states, hidden_states = self.infer_double_stream_block(
+                    self.offload_manager_double.cuda_buffers[0],
+                    hidden_states,
+                    encoder_hidden_states,
+                    double_stream_mod_img,
+                    double_stream_mod_txt,
+                    image_rotary_emb,
+                )
+
+            self.offload_manager_double.swap_blocks()
+
+        # Single-stream blocks consume the concatenated [text; image] sequence.
+        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=0)
+
+        self.offload_manager_single.compute_stream.wait_stream(self.offload_manager_double.compute_stream)
+        for block_idx in range(len(block_weights.single_blocks)):
+            self.block_idx = block_idx
+
+            if self.offload_manager_single.need_init_first_buffer:
+                self.offload_manager_single.init_first_buffer(block_weights.single_blocks)
+
+            # Same double-buffered prefetch pattern as the double-stream loop.
+            self.offload_manager_single.prefetch_weights((block_idx + 1) % len(block_weights.single_blocks), block_weights.single_blocks)
+
+            with torch_device_module.stream(self.offload_manager_single.compute_stream):
+                hidden_states = self.infer_single_stream_block(
+                    self.offload_manager_single.cuda_buffers[0],
+                    hidden_states,
+                    None,
+                    single_stream_mod,
+                    image_rotary_emb,
+                    num_txt_tokens=num_txt_tokens,
+                )
+
+            self.offload_manager_single.swap_blocks()
+
+        hidden_states = hidden_states[num_txt_tokens:, ...]
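+        # Defensive sync (assumption: swap_blocks does not already synchronize
+        # with the caller): make the default stream wait on the side compute
+        # stream before downstream ops read the result tensors.
+        current_stream.wait_stream(self.offload_manager_single.compute_stream)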
+ return hidden_states + + def infer(self, block_weights, pre_infer_out): + return self.infer_func(block_weights, pre_infer_out) diff --git a/lightx2v/models/networks/flux2/model.py b/lightx2v/models/networks/flux2/model.py index a0e09402f..998465a3a 100644 --- a/lightx2v/models/networks/flux2/model.py +++ b/lightx2v/models/networks/flux2/model.py @@ -3,6 +3,7 @@ from torch.nn import functional as F from lightx2v.models.networks.base_model import BaseTransformerModel +from lightx2v.models.networks.flux2.infer.offload.transformer_infer import Flux2OffloadTransformerInfer from lightx2v.models.networks.flux2.infer.post_infer import Flux2PostInfer from lightx2v.models.networks.flux2.infer.pre_infer import Flux2DevPreInfer, Flux2PreInfer from lightx2v.models.networks.flux2.infer.transformer_infer import Flux2TransformerInfer @@ -30,9 +31,13 @@ def _init_infer(self): self.transformer_infer = self.transformer_infer_class(self.config) self.pre_infer = self.pre_infer_class(self.config) self.post_infer = self.post_infer_class(self.config) - if hasattr(self.transformer_infer, "offload_manager"): + if hasattr(self.transformer_infer, "offload_manager_double") and hasattr(self.transformer_infer, "offload_manager_single"): self._init_offload_manager() + def _init_offload_manager(self): + self.transformer_infer.offload_manager_double.init_cuda_buffer(blocks_cuda_buffer=self.transformer_weights.offload_double_block_cuda_buffers) + self.transformer_infer.offload_manager_single.init_cuda_buffer(blocks_cuda_buffer=self.transformer_weights.offload_single_block_cuda_buffers) + @torch.no_grad() def _infer_cond_uncond(self, latents_input, prompt_embeds, infer_condition=True, txt_ids=None, img_ids=None): self.scheduler.infer_condition = infer_condition @@ -98,7 +103,10 @@ class Flux2KleinTransformerModel(_Flux2TransformerModelBase): pre_weight_class = Flux2PreWeights def _init_infer_class(self): - self.transformer_infer_class = Flux2TransformerInfer + if self.cpu_offload and self.offload_granularity == "block": + self.transformer_infer_class = Flux2OffloadTransformerInfer + else: + self.transformer_infer_class = Flux2TransformerInfer self.pre_infer_class = Flux2PreInfer self.post_infer_class = Flux2PostInfer @@ -106,7 +114,12 @@ def _init_infer_class(self): @torch.no_grad() def infer(self, inputs): if self.cpu_offload: - self.to_cuda() + if self.offload_granularity == "model" and self.scheduler.step_index == 0: + self.to_cuda() + elif self.offload_granularity != "model": + self.pre_weight.to_cuda() + self.post_weight.to_cuda() + self.transformer_weights.non_block_weights_to_cuda() latents = self.scheduler.latents do_cfg = self.config.get("enable_cfg", True) and self.config.get("sample_guide_scale", 1.0) > 1.0 @@ -191,6 +204,14 @@ def infer(self, inputs): ) self.scheduler.noise_pred = noise_pred + if self.cpu_offload: + if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1: + self.to_cpu() + elif self.offload_granularity != "model": + self.pre_weight.to_cpu() + self.post_weight.to_cpu() + self.transformer_weights.non_block_weights_to_cpu() + class Flux2DevTransformerModel(_Flux2TransformerModelBase): """Flux2 Dev transformer: single forward pass with embedded guidance (no CFG).""" @@ -198,7 +219,10 @@ class Flux2DevTransformerModel(_Flux2TransformerModelBase): pre_weight_class = Flux2DevPreWeights def _init_infer_class(self): - self.transformer_infer_class = Flux2TransformerInfer + if self.cpu_offload and self.offload_granularity == "block": + 
self.transformer_infer_class = Flux2OffloadTransformerInfer + else: + self.transformer_infer_class = Flux2TransformerInfer self.pre_infer_class = Flux2DevPreInfer self.post_infer_class = Flux2PostInfer @@ -206,7 +230,12 @@ def _init_infer_class(self): @torch.no_grad() def infer(self, inputs): if self.cpu_offload: - self.to_cuda() + if self.offload_granularity == "model" and self.scheduler.step_index == 0: + self.to_cuda() + elif self.offload_granularity != "model": + self.pre_weight.to_cuda() + self.post_weight.to_cuda() + self.transformer_weights.non_block_weights_to_cuda() latents = self.scheduler.latents txt_ids = inputs["text_encoder_output"].get("text_ids", None) @@ -220,3 +249,11 @@ def infer(self, inputs): img_ids=img_ids, ) self.scheduler.noise_pred = noise_pred + + if self.cpu_offload: + if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1: + self.to_cpu() + elif self.offload_granularity != "model": + self.pre_weight.to_cpu() + self.post_weight.to_cpu() + self.transformer_weights.non_block_weights_to_cpu() diff --git a/lightx2v/models/networks/flux2/weights/transformer_weights.py b/lightx2v/models/networks/flux2/weights/transformer_weights.py index 870f1a4cb..da3ed9141 100644 --- a/lightx2v/models/networks/flux2/weights/transformer_weights.py +++ b/lightx2v/models/networks/flux2/weights/transformer_weights.py @@ -5,7 +5,7 @@ class Flux2DoubleBlockWeights(WeightModule): """Weights for a single double-stream transformer block.""" - def __init__(self, config, block_idx): + def __init__(self, config, block_idx, create_cuda_buffer=False, create_cpu_buffer=False): super().__init__() self.config = config self.block_idx = block_idx @@ -20,30 +20,43 @@ def __init__(self, config, block_idx): "to_q", MM_WEIGHT_REGISTER[self.mm_type]( f"{p}.attn.to_q.weight", + None, + create_cuda_buffer, + create_cpu_buffer, ), ) self.add_module( "to_k", MM_WEIGHT_REGISTER[self.mm_type]( f"{p}.attn.to_k.weight", + None, + create_cuda_buffer, + create_cpu_buffer, ), ) self.add_module( "to_v", MM_WEIGHT_REGISTER[self.mm_type]( f"{p}.attn.to_v.weight", + None, + create_cuda_buffer, + create_cpu_buffer, ), ) self.add_module( "norm_q", RMS_WEIGHT_REGISTER[self.rms_norm_type]( f"{p}.attn.norm_q.weight", + create_cuda_buffer, + create_cpu_buffer, ), ) self.add_module( "norm_k", RMS_WEIGHT_REGISTER[self.rms_norm_type]( f"{p}.attn.norm_k.weight", + create_cuda_buffer, + create_cpu_buffer, ), ) @@ -51,30 +64,43 @@ def __init__(self, config, block_idx): "add_q_proj", MM_WEIGHT_REGISTER[self.mm_type]( f"{p}.attn.add_q_proj.weight", + None, + create_cuda_buffer, + create_cpu_buffer, ), ) self.add_module( "add_k_proj", MM_WEIGHT_REGISTER[self.mm_type]( f"{p}.attn.add_k_proj.weight", + None, + create_cuda_buffer, + create_cpu_buffer, ), ) self.add_module( "add_v_proj", MM_WEIGHT_REGISTER[self.mm_type]( f"{p}.attn.add_v_proj.weight", + None, + create_cuda_buffer, + create_cpu_buffer, ), ) self.add_module( "norm_added_q", RMS_WEIGHT_REGISTER[self.rms_norm_type]( f"{p}.attn.norm_added_q.weight", + create_cuda_buffer, + create_cpu_buffer, ), ) self.add_module( "norm_added_k", RMS_WEIGHT_REGISTER[self.rms_norm_type]( f"{p}.attn.norm_added_k.weight", + create_cuda_buffer, + create_cpu_buffer, ), ) @@ -82,12 +108,18 @@ def __init__(self, config, block_idx): "to_out", MM_WEIGHT_REGISTER[self.mm_type]( f"{p}.attn.to_out.0.weight", + None, + create_cuda_buffer, + create_cpu_buffer, ), ) self.add_module( "to_add_out", MM_WEIGHT_REGISTER[self.mm_type]( f"{p}.attn.to_add_out.weight", + 
None, + create_cuda_buffer, + create_cpu_buffer, ), ) @@ -103,12 +135,18 @@ def __init__(self, config, block_idx): "ff_net_0", MM_WEIGHT_REGISTER[self.mm_type]( f"{p}.ff.linear_in.weight", + None, + create_cuda_buffer, + create_cpu_buffer, ), ) self.add_module( "ff_net_2", MM_WEIGHT_REGISTER[self.mm_type]( f"{p}.ff.linear_out.weight", + None, + create_cuda_buffer, + create_cpu_buffer, ), ) @@ -116,12 +154,18 @@ def __init__(self, config, block_idx): "ff_context_net_0", MM_WEIGHT_REGISTER[self.mm_type]( f"{p}.ff_context.linear_in.weight", + None, + create_cuda_buffer, + create_cpu_buffer, ), ) self.add_module( "ff_context_net_2", MM_WEIGHT_REGISTER[self.mm_type]( f"{p}.ff_context.linear_out.weight", + None, + create_cuda_buffer, + create_cpu_buffer, ), ) @@ -139,7 +183,7 @@ def to_cpu(self, non_blocking=True): class Flux2SingleBlockWeights(WeightModule): """Weights for a single single-stream transformer block.""" - def __init__(self, config, block_idx): + def __init__(self, config, block_idx, create_cuda_buffer=False, create_cpu_buffer=False): super().__init__() self.config = config self.block_idx = block_idx @@ -154,6 +198,9 @@ def __init__(self, config, block_idx): "to_qkv_mlp_proj", MM_WEIGHT_REGISTER[self.mm_type]( f"{p}.attn.to_qkv_mlp_proj.weight", + None, + create_cuda_buffer, + create_cpu_buffer, ), ) @@ -161,12 +208,16 @@ def __init__(self, config, block_idx): "norm_q", RMS_WEIGHT_REGISTER[self.rms_norm_type]( f"{p}.attn.norm_q.weight", + create_cuda_buffer, + create_cpu_buffer, ), ) self.add_module( "norm_k", RMS_WEIGHT_REGISTER[self.rms_norm_type]( f"{p}.attn.norm_k.weight", + create_cuda_buffer, + create_cpu_buffer, ), ) @@ -174,6 +225,9 @@ def __init__(self, config, block_idx): "to_out", MM_WEIGHT_REGISTER[self.mm_type]( f"{p}.attn.to_out.weight", + None, + create_cuda_buffer, + create_cpu_buffer, ), ) @@ -210,6 +264,7 @@ def __init__(self, config): self.double_blocks = WeightModuleList([Flux2DoubleBlockWeights(config, i) for i in range(self.num_layers)]) self.single_blocks = WeightModuleList([Flux2SingleBlockWeights(config, i) for i in range(self.num_single_layers)]) + self.register_offload_buffers(config) self.add_module("double_blocks", self.double_blocks) self.add_module("single_blocks", self.single_blocks) @@ -233,23 +288,37 @@ def __init__(self, config): ), ) + def register_offload_buffers(self, config): + if config.get("cpu_offload", False) and config.get("offload_granularity", "block") == "block": + self.offload_double_block_cuda_buffers = WeightModuleList([Flux2DoubleBlockWeights(config, i, create_cuda_buffer=True) for i in range(2)]) + self.add_module("offload_double_block_cuda_buffers", self.offload_double_block_cuda_buffers) + + self.offload_single_block_cuda_buffers = WeightModuleList([Flux2SingleBlockWeights(config, i, create_cuda_buffer=True) for i in range(2)]) + self.add_module("offload_single_block_cuda_buffers", self.offload_single_block_cuda_buffers) + + def non_block_weights_to_cuda(self, non_blocking=True): + self.double_stream_modulation_img_linear.to_cuda(non_blocking=non_blocking) + self.double_stream_modulation_txt_linear.to_cuda(non_blocking=non_blocking) + self.single_stream_modulation_linear.to_cuda(non_blocking=non_blocking) + + def non_block_weights_to_cpu(self, non_blocking=True): + self.double_stream_modulation_img_linear.to_cpu(non_blocking=non_blocking) + self.double_stream_modulation_txt_linear.to_cpu(non_blocking=non_blocking) + self.single_stream_modulation_linear.to_cpu(non_blocking=non_blocking) + def to_cuda(self, non_blocking=True): for 
block in self.double_blocks: block.to_cuda(non_blocking=non_blocking) for block in self.single_blocks: block.to_cuda(non_blocking=non_blocking) - self.double_stream_modulation_img_linear.to_cuda(non_blocking=non_blocking) - self.double_stream_modulation_txt_linear.to_cuda(non_blocking=non_blocking) - self.single_stream_modulation_linear.to_cuda(non_blocking=non_blocking) + self.non_block_weights_to_cuda(non_blocking=non_blocking) def to_cpu(self, non_blocking=True): for block in self.double_blocks: block.to_cpu(non_blocking=non_blocking) for block in self.single_blocks: block.to_cpu(non_blocking=non_blocking) - self.double_stream_modulation_img_linear.to_cpu(non_blocking=non_blocking) - self.double_stream_modulation_txt_linear.to_cpu(non_blocking=non_blocking) - self.single_stream_modulation_linear.to_cpu(non_blocking=non_blocking) + self.non_block_weights_to_cpu(non_blocking=non_blocking) # Backward-compatible aliases diff --git a/lightx2v/models/networks/qwen_image/model.py b/lightx2v/models/networks/qwen_image/model.py index b9dceb1c2..70d195fcc 100755 --- a/lightx2v/models/networks/qwen_image/model.py +++ b/lightx2v/models/networks/qwen_image/model.py @@ -143,3 +143,10 @@ def infer(self, inputs): if self.config["task"] == "i2i": noise_pred = noise_pred[:, : latents.size(1)] self.scheduler.noise_pred = noise_pred + + if self.cpu_offload: + if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1: + self.to_cpu() + elif self.offload_granularity != "model": + self.pre_weight.to_cpu() + self.post_weight.to_cpu() diff --git a/scripts/flux2/infer_flux2_klein_distill_offload.sh b/scripts/flux2/infer_flux2_klein_distill_offload.sh new file mode 100644 index 000000000..7da3a67d0 --- /dev/null +++ b/scripts/flux2/infer_flux2_klein_distill_offload.sh @@ -0,0 +1,15 @@ +#!/bin/bash +lightx2v_path= +model_path="/data/temp/FLUX.2-klein-9B" +export CUDA_VISIBLE_DEVICES=7 + +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ + --model_cls flux2_klein \ + --task t2i \ + --target_shape 1024 1024 \ + --model_path $model_path \ + --prompt "A cat holding a sign that says hello world" \ + --save_result_path "${lightx2v_path}/save_results/flux2_klein_distill_offload.png" \ + --config_json "${lightx2v_path}/configs/flux2/flux2_klein_distill_offload.json"
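+
+# Note: "offload_granularity" in the config JSON selects the offload mode:
+# "model" swaps the whole transformer on/off the GPU once per generation,
+# while "block" streams transformer blocks through two resident GPU buffers.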