-
-
Notifications
You must be signed in to change notification settings - Fork 275
Draft: Add LoRA weight scaling and CFG distillation functionality #1360
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
714aeba
42b2a0d
abd62c5
b618437
8beb308
3d4af62
170a75e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,7 @@ | ||
| from modules.model.StableDiffusionModel import StableDiffusionModel | ||
| from modules.modelSetup.BaseModelSetup import BaseModelSetup | ||
| from modules.modelSetup.BaseStableDiffusionSetup import BaseStableDiffusionSetup | ||
| from modules.modelLoader.mixin.LoRALoaderMixin import LoRALoaderMixin | ||
| from modules.module.LoRAModule import LoRAModuleWrapper | ||
| from modules.util import factory | ||
|
Comment on lines
+4
to
6
|
||
| from modules.util.config.TrainConfig import TrainConfig | ||
|
|
@@ -69,7 +70,7 @@ def setup_model( | |
| if config.train_any_embedding(): | ||
| model.text_encoder.get_input_embeddings().to(dtype=config.embedding_weight_dtype.torch_dtype()) | ||
|
|
||
| create_te = config.text_encoder.train or state_dict_has_prefix(model.lora_state_dict, "lora_te") | ||
| create_te = config.text_encoder.train | ||
| model.text_encoder_lora = LoRAModuleWrapper( | ||
| model.text_encoder, "lora_te", config | ||
| ) if create_te else None | ||
|
|
@@ -79,6 +80,13 @@ def setup_model( | |
| ) | ||
|
|
||
| if model.lora_state_dict: | ||
| # Apply scaling factors to LoRA weights before loading | ||
| model.lora_state_dict = LoRALoaderMixin.scale_lora_state_dict( | ||
| model.lora_state_dict, | ||
| te_scale=config.lora_te_scale, | ||
| unet_scale=config.lora_unet_scale, | ||
| ) | ||
|
|
||
| if create_te: | ||
| model.text_encoder_lora.load_state_dict(model.lora_state_dict) | ||
| model.unet_lora.load_state_dict(model.lora_state_dict) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,7 @@ | ||
| from modules.model.StableDiffusionXLModel import StableDiffusionXLModel | ||
| from modules.modelSetup.BaseModelSetup import BaseModelSetup | ||
| from modules.modelSetup.BaseStableDiffusionXLSetup import BaseStableDiffusionXLSetup | ||
| from modules.modelLoader.mixin.LoRALoaderMixin import LoRALoaderMixin | ||
| from modules.module.LoRAModule import LoRAModuleWrapper | ||
| from modules.util import factory | ||
|
Comment on lines
3
to
6
|
||
| from modules.util.config.TrainConfig import TrainConfig | ||
|
|
@@ -76,8 +77,8 @@ def setup_model( | |
| model: StableDiffusionXLModel, | ||
| config: TrainConfig, | ||
| ): | ||
| create_te1 = config.text_encoder.train or state_dict_has_prefix(model.lora_state_dict, "lora_te1") | ||
| create_te2 = config.text_encoder_2.train or state_dict_has_prefix(model.lora_state_dict, "lora_te2") | ||
| create_te1 = config.text_encoder.train | ||
| create_te2 = config.text_encoder_2.train | ||
|
|
||
| model.text_encoder_1_lora = LoRAModuleWrapper( | ||
| model.text_encoder_1, "lora_te1", config | ||
|
|
@@ -92,6 +93,13 @@ def setup_model( | |
| ) | ||
|
|
||
| if model.lora_state_dict: | ||
| # Apply scaling factors to LoRA weights before loading | ||
| model.lora_state_dict = LoRALoaderMixin.scale_lora_state_dict( | ||
| model.lora_state_dict, | ||
| te_scale=config.lora_te_scale, | ||
| unet_scale=config.lora_unet_scale, | ||
| ) | ||
|
|
||
| if create_te1: | ||
| model.text_encoder_1_lora.load_state_dict(model.lora_state_dict) | ||
| if create_te2: | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -4,7 +4,7 @@ | |||
| from modules.util.config.TrainConfig import TrainConfig | ||||
| from modules.util.DiffusionScheduleCoefficients import DiffusionScheduleCoefficients | ||||
| from modules.util.enum.LossWeight import LossWeight | ||||
| from modules.util.loss.masked_loss import masked_losses, masked_losses_with_prior | ||||
| from modules.util.loss.masked_loss import masked_losses, masked_losses_with_prior, distillation_loss | ||||
| from modules.util.loss.vb_loss import vb_losses | ||||
|
|
||||
| import torch | ||||
|
|
@@ -134,6 +134,21 @@ def __masked_losses( | |||
| normalize_masked_area_loss=config.normalize_masked_area_loss, | ||||
| ).mean(mean_dim) * config.vb_loss_strength | ||||
|
|
||||
| # Distillation loss | ||||
| if config.distillation.enabled and 'prior_target' in data and 'distillation_indices' in data: | ||||
| distillation_indices = data['distillation_indices'] | ||||
| if len(distillation_indices) > 0: | ||||
| # Calculate distillation loss only for samples marked as DISTILLATION | ||||
| dist_loss = distillation_loss( | ||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think distillation needs a new loss function (except maybe: KL loss - see my other comment below) We already modify the training target without modifying the loss function for PRIOR_PREDICTION here: OneTrainer/modules/trainer/GenericTrainer.py Line 744 in 45b6d9a
or in my transfer training PR here. It is not necessary for these — and, I don't think, for your PR — to modify the loss function. It might look like this is necessary because there is some prior prediction code in the loss function handling. However, that is for masked prior prediction, which is an entirely different functionality. |
||||
| student_prediction=data['predicted'][distillation_indices].to(dtype=torch.float32), | ||||
| parent_prediction=data['prior_target'][distillation_indices].to(dtype=torch.float32), | ||||
| loss_type=config.distillation.loss_type, | ||||
| temperature=config.distillation.kl_temperature, | ||||
| mask=batch['latent_mask'][distillation_indices].to(dtype=torch.float32) if config.masked_training else None, | ||||
| reduction='mean', | ||||
| ) | ||||
| losses += dist_loss * config.distillation.loss_weight | ||||
|
|
||||
|
Comment on lines
+137
to
+151
|
||||
| return losses | ||||
|
|
||||
| def __unmasked_losses( | ||||
|
|
||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the purpose of applying a scale to the loaded LoRA?