2 changes: 1 addition & 1 deletion configs/flux2/flux2_dev_offload.json
@@ -11,5 +11,5 @@
     "rope_type": "flashinfer",
     "text_encoder_out_layers": [10, 20, 30],
     "cpu_offload": true,
-    "offload_granularity": "block"
+    "offload_granularity": "model"
 }
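Note: as the model.py changes below show, "model" granularity moves the whole transformer to the device at step 0 and back to CPU after the final step, while "block" granularity keeps weights on CPU and streams individual double/single-stream blocks through reusable device buffers via the new Flux2OffloadTransformerInfer.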
14 changes: 14 additions & 0 deletions configs/flux2/flux2_klein_distill_offload.json
@@ -0,0 +1,14 @@
{
    "model_cls": "flux2_klein",
    "task": "t2i",
    "infer_steps": 4,
    "sample_guide_scale": 1.0,
    "vae_scale_factor": 16,
    "feature_caching": "None",
    "enable_cfg": false,
    "patch_size": 2,
    "tokenizer_max_length": 512,
    "rope_type": "flashinfer",
    "cpu_offload": true,
    "offload_granularity": "model"
}
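This new Klein distill config follows the same offload setup as flux2_dev_offload.json (model-granularity CPU offload), with CFG disabled (enable_cfg: false, sample_guide_scale: 1.0) and only 4 inference steps for the distilled checkpoint.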
127 changes: 127 additions & 0 deletions lightx2v/models/networks/flux2/infer/offload/transformer_infer.py
@@ -0,0 +1,127 @@
import torch
import torch.distributed as dist
import torch.nn.functional as F

from lightx2v.common.offload.manager import WeightAsyncStreamManager
from lightx2v.models.networks.flux2.infer.transformer_infer import Flux2TransformerInfer
from lightx2v_platform.base.global_var import AI_DEVICE

torch_device_module = getattr(torch, AI_DEVICE)


class Flux2OffloadTransformerInfer(Flux2TransformerInfer):
    """Flux2 transformer inference with block-level CPU offload."""

    def __init__(self, config):
        super().__init__(config)
        if self.config.get("cpu_offload", False):
            offload_granularity = self.config.get("offload_granularity", "block")
            if offload_granularity == "block":
                self.infer_func = self.infer_with_blocks_offload
                self.offload_manager_double = WeightAsyncStreamManager(offload_granularity=offload_granularity)
                self.offload_manager_single = WeightAsyncStreamManager(offload_granularity=offload_granularity)
            elif offload_granularity == "model":
                self.infer_func = super().infer
            else:
                raise ValueError(f"Unsupported offload_granularity: {offload_granularity}")
        else:
            self.infer_func = super().infer

    def infer_with_blocks_offload(self, block_weights, pre_infer_out):
        hidden_states = pre_infer_out.hidden_states
        encoder_hidden_states = pre_infer_out.encoder_hidden_states
        timestep = pre_infer_out.timestep
        image_rotary_emb = pre_infer_out.image_rotary_emb

        num_txt_tokens = encoder_hidden_states.shape[0]

        if self.seq_p_group is not None and image_rotary_emb is not None:
            world_size = dist.get_world_size(self.seq_p_group)
            cur_rank = dist.get_rank(self.seq_p_group)

            if isinstance(image_rotary_emb, tuple):
                freqs_cos, freqs_sin = image_rotary_emb

                txt_cos = freqs_cos[:num_txt_tokens]
                img_cos = freqs_cos[num_txt_tokens:]
                txt_sin = freqs_sin[:num_txt_tokens]
                img_sin = freqs_sin[num_txt_tokens:]

                seqlen = img_cos.shape[0]
                padding_size = (world_size - (seqlen % world_size)) % world_size
                if padding_size > 0:
                    img_cos = F.pad(img_cos, (0, 0, 0, padding_size))
                    img_sin = F.pad(img_sin, (0, 0, 0, padding_size))
                img_cos = torch.chunk(img_cos, world_size, dim=0)[cur_rank]
                img_sin = torch.chunk(img_sin, world_size, dim=0)[cur_rank]

                freqs_cos = torch.cat([txt_cos, img_cos], dim=0)
                freqs_sin = torch.cat([txt_sin, img_sin], dim=0)
                image_rotary_emb = (freqs_cos, freqs_sin)
            else:
                txt_emb = image_rotary_emb[:num_txt_tokens]
                img_emb = image_rotary_emb[num_txt_tokens:]

                seqlen = img_emb.shape[0]
                padding_size = (world_size - (seqlen % world_size)) % world_size
                if padding_size > 0:
                    img_emb = F.pad(img_emb, (0, 0, 0, padding_size))
                img_emb = torch.chunk(img_emb, world_size, dim=0)[cur_rank]

                image_rotary_emb = torch.cat([txt_emb, img_emb], dim=0)
Comment on lines +38 to +71
Contributor
medium

The logic for handling image_rotary_emb and applying RoPE padding is identical to the implementation in the base class Flux2TransformerInfer.infer. To improve maintainability and adhere to the DRY (Don't Repeat Yourself) principle, this logic should be refactored into a shared helper method in the base class that can be called by both the standard and offloaded inference paths.
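A shared helper along these lines could serve both paths (a sketch only; the method name _shard_rotary_emb_for_seq_parallel is hypothetical, not part of this PR):

```python
# Hypothetical helper on Flux2TransformerInfer that consolidates the
# duplicated RoPE padding/sharding logic. Name and placement illustrative.
def _shard_rotary_emb_for_seq_parallel(self, image_rotary_emb, num_txt_tokens):
    if self.seq_p_group is None or image_rotary_emb is None:
        return image_rotary_emb

    world_size = dist.get_world_size(self.seq_p_group)
    cur_rank = dist.get_rank(self.seq_p_group)

    def shard(emb):
        # Keep text tokens replicated; pad the image tokens to a multiple
        # of world_size, then keep only this rank's contiguous chunk.
        txt, img = emb[:num_txt_tokens], emb[num_txt_tokens:]
        padding_size = (world_size - (img.shape[0] % world_size)) % world_size
        if padding_size > 0:
            img = F.pad(img, (0, 0, 0, padding_size))
        img = torch.chunk(img, world_size, dim=0)[cur_rank]
        return torch.cat([txt, img], dim=0)

    if isinstance(image_rotary_emb, tuple):
        freqs_cos, freqs_sin = image_rotary_emb
        return (shard(freqs_cos), shard(freqs_sin))
    return shard(image_rotary_emb)
```

Both infer and infer_with_blocks_offload could then open with a single call to this helper.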


        timestep_act = F.silu(timestep)
        double_stream_mod_img = block_weights.double_stream_modulation_img_linear.apply(timestep_act)
        double_stream_mod_txt = block_weights.double_stream_modulation_txt_linear.apply(timestep_act)
        single_stream_mod = block_weights.single_stream_modulation_linear.apply(timestep_act)

        current_stream = torch_device_module.current_stream()
        self.offload_manager_double.compute_stream.wait_stream(current_stream)
        for block_idx in range(len(block_weights.double_blocks)):
            self.block_idx = block_idx

            if self.offload_manager_double.need_init_first_buffer:
                self.offload_manager_double.init_first_buffer(block_weights.double_blocks)

            self.offload_manager_double.prefetch_weights((block_idx + 1) % len(block_weights.double_blocks), block_weights.double_blocks)

            with torch_device_module.stream(self.offload_manager_double.compute_stream):
                encoder_hidden_states, hidden_states = self.infer_double_stream_block(
                    self.offload_manager_double.cuda_buffers[0],
                    hidden_states,
                    encoder_hidden_states,
                    double_stream_mod_img,
                    double_stream_mod_txt,
                    image_rotary_emb,
                )

            self.offload_manager_double.swap_blocks()

        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=0)

        self.offload_manager_single.compute_stream.wait_stream(self.offload_manager_double.compute_stream)
        for block_idx in range(len(block_weights.single_blocks)):
            self.block_idx = block_idx

            if self.offload_manager_single.need_init_first_buffer:
                self.offload_manager_single.init_first_buffer(block_weights.single_blocks)

            self.offload_manager_single.prefetch_weights((block_idx + 1) % len(block_weights.single_blocks), block_weights.single_blocks)

            with torch_device_module.stream(self.offload_manager_single.compute_stream):
                hidden_states = self.infer_single_stream_block(
                    self.offload_manager_single.cuda_buffers[0],
                    hidden_states,
                    None,
                    single_stream_mod,
                    image_rotary_emb,
                    num_txt_tokens=num_txt_tokens,
                )

            self.offload_manager_single.swap_blocks()

        hidden_states = hidden_states[num_txt_tokens:, ...]
        return hidden_states

    def infer(self, block_weights, pre_infer_out):
        return self.infer_func(block_weights, pre_infer_out)
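Aside on the pattern above: each loop prefetches block i+1 on a copy stream while block i runs on a compute stream, ping-ponging between two device buffers. A minimal, self-contained sketch of that double-buffering idea (an illustration of the presumed pattern, not the actual WeightAsyncStreamManager implementation):

```python
import torch

# Double-buffered weight streaming sketch (assumes a CUDA device).
n, d = 4, 1024
cpu_blocks = [torch.randn(d, d).pin_memory() for _ in range(n)]  # pinned for async H2D copies
buffers = [torch.empty(d, d, device="cuda") for _ in range(2)]   # ping-pong buffers

compute_stream = torch.cuda.Stream()
load_stream = torch.cuda.Stream()

# Preload block 0, then let both side streams start after the default stream.
buffers[0].copy_(cpu_blocks[0], non_blocking=True)
compute_stream.wait_stream(torch.cuda.current_stream())
load_stream.wait_stream(torch.cuda.current_stream())

x = torch.randn(d, d, device="cuda")
for i in range(n):
    # Prefetch the next block into the spare buffer while block i computes.
    with torch.cuda.stream(load_stream):
        buffers[1].copy_(cpu_blocks[(i + 1) % n], non_blocking=True)

    with torch.cuda.stream(compute_stream):
        x = x @ buffers[0]

    # The next compute must see the prefetched weights, and the next
    # prefetch must not overwrite a buffer still being read.
    compute_stream.wait_stream(load_stream)
    load_stream.wait_stream(compute_stream)
    buffers[0], buffers[1] = buffers[1], buffers[0]

torch.cuda.current_stream().wait_stream(compute_stream)  # hand results back
```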
47 changes: 42 additions & 5 deletions lightx2v/models/networks/flux2/model.py
@@ -3,6 +3,7 @@
 from torch.nn import functional as F

 from lightx2v.models.networks.base_model import BaseTransformerModel
+from lightx2v.models.networks.flux2.infer.offload.transformer_infer import Flux2OffloadTransformerInfer
 from lightx2v.models.networks.flux2.infer.post_infer import Flux2PostInfer
 from lightx2v.models.networks.flux2.infer.pre_infer import Flux2DevPreInfer, Flux2PreInfer
 from lightx2v.models.networks.flux2.infer.transformer_infer import Flux2TransformerInfer
@@ -30,9 +31,13 @@ def _init_infer(self):
         self.transformer_infer = self.transformer_infer_class(self.config)
         self.pre_infer = self.pre_infer_class(self.config)
         self.post_infer = self.post_infer_class(self.config)
-        if hasattr(self.transformer_infer, "offload_manager"):
+        if hasattr(self.transformer_infer, "offload_manager_double") and hasattr(self.transformer_infer, "offload_manager_single"):
             self._init_offload_manager()

+    def _init_offload_manager(self):
+        self.transformer_infer.offload_manager_double.init_cuda_buffer(blocks_cuda_buffer=self.transformer_weights.offload_double_block_cuda_buffers)
+        self.transformer_infer.offload_manager_single.init_cuda_buffer(blocks_cuda_buffer=self.transformer_weights.offload_single_block_cuda_buffers)
+
     @torch.no_grad()
     def _infer_cond_uncond(self, latents_input, prompt_embeds, infer_condition=True, txt_ids=None, img_ids=None):
         self.scheduler.infer_condition = infer_condition
@@ -98,15 +103,23 @@ class Flux2KleinTransformerModel(_Flux2TransformerModelBase):
     pre_weight_class = Flux2PreWeights

     def _init_infer_class(self):
-        self.transformer_infer_class = Flux2TransformerInfer
+        if self.cpu_offload and self.offload_granularity == "block":
+            self.transformer_infer_class = Flux2OffloadTransformerInfer
+        else:
+            self.transformer_infer_class = Flux2TransformerInfer
         self.pre_infer_class = Flux2PreInfer
         self.post_infer_class = Flux2PostInfer

     @compiled_method()
     @torch.no_grad()
     def infer(self, inputs):
         if self.cpu_offload:
-            self.to_cuda()
+            if self.offload_granularity == "model" and self.scheduler.step_index == 0:
+                self.to_cuda()
+            elif self.offload_granularity != "model":
+                self.pre_weight.to_cuda()
+                self.post_weight.to_cuda()
+                self.transformer_weights.non_block_weights_to_cuda()

         latents = self.scheduler.latents
         do_cfg = self.config.get("enable_cfg", True) and self.config.get("sample_guide_scale", 1.0) > 1.0
@@ -191,22 +204,38 @@ def infer(self, inputs):
         )
         self.scheduler.noise_pred = noise_pred

+        if self.cpu_offload:
+            if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1:
+                self.to_cpu()
+            elif self.offload_granularity != "model":
+                self.pre_weight.to_cpu()
+                self.post_weight.to_cpu()
+                self.transformer_weights.non_block_weights_to_cpu()
+

 class Flux2DevTransformerModel(_Flux2TransformerModelBase):
     """Flux2 Dev transformer: single forward pass with embedded guidance (no CFG)."""

     pre_weight_class = Flux2DevPreWeights

     def _init_infer_class(self):
-        self.transformer_infer_class = Flux2TransformerInfer
+        if self.cpu_offload and self.offload_granularity == "block":
+            self.transformer_infer_class = Flux2OffloadTransformerInfer
+        else:
+            self.transformer_infer_class = Flux2TransformerInfer
         self.pre_infer_class = Flux2DevPreInfer
         self.post_infer_class = Flux2PostInfer

     @compiled_method()
     @torch.no_grad()
     def infer(self, inputs):
         if self.cpu_offload:
-            self.to_cuda()
+            if self.offload_granularity == "model" and self.scheduler.step_index == 0:
+                self.to_cuda()
+            elif self.offload_granularity != "model":
+                self.pre_weight.to_cuda()
+                self.post_weight.to_cuda()
+                self.transformer_weights.non_block_weights_to_cuda()

         latents = self.scheduler.latents
         txt_ids = inputs["text_encoder_output"].get("text_ids", None)
@@ -220,3 +249,11 @@ def infer(self, inputs):
             img_ids=img_ids,
         )
         self.scheduler.noise_pred = noise_pred
+
+        if self.cpu_offload:
+            if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1:
+                self.to_cpu()
+            elif self.offload_granularity != "model":
+                self.pre_weight.to_cpu()
+                self.post_weight.to_cpu()
+                self.transformer_weights.non_block_weights_to_cpu()
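Net effect: with "model" granularity, device transfers now happen once per denoising run (to_cuda at step 0, to_cpu after the final step) instead of on every infer call, while the non-"model" path shuttles only the pre/post and non-block transformer weights per step and leaves the blocks to the streaming offload path above.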