
Draft: Add LoRA weight scaling and CFG distillation functionality #1360

m4xw wants to merge 7 commits into Nerogar:master from m4xw:master

Conversation


@m4xw m4xw commented Mar 6, 2026

What’s working

  • Lower CFG during inference now looks significantly better.
    • Training was done at CFG Distillation Scale 7 (and then 3), and inference at CFG 1 now looks roughly like prior CFG 7 quality.
  • Overall quality gains are already noticeable after 5-10 steps, especially on trained concepts.
  • Style transfer performance is very strong.
  • For concept-focused use cases, results at low CFG are excellent.
  • Inference speed improved substantially at CFG 1:
    • From ~21-24s (CFG 7) to ~11.5-11.7s (CFG 1), close to 2x faster.
  • Compared against the base checkpoint, this setup gives both a quality boost and a speed boost.

Observed behavior / limitations

  • Tested SDXL to SDXL only
  • Trained concepts improve strongly at earlier CFG, but non-trained content improves less consistently.
  • In refiner workflows, behavior is somewhat unstable at first, but very good after further training.
  • There appears to be a drop in baseline activations without additional regularization/diversification.
  • The model strongly biases toward trained concepts even with generic prompts (offset by using lower CFG at inference).

Practical takeaway

So far, this is a success for targeted training:

  • Very high quality at CFG 1
  • Strong style transfer
  • Better ability to learn concepts that the parent model previously resisted

This could be a path toward practical “poor man’s turbo checkpoints” once stabilized.
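For context, the guided teacher target that CFG distillation asks the student to match can be sketched like this (a minimal illustration, not the PR's actual code; function and variable names are made up):

```python
import torch

def cfg_teacher_target(pred_cond: torch.Tensor,
                       pred_uncond: torch.Tensor,
                       cfg_scale: float = 7.0) -> torch.Tensor:
    # Classifier-free guidance: combine the teacher's conditional and
    # unconditional predictions. The student is trained (at CFG 1) to
    # reproduce this guided prediction with a single forward pass,
    # which is where the ~2x inference speedup comes from.
    return pred_uncond + cfg_scale * (pred_cond - pred_uncond)

# Sanity check: at cfg_scale 1 the target is just the conditional prediction.
cond = torch.randn(2, 4, 8, 8)
uncond = torch.randn(2, 4, 8, 8)
assert torch.allclose(cfg_teacher_target(cond, uncond, 1.0), cond)
```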

Additional note

  • Rank compression/back-merge experiments (e.g., train rank 32 -> apply -> diff to rank 16) seem promising:
    • Styles can be preserved
    • CFG-scale shifts can be partially reduced in some merges
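The back-merge experiment above (train at rank 32, merge, then re-extract at rank 16) can be approximated with a truncated SVD of the merged weight delta; a hedged sketch under that assumption (names are illustrative):

```python
import torch

def compress_delta(delta: torch.Tensor, rank: int):
    # Best rank-`rank` approximation of a merged weight delta
    # (merged_W - base_W) via truncated SVD, yielding the two factor
    # matrices of a lower-rank LoRA.
    u, s, vh = torch.linalg.svd(delta, full_matrices=False)
    up = u[:, :rank] * s[:rank]   # (out_features, rank)
    down = vh[:rank, :]           # (rank, in_features)
    return up, down

# A delta that is exactly rank 16 is reconstructed (near-)exactly.
delta = torch.randn(64, 16) @ torch.randn(16, 128)
up, down = compress_delta(delta, 16)
assert torch.allclose(up @ down, delta, atol=1e-2)
```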

Known bug

  • Cache generation can fail on low VRAM after the first epoch.
  • Current workaround:
    1. Press stop after cache generation starts
    2. Wait for cache generation to finish
    3. Ensure sampling and backups are disabled during the cache run

Initial Draft:
This pull request introduces support for distillation training by adding infrastructure for managing a parent model (teacher) during training, including memory-efficient loading and quantization, as well as integrating distillation loss computation into the training loop. Additionally, it improves LoRA weight handling by allowing separate scaling for text encoder and UNet/transformer components, and simplifies logic for LoRA module creation.

Distillation training support:

  • Added a new ParentModelWrapper class to handle loading, quantization, device management, and unloading of the parent (teacher) model for distillation. This enables efficient memory usage by supporting CPU offloading and quantized inference for the parent model.
  • Integrated the parent model wrapper into the trainer (GenericTrainer) and model setup (BaseModelSetup). Added a context manager distillation_parent_model to temporarily load and move the parent model during distillation steps, with VRAM optimization.
  • Added distillation loss calculation to the diffusion loss mixin, applying the configured loss type and temperature only to relevant samples during training.
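The core idea of that context manager can be sketched as follows (a simplified stand-in for the PR's distillation_parent_model, without quantization; names are illustrative):

```python
import contextlib
import torch

@contextlib.contextmanager
def parent_model_on_device(parent: torch.nn.Module, device: torch.device):
    # Move the teacher onto the training device only for the duration of
    # the distillation step, then offload it back to CPU so its weights
    # do not occupy VRAM alongside the student.
    parent.to(device)
    try:
        with torch.no_grad():
            yield parent
    finally:
        parent.to("cpu")

teacher = torch.nn.Linear(8, 8)
with parent_model_on_device(teacher, torch.device("cpu")) as t:
    out = t(torch.randn(2, 8))
assert out.shape == (2, 8)
```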

LoRA handling improvements:

  • Introduced a static method scale_lora_state_dict in LoRALoaderMixin to allow separate scaling of LoRA weights for text encoder and UNet/transformer components, and applied this scaling before loading LoRA weights in both SD and SDXL LoRA setups.
  • Simplified logic for creating LoRA modules for text encoders in SD and SDXL setups by removing checks for state dict prefixes and relying only on the training config.
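Separate TE/UNet scaling can be implemented by multiplying only one factor of each LoRA pair; a minimal sketch (the key prefixes here are illustrative, not necessarily the exact ones the PR uses):

```python
import torch

def scale_lora_state_dict(state_dict: dict, te_scale: float = 1.0, unet_scale: float = 1.0) -> dict:
    # Scale only the 'up' factor of each LoRA pair: since the effective
    # delta is up @ down, this scales the delta linearly. Text-encoder
    # keys are told apart from UNet keys by their prefix.
    scaled = {}
    for key, tensor in state_dict.items():
        if "lora_up" in key:
            scale = te_scale if key.startswith("lora_te") else unet_scale
            scaled[key] = tensor * scale
        else:
            scaled[key] = tensor
    return scaled

sd = {
    "lora_te_layers_0.lora_up.weight": torch.ones(4, 2),
    "lora_te_layers_0.lora_down.weight": torch.ones(2, 8),
    "lora_unet_block_0.lora_up.weight": torch.ones(4, 2),
}
out = scale_lora_state_dict(sd, te_scale=0.5, unet_scale=2.0)
assert out["lora_te_layers_0.lora_up.weight"].eq(0.5).all()
assert out["lora_unet_block_0.lora_up.weight"].eq(2.0).all()
assert out["lora_te_layers_0.lora_down.weight"].eq(1.0).all()
```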

As discussed in #1353, implementing a two-step pipeline and precomputing the predictions may let us reach almost prior-prediction-level performance without the high (V)RAM requirements.

--

As a last note, I rushed this a bit because I was mostly just playing around and wasn't actually planning to open a PR, so I'm not sure it meets upstream's quality standard.

…s. Implement Injecting of TE or UNet

Implement Simple Distillation
Copilot AI review requested due to automatic review settings March 6, 2026 10:14
@m4xw m4xw marked this pull request as draft March 6, 2026 10:14

Copilot AI left a comment


Pull request overview

Adds initial infrastructure for distillation training (teacher/parent model management + distillation loss wiring) and improves LoRA loading by supporting separate TE vs UNet scaling and simplifying TE LoRA module creation.

Changes:

  • Introduces ParentModelWrapper and integrates a distillation parent-model context manager into the trainer/setup.
  • Adds a distillation loss helper and attempts to integrate distillation loss into diffusion loss computation.
  • Adds LoRA state-dict scaling (TE vs UNet) and exposes scaling controls in config + UI.

Reviewed changes

Copilot reviewed 13 out of 13 changed files in this pull request and generated 11 comments.

File Description
modules/util/loss/masked_loss.py Adds distillation_loss() helper and new enum dependency.
modules/util/enum/DistillationLossType.py New enum for distillation loss selection.
modules/util/enum/ConceptType.py Adds DISTILLATION concept type.
modules/util/config/TrainConfig.py Adds DistillationConfig, LoRA scaling config fields, and config migrations/version bump.
modules/ui/TrainingTab.py Adds a distillation settings UI frame.
modules/ui/LoraTab.py Adds TE/UNet LoRA load scaling UI fields.
modules/trainer/GenericTrainer.py Integrates parent model wrapper and attempts to generate/store teacher outputs for distillation.
modules/module/ParentModelWrapper.py New wrapper to load/quantize/offload the teacher model.
modules/modelSetup/mixin/ModelSetupDiffusionLossMixin.py Attempts to add distillation loss into diffusion loss computation.
modules/modelSetup/StableDiffusionXLLoRASetup.py Applies LoRA scaling before loading and simplifies TE module creation logic.
modules/modelSetup/StableDiffusionLoRASetup.py Applies LoRA scaling before loading and simplifies TE module creation logic.
modules/modelSetup/BaseModelSetup.py Adds distillation_parent_model() context manager; minor guard in parameter creation.
modules/modelLoader/mixin/LoRALoaderMixin.py Adds scale_lora_state_dict() helper.


@@ -0,0 +1,188 @@
import torch
from abc import ABCMeta

Copilot AI Mar 6, 2026


ABCMeta is imported but never used, which will trigger Ruff F401. Remove the unused import.

Suggested change
from abc import ABCMeta

Comment on lines +138 to +151
if config.distillation.enabled and 'parent_target' in data and 'distillation_indices' in data:
    distillation_indices = data['distillation_indices']
    if len(distillation_indices) > 0:
        # Calculate distillation loss only for samples marked as DISTILLATION
        dist_loss = distillation_loss(
            student_prediction=data['predicted'][distillation_indices].to(dtype=torch.float32),
            parent_prediction=data['parent_target'][distillation_indices].to(dtype=torch.float32),
            loss_type=config.distillation.loss_type,
            temperature=config.distillation.kl_temperature,
            mask=batch['latent_mask'][distillation_indices].to(dtype=torch.float32) if config.masked_training else None,
            reduction='mean',
        )
        losses += dist_loss * config.distillation.loss_weight

Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Distillation loss is guarded by 'parent_target' in data, but the trainer populates model_output_data['prior_target'] (and never sets parent_target). As a result, distillation loss will never run. Use a consistent key (e.g., set parent_target in the trainer and read that here, or reuse prior_target intentionally).

Suggested change
if config.distillation.enabled and 'parent_target' in data and 'distillation_indices' in data:
    distillation_indices = data['distillation_indices']
    if len(distillation_indices) > 0:
        # Calculate distillation loss only for samples marked as DISTILLATION
        dist_loss = distillation_loss(
            student_prediction=data['predicted'][distillation_indices].to(dtype=torch.float32),
            parent_prediction=data['parent_target'][distillation_indices].to(dtype=torch.float32),
            loss_type=config.distillation.loss_type,
            temperature=config.distillation.kl_temperature,
            mask=batch['latent_mask'][distillation_indices].to(dtype=torch.float32) if config.masked_training else None,
            reduction='mean',
        )
        losses += dist_loss * config.distillation.loss_weight

if config.distillation.enabled and 'distillation_indices' in data:
    # Prefer explicitly-provided parent_target, fall back to prior_target if available
    parent_key = None
    if 'parent_target' in data:
        parent_key = 'parent_target'
    elif 'prior_target' in data:
        parent_key = 'prior_target'
    if parent_key is not None:
        distillation_indices = data['distillation_indices']
        if len(distillation_indices) > 0:
            # Calculate distillation loss only for samples marked as DISTILLATION
            dist_loss = distillation_loss(
                student_prediction=data['predicted'][distillation_indices].to(dtype=torch.float32),
                parent_prediction=data[parent_key][distillation_indices].to(dtype=torch.float32),
                loss_type=config.distillation.loss_type,
                temperature=config.distillation.kl_temperature,
                mask=batch['latent_mask'][distillation_indices].to(dtype=torch.float32) if config.masked_training else None,
                reduction='mean',
            )
            losses += dist_loss * config.distillation.loss_weight

Comment on lines +754 to +788
# Run parent/prior model if needed for either prior prediction or distillation
if len(prior_pred_indices) > 0 or len(distillation_indices) > 0 \
        or (self.config.masked_training
            and self.config.masked_prior_preservation_weight > 0
            and self.config.training_method == TrainingMethod.LORA):
    with self.model_setup.prior_model(self.model, self.config), torch.no_grad():
        # do NOT create a subbatch using the indices, even though it would be more efficient:
        # different timesteps are used for a smaller subbatch by predict(), but the conditioning must match exactly:
        prior_model_output_data = self.model_setup.predict(self.model, batch, self.config, train_progress)
    with self.model_setup.distillation_parent_model(
            self.model,
            self.config,
            self.parent_model_wrapper
    ) as parent_model, torch.no_grad():
        # Do NOT create a subbatch using the indices, even though it would be more efficient:
        # Different timesteps are used for a smaller subbatch by predict(), but the conditioning must match exactly
        parent_model_output_data = self.model_setup.predict(
            parent_model if parent_model != self.model else self.model,
            batch,
            self.config,
            train_progress
        )

# Run student model
model_output_data = self.model_setup.predict(self.model, batch, self.config, train_progress)
prior_model_prediction = prior_model_output_data['predicted'].to(dtype=model_output_data['target'].dtype)
model_output_data['target'][prior_pred_indices] = prior_model_prediction[prior_pred_indices]

# Get parent predictions
prior_model_prediction = parent_model_output_data['predicted'].to(dtype=model_output_data['target'].dtype)

# For prior_prediction: Replace target (legacy behavior)
if len(prior_pred_indices) > 0:
    model_output_data['target'][prior_pred_indices] = prior_model_prediction[prior_pred_indices]

# Store parent prediction for masked prior preservation and distillation
model_output_data['prior_target'] = prior_model_prediction

# For distillation: Store indices for loss calculation
if len(distillation_indices) > 0:
    model_output_data['distillation_indices'] = distillation_indices

Copilot AI Mar 6, 2026


prior_model_prediction is computed from parent_model_output_data (teacher when distillation is enabled) and then stored into model_output_data['prior_target']. This changes the meaning of prior_target (used by masked prior preservation losses) from “student base model without LoRA” to “external parent model”, altering existing training behavior whenever distillation is enabled. Store teacher output under a separate key (e.g., parent_target) and continue computing prior_target via prior_model() (unhooked adapters) when masked prior preservation / PRIOR_PREDICTION is needed.

Comment on lines +83 to +93
# For KL divergence on continuous predictions (e.g., latents)
# We use a log-normal approximation:
# Treat predictions as means of distributions, compute KL between Gaussians
# KL(P||Q) = 0.5 * (||mu_p - mu_q||^2 / sigma^2)
# For simplicity, assume unit variance (sigma=1)
# Temperature scaling: divide by temperature before computing
student_scaled = student_prediction / temperature
parent_scaled = parent_prediction / temperature

# MSE-based KL approximation with temperature scaling
loss = F.mse_loss(student_scaled, parent_scaled, reduction='none') * (temperature ** 2)

Copilot AI Mar 6, 2026


DistillationLossType.KL_DIVERGENCE currently computes an MSE between scaled predictions (a Gaussian KL up to a constant factor), not an actual KL divergence. This makes the enum name/UI option misleading and makes KL and MSE nearly equivalent (aside from temperature). Either rename it to reflect the implemented metric (e.g., Gaussian_KL / scaled_MSE) or implement a true KL for the intended distribution.

Suggested change
# For KL divergence on continuous predictions (e.g., latents)
# We use a log-normal approximation:
# Treat predictions as means of distributions, compute KL between Gaussians
# KL(P||Q) = 0.5 * (||mu_p - mu_q||^2 / sigma^2)
# For simplicity, assume unit variance (sigma=1)
# Temperature scaling: divide by temperature before computing
student_scaled = student_prediction / temperature
parent_scaled = parent_prediction / temperature
# MSE-based KL approximation with temperature scaling
loss = F.mse_loss(student_scaled, parent_scaled, reduction='none') * (temperature ** 2)
# KL divergence for distillation on logit-like predictions.
# Standard formulation:
# KL(P_teacher || P_student) with
# P_teacher = softmax(parent / T), P_student = softmax(student / T)
# implemented as:
# kl_div(log_softmax(student / T), softmax(parent / T)).
# As in common distillation practice, we multiply by T^2 to preserve
# gradient magnitudes when using temperature scaling.
student_log_probs = F.log_softmax(student_prediction / temperature, dim=-1)
parent_probs = F.softmax(parent_prediction / temperature, dim=-1)
loss = F.kl_div(student_log_probs, parent_probs, reduction='none') * (temperature ** 2)

Comment on lines +768 to +769
#("NFloat4", DataType.NFLOAT_4),
#("Int8", DataType.INT_8),

Copilot AI Mar 6, 2026


The distillation quantization dtype dropdown omits NFLOAT_4 / INT_8, but DistillationConfig.default_values() sets parent_quantization_dtype=NFLOAT_4. components.options_kv() will treat this as invalid and silently force the value to the first option (None), so the UI won’t reflect config defaults. Either include the default dtype(s) in the options list or change the config default to match what the UI supports.

Suggested change
#("NFloat4", DataType.NFLOAT_4),
#("Int8", DataType.INT_8),
("NFloat4", DataType.NFLOAT_4),
("Int8", DataType.INT_8),

Author


For the record, those types are broken in this setup for some reason.

Comment on lines +137 to +151
# Distillation loss
if config.distillation.enabled and 'parent_target' in data and 'distillation_indices' in data:
distillation_indices = data['distillation_indices']
if len(distillation_indices) > 0:
# Calculate distillation loss only for samples marked as DISTILLATION
dist_loss = distillation_loss(
student_prediction=data['predicted'][distillation_indices].to(dtype=torch.float32),
parent_prediction=data['parent_target'][distillation_indices].to(dtype=torch.float32),
loss_type=config.distillation.loss_type,
temperature=config.distillation.kl_temperature,
mask=batch['latent_mask'][distillation_indices].to(dtype=torch.float32) if config.masked_training else None,
reduction='mean',
)
losses += dist_loss * config.distillation.loss_weight


Copilot AI Mar 6, 2026


Distillation loss is only added inside __masked_losses(). When config.masked_training is false, the training loop uses __unmasked_losses() and distillation has no effect. Distillation should be applied in both masked and unmasked loss paths (with masking applied conditionally).

Comment on lines +754 to +771
# Run parent/prior model if needed for either prior prediction or distillation
if len(prior_pred_indices) > 0 or len(distillation_indices) > 0 \
or (self.config.masked_training
and self.config.masked_prior_preservation_weight > 0
and self.config.training_method == TrainingMethod.LORA):
with self.model_setup.prior_model(self.model, self.config), torch.no_grad():
#do NOT create a subbatch using the indices, even though it would be more efficient:
#different timesteps are used for a smaller subbatch by predict(), but the conditioning must match exactly:
prior_model_output_data = self.model_setup.predict(self.model, batch, self.config, train_progress)
with self.model_setup.distillation_parent_model(
self.model,
self.config,
self.parent_model_wrapper
) as parent_model, torch.no_grad():
# Do NOT create a subbatch using the indices, even though it would be more efficient:
# Different timesteps are used for a smaller subbatch by predict(), but the conditioning must match exactly
parent_model_output_data = self.model_setup.predict(
parent_model if parent_model != self.model else self.model,
batch,
self.config,
train_progress
)

Copilot AI Mar 6, 2026


When distillation is enabled, the code runs distillation_parent_model(...) even for ConceptType.PRIOR_PREDICTION samples. That means PRIOR_PREDICTION targets come from the external parent model instead of the “model as it was before training” (as described in the UI/help text), which is a behavioral break. Consider running PRIOR_PREDICTION via prior_model() (unhook LoRA adapters) regardless of distillation, and only use the external parent model for ConceptType.DISTILLATION samples.

Author


For the record: this was intended behaviour during my testing phase.

Comment on lines +97 to +102
# Apply mask if provided
if mask is not None:
# Use same masking logic as other losses
clamped_mask = torch.clamp(mask, 0.1, 1.0) # Use default unmasked_weight
loss = loss * clamped_mask


Copilot AI Mar 6, 2026


distillation_loss() applies a mask by clamping with a hard-coded 0.1 (unmasked_weight). This bypasses the configurable TrainConfig.unmasked_weight used everywhere else, so masked distillation behaves inconsistently with other losses. Pass unmasked_weight into this helper (or clamp using the caller’s configured value).

Comment on lines +82 to +93
elif loss_type == DistillationLossType.KL_DIVERGENCE:
# For KL divergence on continuous predictions (e.g., latents)
# We use a log-normal approximation:
# Treat predictions as means of distributions, compute KL between Gaussians
# KL(P||Q) = 0.5 * (||mu_p - mu_q||^2 / sigma^2)
# For simplicity, assume unit variance (sigma=1)
# Temperature scaling: divide by temperature before computing
student_scaled = student_prediction / temperature
parent_scaled = parent_prediction / temperature

# MSE-based KL approximation with temperature scaling
loss = F.mse_loss(student_scaled, parent_scaled, reduction='none') * (temperature ** 2)

Copilot AI Mar 6, 2026


temperature is used as a divisor for the KL/scaled-MSE path but is not validated. If the config/UI allows setting it to 0 (or a negative value), this will produce inf/NaNs. Add an explicit check that temperature > 0 (and raise a clear ValueError) before dividing.
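A guard along the lines this comment suggests could look like the following sketch (the function name is made up; the parameter name follows the config field mentioned above):

```python
def checked_temperature(temperature: float) -> float:
    # Validate before the value is used as a divisor; `not temperature > 0`
    # also rejects NaN, which a plain `temperature <= 0` check would miss.
    if not temperature > 0:
        raise ValueError(f"kl_temperature must be > 0, got {temperature}")
    return temperature

assert checked_temperature(2.0) == 2.0
try:
    checked_temperature(0.0)
    raised = False
except ValueError:
    raised = True
assert raised
```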

Comment on lines 3 to 6
from modules.modelSetup.BaseStableDiffusionXLSetup import BaseStableDiffusionXLSetup
from modules.modelLoader.mixin.LoRALoaderMixin import LoRALoaderMixin
from modules.module.LoRAModule import LoRAModuleWrapper
from modules.util import factory

Copilot AI Mar 6, 2026


After removing the state_dict_has_prefix(...) logic, the state_dict_has_prefix import in this file is now unused and will trigger Ruff F401. Remove that import (or reintroduce usage if still intended).


m4xw commented Mar 6, 2026

I guess the biggest issue is that I still mix some prior_prediction logic into the distillation code paths; I'll fix that when I get home from work, and maybe also implement the loss properly / unsimplified.

m4xw added 2 commits March 6, 2026 11:55
TODO: Split prior_target into parent_target for type DISTILLATION

m4xw commented Mar 6, 2026

Implemented a crude prediction cache which doubles the perf on a consumer GPU (but it looks like there's a small leak introduced by the cache).


@dxqb dxqb left a comment


Here are some early comments, not meant as a full review.
I did not read Copilot's comments here. Could you clean those up, and only leave open those that you plan to work on?

distillation_indices = data['distillation_indices']
if len(distillation_indices) > 0:
    # Calculate distillation loss only for samples marked as DISTILLATION
    dist_loss = distillation_loss(

I don't think distillation needs a new loss function (except maybe the KL loss - see my other comment below).
Distillation is a modification of the training target, not of the loss function per se.

We already modify the training target without modifying the loss function for PRIOR_PREDICTION here:

model_output_data['target'][prior_pred_indices] = prior_model_prediction[prior_pred_indices]

or in my transfer training PR here:
https://github.com/dxqb/OneTrainer/blob/38708bdfb31c608c51d73faec4f761f6fb6b037b/modules/trainer/GenericTrainer.py#L772

It is not necessary for those cases to modify the loss function, and I don't think it is for your PR either. It might look necessary because there is some prior-prediction code in the loss-function handling; however, that is for masked prior prediction, which is an entirely different piece of functionality.
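The target-modification approach described here, in miniature (toy tensors, not the PR's actual code):

```python
import torch

# Replace the training target with the teacher prediction for the samples
# marked for distillation; the regular (e.g. MSE) loss then applies unchanged.
target = torch.zeros(4, 3)
teacher_pred = torch.ones(4, 3)
distillation_indices = torch.tensor([1, 3])

target[distillation_indices] = teacher_pred[distillation_indices]

assert target[1].eq(1).all() and target[3].eq(1).all()
assert target[0].eq(0).all() and target[2].eq(0).all()
```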

pass

@staticmethod
def scale_lora_state_dict(

What is the purpose of applying a scale to the loaded LoRA?

  • I can't think of many situations where this could be useful.
  • If it is useful, you could achieve the same by converting a LoRA to lower weights (for example through a simple Comfy workflow). Should this be in OneTrainer?
  • I don't see the relation to distillation. If I'm wrong on my earlier two points, I think it should be a separate PR.

Hash-based cache key
"""
# Combine image path and timestep for uniqueness
key_string = f"{image_path}_{timestep}"

Image path and timestep aren't a sufficient cache key, because we have variations.
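A key that also covers variations might look like this sketch (field names beyond image_path/timestep are illustrative; any input that changes the conditioning or noise would need to be included):

```python
import hashlib

def cache_key(image_path: str, timestep: int, variation: int, seed: int) -> str:
    # Hash every input that distinguishes one teacher prediction from
    # another; image path and timestep alone would collide across variations.
    payload = f"{image_path}|{timestep}|{variation}|{seed}"
    return hashlib.sha256(payload.encode()).hexdigest()

assert cache_key("a.png", 10, 0, 1) != cache_key("a.png", 10, 1, 1)
assert cache_key("a.png", 10, 0, 1) == cache_key("a.png", 10, 0, 1)
```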

from torch import Tensor


class DistillationCacheManager:

@dxqb dxqb Mar 9, 2026


I'm not sure this needs to be designed as a cache.
For distillation, you first have to make a prediction with the teacher model. Then you set that prediction as the training target of the student model.

To avoid having both models in RAM/VRAM at the same time, you can separate these two steps and save the teacher predictions. But does it have to be a cache, with keys, hits, and misses?
Or do we simply write parent predictions sequentially to disk, and then read them back sequentially later?
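The sequential alternative could be as simple as this sketch (illustrative only; a real implementation would likely stream per-batch files rather than save one list):

```python
import os
import tempfile
import torch

# Pass 1: write teacher predictions to disk in batch order.
preds = [torch.randn(2, 4) for _ in range(3)]
path = os.path.join(tempfile.mkdtemp(), "parent_predictions.pt")
torch.save(preds, path)

# Pass 2: read them back in the same order - no keys, hits, or misses.
loaded = torch.load(path)
assert all(torch.equal(a, b) for a, b in zip(preds, loaded))
```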

quantized_modules = []

# Step 1: Replace linear layers with quantized versions
for attr_name in dir(self.parent_model):

I understand this PR is a draft, but a final PR shouldn't have this amount of code duplication (from quantization_util.py)

# Parent Model Type
# Filter only Stable Diffusion family models
sd_model_types = [
("SD 1.5", ModelType.STABLE_DIFFUSION_15),

code duplication


return migrated_data

def __migration_10(self, data: dict) -> dict:

We implement migrations only for versions that might already be out there; it doesn't make sense to implement two versions in one PR. There is no version 11 file out there anywhere, only version 10 before this PR and version 12 after this PR.

parent_scaled = parent_prediction / temperature

# MSE-based KL approximation with temperature scaling
loss = F.mse_loss(student_scaled, parent_scaled, reduction='none') * (temperature ** 2)

I don't think this does what the comments or the name KL_DIVERGENCE suggest.
It is just a simple MSE loss.

Temperature is applied, but it doesn't do anything to MSE loss except put a constant weight in front.
Consider MSE loss:

L = (x_student - x_parent)^2

And then add temperature (with the T^2 factor from the code):

L = (x_student / T - x_parent / T)^2 * T^2

simplify:

L = (x_student - x_parent)^2

The temperature cancels entirely; without the T^2 factor it would just be a constant loss weight. Temperature makes sense in non-linear functions such as softmax or sigmoid, but here?

@O-J1 O-J1 changed the title DRAFT: Request for Comments: Add LoRA weight scaling and simple distillation functionality Draft: Add LoRA weight scaling and simple distillation functionality Mar 10, 2026
m4xw added 3 commits March 10, 2026 16:15
- Introduced DistillationTargetMode enum for target transformation options.
- Enhanced DistillationCacheManager to support new parameters: target_mode, cfg_scale, rollout_steps, and rollout_blend.
- Updated GenericTrainer to utilize new distillation target mode settings.
- Modified TrainingTab UI to include options for target mode, cfg scale, rollout steps, and rollout blend.
- Updated TrainConfig to include new distillation configuration parameters.

m4xw commented Mar 12, 2026

What’s working

  • Lower CFG during inference now looks significantly better.
    • Training was done at CFG Distillation Scale 7 (and then 3), and inference at CFG 1 now looks roughly like prior CFG 7 quality.
  • Overall quality gains are already noticeable after 5-10 steps, especially on trained concepts.
  • Style transfer performance is very strong.
  • For concept-focused use cases, results at low CFG are excellent.
  • Inference speed improved substantially at CFG 1:
    • From ~21-24s (CFG 7) to ~11.5-11.7s (CFG 1), close to 2x faster.
  • Compared against the base checkpoint, this setup gives both a quality boost and a speed boost.

Observed behavior / limitations

  • Trained concepts improve strongly at earlier CFG, but non-trained content improves less consistently.
  • In refiner workflows, behavior is somewhat unstable at first, but very good after further training.
  • There appears to be a drop in baseline activations without additional regularization/diversification.
  • The model strongly biases toward trained concepts even with generic prompts (offset by using lower CFG at inference).

Practical takeaway

So far, this is a success for targeted training:

  • Very high quality at CFG 1
  • Strong style transfer
  • Better ability to learn concepts that the parent model previously resisted

This could be a path toward practical “poor man’s turbo checkpoints” once stabilized.

Additional note

  • Rank compression/back-merge experiments (e.g., train rank 32 -> apply -> diff to rank 16) seem promising:
    • Styles can be preserved
    • CFG-scale shifts can be partially reduced in some merges

Known bug

  • Cache generation can fail on low VRAM after the first epoch.
  • Current workaround:
    1. Press stop after cache generation starts
    2. Wait for cache generation to finish
    3. Ensure sampling and backups are disabled during the cache run

I will squash it up and separate the features into different PRs next time I work on it.

@m4xw m4xw changed the title Draft: Add LoRA weight scaling and simple distillation functionality Draft: Add LoRA weight scaling and CFG distillation functionality Mar 12, 2026

dxqb commented Mar 28, 2026

What is your plan regarding this PR?
If you want to finish it towards a mergeable PR, I'd advise to break it down to smaller pieces and submit multiple, smaller PRs. It currently does multiple unrelated things. Most obvious example: LoRA weight scaling.

Also, please think about what this PR should achieve in terms of transfer learning / distillation, and what is beyond its scope. It seems to me that this PR currently attempts to do it all, but distillation is an active field of research. For example, what made Z-Image famous is a new type of distillation: https://arxiv.org/pdf/2511.13649
