
feat(moe): add orthogonal initialization for gate parameters#664

Open
joenaess wants to merge 6 commits into arcee-ai:main from joenaess:feature/moe-orthogonal-init

Conversation

@joenaess

@joenaess joenaess commented Feb 4, 2026

Overview

This PR introduces Orthogonal Initialization for Mixture-of-Experts (MoE) gate parameters. This is a critical feature for "Sparse Upcycling" workflows where a dense monolingual model is transformed into a sparse MoE architecture.

Technical Justification

Standard Gaussian initialization (random) can lead to high correlation between gate vectors in the early stages of training. In a language technology context, this causes Expert Collapse, where multiple experts are updated with gradients for the same token clusters, hindering specialization.

By implementing torch.nn.init.orthogonal_, we ensure:

  • The gate matrix has a condition number of 1.
  • Experts start by covering maximally distinct regions of the hidden state manifold.
  • Faster convergence during the initial multilingual fine-tuning phase.
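The first two claims can be checked directly with a short, standalone sketch in plain PyTorch (independent of mergekit internals; the shapes are illustrative). Because the gate matrix has far fewer rows (experts) than columns (hidden size), torch.nn.init.orthogonal_ gives it orthonormal rows:

```python
import torch

num_experts, hidden_size = 8, 4096

# Gate matrix with num_experts << hidden_size; orthogonal_ fills it with a
# semi-orthogonal matrix, i.e. orthonormal rows.
gate = torch.empty(num_experts, hidden_size)
torch.nn.init.orthogonal_(gate)

# Rows are mutually orthogonal unit vectors: Q @ Q.T == I.
identity = torch.eye(num_experts)
assert torch.allclose(gate @ gate.T, identity, atol=1e-5)

# All singular values are 1, so the condition number is 1.
svals = torch.linalg.svdvals(gate)
assert torch.allclose(svals, torch.ones(num_experts), atol=1e-5)
```

Each row is a unit vector orthogonal to every other row, so every expert's gate direction starts maximally separated from its neighbours.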

Changes

  • mergekit/moe/config.py: Added orthogonal to GateMode Literal for configuration validation.
  • mergekit/moe/router.py: Implemented the initialization logic in get_gate_params, ensuring float32 precision during computation for mathematical stability.
  • tests/test_moe_orthogonal.py: Added unit tests to verify the mathematical orthogonality ($Q Q^T = I$) across all layers.

Verification

Ran unit tests using uv:
uv run python -m unittest tests/test_moe_orthogonal.py
Status: PASSED


Note

Medium Risk
Introduces new initialization behavior for MoE routing weights and widens accepted config values, which can affect downstream training/merge outputs. Also includes a minor but potentially risky change in tokensurgeon/rope_helpers.py that leaves a stray no-op statement that could indicate an accidental edit.

Overview
Adds a new MoE gate_mode option, orthogonal, and implements it in get_gate_params by generating per-layer, orthogonally-initialized gate matrices (initialized in float32, then cast/moved to the requested dtype/device).

Updates MoE config validation to strictly constrain gate_mode via Literal, and adds unit tests asserting Q @ Q.T ≈ I for the new initialization path.
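The described path (per-layer orthogonal init in float32, then cast/move) might look roughly like the following sketch. This is a hypothetical helper for illustration only; the real logic lives in get_gate_params in mergekit/moe/router.py and its signature differs:

```python
import torch

def orthogonal_gate_params(
    num_layers: int,
    num_experts: int,
    hidden_size: int,
    target_dtype: torch.dtype = torch.bfloat16,
    device: str = "cpu",
) -> torch.Tensor:
    """Hypothetical sketch of per-layer orthogonal gate initialization."""
    layers = []
    for _ in range(num_layers):
        # Initialize in float32 so the underlying QR decomposition is
        # numerically stable, regardless of the target dtype.
        layer_gate = torch.empty(num_experts, hidden_size, dtype=torch.float32)
        torch.nn.init.orthogonal_(layer_gate)
        # Cast to the target dtype and move to the requested device.
        layers.append(layer_gate.to(dtype=target_dtype, device=device))
    # Stack into the (num_layers, num_experts, hidden_size) shape that
    # callers of the other gate modes expect.
    return torch.stack(layers, dim=0)
```

Note that casting to a low-precision dtype such as bfloat16 after the float32 init slightly perturbs the exact orthogonality, which is why the verification tests allow a small tolerance.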

Separately, this PR includes broad lint/formatting cleanups (e.g., is vs ==, is not None, assert formatting), adds ruff to dev dependencies, and contains a small functional change in tokensurgeon/rope_helpers.py where an unused n_heads local is removed (leaving a no-op expression).

Written by Cursor Bugbot for commit 378bb89. This will update automatically on new commits.

@github-actions

github-actions bot commented Feb 4, 2026

All contributors have signed the CLA ✍️ ✅
Posted by the CLA Assistant Lite bot.

@joenaess
Author

joenaess commented Feb 4, 2026

I have read the CLA Document and I hereby sign the CLA

Comment thread mergekit/moe/router.py
# 3. Cast to the target dtype and move to the requested device
gate_vecs.append(layer_gate.to(dtype=target_dtype, device=device if device != "auto" else "cpu"))

return gate_vecs

Orthogonal mode returns list instead of tensor

High Severity

The new orthogonal mode returns gate_vecs as a list of tensors, while all other modes (random, uniform_random, hidden, cheap_embed) return a single tensor with shape (num_layers, num_experts, hidden_size). Callers in moe.py use tensor indexing like gate_vecs[:, :len(...), :] and warn_degenerate_gates expects gate_vecs.shape to exist. This will cause a runtime crash when using orthogonal mode.
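A sketch of the kind of fix this implies (illustrative shapes, not the actual patch): build the per-layer list as before, then stack it into the single tensor shape the other modes return.

```python
import torch

num_layers, num_experts, hidden_size = 4, 8, 64

# Per-layer orthogonal gates, as the new mode builds them.
gate_vecs = []
for _ in range(num_layers):
    layer_gate = torch.empty(num_experts, hidden_size)
    torch.nn.init.orthogonal_(layer_gate)
    gate_vecs.append(layer_gate)

# Stack into one (num_layers, num_experts, hidden_size) tensor so callers
# can keep using tensor indexing like gate_vecs[:, :k, :] and .shape.
gate_tensor = torch.stack(gate_vecs, dim=0)
assert gate_tensor.shape == (num_layers, num_experts, hidden_size)
```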


Comment thread mergekit/moe/config.py
# "random" is standard normal distribution (torch.randn)
# "uniform_random" matches default initialization for torch.nn.Linear
# "orthogonal" ensures gate vectors are orthogonal for better expert specialization


Missing validation bypass for orthogonal mode prompts

Medium Severity

The is_bad_config function has an early return for "random" mode to skip prompt validation, but the new "orthogonal" mode (which also doesn't use prompts) wasn't added to this check. Users attempting to use orthogonal initialization without prompts will get the error "Expert X has no positive prompts" even though orthogonal mode generates gate vectors mathematically without using prompts at all.
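The shape of the suggested guard, as a hedged sketch (the signature and field names here are assumptions for illustration; the actual is_bad_config lives in mergekit/moe/config.py):

```python
def is_bad_config(config, allow_all_same: bool = False) -> bool:
    """Sketch: reject configs with missing prompts, except for gate modes
    that synthesize gate vectors mathematically without any prompts."""
    # Both "random" and the new "orthogonal" mode never consult prompts,
    # so neither should trigger prompt validation.
    if config.gate_mode in ("random", "orthogonal"):
        return False
    for expert in config.experts:
        if not expert.positive_prompts:
            print(f"Expert {expert.source_model} has no positive prompts.")
            return True
    return False
```

With only "random" in the early return, an orthogonal config without prompts would fall through to the loop and be rejected spuriously, which is exactly the failure mode described above.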


Comment thread tests/test_moe_init.py

gates = get_gate_params(
model_cfg=mock_cfg, experts=mock_experts, mode="orthogonal"
)

Test uses incorrect parameter name for function

Low Severity

The test calls get_gate_params(model_cfg=mock_cfg, ...) but the function signature expects model_ref as the first parameter, not model_cfg. It's also missing the required tokenizer parameter. This test will fail with a TypeError about unexpected keyword arguments.


Comment thread pyproject.toml
dev = [
"black~=25.1.0",
"isort~=6.0.1",
"pre-commit~=4.2.0",

Inconsistent indentation in pyproject.toml dev dependencies

Low Severity

Lines 98-99 in the dev dependency list have inconsistent indentation (only 1 space) compared to the surrounding lines, which use 4 spaces. TOML ignores whitespace inside arrays, so this won't break parsing, but it makes the formatting confusing.



@cursor cursor bot left a comment


Cursor Bugbot has reviewed your changes and found 1 potential issue.

Bugbot Autofix is OFF. To automatically fix reported issues with Cloud Agents, enable Autofix in the Cursor dashboard.

stats.to_approximate += 1

- donor_tokenizer = transformers.AutoTokenizer.from_pretrained(
+ transformers.AutoTokenizer.from_pretrained(

Wasteful tokenizer loading with discarded result

Low Severity

The transformers.AutoTokenizer.from_pretrained call loads a donor tokenizer, but the result is discarded. This performs network/disk I/O and memory allocation for no purpose. The change removed the donor_tokenizer variable assignment but kept the now-useless function call; since the loaded tokenizer is never used, the entire call can be removed.

