Draft: Add LoRA weight scaling and CFG distillation functionality #1360
m4xw wants to merge 7 commits into Nerogar:master
Conversation
…s. Implement Injecting of TE or UNet Implement Simple Distillation
Pull request overview
Adds initial infrastructure for distillation training (teacher/parent model management + distillation loss wiring) and improves LoRA loading by supporting separate TE vs UNet scaling and simplifying TE LoRA module creation.
Changes:
- Introduces `ParentModelWrapper` and integrates a distillation parent-model context manager into the trainer/setup.
- Adds a distillation loss helper and attempts to integrate distillation loss into diffusion loss computation.
- Adds LoRA state-dict scaling (TE vs UNet) and exposes scaling controls in config + UI.
Reviewed changes
Copilot reviewed 13 out of 13 changed files in this pull request and generated 11 comments.
| File | Description |
|---|---|
| modules/util/loss/masked_loss.py | Adds distillation_loss() helper and new enum dependency. |
| modules/util/enum/DistillationLossType.py | New enum for distillation loss selection. |
| modules/util/enum/ConceptType.py | Adds DISTILLATION concept type. |
| modules/util/config/TrainConfig.py | Adds DistillationConfig, LoRA scaling config fields, and config migrations/version bump. |
| modules/ui/TrainingTab.py | Adds a distillation settings UI frame. |
| modules/ui/LoraTab.py | Adds TE/UNet LoRA load scaling UI fields. |
| modules/trainer/GenericTrainer.py | Integrates parent model wrapper and attempts to generate/store teacher outputs for distillation. |
| modules/module/ParentModelWrapper.py | New wrapper to load/quantize/offload the teacher model. |
| modules/modelSetup/mixin/ModelSetupDiffusionLossMixin.py | Attempts to add distillation loss into diffusion loss computation. |
| modules/modelSetup/StableDiffusionXLLoRASetup.py | Applies LoRA scaling before loading and simplifies TE module creation logic. |
| modules/modelSetup/StableDiffusionLoRASetup.py | Applies LoRA scaling before loading and simplifies TE module creation logic. |
| modules/modelSetup/BaseModelSetup.py | Adds distillation_parent_model() context manager; minor guard in parameter creation. |
| modules/modelLoader/mixin/LoRALoaderMixin.py | Adds scale_lora_state_dict() helper. |
```python
import torch
from abc import ABCMeta
```

`ABCMeta` is imported but never used, which will trigger Ruff F401. Remove the unused import:

```python
from abc import ABCMeta
```
```python
if config.distillation.enabled and 'parent_target' in data and 'distillation_indices' in data:
    distillation_indices = data['distillation_indices']
    if len(distillation_indices) > 0:
        # Calculate distillation loss only for samples marked as DISTILLATION
        dist_loss = distillation_loss(
            student_prediction=data['predicted'][distillation_indices].to(dtype=torch.float32),
            parent_prediction=data['parent_target'][distillation_indices].to(dtype=torch.float32),
            loss_type=config.distillation.loss_type,
            temperature=config.distillation.kl_temperature,
            mask=batch['latent_mask'][distillation_indices].to(dtype=torch.float32) if config.masked_training else None,
            reduction='mean',
        )
        losses += dist_loss * config.distillation.loss_weight
```
Distillation loss is guarded by 'parent_target' in data, but the trainer populates model_output_data['prior_target'] (and never sets parent_target). As a result, distillation loss will never run. Use a consistent key (e.g., set parent_target in the trainer and read that here, or reuse prior_target intentionally).
Suggested change:

```python
if config.distillation.enabled and 'distillation_indices' in data:
    # Prefer explicitly-provided parent_target, fall back to prior_target if available
    parent_key = None
    if 'parent_target' in data:
        parent_key = 'parent_target'
    elif 'prior_target' in data:
        parent_key = 'prior_target'
    if parent_key is not None:
        distillation_indices = data['distillation_indices']
        if len(distillation_indices) > 0:
            # Calculate distillation loss only for samples marked as DISTILLATION
            dist_loss = distillation_loss(
                student_prediction=data['predicted'][distillation_indices].to(dtype=torch.float32),
                parent_prediction=data[parent_key][distillation_indices].to(dtype=torch.float32),
                loss_type=config.distillation.loss_type,
                temperature=config.distillation.kl_temperature,
                mask=batch['latent_mask'][distillation_indices].to(dtype=torch.float32) if config.masked_training else None,
                reduction='mean',
            )
            losses += dist_loss * config.distillation.loss_weight
```
```diff
 # Run parent/prior model if needed for either prior prediction or distillation
 if len(prior_pred_indices) > 0 or len(distillation_indices) > 0 \
         or (self.config.masked_training
             and self.config.masked_prior_preservation_weight > 0
             and self.config.training_method == TrainingMethod.LORA):
-    with self.model_setup.prior_model(self.model, self.config), torch.no_grad():
-        #do NOT create a subbatch using the indices, even though it would be more efficient:
-        #different timesteps are used for a smaller subbatch by predict(), but the conditioning must match exactly:
-        prior_model_output_data = self.model_setup.predict(self.model, batch, self.config, train_progress)
+    with self.model_setup.distillation_parent_model(
+        self.model,
+        self.config,
+        self.parent_model_wrapper
+    ) as parent_model, torch.no_grad():
+        # Do NOT create a subbatch using the indices, even though it would be more efficient:
+        # Different timesteps are used for a smaller subbatch by predict(), but the conditioning must match exactly
+        parent_model_output_data = self.model_setup.predict(
+            parent_model if parent_model != self.model else self.model,
+            batch,
+            self.config,
+            train_progress
+        )

 # Run student model
 model_output_data = self.model_setup.predict(self.model, batch, self.config, train_progress)
-prior_model_prediction = prior_model_output_data['predicted'].to(dtype=model_output_data['target'].dtype)
-model_output_data['target'][prior_pred_indices] = prior_model_prediction[prior_pred_indices]
+
+# Get parent predictions
+prior_model_prediction = parent_model_output_data['predicted'].to(dtype=model_output_data['target'].dtype)
+
+# For prior_prediction: Replace target (legacy behavior)
+if len(prior_pred_indices) > 0:
+    model_output_data['target'][prior_pred_indices] = prior_model_prediction[prior_pred_indices]
+
+# Store parent prediction for masked prior preservation and distillation
+model_output_data['prior_target'] = prior_model_prediction
+
+# For distillation: Store indices for loss calculation
+if len(distillation_indices) > 0:
+    model_output_data['distillation_indices'] = distillation_indices
```
prior_model_prediction is computed from parent_model_output_data (teacher when distillation is enabled) and then stored into model_output_data['prior_target']. This changes the meaning of prior_target (used by masked prior preservation losses) from “student base model without LoRA” to “external parent model”, altering existing training behavior whenever distillation is enabled. Store teacher output under a separate key (e.g., parent_target) and continue computing prior_target via prior_model() (unhooked adapters) when masked prior preservation / PRIOR_PREDICTION is needed.
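One way to keep the two meanings separate, as the comment suggests, is to store each model's prediction under its own key. This is only an illustrative sketch (plain lists stand in for tensors, and `store_predictions` plus the `parent_target` key are hypothetical names, not this PR's final API):

```python
def store_predictions(model_output_data, prior_pred, parent_pred,
                      prior_pred_indices, distillation_indices):
    # prior_target: always the student base model without LoRA
    # (computed via prior_model() with adapters unhooked)
    model_output_data['prior_target'] = prior_pred
    # parent_target: the external teacher, used only for distillation samples
    if distillation_indices:
        model_output_data['parent_target'] = parent_pred
        model_output_data['distillation_indices'] = distillation_indices
    # Legacy PRIOR_PREDICTION behavior: replace the target in place
    for i in prior_pred_indices:
        model_output_data['target'][i] = prior_pred[i]
    return model_output_data
```

With this split, enabling distillation no longer changes what masked prior preservation sees under `prior_target`.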
```python
# For KL divergence on continuous predictions (e.g., latents)
# We use a log-normal approximation:
# Treat predictions as means of distributions, compute KL between Gaussians
# KL(P||Q) = 0.5 * (||mu_p - mu_q||^2 / sigma^2)
# For simplicity, assume unit variance (sigma=1)
# Temperature scaling: divide by temperature before computing
student_scaled = student_prediction / temperature
parent_scaled = parent_prediction / temperature

# MSE-based KL approximation with temperature scaling
loss = F.mse_loss(student_scaled, parent_scaled, reduction='none') * (temperature ** 2)
```
DistillationLossType.KL_DIVERGENCE currently computes an MSE between scaled predictions (a Gaussian KL up to a constant factor), not an actual KL divergence. This makes the enum name/UI option misleading and makes KL and MSE nearly equivalent (aside from temperature). Either rename it to reflect the implemented metric (e.g., Gaussian_KL / scaled_MSE) or implement a true KL for the intended distribution.
Suggested change:

```python
# KL divergence for distillation on logit-like predictions.
# Standard formulation:
#   KL(P_teacher || P_student) with
#   P_teacher = softmax(parent / T), P_student = softmax(student / T)
# implemented as:
#   kl_div(log_softmax(student / T), softmax(parent / T)).
# As in common distillation practice, we multiply by T^2 to preserve
# gradient magnitudes when using temperature scaling.
student_log_probs = F.log_softmax(student_prediction / temperature, dim=-1)
parent_probs = F.softmax(parent_prediction / temperature, dim=-1)
loss = F.kl_div(student_log_probs, parent_probs, reduction='none') * (temperature ** 2)
```
| #("NFloat4", DataType.NFLOAT_4), | ||
| #("Int8", DataType.INT_8), |
There was a problem hiding this comment.
The distillation quantization dtype dropdown omits NFLOAT_4 / INT_8, but DistillationConfig.default_values() sets parent_quantization_dtype=NFLOAT_4. components.options_kv() will treat this as invalid and silently force the value to the first option (None), so the UI won’t reflect config defaults. Either include the default dtype(s) in the options list or change the config default to match what the UI supports.
| #("NFloat4", DataType.NFLOAT_4), | |
| #("Int8", DataType.INT_8), | |
| ("NFloat4", DataType.NFLOAT_4), | |
| ("Int8", DataType.INT_8), |
For the record, those types are broken in this setup for some reason.
```python
# Distillation loss
if config.distillation.enabled and 'parent_target' in data and 'distillation_indices' in data:
    distillation_indices = data['distillation_indices']
    if len(distillation_indices) > 0:
        # Calculate distillation loss only for samples marked as DISTILLATION
        dist_loss = distillation_loss(
            student_prediction=data['predicted'][distillation_indices].to(dtype=torch.float32),
            parent_prediction=data['parent_target'][distillation_indices].to(dtype=torch.float32),
            loss_type=config.distillation.loss_type,
            temperature=config.distillation.kl_temperature,
            mask=batch['latent_mask'][distillation_indices].to(dtype=torch.float32) if config.masked_training else None,
            reduction='mean',
        )
        losses += dist_loss * config.distillation.loss_weight
```
Distillation loss is only added inside __masked_losses(). When config.masked_training is false, the training loop uses __unmasked_losses() and distillation has no effect. Distillation should be applied in both masked and unmasked loss paths (with masking applied conditionally).
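One way to avoid the two paths diverging is to factor the distillation term into a single helper that both `__masked_losses()` and `__unmasked_losses()` call, applying the mask only when one is given. A simplified sketch with scalar "tensors" and a hypothetical helper name, not the PR's actual code:

```python
def add_distillation_term(losses, student, parent, indices, weight, mask=None):
    """Add a (possibly masked) squared-error distillation term.

    Called from both the masked and unmasked loss paths; the mask is
    applied only when the caller provides one.
    """
    if not indices:
        return losses
    for i in indices:
        term = (student[i] - parent[i]) ** 2
        if mask is not None:
            term *= mask[i]
        losses += term * weight
    return losses
```

The unmasked path would simply pass `mask=None`, so enabling or disabling masked_training no longer decides whether distillation has any effect.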
```diff
 # Run parent/prior model if needed for either prior prediction or distillation
 if len(prior_pred_indices) > 0 or len(distillation_indices) > 0 \
         or (self.config.masked_training
             and self.config.masked_prior_preservation_weight > 0
             and self.config.training_method == TrainingMethod.LORA):
-    with self.model_setup.prior_model(self.model, self.config), torch.no_grad():
-        #do NOT create a subbatch using the indices, even though it would be more efficient:
-        #different timesteps are used for a smaller subbatch by predict(), but the conditioning must match exactly:
-        prior_model_output_data = self.model_setup.predict(self.model, batch, self.config, train_progress)
+    with self.model_setup.distillation_parent_model(
+        self.model,
+        self.config,
+        self.parent_model_wrapper
+    ) as parent_model, torch.no_grad():
+        # Do NOT create a subbatch using the indices, even though it would be more efficient:
+        # Different timesteps are used for a smaller subbatch by predict(), but the conditioning must match exactly
+        parent_model_output_data = self.model_setup.predict(
+            parent_model if parent_model != self.model else self.model,
+            batch,
+            self.config,
+            train_progress
+        )
```
When distillation is enabled, the code runs distillation_parent_model(...) even for ConceptType.PRIOR_PREDICTION samples. That means PRIOR_PREDICTION targets come from the external parent model instead of the “model as it was before training” (as described in the UI/help text), which is a behavioral break. Consider running PRIOR_PREDICTION via prior_model() (unhook LoRA adapters) regardless of distillation, and only use the external parent model for ConceptType.DISTILLATION samples.
For the record: this was wanted behaviour in my testing phase.
```python
# Apply mask if provided
if mask is not None:
    # Use same masking logic as other losses
    clamped_mask = torch.clamp(mask, 0.1, 1.0)  # Use default unmasked_weight
    loss = loss * clamped_mask
```
distillation_loss() applies a mask by clamping with a hard-coded 0.1 (unmasked_weight). This bypasses the configurable TrainConfig.unmasked_weight used everywhere else, so masked distillation behaves inconsistently with other losses. Pass unmasked_weight into this helper (or clamp using the caller’s configured value).
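A minimal fix is to take the floor value as a parameter instead of hard-coding 0.1 (a sketch with plain floats; the plumbing of `unmasked_weight` from TrainConfig into the helper is assumed):

```python
def clamp_mask(mask, unmasked_weight):
    # Same clamping as the other masked losses, but using the caller's
    # configured unmasked_weight instead of a hard-coded 0.1.
    return [min(max(m, unmasked_weight), 1.0) for m in mask]
```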
```python
elif loss_type == DistillationLossType.KL_DIVERGENCE:
    # For KL divergence on continuous predictions (e.g., latents)
    # We use a log-normal approximation:
    # Treat predictions as means of distributions, compute KL between Gaussians
    # KL(P||Q) = 0.5 * (||mu_p - mu_q||^2 / sigma^2)
    # For simplicity, assume unit variance (sigma=1)
    # Temperature scaling: divide by temperature before computing
    student_scaled = student_prediction / temperature
    parent_scaled = parent_prediction / temperature

    # MSE-based KL approximation with temperature scaling
    loss = F.mse_loss(student_scaled, parent_scaled, reduction='none') * (temperature ** 2)
```
temperature is used as a divisor for the KL/scaled-MSE path but is not validated. If the config/UI allows setting it to 0 (or a negative value), this will produce inf/NaNs. Add an explicit check that temperature > 0 (and raise a clear ValueError) before dividing.
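A sketch of the suggested guard (hypothetical helper name; in practice this check could live at the top of `distillation_loss()` or in config validation):

```python
def check_temperature(temperature: float) -> float:
    # Fail fast with a clear message instead of letting a zero or negative
    # divisor propagate inf/NaN values into the loss.
    if temperature <= 0:
        raise ValueError(
            f"distillation kl_temperature must be > 0, got {temperature}")
    return temperature
```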
```python
from modules.modelSetup.BaseStableDiffusionXLSetup import BaseStableDiffusionXLSetup
from modules.modelLoader.mixin.LoRALoaderMixin import LoRALoaderMixin
from modules.module.LoRAModule import LoRAModuleWrapper
from modules.util import factory
```
After removing the state_dict_has_prefix(...) logic, the state_dict_has_prefix import in this file is now unused and will trigger Ruff F401. Remove that import (or reintroduce usage if still intended).
I guess the biggest issue is that I still mess around with some prior_prediction stuff in the distillation codepaths. I'm going to fix that when I get home from work, and maybe also implement the loss properly / unsimplified.
TODO: Split prior_target into parent_target for type DISTILLATION
… and add cache configuration options
Implemented a crude prediction cache which doubles the performance on consumer GPUs (but it looks like there's a small leak introduced by the cache).
dxqb left a comment
Here are some early comments, not meant as a full review.
I did not read Copilot's comments here. Could you clean those up, and only leave open those that you plan to work on?
```python
distillation_indices = data['distillation_indices']
if len(distillation_indices) > 0:
    # Calculate distillation loss only for samples marked as DISTILLATION
    dist_loss = distillation_loss(
```
I don't think distillation needs a new loss function (except maybe KL loss - see my other comment below).
Distillation is a modification of the training target, not of the loss function per se.
We already modify the training target without modifying the loss function for PRIOR_PREDICTION here:
OneTrainer/modules/trainer/GenericTrainer.py, line 744 in 45b6d9a
or in my transfer training PR here:
https://github.com/dxqb/OneTrainer/blob/38708bdfb31c608c51d73faec4f761f6fb6b037b/modules/trainer/GenericTrainer.py#L772
It is not necessary for these, and I don't think for your PR, to modify the loss function. It might look like this is necessary because there is some prior prediction code in the loss function handling; however, that is for masked prior prediction, which is an entirely different functionality.
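The target-modification approach described here can be sketched as follows (hypothetical function name, plain lists standing in for tensors; this mirrors the existing PRIOR_PREDICTION pattern rather than this PR's actual code):

```python
def apply_distillation_targets(target, teacher_pred, distillation_indices):
    # Distillation as target replacement: swap in the teacher's prediction
    # for the distillation samples, then let the regular (MSE/Huber/...)
    # loss pull the student toward the teacher with no special loss function.
    for i in distillation_indices:
        target[i] = teacher_pred[i]
    return target
```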
```python
        pass

    @staticmethod
    def scale_lora_state_dict(
```
What is the purpose of applying a scale to the loaded LoRA?
- I can't think of many situations this could be useful
- if it is useful, you could achieve the same by converting a LoRA to lower weights (for example through a simple Comfy workflow). Should this be in OneTrainer?
- I don't see the relation to distillation. if I'm wrong with my earlier two points, I think it should be a separate PR
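For context on what such scaling amounts to: a LoRA delta is `up @ down`, so multiplying either factor by a scale multiplies the whole delta by that scale. A sketch with plain nested lists standing in for tensors (the key prefixes and the choice to scale the up-weight are illustrative, not necessarily what this PR does):

```python
def scale_lora_state_dict_sketch(state_dict, te_scale, unet_scale):
    # delta = up @ down, so scaling only the lora_up weight by s
    # scales the whole low-rank delta by s.
    scaled = {}
    for key, weight in state_dict.items():
        scale = te_scale if key.startswith('lora_te') else unet_scale
        if key.endswith('lora_up.weight'):
            scaled[key] = [w * scale for w in weight]
        else:
            scaled[key] = weight
    return scaled
```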
```python
    Hash-based cache key
    """
    # Combine image path and timestep for uniqueness
    key_string = f"{image_path}_{timestep}"
```
Image path and timestep aren't a sufficient cache key, because we have variations.
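A key that also covers variations could fold the variation state into the hash (field names like `variation_index` and `aug_seed` are assumptions about what distinguishes variations here):

```python
import hashlib

def cache_key(image_path, timestep, variation_index, aug_seed=None):
    # Include everything that changes the latents for the same image path:
    # the concept variation index and any augmentation seed.
    key_string = f"{image_path}_{timestep}_{variation_index}_{aug_seed}"
    return hashlib.sha256(key_string.encode('utf-8')).hexdigest()
```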
```python
from torch import Tensor


class DistillationCacheManager:
```
I'm not sure this needs to be designed as a cache.
For distillation, you first have to make a prediction with the teacher model. Then you set that prediction as the training target of the student model.
To avoid having both models in RAM/VRAM at the same time, you can separate these two steps and save the teacher predictions. But does it have to be a cache, with keys, hits and misses?
Or do we simply write parent predictions sequentially to disk, and then read them sequentially later?
```python
quantized_modules = []

# Step 1: Replace linear layers with quantized versions
for attr_name in dir(self.parent_model):
```
I understand this PR is a draft, but a final PR shouldn't have this amount of code duplication (from quantization_util.py)
```python
# Parent Model Type
# Filter only Stable Diffusion family models
sd_model_types = [
    ("SD 1.5", ModelType.STABLE_DIFFUSION_15),
```
```python
    return migrated_data


def __migration_10(self, data: dict) -> dict:
```
We implement migrations only for versions that might already be out there. It doesn't make sense to implement 2 versions in 1 PR: there is no version 11 file out there anywhere, only version 10 before this PR and version 12 after this PR.
```python
parent_scaled = parent_prediction / temperature

# MSE-based KL approximation with temperature scaling
loss = F.mse_loss(student_scaled, parent_scaled, reduction='none') * (temperature ** 2)
```
I don't think this does what the comments or the name KL_DIVERGENCE suggest. It is just a simple MSE loss. Temperature is applied, but that doesn't do anything to an MSE loss except put a constant weight in front.
Consider the MSE loss: `L = mean((s - p)^2)`
And then add temperature: `L_T = mean((s/T - p/T)^2) * T^2`
Simplify: `L_T = (1/T^2) * mean((s - p)^2) * T^2 = mean((s - p)^2) = L`
That's just a weight (here it cancels out entirely). Temperature makes sense in non-linear functions such as softmax or sigmoid, but here?
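The cancellation is easy to verify numerically. A plain-Python re-implementation of the `mse * T**2` formula from the diff (no torch needed for the check):

```python
def mse(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)

def scaled_mse(student, parent, temperature):
    # The KL_DIVERGENCE branch from the diff: MSE of temperature-scaled
    # predictions, multiplied back by temperature**2.
    s = [x / temperature for x in student]
    p = [x / temperature for x in parent]
    return mse(s, p) * temperature ** 2

student, parent = [0.2, -1.0, 3.0], [0.5, -0.8, 2.0]
for t in (0.5, 1.0, 4.0):
    # Identical to plain MSE for every temperature
    assert abs(scaled_mse(student, parent, t) - mse(student, parent)) < 1e-9
```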
- Introduced DistillationTargetMode enum for target transformation options.
- Enhanced DistillationCacheManager to support new parameters: target_mode, cfg_scale, rollout_steps, and rollout_blend.
- Updated GenericTrainer to utilize new distillation target mode settings.
- Modified TrainingTab UI to include options for target mode, cfg scale, rollout steps, and rollout blend.
- Updated TrainConfig to include new distillation configuration parameters.
What’s working
Observed behavior / limitations
Practical takeaway
So far, this is a success for targeted training:
This could be a path toward practical “poor man’s turbo checkpoints” once stabilized.
Additional note
Known bug
I will squash it up and separate the features into different PRs next time I work on it.
What is your plan regarding this PR? Also, please think about what this PR should achieve in terms of transfer learning / distillation, and what is beyond its scope. It seems to me that this PR currently attempts to do it all.
Initial Draft:
This pull request introduces support for distillation training by adding infrastructure for managing a parent model (teacher) during training, including memory-efficient loading and quantization, as well as integrating distillation loss computation into the training loop. Additionally, it improves LoRA weight handling by allowing separate scaling for text encoder and UNet/transformer components, and simplifies logic for LoRA module creation.
Distillation training support:
- `ParentModelWrapper` class to handle loading, quantization, device management, and unloading of the parent (teacher) model for distillation. This enables efficient memory usage by supporting CPU offloading and quantized inference for the parent model.
- Integrated into the trainer (`GenericTrainer`) and model setup (`BaseModelSetup`). Added a context manager `distillation_parent_model` to temporarily load and move the parent model during distillation steps, with VRAM optimization.
- Added `scale_lora_state_dict` in `LoRALoaderMixin` to allow separate scaling of LoRA weights for text encoder and UNet/transformer components, and applied this scaling before loading LoRA weights in both SD and SDXL LoRA setups.

As discussed in #1353, implementing a 2-step pipeline and precomputing the predictions may allow us to hit almost prior-prediction-level performance without the high (V)RAM requirements.
--
As a last note, I kinda rushed this because I was just playing around and wasn't actually planning to PR it, so I'm not sure it meets upstream's quality standards.