[advoptm] Spectral Normalization for Muon Variants#1263

Draft
Koratahiu wants to merge 11 commits into Nerogar:master from Koratahiu:SN_MUON

Conversation

@Koratahiu
Contributor

TL;DR: Tune Once, Train Anywhere

This PR implements the spectral normalization/scaling proposed in Hyperparameter Transfer Enables Consistent Gains of Matrix-Preconditioned Optimizers Across Scales (NeurIPS 2025) for Muon_adv and AdaMuon_adv.

This method allows you to tune hyperparameters (LR, Weight Decay) just once, and they will transfer to any model size.

  • Cross-Model: Train SD1.5, SDXL, and Flux using the exact same LR/WD.
  • Cross-Rank: A Rank 1 LoRA typically uses ~1e-3 LR. With this, you can use that same 1e-3 LR for Rank 1, Rank 128, and beyond.

Important Notes:

  • LoRA Alpha: The built-in scaling of LoRA alpha negatively interacts with this method. Set alpha=rank to disable the internal scaling!
  • OFT: OFT uses matrices with unique (skewed) dimensions (Rank vs. Total Elements). While this is compatible, the extreme aspect ratio is sub-optimal for matrix-aware optimizers like Muon. It will work, but likely requires a different LR range than standard methods (0.1 LR worked for me).

Other Notes:

  • Suggested Ranges: Start with 1e-3 LR for standard LoRAs/Finetunes. For OFT, you may need to go up to 0.1.
  • High Robustness: This method is extremely stable. You can often multiply the LR by 10x and still maintain very similar validation loss.
  • Unified Rates: You can typically use the exact same (or very similar) LR for both the UNet/DiT and Text Encoders.
  • Stable & Tested: This method has been tested on SDXL LoRA, OFT, and Full Finetuning with solid results. However, I am leaving this as a draft/dev version to collect more feedback.

More Info:
Koratahiu/Advanced_Optimizers#14

@Koratahiu
Contributor Author

Update 1:
I wouldn't recommend this method, or the orthogonal optimizers in general, for OFT: its matrix shape and mechanics do not work well with them. I did train OFT at 0.1 LR; it trains, but not optimally.

I have made the weight decay for this method static and "decoupled" from the learning rate.
The paper recommends a weight decay value of 0.1; however, this needs more testing.

@Koratahiu
Contributor Author

Update 2:

  • This is incompatible with DoRA (muP/spectral scaling conflicts with DoRA scaling).

  • The optimal value for weight decay is what the paper found:
    0.1 (hyperparameter value) * 1/width (scaling rule)
    This worked very well for me for both LoRAs and finetunes.

@Koratahiu
Contributor Author

I tested this, and it works very well.

I used the same hyperparameters for LoRA/finetuning across SDXL, Chroma, and Zib; it trained successfully for all of them and delivered very solid results.

I’ve found the baseline for all of them to be:

  • Learning Rate (LR): 1e-3
  • Weight Decay: 0.1

From there, you can adjust as needed (e.g., a higher Batch Size requires a larger LR, BF16 needs a higher LR, etc.).

The weight decay is constant and differs from standard implementations, meaning it maintains the same effect regardless of whether the LR is high or low. The formula for it is:
Weight_decay (hyperparameter value) * 1/width (scaling rule)
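
As a concrete illustration, here is a minimal sketch of this LR-decoupled decay rule. The helper name and the "width = fan-in" convention are my assumptions for illustration, not the PR's exact code:

```python
def decoupled_weight_decay(weight, wd=0.1):
    """Shrink a 2D weight (list of rows) by wd * (1/width).

    The shrink factor contains no learning-rate term, so the decay
    strength is identical whether the LR is high or low.
    """
    width = len(weight[0])        # fan-in (input dimension), assumed convention
    factor = 1.0 - wd / width     # e.g. wd=0.1, width=16 -> 0.99375
    return [[v * factor for v in row] for row in weight]

w = [[1.0] * 16 for _ in range(4)]     # a 4x16 weight of ones
w = decoupled_weight_decay(w, wd=0.1)
# every entry is now 1 - 0.1/16 = 0.99375
```

Because the factor scales as 1/width, a wider layer is decayed more gently per element, which is what keeps the 0.1 hyperparameter transferable across model sizes.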

❗ Note that this method does not work with DoRA, as DoRA has its own scaling that conflicts with this approach. It also behaves unpredictably with OFT (it does train at 0.1 LR, but I'm not sure about the quality).

❕ For full finetuning, 1D vectors are trained using AuxAdam, so you should use standard AdamW LR and weight decay settings for those.

@Koratahiu Koratahiu marked this pull request as ready for review January 31, 2026 18:39
@Koratahiu
Contributor Author

More Helpful Notes for LoRA Using This Method

1) Rank-Invariant Updates:

It is interesting to note that using this method for LoRA completely cancels out the rank effect (assuming alpha = rank).
Its update rule can be simplified as:
ΔW = A · √(height/rank) × B · √(rank/width)

In this scenario, rank cancels out, allowing us to apply the full-finetuning scaling rule:
scale(ΔW) = scale(A·B) = √(height/width)

This achieves the same learning rate as full finetuning and results in rank-invariant updates.
This leads to a universal, shared LR across all ranks, which aligns perfectly with the goal of this method: tuning once for all ranks and adapters.
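
The cancellation argument above can be checked numerically (the function name and the scale convention, √(height/rank) for A and √(rank/width) for B, follow the formulas above; this is a sanity check, not the PR's code):

```python
import math

def spectral_scale(fan_out, fan_in):
    """Per-matrix spectral update scale sqrt(out/in), as in the argument above."""
    return math.sqrt(fan_out / fan_in)

height, width = 1024, 768
combined_scales = []
for rank in (1, 32, 128):
    # A is scaled by sqrt(height/rank), B by sqrt(rank/width); rank cancels:
    combined = spectral_scale(height, rank) * spectral_scale(rank, width)
    combined_scales.append(combined)

# every entry equals sqrt(height/width), independent of rank
full_ft_scale = math.sqrt(height / width)
```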

2) Addressing the LoRA A-Matrix

Muon appears to be sub-optimal for LoRA because the A matrix is often extremely "flat" (it has an extreme aspect ratio), which leads to unstable or "garbage" orthogonalization.
However, spectral scaling seems to resolve this: it yields a very high eps, which heavily dampens the orthogonalization and forces Muon to behave more like normalized SGD.

Mathematically, we achieve:

  • Rank-invariant updates.
  • A stabilized A matrix.

These are my own findings, as the original paper did not experiment with LoRAs. Nonetheless, these results are very promising for LoRA/Muon combinations.
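
The damping claim can be sanity-checked with a toy Newton-Schulz run. The quintic coefficients below are the standard Muon ones; the eps placement (added to the Frobenius norm before orthogonalization) and the exaggerated eps value are my assumptions for illustration:

```python
import numpy as np

def newton_schulz(G, eps, steps=5):
    """Muon-style orthogonalization with eps damping the pre-normalization."""
    a, b, c = 3.4445, -4.7750, 2.0315   # standard Muon quintic coefficients
    X = G / (np.linalg.norm(G) + eps)   # large eps => X starts out tiny
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return X

rng = np.random.default_rng(0)
G = rng.standard_normal((4, 8))          # a "flat" LoRA-A-like gradient

damped = newton_schulz(G, eps=1e6)       # heavily damped orthogonalization

# With a large eps the cubic/quintic terms are negligible, so the output
# stays parallel to G: the step degenerates to a rescaled gradient, i.e.
# normalized SGD.
cos = np.sum(damped * G) / (np.linalg.norm(damped) * np.linalg.norm(G))
```

With eps near zero the same function instead drives the singular values toward 1, which is the regular Muon orthogonalization behavior.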

@dxqb
Collaborator

dxqb commented Feb 27, 2026

Is this superseded by #1344?

@Koratahiu
Contributor Author

Is this superseded by #1344?

Yes and no.
I didn't touch the Muon logic in #1344, and I left this PR as-is since it's well tested and proven.
Once #1344 is ready and stable, it should supersede this one.

@Koratahiu
Contributor Author

@dxqb
Can we merge this?

joined_patterns = "|".join([re.escape(p) for p in default_patterns])
pattern = re.compile(rf'(?:^|\.)(?:{joined_patterns})\.\d+$')

layer_counts = {}
Collaborator

@dxqb dxqb Mar 12, 2026


this function always returns {}: layer_counts is never modified.
it doesn't seem to have side effects either.

what is it supposed to do?

it appears that it's supposed to count the number of trained layers, I guess for scaling later in the optimizer.
But why does it have its own regex layer filter? Shouldn't the count depend on what layers the user is actually training (via the layer filter on the training tab)?

Contributor Author


this function always returns {}: layer_counts is never modified. it doesn't seem to have side effects either.

what is it supposed to do?

Fixed; it was deleted accidentally.

it appears that it's supposed to count the number of trained layers, I guess for scaling later in the optimizer. But why does it have its own regex layer filter? Shouldn't the count depend on what layers the user is actually training (via the layer filter on the training tab)?

It calculates the model depth (the number of residual layers). For SDXL, this consists of transformer_blocks and resnets; for Transformers, it includes only transformer_blocks (or their equivalent names).

I think we have two additional options:

  1. Create a new utility specifically to calculate depth (the same logic as this)
  2. Hardcode the integer values (e.g., SDXL = 48).

You may ask why we need the depth. To achieve scale-invariance in the optimizer, we must utilize the depth as follows:

  1. For Muon: It is inserted as a damping factor for orthogonalization (eps).
  2. For Adam: It is inserted as a damping factor for the second moment (eps).

This ensures that the damping factor grows as the model grows. For example, with Klein 8B and Klein 4B, these scalings allow us to use the same hyperparameters for both models.
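
Putting the pieces together, here is a hedged sketch of depth counting plus depth-scaled damping. The block-name patterns follow the snippet under review, but the regex is unanchored here so it can match block paths mid-name, and the eps formula is my assumed form, not the PR's verified rule:

```python
import re

# module names that each mark one residual block (per the snippet under review)
default_patterns = ["transformer_blocks", "resnets"]
joined = "|".join(re.escape(p) for p in default_patterns)
block_re = re.compile(rf"(?:^|\.)(?:{joined})\.(\d+)(?:\.|$)")

def model_depth(param_names):
    """Count distinct residual blocks (the model depth) from parameter names."""
    blocks = set()
    for name in param_names:
        for m in block_re.finditer(name):
            # key each block by its full dotted prefix, e.g. "down.0.resnets.0"
            blocks.add(name[: m.end(1)])
    return len(blocks)

names = [
    "down.0.transformer_blocks.0.attn.to_q.weight",
    "down.0.transformer_blocks.1.attn.to_q.weight",
    "down.0.resnets.0.conv1.weight",
]
depth = model_depth(names)          # 3 distinct blocks

# Depth then enters the optimizer as a damping factor, e.g. (assumed form):
base_eps = 1e-8
eps_effective = base_eps * depth    # grows as the model gets deeper
```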

@dxqb dxqb marked this pull request as draft March 15, 2026 14:43
