Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion modules/ui/OptimizerParamsWindow.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def create_dynamic_ui(
'muon_adam_lr': {'title': 'Auxiliary Adam LR', 'tooltip': 'Learning rate for the auxiliary AdamW optimizer. If empty, it will use the main learning rate.', 'type': 'float'},
'muon_te1_adam_lr': {'title': 'AuxAdam TE1 LR', 'tooltip': 'Learning rate for the auxiliary AdamW optimizer for the first text encoder. If empty, it will use the Auxiliary Adam LR.', 'type': 'float'},
'muon_te2_adam_lr': {'title': 'AuxAdam TE2 LR', 'tooltip': 'Learning rate for the auxiliary AdamW optimizer for the second text encoder. If empty, it will use the Auxiliary Adam LR.', 'type': 'float'},
'rms_rescaling': {'title': 'RMS Rescaling', 'tooltip': 'Muon already scales its updates to approximate and use the same learning rate (LR) as Adam. This option integrates a more accurate method to match the Adam LR, but it is slower.', 'type': 'bool'},
'rms_rescaling': {'title': 'RMS Rescaling', 'tooltip': 'Normalizes Muon update magnitudes to align with Adam. This allows to reuse standard "Adam-style" learning rates instead of specialized Muon scales.', 'type': 'bool'},
'normuon_variant': {'title': 'NorMuon Variant', 'tooltip': 'Enables the NorMuon optimizer variant, which combines Muon orthogonalization with per-neuron adaptive learning rates for better convergence and balanced parameter updates. Costs only one scalar state buffer per parameter group, size few KBs, maintaining high memory efficiency.', 'type': 'bool'},
'beta2_normuon': {'title': 'NorMuon Beta2', 'tooltip': 'Exponential decay rate for the neuron-wise second-moment estimator in NorMuon (analogous to Adams beta2). Controls how past squared updates influence current normalization.', 'type': 'float'},
'normuon_eps': {'title': 'NorMuon EPS', 'tooltip': 'Epsilon for NorMuon normalization stability.', 'type': 'float'},
Expand All @@ -199,6 +199,7 @@ def create_dynamic_ui(
'kappa_p': {'title': 'Lion-K P-value', 'tooltip': 'Controls the Lp-norm geometry for the Lion update. 1.0 = Standard Lion (Sign update, coordinate-wise), best for Transformers. 2.0 = Spherical Lion (Normalized update, rotational invariant), best for Conv2d layers (in unet models). Values between 1.0 and 2.0 interpolate behavior between the two.', 'type': 'float'},
'auto_kappa_p': {'title': 'Auto Lion-K', 'tooltip': 'Automatically determines the optimal P-value based on layer dimensions. Uses p=2.0 (Spherical) for 4D (Conv) tensors for stability and rotational invariance, and p=1.0 (Sign) for 2D (Linear) tensors for sparsity. Overrides the manual P-value. Recommend for unet models.', 'type': 'bool'},
'compile': {'title': 'Compiled Optimizer', 'tooltip': 'Enables PyTorch compilation for the optimizer internal step logic. This is intended to improve performance by allowing PyTorch to fuse operations and optimize the computational graph.', 'type': 'bool'},
'spectral_normalization': {'title': 'Spectral Scaling', 'tooltip': 'Enables explicit Spectral Normalization to automatically rescale the update magnitude and Weight Decay based on layer dimensions. This allows hyperparameters to transfer seamlessly from small to large models without retuning, while making the optimizer highly robust to a wide range of learning rates. This ensures consistent performance across different model sizes, adapter methods, and ranks. ', 'type': 'bool'},
}
# @formatter:on

Expand Down
2 changes: 2 additions & 0 deletions modules/util/config/TrainConfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ class TrainOptimizerConfig(BaseConfig):
kappa_p: float
auto_kappa_p: False
compile: False
spectral_normalization: False

def __init__(self, data: list[(str, Any, type, bool)]):
super().__init__(data)
Expand Down Expand Up @@ -261,6 +262,7 @@ def default_values():
data.append(("kappa_p", None, float, True))
data.append(("auto_kappa_p", False, bool, False))
data.append(("compile", False, bool, False))
data.append(("spectral_normalization", False, bool, False))

return TrainOptimizerConfig(data)

Expand Down
34 changes: 29 additions & 5 deletions modules/util/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from modules.util.optimizer.adafactor_extensions import patch_adafactor
from modules.util.optimizer.adam_extensions import patch_adam
from modules.util.optimizer.adamw_extensions import patch_adamw
from modules.util.optimizer.muon_util import split_parameters_for_muon
from modules.util.optimizer.muon_util import calculate_muon_n_layers, split_parameters_for_muon
from modules.util.TrainProgress import TrainProgress
from modules.zluda import ZLUDA

Expand Down Expand Up @@ -125,7 +125,7 @@ def create_optimizer(
parameter_group_collection: NamedParameterGroupCollection,
state_dict: dict | None,
config: TrainConfig,
layer_key_fn: dict[int, str] | None = None,
model: BaseModel | None = None,
) -> torch.optim.Optimizer | None:
optimizer = None
optimizer_config = config.optimizer
Expand Down Expand Up @@ -842,7 +842,18 @@ def create_optimizer(

from adv_optm import Muon_adv

params_for_optimizer, MuonWithAuxAdam = split_parameters_for_muon(parameters, layer_key_fn, config)
params_for_optimizer, MuonWithAuxAdam = split_parameters_for_muon(model, parameters, config)

if optimizer_config.spectral_normalization:
# Calculate n_layers for spectral normalization
n_layers_map = calculate_muon_n_layers(model)

for group in params_for_optimizer:
group_name = group.get('name')
if group_name in n_layers_map:
group['n_layers'] = n_layers_map[group_name]
else:
group['n_layers'] = n_layers_map.get('default', 1)

# Prepare Adam-specific keyword arguments from the config
adam_kwargs = {}
Expand Down Expand Up @@ -885,6 +896,7 @@ def create_optimizer(
compiled_optimizer=optimizer_config.compile if optimizer_config.compile is not None else False,
Simplified_AdEMAMix=optimizer_config.Simplified_AdEMAMix if optimizer_config.Simplified_AdEMAMix is not None else False,
alpha_grad=optimizer_config.alpha_grad if optimizer_config.alpha_grad is not None else 100,
spectral_normalization=optimizer_config.spectral_normalization if optimizer_config.spectral_normalization is not None else False,
**adam_kwargs
)

Expand All @@ -894,7 +906,18 @@ def create_optimizer(

from adv_optm import AdaMuon_adv

params_for_optimizer, MuonWithAuxAdam = split_parameters_for_muon(parameters, layer_key_fn, config)
params_for_optimizer, MuonWithAuxAdam = split_parameters_for_muon(model, parameters, config)

if optimizer_config.spectral_normalization:
# Calculate n_layers for spectral normalization
n_layers_map = calculate_muon_n_layers(model)

for group in params_for_optimizer:
group_name = group.get('name')
if group_name in n_layers_map:
group['n_layers'] = n_layers_map[group_name]
else:
group['n_layers'] = n_layers_map.get('default', 1)

# Prepare Adam-specific keyword arguments from the config
adam_kwargs = {}
Expand Down Expand Up @@ -939,6 +962,7 @@ def create_optimizer(
orthogonal_gradient=optimizer_config.orthogonal_gradient if optimizer_config.orthogonal_gradient is not None else False,
approx_mars=optimizer_config.approx_mars if optimizer_config.approx_mars is not None else False,
compiled_optimizer=optimizer_config.compile if optimizer_config.compile is not None else False,
spectral_normalization=optimizer_config.spectral_normalization if optimizer_config.spectral_normalization is not None else False,
**adam_kwargs
)

Expand All @@ -947,7 +971,7 @@ def create_optimizer(

from muon import MuonWithAuxAdam, SingleDeviceMuonWithAuxAdam

params_for_optimizer, ___ = split_parameters_for_muon(parameters, layer_key_fn, config)
params_for_optimizer, ___ = split_parameters_for_muon(model, parameters, config)

final_param_groups = []
for group in params_for_optimizer:
Expand Down
78 changes: 77 additions & 1 deletion modules/util/optimizer/muon_util.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from collections.abc import Callable

from modules.model.BaseModel import BaseModel, TrainConfig
Expand All @@ -8,6 +9,79 @@
import torch


def calculate_muon_n_layers(model: BaseModel) -> dict[str, int]:
"""
Calculates the number of residual layers (the depth) in each component of the model.
Used for Muon optimizer spectral normalization scaling.
"""
match model.model_type:
case (ModelType.STABLE_DIFFUSION_15 | ModelType.STABLE_DIFFUSION_15_INPAINTING |
ModelType.STABLE_DIFFUSION_20_BASE | ModelType.STABLE_DIFFUSION_20_INPAINTING |
ModelType.STABLE_DIFFUSION_20 | ModelType.STABLE_DIFFUSION_21 |
ModelType.STABLE_DIFFUSION_21_BASE | ModelType.STABLE_DIFFUSION_XL_10_BASE |
ModelType.STABLE_DIFFUSION_XL_10_BASE_INPAINTING | ModelType.STABLE_CASCADE_1 |
ModelType.WUERSTCHEN_2):
default_patterns = ['transformer_blocks', 'resnets', 'layers']
case (ModelType.STABLE_DIFFUSION_3 | ModelType.STABLE_DIFFUSION_35 | ModelType.SANA |
ModelType.FLUX_DEV_1 | ModelType.CHROMA_1 | ModelType.QWEN |
ModelType.PIXART_ALPHA | ModelType.PIXART_SIGMA):
default_patterns = ['transformer_blocks', 'single_transformer_blocks', 'encoder.block']
case ModelType.HI_DREAM_FULL:
default_patterns = ['double_stream_blocks', 'single_stream_blocks']
case ModelType.Z_IMAGE:
default_patterns = [
'layers',
'refiner',
]
case _:
raise NotImplementedError(f"Muon optimizer spectral normalization is not implemented for model type: {model.model_type}")

# Build the regex pattern dynamically
joined_patterns = "|".join([re.escape(p) for p in default_patterns])
pattern = re.compile(rf'(?:^|\.)(?:{joined_patterns})\.\d+$')

layer_counts = {}
Copy link
Copy Markdown
Collaborator

@dxqb dxqb Mar 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this function always returns {}: layer_counts is never modified.
it doesn't seem to have side effects either.

what is it supposed to do?

it appears that it's supposed to count the number of trained layers, I guess for scaling later in the optimizer.
But why does it have its own regex layer filter? Shouldn't the count depend on what layers the user is actually training (via the layer filter on the training tab)?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this function always returns {}: layer_counts is never modified. it doesn't seem to have side effects either.

what is it supposed to do?

Fixed; it was deleted accidentally.

it appears that it's supposed to count the number of trained layers, I guess for scaling later in the optimizer. But why does it have its own regex layer filter? Shouldn't the count depend on what layers the user is actually training (via the layer filter on the training tab)?

It calculates the model depth (the number of residual layers). For SDXL, this consists of transformer_blocks and resnets; for Transformers, it includes only transformer_blocks (or their equivalent names).

I think we have two additional options:

  1. Create a new utility specifically to calculate depth (the same logic as this)
  2. Hardcode the integer values (e.g., SDXL = 48).

You may ask why we need the depth. To achieve scale-invariance in the optimizer, we must utilize the depth as follows:

  1. For Muon: It is inserted as a damping factor for orthogonalization (eps).
  2. For Adam: It is inserted as a damping factor for the second moment (eps).

This ensures that the damping factor grows as the model grows. For example, with Klein 8B and Klein 4B, these scalings allow us to use the same hyperparameters for both models.


# Iterate over model components (e.g., 'unet', 'text_encoder', 'transformer')
for attr_name, module in vars(model).items():

# Identify the 'Ground Truth' blocks in this component.
target_module = module
if isinstance(module, LoRAModuleWrapper):
target_module = module.orig_module

valid_component_blocks = set()
if isinstance(target_module, torch.nn.Module):
for name, _ in target_module.named_modules():
if pattern.search(name):
valid_component_blocks.add(name)

if not valid_component_blocks:
continue

active_component_blocks = set()

# Filter: Only count blocks that are actually being trained.
if isinstance(module, LoRAModuleWrapper):
# For LoRA, we check if the active leaves reside inside a valid block.
for layer_name in module.lora_modules:
parts = layer_name.split('.')
for i in range(len(parts), 0, -1):
candidate = ".".join(parts[:i])
if candidate in valid_component_blocks:
active_component_blocks.add(candidate)

elif isinstance(module, torch.nn.Module):
# For standard full-finetuning, all valid blocks are active.
active_component_blocks = valid_component_blocks

count = len(active_component_blocks)
if count > 0:
layer_counts[attr_name] = count

return layer_counts


def build_muon_adam_key_fn(
model: BaseModel,
config: TrainConfig,
Expand Down Expand Up @@ -105,8 +179,8 @@ def get_optim_type(param_name: str, p: torch.nn.Parameter) -> str:
return param_map

def split_parameters_for_muon(
model: BaseModel,
parameters: list[dict],
layer_key_fn: dict[int, str],
config: TrainConfig,
) -> tuple[list[dict], bool]:
"""
Expand All @@ -116,6 +190,8 @@ def split_parameters_for_muon(
optimizer_config = config.optimizer

has_adam_params = False
layer_key_fn = build_muon_adam_key_fn(model, config)

if layer_key_fn:
for group in parameters:
for p in group['params']:
Expand Down
10 changes: 3 additions & 7 deletions modules/util/optimizer_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from modules.util.config.TrainConfig import TrainConfig, TrainOptimizerConfig
from modules.util.enum.Optimizer import Optimizer
from modules.util.NamedParameterGroup import NamedParameterGroupCollection
from modules.util.optimizer.muon_util import build_muon_adam_key_fn
from modules.util.torch_util import optimizer_to_device_

import torch
Expand Down Expand Up @@ -59,13 +58,8 @@ def init_model_parameters(
#to be safe, do that before the optimizer is created because the optimizer could take copies
multi.broadcast_parameters(parameters.parameters(), train_device)

layer_key_fn = None
if model.train_config.optimizer.MuonWithAuxAdam:
print("INFO: Creating layer keys for MuonWithAuxAdam.")
layer_key_fn = build_muon_adam_key_fn(model, model.train_config)

model.optimizer = create.create_optimizer(
parameters, model.optimizer_state_dict, model.train_config, layer_key_fn
parameters, model.optimizer_state_dict, model.train_config, model=model
)

if model.optimizer is not None:
Expand Down Expand Up @@ -596,6 +590,7 @@ def init_model_parameters(
"low_rank_ortho": False,
"ortho_rank": 128,
"rms_rescaling": True,
"spectral_normalization": False,
"nnmf_factor": False,
"stochastic_rounding": True,
"compile": False,
Expand Down Expand Up @@ -627,6 +622,7 @@ def init_model_parameters(
"low_rank_ortho": False,
"ortho_rank": 128,
"rms_rescaling": True,
"spectral_normalization": False,
"nnmf_factor": False,
"stochastic_rounding": True,
"compile": False,
Expand Down