feat(moe): add orthogonal initialization for gate parameters #664
joenaess wants to merge 6 commits into arcee-ai:main from
Conversation
All contributors have signed the CLA ✍️ ✅
```python
# 3. Cast to the target dtype and move to the requested device
gate_vecs.append(
    layer_gate.to(dtype=target_dtype, device=device if device != "auto" else "cpu")
)

return gate_vecs
```
Orthogonal mode returns list instead of tensor
High Severity
The new orthogonal mode returns gate_vecs as a list of tensors, while all other modes (random, uniform_random, hidden, cheap_embed) return a single tensor with shape (num_layers, num_experts, hidden_size). Callers in moe.py use tensor indexing like gate_vecs[:, :len(...), :] and warn_degenerate_gates expects gate_vecs.shape to exist. This will cause a runtime crash when using orthogonal mode.
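A minimal NumPy sketch of the suggested fix (the real code uses torch; `stack_gate_vecs` and its shapes are illustrative, not mergekit's actual API): stacking the per-layer matrices restores the single-tensor contract the other modes follow.

```python
import numpy as np

def stack_gate_vecs(per_layer_gates):
    # Hypothetical helper. per_layer_gates is a list of num_layers arrays,
    # each of shape (num_experts, hidden_size). Stacking on a new leading
    # axis restores the (num_layers, num_experts, hidden_size) contract
    # that callers such as warn_degenerate_gates rely on.
    return np.stack(per_layer_gates, axis=0)

gates = stack_gate_vecs([np.ones((4, 8)) for _ in range(3)])
print(gates.shape)  # (3, 4, 8)
```

With torch, the equivalent one-liner before returning would be `torch.stack(gate_vecs, dim=0)`.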
```python
# "random" is standard normal distribution (torch.randn)
# "uniform_random" matches default initialization for torch.nn.Linear
# "orthogonal" ensures gate vectors are orthogonal for better expert specialization
```
Missing validation bypass for orthogonal mode prompts
Medium Severity
The is_bad_config function has an early return for "random" mode to skip prompt validation, but the new "orthogonal" mode (which also doesn't use prompts) wasn't added to this check. Users attempting to use orthogonal initialization without prompts will get the error "Expert X has no positive prompts" even though orthogonal mode generates gate vectors mathematically without using prompts at all.
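A sketch of the suggested fix. `PROMPT_FREE_MODES` and `needs_prompt_validation` are illustrative names rather than mergekit's actual API, and the exact set of prompt-free modes is an assumption based on this review comment:

```python
# Assumption: these modes generate gate vectors mathematically and never
# read prompts, so prompt validation should be skipped for all of them.
PROMPT_FREE_MODES = {"random", "orthogonal"}

def needs_prompt_validation(gate_mode: str) -> bool:
    # Returning False here avoids the spurious
    # "Expert X has no positive prompts" error for prompt-free modes.
    return gate_mode not in PROMPT_FREE_MODES

print(needs_prompt_validation("orthogonal"))  # False
print(needs_prompt_validation("hidden"))      # True
```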
```python
gates = get_gate_params(
    model_cfg=mock_cfg, experts=mock_experts, mode="orthogonal"
)
```
Test uses incorrect parameter name for function
Low Severity
The test calls get_gate_params(model_cfg=mock_cfg, ...) but the function signature expects model_ref as the first parameter, not model_cfg. It's also missing the required tokenizer parameter. This test will fail with a TypeError about unexpected keyword arguments.
```toml
dev = [
    "black~=25.1.0",
    "isort~=6.0.1",
    "pre-commit~=4.2.0",
```
Inconsistent indentation in pyproject.toml dev dependencies
Low Severity
Lines 98-99 in the dev dependency list have inconsistent indentation (only 1 space) compared to the surrounding lines which use 4 spaces. This inconsistency may cause TOML parsing issues or at minimum creates confusing formatting.
Cursor Bugbot has reviewed your changes and found 1 potential issue.
```diff
 stats.to_approximate += 1

-donor_tokenizer = transformers.AutoTokenizer.from_pretrained(
+transformers.AutoTokenizer.from_pretrained(
```
Wasteful tokenizer loading with discarded result
Low Severity
The transformers.AutoTokenizer.from_pretrained call loads a donor tokenizer but the result is discarded. This performs network/disk I/O and memory allocation for no purpose. The change removed the donor_tokenizer variable assignment but kept the useless function call - the entire call can be removed since the loaded tokenizer is never used.


Overview
This PR introduces Orthogonal Initialization for Mixture-of-Experts (MoE) gate parameters. This is a critical feature for "Sparse Upcycling" workflows where a dense monolingual model is transformed into a sparse MoE architecture.
Technical Justification
Standard Gaussian initialization (`random`) can lead to high correlation between gate vectors in the early stages of training. In a language technology context, this causes Expert Collapse: multiple experts receive gradient updates for the same token clusters, hindering specialization. Initializing with `torch.nn.init.orthogonal_` ensures the gate vectors start out mutually orthogonal.

Changes
- Added `orthogonal` to the `GateMode` `Literal` for configuration validation.
- Implemented the orthogonal path in `get_gate_params`, ensuring float32 precision during computation for mathematical stability.

Verification
Ran unit tests using `uv`:
`uv run python -m unittest tests/test_moe_orthogonal.py`
Status: PASSED
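For context, the orthogonality property the unit tests assert (`Q @ Q.T ≈ I`) can be sketched with NumPy's QR decomposition as a portable stand-in for `torch.nn.init.orthogonal_` (the shapes and variable names here are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
num_experts, hidden_size = 4, 16

# QR of a random (hidden_size, num_experts) matrix yields q with
# orthonormal columns; its transpose is a gate matrix whose rows are
# mutually orthogonal unit vectors, one per expert.
a = rng.standard_normal((hidden_size, num_experts))
q, _ = np.linalg.qr(a)
gate = q.T  # shape (num_experts, hidden_size)

print(np.allclose(gate @ gate.T, np.eye(num_experts)))  # True
```

This only holds row-wise when `num_experts <= hidden_size`, which is the usual regime for MoE gates.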
Note
Medium Risk
Introduces new initialization behavior for MoE routing weights and widens accepted config values, which can affect downstream training/merge outputs. Also includes a minor but potentially risky change in `tokensurgeon/rope_helpers.py` that leaves a stray no-op statement, which could indicate an accidental edit.

Overview
Adds a new MoE `gate_mode` option, `orthogonal`, and implements it in `get_gate_params` by generating per-layer, orthogonally-initialized gate matrices (initialized in float32, then cast/moved to the requested dtype/device).

Updates MoE config validation to strictly constrain `gate_mode` via `Literal`, and adds unit tests asserting `Q @ Q.T ≈ I` for the new initialization path.

Separately, this PR includes broad lint/formatting cleanups (e.g., `is` vs `==`, `is not None`, assert formatting), adds `ruff` to dev dependencies, and contains a small functional change in `tokensurgeon/rope_helpers.py` where an unused `n_heads` local is removed (leaving a no-op expression).

Written by Cursor Bugbot for commit 378bb89.