Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions modules/modelLoader/mixin/LoRALoaderMixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,48 @@ def __init__(self):
def _get_convert_key_sets(self, model: BaseModel) -> list[LoraConversionKeySet] | None:
pass

@staticmethod
def scale_lora_state_dict(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the purpose of applying a scale to the loaded LoRA?

  • I can't think of many situations this could be useful
  • if it is useful, you could achieve the same by converting a LoRA to lower weights (for example through a simple Comfy workflow). Should this be in OneTrainer?
  • I don't see the relation to distillation. If I'm wrong about my earlier two points, I still think it should be a separate PR

state_dict: dict,
te_scale: float = 1.0,
unet_scale: float = 1.0,
) -> dict:
"""
Scales LoRA weights for Text Encoder and main component (UNet/Transformer) separately.

Args:
state_dict: The LoRA state dict to scale
te_scale: Scale factor for Text Encoder LoRA weights (default 1.0, applies to lora_te*)
unet_scale: Scale factor for main component LoRA weights (default 1.0, applies to everything else)

Returns:
The scaled state dict
"""
scaled_dict = {}

weight_suffixes = (
".weight",
"hada_w1_a",
"hada_w1_b",
"hada_w2_a",
"hada_w2_b",
"lokr_w1",
"lokr_w2",
"lokr_t1",
"lokr_t2",
)

for key, value in state_dict.items():
is_weight = isinstance(value, torch.Tensor) and key.endswith(weight_suffixes)
if key.startswith("lora_te"):
# Text Encoder LoRA (matches lora_te, lora_te1, lora_te2, etc.)
scaled_dict[key] = value * te_scale if is_weight else value
else:
# Other components: unet, transformer, prior, decoder, etc.
scaled_dict[key] = value * unet_scale if is_weight else value

return scaled_dict

def __load_safetensors(
self,
model: BaseModel,
Expand Down
61 changes: 60 additions & 1 deletion modules/modelSetup/BaseModelSetup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
from contextlib import contextmanager

from modules.model.BaseModel import BaseModel
from modules.module.ParentModelWrapper import ParentModelWrapper
from modules.util.config.TrainConfig import TrainConfig, TrainEmbeddingConfig, TrainModelPartConfig
from modules.util.enum.TrainingMethod import TrainingMethod
from modules.util.ModuleFilter import ModuleFilter
from modules.util.NamedParameterGroup import NamedParameterGroup, NamedParameterGroupCollection
from modules.util.TimedActionMixin import TimedActionMixin
from modules.util.torch_util import torch_gc
from modules.util.TrainProgress import TrainProgress

import torch
Expand Down Expand Up @@ -73,6 +75,7 @@ def predict(
train_progress: TrainProgress,
*,
deterministic: bool = False,
generate_distillation_empty: bool = False,
) -> dict:
pass

Expand Down Expand Up @@ -177,6 +180,62 @@ def prior_model(self, model: BaseModel, config: TrainConfig):
for adapter in model.adapters():
adapter.hook_to_module()

@contextmanager
def distillation_parent_model(
    self,
    model: BaseModel,
    config: TrainConfig,
    parent_wrapper: ParentModelWrapper | None,
):
    """
    Context manager that prepares a parent (teacher) model for distillation.

    When parent_wrapper is given and distillation is enabled, the parent
    model is loaded on demand and moved to the train device for the
    duration of the context. If config.distillation.keep_parent_on_cpu is
    set, the student is parked on the temp device first so the two models
    are never resident on the GPU simultaneously, and both are swapped back
    on exit; otherwise the parent remains on the train device afterwards.
    Without a usable parent this degrades to prior_model(), which unhooks
    the LoRA adapters and yields the student itself.

    Args:
        model: Student model being trained.
        config: Training configuration.
        parent_wrapper: Optional wrapper containing the parent model.

    Yields:
        The parent model when available, otherwise the student model with
        its adapters unhooked.
    """
    if parent_wrapper is None or not config.distillation.enabled:
        # No external parent: behave exactly like prior_model().
        with self.prior_model(model, config):
            yield model
        return

    if not parent_wrapper.is_loaded():
        parent_wrapper.load_parent_model()

    # Peak-VRAM optimization: move the student off the train device before
    # the parent is brought on, and remember that we did so.
    student_swapped = False
    if config.distillation.keep_parent_on_cpu:
        model.to(self.temp_device)
        student_swapped = True
        torch_gc()

    parent_wrapper.to_device(self.train_device)

    try:
        yield parent_wrapper.parent_model
    finally:
        # Return the parent to the temp device (CPU) when it should not
        # stay resident on the GPU between steps.
        if config.distillation.keep_parent_on_cpu:
            parent_wrapper.to_device(self.temp_device)
            torch_gc()

        # Bring the student back onto the train device if we moved it.
        if student_swapped:
            model.to(self.train_device)
            torch_gc()

def _create_model_part_parameters(
self,
parameter_group_collection: NamedParameterGroupCollection,
Expand All @@ -186,7 +245,7 @@ def _create_model_part_parameters(
freeze: list[ModuleFilter] | None = None,
debug: bool = False,
):
if not config.train:
if not config.train or model is None:
return

if freeze is not None and len(freeze) > 0:
Expand Down
44 changes: 44 additions & 0 deletions modules/modelSetup/BaseStableDiffusionXLSetup.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ def predict(
train_progress: TrainProgress,
*,
deterministic: bool = False,
generate_distillation_empty: bool = False,
) -> dict:
with model.autocast_context:
batch_seed = 0 if deterministic else train_progress.global_step * multi.world_size() + multi.rank()
Expand Down Expand Up @@ -370,6 +371,49 @@ def predict(
)

model_output_data['prediction_type'] = model.noise_scheduler.config.prediction_type

# For CFG_DISTILL: Generate empty prompt prediction
if generate_distillation_empty and config.distillation.enabled \
and config.distillation.target_mode.value == 'CFG_DISTILL':
with torch.no_grad():
# Create empty text embeddings (as unconditional guidance)
empty_text_encoder_output, empty_pooled_text_encoder_2_output = model.combine_text_encoder_output(
*model.encode_text(
train_device=self.train_device,
batch_size=batch['latent_image'].shape[0],
rand=rand,
text="",
tokens_1=None,
tokens_2=None,
text_encoder_1_layer_skip=config.text_encoder_layer_skip,
text_encoder_2_layer_skip=config.text_encoder_2_layer_skip,
text_encoder_1_output=None,
text_encoder_2_output=None,
pooled_text_encoder_2_output=None,
text_encoder_1_dropout_probability=0.0,
text_encoder_2_dropout_probability=0.0,
)
)

# Create latent input (same structure, but with empty conditioning)
if config.model_type.has_mask_input() and config.model_type.has_conditioning_image_input():
empty_latent_input = torch.concat(
[scaled_noisy_latent_image, batch['latent_mask'], scaled_latent_conditioning_image], 1
)
else:
empty_latent_input = scaled_noisy_latent_image

# Run UNet with empty conditioning
empty_added_cond_kwargs = {"text_embeds": empty_pooled_text_encoder_2_output, "time_ids": add_time_ids}
predicted_latent_noise_empty = model.unet(
sample=empty_latent_input.to(dtype=model.train_dtype.torch_dtype()),
timestep=timestep,
encoder_hidden_states=empty_text_encoder_output.to(dtype=model.train_dtype.torch_dtype()),
added_cond_kwargs=empty_added_cond_kwargs,
).sample

model_output_data['predicted_empty'] = predicted_latent_noise_empty

return model_output_data

def calculate_loss(
Expand Down
10 changes: 9 additions & 1 deletion modules/modelSetup/StableDiffusionLoRASetup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from modules.model.StableDiffusionModel import StableDiffusionModel
from modules.modelSetup.BaseModelSetup import BaseModelSetup
from modules.modelSetup.BaseStableDiffusionSetup import BaseStableDiffusionSetup
from modules.modelLoader.mixin.LoRALoaderMixin import LoRALoaderMixin
from modules.module.LoRAModule import LoRAModuleWrapper
from modules.util import factory
Comment on lines +4 to 6
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After removing the state_dict_has_prefix(...) logic, the state_dict_has_prefix import in this file is now unused and will trigger Ruff F401. Remove that import (or reintroduce usage if still intended).

Copilot uses AI. Check for mistakes.
from modules.util.config.TrainConfig import TrainConfig
Expand Down Expand Up @@ -69,7 +70,7 @@ def setup_model(
if config.train_any_embedding():
model.text_encoder.get_input_embeddings().to(dtype=config.embedding_weight_dtype.torch_dtype())

create_te = config.text_encoder.train or state_dict_has_prefix(model.lora_state_dict, "lora_te")
create_te = config.text_encoder.train
model.text_encoder_lora = LoRAModuleWrapper(
model.text_encoder, "lora_te", config
) if create_te else None
Expand All @@ -79,6 +80,13 @@ def setup_model(
)

if model.lora_state_dict:
# Apply scaling factors to LoRA weights before loading
model.lora_state_dict = LoRALoaderMixin.scale_lora_state_dict(
model.lora_state_dict,
te_scale=config.lora_te_scale,
unet_scale=config.lora_unet_scale,
)

if create_te:
model.text_encoder_lora.load_state_dict(model.lora_state_dict)
model.unet_lora.load_state_dict(model.lora_state_dict)
Expand Down
12 changes: 10 additions & 2 deletions modules/modelSetup/StableDiffusionXLLoRASetup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from modules.model.StableDiffusionXLModel import StableDiffusionXLModel
from modules.modelSetup.BaseModelSetup import BaseModelSetup
from modules.modelSetup.BaseStableDiffusionXLSetup import BaseStableDiffusionXLSetup
from modules.modelLoader.mixin.LoRALoaderMixin import LoRALoaderMixin
from modules.module.LoRAModule import LoRAModuleWrapper
from modules.util import factory
Comment on lines 3 to 6
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After removing the state_dict_has_prefix(...) logic, the state_dict_has_prefix import in this file is now unused and will trigger Ruff F401. Remove that import (or reintroduce usage if still intended).

Copilot uses AI. Check for mistakes.
from modules.util.config.TrainConfig import TrainConfig
Expand Down Expand Up @@ -76,8 +77,8 @@ def setup_model(
model: StableDiffusionXLModel,
config: TrainConfig,
):
create_te1 = config.text_encoder.train or state_dict_has_prefix(model.lora_state_dict, "lora_te1")
create_te2 = config.text_encoder_2.train or state_dict_has_prefix(model.lora_state_dict, "lora_te2")
create_te1 = config.text_encoder.train
create_te2 = config.text_encoder_2.train

model.text_encoder_1_lora = LoRAModuleWrapper(
model.text_encoder_1, "lora_te1", config
Expand All @@ -92,6 +93,13 @@ def setup_model(
)

if model.lora_state_dict:
# Apply scaling factors to LoRA weights before loading
model.lora_state_dict = LoRALoaderMixin.scale_lora_state_dict(
model.lora_state_dict,
te_scale=config.lora_te_scale,
unet_scale=config.lora_unet_scale,
)

if create_te1:
model.text_encoder_1_lora.load_state_dict(model.lora_state_dict)
if create_te2:
Expand Down
17 changes: 16 additions & 1 deletion modules/modelSetup/mixin/ModelSetupDiffusionLossMixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from modules.util.config.TrainConfig import TrainConfig
from modules.util.DiffusionScheduleCoefficients import DiffusionScheduleCoefficients
from modules.util.enum.LossWeight import LossWeight
from modules.util.loss.masked_loss import masked_losses, masked_losses_with_prior
from modules.util.loss.masked_loss import masked_losses, masked_losses_with_prior, distillation_loss
from modules.util.loss.vb_loss import vb_losses

import torch
Expand Down Expand Up @@ -134,6 +134,21 @@ def __masked_losses(
normalize_masked_area_loss=config.normalize_masked_area_loss,
).mean(mean_dim) * config.vb_loss_strength

# Distillation loss
if config.distillation.enabled and 'prior_target' in data and 'distillation_indices' in data:
distillation_indices = data['distillation_indices']
if len(distillation_indices) > 0:
# Calculate distillation loss only for samples marked as DISTILLATION
dist_loss = distillation_loss(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think distillation needs a new loss function (except maybe: KL loss - see my other comment below)
distillation is a modification of the training target, not of the loss function per se.

We already modify the training target without modifying the loss function for PRIOR_PREDICTION here:

model_output_data['target'][prior_pred_indices] = prior_model_prediction[prior_pred_indices]

or in my transfer training PR here:
https://github.com/dxqb/OneTrainer/blob/38708bdfb31c608c51d73faec4f761f6fb6b037b/modules/trainer/GenericTrainer.py#L772

it is not necessary for these, and I don't think for your PR, to modify the loss function. it might look like this is necessary because there is some prior prediction code in the loss function handling. However, this is for masked prior prediction which is an entirely different functionality.

student_prediction=data['predicted'][distillation_indices].to(dtype=torch.float32),
parent_prediction=data['prior_target'][distillation_indices].to(dtype=torch.float32),
loss_type=config.distillation.loss_type,
temperature=config.distillation.kl_temperature,
mask=batch['latent_mask'][distillation_indices].to(dtype=torch.float32) if config.masked_training else None,
reduction='mean',
)
losses += dist_loss * config.distillation.loss_weight

Comment on lines +137 to +151
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Distillation loss is only added inside __masked_losses(). When config.masked_training is false, the training loop uses __unmasked_losses() and distillation has no effect. Distillation should be applied in both masked and unmasked loss paths (with masking applied conditionally).

Copilot uses AI. Check for mistakes.
return losses

def __unmasked_losses(
Expand Down
Loading