
[advoptm] New Features: Scaled Optimizers, Centered WD, Factored 2nd Moment & More#1344

Draft
Koratahiu wants to merge 25 commits into Nerogar:master from Koratahiu:scaled_optm

Conversation

@Koratahiu
Contributor

This PR introduces several powerful improvements to the advanced optimizers, including:

  1. Scaled Optimizers: Tune once, train anywhere.
    (Continuation of [advoptm] Spectral Normalization for Muon Variants #1263 but for all adv optimizers!)
  2. Centered WD: A powerful technique for full finetuning that preserves original model knowledge by pulling weights toward their original pre-trained state.
    • Note: It also works for DoRA by pulling the DoRA scale norm to its original state.
  3. Factored Second Moment: The second moment of Adam variants is very special. By factoring it, we make the optimizer highly robust to a wide range of LRs, mimicking high-order optimization.
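For reference, factoring the second moment typically follows the Adafactor-style rank-1 reconstruction from row and column statistics; a minimal sketch under that assumption (function and state names are illustrative, not this PR's code):

```python
import torch

def factored_second_moment(grad, v_row, v_col, beta2):
    """Adafactor-style factored second moment for a 2D gradient (illustrative).

    Instead of storing the full (d_out, d_in) second-moment tensor, keep only
    row and column running means of grad**2 and rebuild a rank-1 proxy.
    """
    g2 = grad.pow(2)
    v_row.mul_(beta2).add_(g2.mean(dim=1), alpha=1 - beta2)  # shape (d_out,)
    v_col.mul_(beta2).add_(g2.mean(dim=0), alpha=1 - beta2)  # shape (d_in,)
    # Rank-1 reconstruction: V_hat = outer(v_row, v_col) / mean(v_row)
    return torch.outer(v_row, v_col) / v_row.mean().clamp_min(1e-30)
```

This cuts the optimizer state for a d_out × d_in tensor from O(d_out·d_in) to O(d_out + d_in), and the reconstruction is exact whenever grad**2 happens to be rank-1.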

Scaled Optimizer

Automatically scales the LR/WD to transfer seamlessly across all training methods. You only need to tune once, and your hyperparameters will transfer optimally across all LoRA ranks, Full Finetuning, etc.

We already achieved this with spectral Muon (#1263), but this PR extends the technique to all advanced optimizers!

Compatible with: Full finetuning, LoRA, DoRA, and OFT.

  • Set alpha=rank for LoRA/DoRA.

  • The Scaled Optimizer also decouples the WD from the LR.
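To make the idea concrete, here is a minimal sketch of a spectral scaling rule for a 2D update, assuming the Modular-Norm target $\eta \sqrt{d_{out}/d_{in}}$ discussed later in this thread (the function name and exact-SVD shortcut are mine, not the PR's code):

```python
import math
import torch

def spectral_scale(update: torch.Tensor, lr: float, eps: float = 1e-12) -> torch.Tensor:
    """Rescale a 2D update so its spectral norm equals lr * sqrt(d_out / d_in).

    With this target, the induced change in the layer's output RMS stays O(1)
    regardless of the layer's dimensions, which is what lets one LR transfer
    across ranks and training methods.
    """
    d_out, d_in = update.shape
    sigma = torch.linalg.svdvals(update)[0]  # largest singular value
    target = lr * math.sqrt(d_out / d_in)
    return update * (target / (sigma + eps))
```

In practice an exact SVD per step is expensive; a few power iterations are the usual cheap approximation of the top singular value.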


Centered Weight Decay

For small-to-medium scale training, we want to learn new concepts while preserving the model's original knowledge. However, standard WD (for full finetuning & DoRA) is often useless or harmful here, as it pulls weights down to zero, leading to forgetting and destructive behavior.

Centered WD mitigates this by pulling the weights toward their original pre-trained state instead of zero. This forces the optimizer to learn a smooth representation of the dataset while preserving the original model knowledge.

  • Quantization Support: The original model state used as the decay target can be stored in several quantization formats (an approximation is sufficient): full, fp8, int8, and int4.
  • Hybrid Usage: Can be used alongside standard weight decay.
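The core update is a one-liner; a hedged sketch of decoupled centered decay (not the PR's exact code):

```python
import torch

def centered_wd_(param: torch.Tensor, ref: torch.Tensor, lr: float, wd: float) -> None:
    """Decoupled weight decay toward a stored reference instead of zero.

    `ref` is a (possibly quantized, then dequantized) snapshot of the
    pre-trained weights; standard decoupled WD is the special case ref = 0.
    """
    param.add_(param - ref, alpha=-lr * wd)
```

With lr * wd = 0.1 and ref = 2.0, a weight at 1.0 moves to 1.1, i.e. toward the reference rather than toward zero.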

Special Case with DoRA:
When using DoRA, there's a discrepancy between how the LoRA factors decay and how the DoRA scale decays. The LoRA factors decay toward the original model state, while the DoRA norm scale decays to zero. In my tests, the LoRA factors tolerated very high decay values (and delivered great results), while the DoRA scale broke at small decay values (which makes sense, since decaying it destroys the original model norm).

The Solution: To get around this discrepancy, you can use both Centered WD and Standard WD together. We route them accordingly: Centered WD is applied to the LoRA factors (and DoRA), while Standard WD is applied only to the DoRA scale. This allows you to tune them independently: a very small Standard WD for the DoRA scale, and a high Centered WD for everything else.


Signed Optimizers

This PR introduces two main improvements to signed optimizers (SignSGD_adv and Lion_adv):

1. Freeze-on-Flip (Projected Variant)

The sign operation is discontinuous; it flips (+/- 1) randomly and rapidly near zero, leading to unstable training. This instability is especially visible near convergence, where the optimizer has to slow down to reach the optimal solution (this is also why signed optimizers typically need decaying schedulers).

Freeze-on-Flip is a simple method that stores the sign of the previous step (skipping this for factored mode, and using uint8 for standard) and freezes any current update components that have flipped signs. Over successive steps, this simulates the effect of 0 and makes the signed update semi-continuous. This leads to better convergence and much more stable training.
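A minimal sketch of the masking step (storage details like the uint8 packing are omitted; names are illustrative):

```python
import torch

def freeze_on_flip(update: torch.Tensor, prev_sign: torch.Tensor):
    """Zero out sign-update components that flipped since the previous step.

    Components oscillating around zero would otherwise jump between +1 and -1
    every step; masking them to 0 simulates the missing middle value and makes
    the signed update semi-continuous near convergence.
    """
    s = torch.sign(update)
    flipped = (s * prev_sign) < 0
    masked = torch.where(flipped, torch.zeros_like(s), s)
    return masked, s  # `s` becomes prev_sign for the next step
```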

2. Projected Adaptivity

We can make signed optimizers more curvature-aware and adaptive by scaling the LR by the L1 norm, utilizing the dual norm of their geometry ($L_\infty \rightarrow L_1$).
This makes the optimizer semi-adaptive to changes in gradients.
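A sketch of the dual-norm scaling, where I normalize the L1 norm by the element count so the scale is size-independent (that normalization is my assumption, not necessarily the PR's exact choice):

```python
import torch

def projected_sign_update(grad: torch.Tensor, lr: float) -> torch.Tensor:
    """sign(g) step with the LR scaled by the mean |g| (L1 norm / numel).

    Under the L-inf -> L-1 dual-norm pairing, the step size now tracks the
    gradient magnitude, making the signed optimizer semi-adaptive to
    curvature changes instead of taking a fixed-size step.
    """
    scale = grad.abs().mean()  # L1 norm divided by element count
    return torch.sign(grad) * (lr * scale)
```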

gesen2egee added a commit to gesen2egee/OneTrainer that referenced this pull request Feb 26, 2026
@Koratahiu
Contributor Author

Update 1

Improved K-B for LoRA/OFT

K-B now calculates $\beta_2$ per-rank/block for LoRA/OFT. This achieves:

  1. Rank-invariant estimations: More stable estimations that are no longer affected by the size of the rank/block.
  2. Fine-grained $\beta_2$: Calculations are now localized; every rank/block (each row/column) receives its own adaptive $\beta_2$ value.

Summary: By ensuring every rank/block in LoRA/OFT has its own adaptive $\beta_2$, training becomes significantly more stable and robust. If one rank/block spikes or becomes noisy, K-B reacts to it immediately without affecting the rest of the ranks/blocks.
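The exact K-B update lives in the PR's code; purely as an illustration of the per-row idea, an adaptive $\beta_2$ of this general shape (the spike ratio, EMA, and squashing below are my assumptions, not the PR's formula) could look like:

```python
import torch

def per_row_beta2(grad, norm_ema, beta2_min=0.88, beta2_max=0.999,
                  ema_decay=0.9, eps=1e-8):
    """Illustrative per-row adaptive beta2 (NOT the PR's exact K-B code).

    Each row (rank/block) gets its own beta2: a row whose gradient norm
    spikes relative to its running norm gets a smaller beta2, so its second
    moment reacts immediately without affecting the other rows.
    """
    row_norm = grad.norm(dim=1)                        # one norm per rank/block
    norm_ema.mul_(ema_decay).add_(row_norm, alpha=1 - ema_decay)
    spike = row_norm / (norm_ema + eps)
    spike = spike / (1.0 + spike)                      # squash into [0, 1)
    return beta2_max - (beta2_max - beta2_min) * spike, norm_ema
```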


Improved OrthoGrad for LoRA/OFT

OrthoGrad has been optimized for LoRA/OFT:

  • Previous behavior: It flattened the matrix into one large vector. While functional for full fine-tuning, this was sub-optimal for LoRA factors and OFT blocks.
  • New behavior: It now iterates per-rank/block to orthogonalize them individually.

This results in more accurate orthogonalization and ensures that if a specific block or rank is noisy, it will not negatively impact the orthogonalization of others.
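The per-row projection can be sketched as follows (OrthoGrad in the original formulation also rescales to preserve the gradient norm; that step is omitted here, and the names are illustrative):

```python
import torch

def orthograd_per_row(grad: torch.Tensor, weight: torch.Tensor,
                      eps: float = 1e-30) -> torch.Tensor:
    """Project each gradient row orthogonal to the matching weight row.

    Flattening the whole matrix into one vector mixes all ranks/blocks into a
    single projection; doing it per row keeps a noisy rank/block from skewing
    the orthogonalization of the others.
    """
    dot = (grad * weight).sum(dim=1, keepdim=True)
    w_sq = weight.pow(2).sum(dim=1, keepdim=True).clamp_min(eps)
    return grad - (dot / w_sq) * weight
```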

@Koratahiu
Contributor Author

Koratahiu commented Mar 16, 2026

Update 2:

Added Fisher Weight Decay (Natural Weight Decay)

The Fisher Information Matrix (FIM) tells us how sensitive the model's output is to changes in parameters. The diagonal of this matrix, diag(F), represents the individual sensitivity of each weight.

Adam's second moment can be interpreted as an approximation of the diag(F), giving the importance of each parameter.

Standard weight decay applies a uniform penalty. Fisher WD is a form of adaptive weight decay derived from Adam's second moment: it applies a per-parameter decay strength.

How it behaves:

  • High Curvature (Large second moment): The FIM is large, meaning this weight is "important" and the loss surface is steep. Fisher WD becomes small. Consequently, the weight decay penalty is reduced, protecting "load-bearing" weights.
  • Flat Regions (Small second moment): The FIM is small, meaning the weight doesn't affect the output much. The Fisher WD becomes large. This accelerates the decay of these "useless" weights, effectively pruning them or pushing them toward zero.
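A hedged sketch of that behavior, where the inverse-Fisher factor is normalized by its mean so the average decay strength stays comparable to plain WD (that normalization is my assumption, not necessarily the PR's exact formulation):

```python
import torch

def fisher_wd_(param, exp_avg_sq, lr, wd, eps=1e-8):
    """Per-parameter decay scaled by the inverse sqrt of Adam's second moment.

    Treating exp_avg_sq as diag(F): important (high-curvature) weights get a
    weaker penalty, while flat-region weights decay faster and are effectively
    pruned toward zero.
    """
    inv_fisher = 1.0 / (exp_avg_sq.sqrt() + eps)
    inv_fisher = inv_fisher / inv_fisher.mean()  # assumption: normalized form
    param.mul_(1 - lr * wd * inv_fisher)
```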

Scale-Invariant WD

By multiplying by the (inverse) square root of the Fisher, we make the regularization scale-invariant. The penalty becomes proportional to how much the weight actually matters to the loss, rather than just how large its raw numerical value is.

  • Standard weight decay is biased by tensor width (scale-variant: wider tensors receive stronger effective regularization).

When using Fisher WD alongside the Scaled Optimizer option, yet another part of the optimizer becomes scale-invariant: the weight decay!

Usage

  • Enable Fisher Weight Decay

  • Set weight decay to:
    weight decay = 1e-6/LR
    (or just 1e-6 when using Scaled Optimizer)

    • An effective WD of 1e-3 is a good starting point; since it's scale-invariant, you only need to tune it once for all models, trainings, etc.
    • Avoid a long-stalled second moment (very high beta2); I would recommend a beta2 of 0.95 to 0.99 (or use K-B to auto-tune beta2).


@Koratahiu
Contributor Author

Gemini 3.1 Pro inspection of Scaled Optimizer Logic for all adapters (pretty neat):

Based on the provided codebase (specifically scaled_optm.py and LoRAModule.py), the update complexity of the output variance $\mathbb{E}[|\Delta y|^2]$ for all these modules is strictly $\mathcal{O}(1)$ with respect to network dimensions ($d_{in}, d_{out}$), OFT block sizes ($b$), and LoRA ranks ($r$).

This means the optimizer effectively decouples the learning rate from the architecture's geometry (similar to muP / Maximal Update Parametrization scaling).

Below is the mathematical breakdown of the exact variance for each module.


1. Update Complexity Analysis

Assume a standard initialization where the input vector $x$ has $\mathbb{E}[x_i^2] = 1$ (so $\|x\|^2 = d_{in}$) and the base weights have variance $1/d_{in}$. Let $\eta$ be the learning rate (lr).

A. DoRA Scale (Magnitude Vector $m$)

  • Code logic: In scaled_optm.py, 1D tensors (like dora_scale) undergo rms_normalization.
    scale_n = math.sqrt(n)
    return update.mul_(lr * scale_n / norm)
  • Variance: This forces the $\Delta m$ update to have a Root Mean Square (RMS) of exactly $\eta$. Thus, the variance of each element $\Delta m_i$ is $\eta^2$.
  • Output impact: The forward pass is $y = m \odot \text{normalized features}$. Since the normalized features have a variance of $\approx 1$, the change in output $\Delta y_i \approx \Delta m_i \cdot 1$.
  • Update Complexity: $\mathcal{O}(\eta^2)$ $\rightarrow$ $\mathcal{O}(1)$ independent of $d_{out}$.
  • Exact Variance: $\mathbf{1 \cdot \eta^2}$

B. OFT Blocks ($Q$)

  • Code logic: OFT matrices are block-diagonal skew-symmetric matrices. For a block of size $b$, the number of learnable upper-triangular elements is $n_{el} \approx b^2/2$. The optimizer scales each block independently:
    target_norm = math.sqrt(b / 8)
    scale = target_norm / math.sqrt(n)
  • Variance: rms_normalization(dim=1) makes the vector $L_2$ norm $\eta \sqrt{n_{el}}$. Multiplying by scale makes the $L_2$ norm exactly $\eta \sqrt{b/8}$.
    When populating the full $b \times b$ skew-symmetric matrix block, the Frobenius norm squared becomes $2 \times (\eta^2 b / 8) = \eta^2 b / 4$.
    Distributing this variance across the $b^2$ elements gives a per-element variance of $\frac{\eta^2 b/4}{b^2} = \frac{\eta^2}{4b}$.
  • Output impact: $\Delta y \approx W \Delta Q^T x$. An inner product with $x$ over the $b$ elements in a block yields a variance of $b \times \frac{\eta^2}{4b} \times 1 = \frac{\eta^2}{4}$. Multiplying by $W$ preserves this variance.
  • The derivative of the Cayley transform introduces an additional factor of 2, which is accounted for.
  • Update Complexity: $\mathcal{O}(\eta^2)$ $\rightarrow$ $\mathcal{O}(1)$ independent of $b$ and $d_{in}$.
  • Exact Variance: $\mathbf{1 \cdot \eta^2}$

C. LoRA Factors ($A$ and $B$)

  • Code logic: 2D matrices undergo spectral_normalization to maintain a target spectral norm derived from the "Modular Norm" paper:
    target_scale = math.sqrt(d_out / d_in)
  • Variance: This ensures $\|\Delta A\|_2 = \eta \sqrt{r/d_{in}}$ and $\|\Delta B\|_2 = \eta \sqrt{d_{out}/r}$. In terms of Frobenius norm, $\Delta B$ acts as an isotropic matrix with variance bounded by $\frac{\eta^2}{r}$ per element.
  • Output impact: $\Delta y = \frac{\alpha}{r} (\Delta B A + B \Delta A) x$. Assuming $B$ starts at 0, $\Delta y = \frac{\alpha}{r} \Delta B A x$. The product $(Ax)$ has variance 1. Multiplying by $\Delta B$ (summing over $r$ elements) yields $r \times \frac{\eta^2}{r} \times 1 = \eta^2$. Applying the LoRA scalar $\frac{\alpha}{r}$ multiplies the variance by $(\frac{\alpha}{r})^2$.
  • Update Complexity: $\mathcal{O}(\eta^2)$ $\rightarrow$ $\mathcal{O}(1)$ independent of $r, d_{in}, d_{out}$.
  • Exact Variance: $\mathbf{(\frac{\alpha}{r})^2 \cdot \eta^2}$
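The claimed $(\frac{\alpha}{r})^2 \eta^2$ variance for the $\Delta B A x$ path can be checked with a quick Monte-Carlo simulation under the stated assumptions (unit-variance inputs, $1/d_{in}$-variance $A$, per-element $\eta^2/r$ variance in $\Delta B$); the dimensions below are arbitrary:

```python
import math
import torch

torch.manual_seed(0)
d_in, r, d_out = 256, 16, 256
eta, alpha = 1e-2, 16  # alpha = r, so (alpha/r)^2 = 1

samples = []
for _ in range(2000):
    x = torch.randn(d_in)                              # E[x_i^2] = 1
    A = torch.randn(r, d_in) / math.sqrt(d_in)         # Var = 1/d_in
    dB = torch.randn(d_out, r) * (eta / math.sqrt(r))  # Var = eta^2 / r
    samples.append((alpha / r) * dB @ (A @ x))         # Delta y

var = torch.stack(samples).var().item()
# Predicted variance: (alpha/r)^2 * eta^2 = 1e-4, independent of d_in, r, d_out
```

Rerunning with different `d_in`, `r`, or `d_out` leaves the measured variance essentially unchanged, which is the dimension-independence claim.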

Aligning LoRA to DoRA/OFT

To bring the LoRA variance to exactly $1.0 \eta^2$, you just need to ensure the term $(\frac{\alpha}{r})^2$ equals $1$.

  • In your training config, simply enforce that lora_alpha is strictly equal to lora_rank (i.e., $\alpha = r$). If you do this, scaled_optm natively guarantees alignment.

Understanding DoRA-OFT Joint Variance

Because the DoRA magnitude vector ($m$) and the OFT rotation blocks ($Q$) map to orthogonal transformations of the weight matrix (magnitude vs direction), their variance contributions are strictly additive.
One training step of DoRA-OFT will yield a combined output variance of exactly:

$$ Var(\Delta y_{total}) = Var(\Delta y_m) + Var(\Delta y_Q) = \eta^2 + \eta^2 = \mathbf{2 \eta^2} $$

In your training config, use a 2× smaller LR for DoRA-OFT.

@Koratahiu
Contributor Author

Koratahiu commented Mar 17, 2026

Update 3: Scale-Invariant epsilon

Scale-Invariant eps

I've implemented the final piece of scale-invariance for Adam-based optimizers: a scaling rule for the eps hyperparameter.

This ensures that second-moment estimations remain stable and do not grow or diminish as tensor sizes change. By doing so, we effectively eliminate eps as a manual hyperparameter, joining Learning Rate and Weight Decay in the "auto-scaled" suite.
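The comment doesn't spell out the exact scaling rule; one well-known scale-invariant construction, shown here purely as an illustration (an Adafactor-style relative epsilon, not necessarily what this PR implements), is:

```python
import torch

def relative_eps_denom(exp_avg_sq: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Adam denominator with a relative (scale-invariant) epsilon floor.

    Instead of a fixed absolute eps, the floor scales with the tensor's own
    second-moment magnitude, so the denominator behaves identically whether
    the tensor's statistics are large or tiny.
    """
    rms = exp_avg_sq.sqrt()
    return rms + eps * rms.mean().clamp_min(1e-30)
```

Scale-invariance check: multiplying `exp_avg_sq` by $c^2$ scales the whole denominator by exactly $c$, which a fixed absolute eps would not do.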

With this addition, Scaled AdamW maintained a stable 1e-3 LR across all training methods and even outperformed Spectral Muon.
(attached screenshot: image_2026-03-16_11-29-57)


TODO

  • Input/Output Layer Support: Currently, the scaling logic is implemented for residual layers (those trained in LoRA/OFT/DoRA); I plan to extend support to the full model architecture (even though for small-to-medium trainings, those layers shouldn't be trained).
  • State Management: As discussed on Discord, I’m considering string-based selection options for optimizer states (e.g., forced FP32, factorization, BF16 with Stochastic Rounding, etc.) to improve flexibility.

@Koratahiu Koratahiu mentioned this pull request Mar 24, 2026