From cf689519d0ab513caedd1651d73e95e881d4cf08 Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Wed, 18 Mar 2026 13:39:25 +0000
Subject: [PATCH 1/6] update test for 5.3.0
---
tests/test_convert_model.py | 45 +++++++++++++++++++++++++++++++++++--
1 file changed, 43 insertions(+), 2 deletions(-)
diff --git a/tests/test_convert_model.py b/tests/test_convert_model.py
index 2b1babf..a1f68dd 100644
--- a/tests/test_convert_model.py
+++ b/tests/test_convert_model.py
@@ -282,7 +282,12 @@ def _copy_sparse_moe_weights(original_block: nn.Module, defused_block: nn.Module
def _assert_sparse_moe_defused_matches_fused_math(
- original_block: nn.Module, defused_block: nn.Module, hidden_states: torch.Tensor
+ original_block: nn.Module,
+ defused_block: nn.Module,
+ hidden_states: torch.Tensor,
+ *,
+ atol: float | None = None,
+ rtol: float | None = None,
) -> None:
_seed_floating_tensors(original_block)
_copy_sparse_moe_weights(original_block, defused_block)
@@ -291,7 +296,12 @@ def _assert_sparse_moe_defused_matches_fused_math(
actual = defused_block.eval()(hidden_states)
# The defused replacement must preserve the exact MoE matmul path of the fused block.
- torch.testing.assert_close(actual, expected)
+ assert_close_kwargs = {}
+ if atol is not None:
+ assert_close_kwargs["atol"] = atol
+ if rtol is not None:
+ assert_close_kwargs["rtol"] = rtol
+ torch.testing.assert_close(actual, expected, **assert_close_kwargs)
def test_qwen2_moe():
@@ -634,9 +644,40 @@ def test_glm4_moe_defused_forward_matches_fused_math():
Glm4MoeMoE(config),
LinearGlm4MoeMoE(config),
hidden_states,
+ # GLM4 MoE now shows tiny fp32 roundoff drift in 5.3.0 because the fused gate/up matmul
+ # is compared against two split linears. Keep the tolerance narrow enough to catch real regressions.
+ atol=3e-4,
+ rtol=1e-4,
)
+def test_defused_models_preserve_output_router_logits_capture():
+ cases = [
+ (
+ "mixtral",
+ lambda: MixtralForCausalLM(_tiny_mixtral_config()),
+ ),
+ (
+ "qwen2_moe",
+ lambda: Qwen2MoeForCausalLM(_tiny_moe_config(Qwen2MoeConfig)),
+ ),
+ (
+ "qwen3_moe",
+ lambda: Qwen3MoeForCausalLM(_tiny_moe_config(Qwen3MoeConfig)),
+ ),
+ ]
+
+ for model_type, build_model in cases:
+ replace_fused_blocks(model_type)
+ model = build_model().eval()
+ outputs = model(input_ids=torch.tensor([[1, 2, 3]]), output_router_logits=True)
+
+ # Router logits are captured through upstream hooks, so defused blocks must keep the same router module type.
+ assert outputs.router_logits is not None
+ assert len(outputs.router_logits) == 1
+ assert outputs.router_logits[0].shape == (3, model.config.num_experts)
+
+
def test_glm4v_checkpoint_mapping_splits_gate_up_proj():
from defuser.defuser import get_checkpoint_conversion_mapping
From 1a572bd8bd75701bdd51107d92f94423f73db51b Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Wed, 18 Mar 2026 13:43:51 +0000
Subject: [PATCH 2/6] update readme
---
README.md | 196 ++++++++++++++++++++++++++++++++++++++++++++++++++++--
1 file changed, 190 insertions(+), 6 deletions(-)
diff --git a/README.md b/README.md
index 17ead0e..4594d57 100644
--- a/README.md
+++ b/README.md
@@ -10,10 +10,194 @@
-Model defuser helper for HF Transformers >= 5.0. In HF Transformers 5.x releases, many MoE modules became auto-stacked or auto-fused by new modeling code which has benefits but also downsides.
-* Goal is to provide naive module/layer forwarding code for all models supported by HF transformers where run-time
- weight and structure level optimizations such weight merging, stacking, fusing are reversed so the model is operating
- in a simple naive state.
-* There are cases, quantization libraries, where we need to run inference where module input/output needs to be
- individually captured and this pkg can help complete this task.
+Defuser converts select Hugging Face Transformers 5.x fused or stacked MoE and MLP blocks back into plain, per-expert `nn.Linear` modules. It keeps the forward math intact while exposing individual projections again so quantizers, activation capture, debugging hooks, and checkpoint tooling can work against a simple module layout instead of fused expert tensors.
+
+## Purpose
+
+Defuser exists for cases where newer Transformers modeling code optimizes model structure in ways that are good for runtime, but harder for tooling that needs direct access to individual projections.
+
+Depending on the model family, Defuser can:
+
+- patch a supported model class before load so HF instantiates a defused block directly
+- split fused tensors such as `gate_up_proj` into `gate_proj` + `up_proj`
+- convert 3D expert tensors into numbered expert `nn.Linear` modules
+- preserve the original fused math while presenting a naive module structure again
+
+Public API:
+
+```python
+from defuser import convert_model, replace_fused_blocks
+```
+
+- `replace_fused_blocks(model_type)` patches supported HF model classes before `from_pretrained()` or direct model construction.
+- `convert_model(model, cleanup_original=True, max_layers=None)` converts an already loaded model in place. This is the runtime defusion path used for `qwen3_5_moe` style checkpoints.
+
+## Supported Models
+
+| Model type | Min `transformers` | Recommended entrypoint | Defused op performed |
+| --- | --- | --- | --- |
+| `mixtral` | `5.0.0` | `replace_fused_blocks("mixtral")` before load | Replaces `MixtralSparseMoeBlock` with `LinearMixtralSparseMoeBlock`. Also remaps legacy Mixtral checkpoint keys and splits fused expert `gate_up_proj` tensors into per-expert `gate_proj` and `up_proj`, plus per-expert `down_proj`. |
+| `qwen2_moe` | `5.0.0` | `replace_fused_blocks("qwen2_moe")` before load | Replaces `Qwen2MoeSparseMoeBlock` with a defused per-expert linear MoE block. |
+| `qwen3_moe` | `5.0.0` | `replace_fused_blocks("qwen3_moe")` before load | Replaces `Qwen3MoeSparseMoeBlock` with a defused per-expert linear MoE block. |
+| `qwen3_5_moe` | `5.2.0` | `convert_model(model)` after load | Runtime expert tensor defusion. Splits fused `gate_up_proj` into `gate_proj` + `up_proj` and converts 3D expert tensors into numbered expert `nn.Linear` modules. |
+| `qwen3_5_moe_text` | `5.2.0` | `convert_model(model)` after load | Same runtime expert tensor defusion path as `qwen3_5_moe`, applied to the text-only backbone. |
+| `qwen3_next` | `5.0.0` | `replace_fused_blocks("qwen3_next")` before load | Replaces `Qwen3NextSparseMoeBlock` with a defused per-expert linear MoE block. |
+| `qwen3_omni_moe` | `5.0.0` | `replace_fused_blocks("qwen3_omni_moe")` before load | Replaces the thinker text sparse MoE block with a defused per-expert linear block and applies small runtime compatibility patches for text `forward()` and `generate()`. |
+| `glm4_moe` | `5.0.0` | `replace_fused_blocks("glm4_moe")` before load | Replaces `Glm4MoeMoE` with a defused per-expert linear MoE block. |
+| `glm4v` | `5.0.0` | `replace_fused_blocks("glm4v")` before load | Replaces the fused text MLP with split `gate_proj`, `up_proj`, and `down_proj` layers. Also splits fused checkpoint `mlp.gate_up_proj.weight` into `mlp.gate_proj.weight` + `mlp.up_proj.weight`. |
+
+## Workflow Summary
+
+Use `replace_fused_blocks()` for model families that Defuser can patch before load:
+
+```python
+from defuser import replace_fused_blocks
+from transformers import MixtralForCausalLM
+
+replace_fused_blocks("mixtral")
+model = MixtralForCausalLM.from_pretrained(
+ "mistralai/Mixtral-8x7B-v0.1",
+ dtype="auto",
+ device_map="auto",
+)
+```
+
+Use `convert_model()` for already loaded models whose expert tensors still need runtime defusion:
+
+```python
+from defuser import convert_model
+
+converted = convert_model(model)
+print(converted) # True when runtime defusion happened
+```
+
+## Real Qwen3.5 MoE Example
+
+The example below is written for the `transformers==5.3.0` public API surface and uses the real Hugging Face model `Qwen/Qwen3.5-35B-A3B-Instruct`. Defuser's support gate for this family starts at `transformers>=5.2.0`.
+
+### Fused Weights Before And After
+
+Before `convert_model(model)`:
+
+```text
++--------------------------------------------------------+---------------------------------------------+
+| State dict key | Layout |
++--------------------------------------------------------+---------------------------------------------+
+| model.language_model.layers.0.mlp.experts.gate_up_proj | fused gate+up tensor for all experts |
+| | [num_experts, 2 * moe_intermediate, hidden] |
+| model.language_model.layers.0.mlp.experts.down_proj | fused per-expert down tensor |
+| | [num_experts, hidden, moe_intermediate] |
++--------------------------------------------------------+---------------------------------------------+
+```
+
+After `convert_model(model)`:
+
+```text
++-----------------------------------------------------------------+--------------------------------------+
+| State dict key | Layout |
++-----------------------------------------------------------------+--------------------------------------+
+| model.language_model.layers.0.mlp.experts.0.gate_proj.weight | expert 0 gate projection |
+| model.language_model.layers.0.mlp.experts.0.up_proj.weight | expert 0 up projection |
+| model.language_model.layers.0.mlp.experts.0.down_proj.weight | expert 0 down projection |
+| ... repeated for experts 1..N-1 | numbered expert nn.Linear modules |
++-----------------------------------------------------------------+--------------------------------------+
+```
+
+### Sample 1: Inspect The Conversion In Place
+
+```python
+from defuser import convert_model
+from transformers import Qwen3_5MoeForConditionalGeneration
+
+model_id = "Qwen/Qwen3.5-35B-A3B-Instruct"
+
+model = Qwen3_5MoeForConditionalGeneration.from_pretrained(
+ model_id,
+ dtype="auto",
+ device_map="auto",
+)
+
+prefix = "model.language_model.layers.0.mlp.experts"
+
+before = [name for name, _ in model.named_parameters() if name.startswith(prefix)]
+print(before)
+# [
+# "model.language_model.layers.0.mlp.experts.gate_up_proj",
+# "model.language_model.layers.0.mlp.experts.down_proj",
+# ]
+
+converted = convert_model(model)
+assert converted is True
+
+after = [name for name, _ in model.named_parameters() if name.startswith(prefix)]
+print(after[:6])
+# [
+# "model.language_model.layers.0.mlp.experts.0.down_proj.weight",
+# "model.language_model.layers.0.mlp.experts.0.gate_proj.weight",
+# "model.language_model.layers.0.mlp.experts.0.up_proj.weight",
+# "model.language_model.layers.0.mlp.experts.1.down_proj.weight",
+# "model.language_model.layers.0.mlp.experts.1.gate_proj.weight",
+# "model.language_model.layers.0.mlp.experts.1.up_proj.weight",
+# ]
+```
+
+### Sample 2: Convert And Keep Using The Model Normally
+
+```python
+import torch
+
+from defuser import convert_model
+from transformers import AutoProcessor, Qwen3_5MoeForConditionalGeneration
+
+model_id = "Qwen/Qwen3.5-35B-A3B-Instruct"
+
+model = Qwen3_5MoeForConditionalGeneration.from_pretrained(
+ model_id,
+ dtype="auto",
+ device_map="auto",
+)
+processor = AutoProcessor.from_pretrained(model_id)
+
+convert_model(model)
+
+messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "Explain mixture-of-experts routing in one sentence."},
+ ],
+ }
+]
+
+inputs = processor.apply_chat_template(
+ messages,
+ tokenize=True,
+ add_generation_prompt=True,
+ return_dict=True,
+ return_tensors="pt",
+)
+inputs = inputs.to(model.device)
+
+with torch.inference_mode():
+ output_ids = model.generate(**inputs, max_new_tokens=64)
+
+generated_ids = [
+ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, output_ids)
+]
+text = processor.batch_decode(
+ generated_ids,
+ skip_special_tokens=True,
+ clean_up_tokenization_spaces=False,
+)[0]
+print(text)
+```
+
+After conversion, the first routed expert in the first MoE layer is exposed as normal submodules:
+
+```python
+expert0 = model.model.language_model.layers[0].mlp.experts[0]
+print(type(expert0.gate_proj).__name__) # Linear
+print(type(expert0.up_proj).__name__) # Linear
+print(type(expert0.down_proj).__name__) # Linear
+```
From e1d56e238c45f272eeab794d3b8d505dff2a41a2 Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Wed, 18 Mar 2026 13:49:08 +0000
Subject: [PATCH 3/6] pin min version to transformers 5.3.0
---
README.md | 27 ++++++++++---------
defuser/defuser.py | 16 ++++++++++-
defuser/model_registry.py | 19 ++++++-------
defuser/utils/common.py | 22 ++++++++++++++-
defuser/utils/hf.py | 4 +++
tests/test_convert_model.py | 54 +++++++++++++++++++++++++++++++++++++
6 files changed, 118 insertions(+), 24 deletions(-)
diff --git a/README.md b/README.md
index 4594d57..0ca1ed1 100644
--- a/README.md
+++ b/README.md
@@ -11,7 +11,7 @@
-Defuser converts select Hugging Face Transformers 5.x fused or stacked MoE and MLP blocks back into plain, per-expert `nn.Linear` modules. It keeps the forward math intact while exposing individual projections again so quantizers, activation capture, debugging hooks, and checkpoint tooling can work against a simple module layout instead of fused expert tensors.
+Defuser converts select Hugging Face Transformers `5.3.0+` fused or stacked MoE and MLP blocks back into plain, per-expert `nn.Linear` modules. It keeps the forward math intact while exposing individual projections again so quantizers, activation capture, debugging hooks, and checkpoint tooling can work against a simple module layout instead of fused expert tensors.
## Purpose
@@ -32,20 +32,21 @@ from defuser import convert_model, replace_fused_blocks
- `replace_fused_blocks(model_type)` patches supported HF model classes before `from_pretrained()` or direct model construction.
- `convert_model(model, cleanup_original=True, max_layers=None)` converts an already loaded model in place. This is the runtime defusion path used for `qwen3_5_moe` style checkpoints.
+- Defuser only supports `transformers>=5.3.0`. Older versions log a warning on these public APIs and are skipped as unsupported.
## Supported Models
-| Model type | Min `transformers` | Recommended entrypoint | Defused op performed |
-| --- | --- | --- | --- |
-| `mixtral` | `5.0.0` | `replace_fused_blocks("mixtral")` before load | Replaces `MixtralSparseMoeBlock` with `LinearMixtralSparseMoeBlock`. Also remaps legacy Mixtral checkpoint keys and splits fused expert `gate_up_proj` tensors into per-expert `gate_proj` and `up_proj`, plus per-expert `down_proj`. |
-| `qwen2_moe` | `5.0.0` | `replace_fused_blocks("qwen2_moe")` before load | Replaces `Qwen2MoeSparseMoeBlock` with a defused per-expert linear MoE block. |
-| `qwen3_moe` | `5.0.0` | `replace_fused_blocks("qwen3_moe")` before load | Replaces `Qwen3MoeSparseMoeBlock` with a defused per-expert linear MoE block. |
-| `qwen3_5_moe` | `5.2.0` | `convert_model(model)` after load | Runtime expert tensor defusion. Splits fused `gate_up_proj` into `gate_proj` + `up_proj` and converts 3D expert tensors into numbered expert `nn.Linear` modules. |
-| `qwen3_5_moe_text` | `5.2.0` | `convert_model(model)` after load | Same runtime expert tensor defusion path as `qwen3_5_moe`, applied to the text-only backbone. |
-| `qwen3_next` | `5.0.0` | `replace_fused_blocks("qwen3_next")` before load | Replaces `Qwen3NextSparseMoeBlock` with a defused per-expert linear MoE block. |
-| `qwen3_omni_moe` | `5.0.0` | `replace_fused_blocks("qwen3_omni_moe")` before load | Replaces the thinker text sparse MoE block with a defused per-expert linear block and applies small runtime compatibility patches for text `forward()` and `generate()`. |
-| `glm4_moe` | `5.0.0` | `replace_fused_blocks("glm4_moe")` before load | Replaces `Glm4MoeMoE` with a defused per-expert linear MoE block. |
-| `glm4v` | `5.0.0` | `replace_fused_blocks("glm4v")` before load | Replaces the fused text MLP with split `gate_proj`, `up_proj`, and `down_proj` layers. Also splits fused checkpoint `mlp.gate_up_proj.weight` into `mlp.gate_proj.weight` + `mlp.up_proj.weight`. |
+| Model type | Recommended entrypoint | Defused op performed |
+| --- | --- | --- |
+| `mixtral` | `replace_fused_blocks("mixtral")` before load | Replaces `MixtralSparseMoeBlock` with `LinearMixtralSparseMoeBlock`. Also remaps legacy Mixtral checkpoint keys and splits fused expert `gate_up_proj` tensors into per-expert `gate_proj` and `up_proj`, plus per-expert `down_proj`. |
+| `qwen2_moe` | `replace_fused_blocks("qwen2_moe")` before load | Replaces `Qwen2MoeSparseMoeBlock` with a defused per-expert linear MoE block. |
+| `qwen3_moe` | `replace_fused_blocks("qwen3_moe")` before load | Replaces `Qwen3MoeSparseMoeBlock` with a defused per-expert linear MoE block. |
+| `qwen3_5_moe` | `convert_model(model)` after load | Runtime expert tensor defusion. Splits fused `gate_up_proj` into `gate_proj` + `up_proj` and converts 3D expert tensors into numbered expert `nn.Linear` modules. |
+| `qwen3_5_moe_text` | `convert_model(model)` after load | Same runtime expert tensor defusion path as `qwen3_5_moe`, applied to the text-only backbone. |
+| `qwen3_next` | `replace_fused_blocks("qwen3_next")` before load | Replaces `Qwen3NextSparseMoeBlock` with a defused per-expert linear MoE block. |
+| `qwen3_omni_moe` | `replace_fused_blocks("qwen3_omni_moe")` before load | Replaces the thinker text sparse MoE block with a defused per-expert linear block and applies small runtime compatibility patches for text `forward()` and `generate()`. |
+| `glm4_moe` | `replace_fused_blocks("glm4_moe")` before load | Replaces `Glm4MoeMoE` with a defused per-expert linear MoE block. |
+| `glm4v` | `replace_fused_blocks("glm4v")` before load | Replaces the fused text MLP with split `gate_proj`, `up_proj`, and `down_proj` layers. Also splits fused checkpoint `mlp.gate_up_proj.weight` into `mlp.gate_proj.weight` + `mlp.up_proj.weight`. |
## Workflow Summary
@@ -74,7 +75,7 @@ print(converted) # True when runtime defusion happened
## Real Qwen3.5 MoE Example
-The example below is written for the `transformers==5.3.0` public API surface and uses the real Hugging Face model `Qwen/Qwen3.5-35B-A3B-Instruct`. Defuser's support gate for this family starts at `transformers>=5.2.0`.
+The example below is written for the `transformers==5.3.0` public API surface and uses the real Hugging Face model `Qwen/Qwen3.5-35B-A3B-Instruct`. Defuser supports `transformers>=5.3.0`.
### Fused Weights Before And After
diff --git a/defuser/defuser.py b/defuser/defuser.py
index 1e8ab55..9a9ccbe 100644
--- a/defuser/defuser.py
+++ b/defuser/defuser.py
@@ -10,6 +10,11 @@
from defuser.model_registry import MODEL_CONFIG, PATCH
from defuser.modeling.model_patches import apply_model_class_patches, apply_model_patches
from defuser.modeling.update_module import update_module
+from defuser.utils.common import (
+ MIN_SUPPORTED_TRANSFORMERS_VERSION,
+ is_supported_transformers_version,
+ warn_if_public_api_transformers_unsupported,
+)
from packaging import version
import transformers
from logbar import LogBar
@@ -36,6 +41,9 @@ class PatchError(Exception):
def replace_fused_blocks(model_type: str) -> bool:
+ if warn_if_public_api_transformers_unsupported("replace_fused_blocks()", logger):
+ return False
+
apply_model_class_patches(model_type)
cfg = MODEL_CONFIG.get(model_type)
@@ -60,7 +68,7 @@ def replace_fused_blocks(model_type: str) -> bool:
custom_class = getattr(custom_module, custom_class_name)
setattr(orig_module, orig_class_name, custom_class)
- if version.parse(transformers.__version__) >= version.parse("5.0.0"):
+ if version.parse(transformers.__version__) >= version.parse(MIN_SUPPORTED_TRANSFORMERS_VERSION):
from transformers import conversion_mapping
if not hasattr(conversion_mapping, "orig_get_checkpoint_conversion_mapping"):
@@ -89,6 +97,9 @@ def check_model_compatibility(model: nn.Module) -> bool:
if model_type not in MODEL_CONFIG:
return False
+ if not is_supported_transformers_version():
+ return False
+
min_ver = MODEL_CONFIG[model_type].get("min_transformers_version")
current_ver = version.parse(transformers.__version__)
if min_ver and current_ver < version.parse(min_ver):
@@ -106,6 +117,9 @@ def convert_model(
cleanup_original: bool = False,
max_layers: int | None = None,
) -> bool:
+ if warn_if_public_api_transformers_unsupported("convert_model()", logger):
+ return False
+
if max_layers is not None and max_layers < 1:
raise ValueError("max_layers must be >= 1 when provided")
diff --git a/defuser/model_registry.py b/defuser/model_registry.py
index 2e42505..967f2fb 100644
--- a/defuser/model_registry.py
+++ b/defuser/model_registry.py
@@ -7,6 +7,7 @@
from transformers.core_model_loading import WeightConverter, WeightRenaming
from defuser.checkpoint_ops import OwnedChunk, SplitFusedExpertDownProj, SplitFusedExpertGateUpProj
+from defuser.utils.common import MIN_SUPPORTED_TRANSFORMERS_VERSION
class PATCH(str, Enum):
@@ -15,7 +16,7 @@ class PATCH(str, Enum):
MODEL_CONFIG = {
"mixtral": {
- "min_transformers_version": "5.0.0",
+ "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
PATCH.REPLACE_MODULE: [
(
"transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock",
@@ -43,7 +44,7 @@ class PATCH(str, Enum):
],
},
"qwen2_moe": {
- "min_transformers_version": "5.0.0",
+ "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
PATCH.REPLACE_MODULE: [
(
"transformers.models.qwen2_moe.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock",
@@ -52,7 +53,7 @@ class PATCH(str, Enum):
],
},
"qwen3_moe": {
- "min_transformers_version": "5.0.0",
+ "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
# structure path only replaces modeling structure
PATCH.REPLACE_MODULE: [
(
@@ -62,13 +63,13 @@ class PATCH(str, Enum):
],
},
"qwen3_5_moe": {
- "min_transformers_version": "5.2.0",
+ "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
},
"qwen3_5_moe_text": {
- "min_transformers_version": "5.2.0",
+ "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
},
"qwen3_next": {
- "min_transformers_version": "5.0.0",
+ "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
PATCH.REPLACE_MODULE: [
(
"transformers.models.qwen3_next.modeling_qwen3_next.Qwen3NextSparseMoeBlock",
@@ -77,7 +78,7 @@ class PATCH(str, Enum):
],
},
"qwen3_omni_moe": {
- "min_transformers_version": "5.0.0",
+ "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
PATCH.REPLACE_MODULE: [
(
"transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe.Qwen3OmniMoeThinkerTextSparseMoeBlock",
@@ -86,7 +87,7 @@ class PATCH(str, Enum):
],
},
"glm4_moe": {
- "min_transformers_version": "5.0.0",
+ "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
PATCH.REPLACE_MODULE: [
(
"transformers.models.glm4_moe.modeling_glm4_moe.Glm4MoeMoE",
@@ -95,7 +96,7 @@ class PATCH(str, Enum):
],
},
"glm4v": {
- "min_transformers_version": "5.0.0",
+ "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
PATCH.REPLACE_MODULE: [
(
"transformers.models.glm4v.modeling_glm4v.Glm4vTextMLP",
diff --git a/defuser/utils/common.py b/defuser/utils/common.py
index 8d49a4c..30620ae 100644
--- a/defuser/utils/common.py
+++ b/defuser/utils/common.py
@@ -8,12 +8,14 @@
from functools import lru_cache
import re
+from packaging import version
# Match module paths like "...layers.0..." and capture the numeric layer index.
_LAYER_NAME_RE = re.compile(r"(?:^|\.)layers\.(\d+)(?:\.|$)")
TRUTHFUL = {"1", "true", "yes", "on", "y"}
+MIN_SUPPORTED_TRANSFORMERS_VERSION = "5.3.0"
def env_flag(name: str, default: str | bool | None = "0") -> bool:
@@ -32,11 +34,29 @@ def env_flag(name: str, default: str | bool | None = "0") -> bool:
@lru_cache(None)
def is_transformers_version_greater_or_equal_5():
import transformers
- from packaging import version
return version.parse(transformers.__version__) >= version.parse("5.0.0")
+def is_supported_transformers_version() -> bool:
+ import transformers
+
+ return version.parse(transformers.__version__) >= version.parse(MIN_SUPPORTED_TRANSFORMERS_VERSION)
+
+
+def warn_if_public_api_transformers_unsupported(api_name: str, logger) -> bool:
+ import transformers
+
+ if is_supported_transformers_version():
+ return False
+
+ logger.warning(
+ f"Defuser public API `{api_name}` requires transformers>={MIN_SUPPORTED_TRANSFORMERS_VERSION}. "
+ f"Current version is {transformers.__version__}. This call is unsupported and will be skipped."
+ )
+ return True
+
+
def is_within_max_layers(module_name: str, max_layers: int | None) -> bool:
"""Return True when module path is within requested layer limit."""
if max_layers is None:
diff --git a/defuser/utils/hf.py b/defuser/utils/hf.py
index b4168f9..f8b7d1f 100644
--- a/defuser/utils/hf.py
+++ b/defuser/utils/hf.py
@@ -17,6 +17,7 @@
from transformers import AutoConfig
from defuser.model_registry import MODEL_CONFIG
+from defuser.utils.common import warn_if_public_api_transformers_unsupported
logger = LogBar(__name__)
@@ -73,6 +74,9 @@ def get_file_path_via_model_name(model_or_path: str, file_name):
def pre_check_config(model_name: str | torch.nn.Module):
+ if warn_if_public_api_transformers_unsupported("pre_check_config()", logger):
+ return False
+
if isinstance(model_name, str):
config = AutoConfig.from_pretrained(model_name)
elif isinstance(model_name, torch.nn.Module):
diff --git a/tests/test_convert_model.py b/tests/test_convert_model.py
index a1f68dd..ee5ec58 100644
--- a/tests/test_convert_model.py
+++ b/tests/test_convert_model.py
@@ -36,8 +36,11 @@
Qwen3OmniMoeThinkerTextSparseMoeBlock,
)
+import defuser.defuser as defuser_api
+import defuser.utils.hf as hf_utils
from defuser import convert_model, replace_fused_blocks
from defuser.checkpoint_ops import OwnedChunk, SplitFusedExpertGateUpProj
+from defuser.model_registry import MODEL_CONFIG
from defuser.modeling.replace_modules import ReplacementModuleBase, apply_replacements, materialize_model
from defuser.modeling.unfused_moe.glm4_moe import LinearGlm4MoeMoE
from defuser.modeling.unfused_moe.mixtral import LinearMixtralSparseMoeBlock
@@ -45,6 +48,7 @@
from defuser.modeling.unfused_moe.qwen3_moe import LinearQwen3MoeSparseMoeBlock
from defuser.modeling.unfused_moe.qwen3_next import LinearQwen3NextSparseMoeBlock
from defuser.modeling.unfused_moe.qwen3_omni_moe import LinearQwen3OmniMoeThinkerTextSparseMoeBlock
+from defuser.utils.common import MIN_SUPPORTED_TRANSFORMERS_VERSION
@@ -417,6 +421,56 @@ def test_replace_fused_blocks_returns_false_for_unregistered_model():
assert replace_fused_blocks("unsupported_model_type") is False
+def test_model_registry_requires_transformers_5_3_or_newer():
+ assert {cfg["min_transformers_version"] for cfg in MODEL_CONFIG.values()} == {MIN_SUPPORTED_TRANSFORMERS_VERSION}
+
+
+def test_replace_fused_blocks_warns_on_unsupported_transformers(monkeypatch):
+ warnings = []
+
+ monkeypatch.setattr(defuser_api.transformers, "__version__", "5.2.9")
+ monkeypatch.setattr(defuser_api.logger, "warning", warnings.append)
+
+ assert defuser_api.replace_fused_blocks("mixtral") is False
+ assert len(warnings) == 1
+ assert "replace_fused_blocks()" in warnings[0]
+ assert f"transformers>={MIN_SUPPORTED_TRANSFORMERS_VERSION}" in warnings[0]
+
+
+def test_convert_model_warns_on_unsupported_transformers(monkeypatch):
+ warnings = []
+
+ class DummyModel(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.config = SimpleNamespace(model_type="mixtral")
+
+ monkeypatch.setattr(defuser_api.transformers, "__version__", "5.2.9")
+ monkeypatch.setattr(defuser_api.logger, "warning", warnings.append)
+
+ assert defuser_api.convert_model(DummyModel()) is False
+ assert len(warnings) == 1
+ assert "convert_model()" in warnings[0]
+ assert f"transformers>={MIN_SUPPORTED_TRANSFORMERS_VERSION}" in warnings[0]
+
+
+def test_pre_check_config_warns_on_unsupported_transformers(monkeypatch):
+ warnings = []
+
+ class DummyModel(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.config = SimpleNamespace(model_type="mixtral")
+
+ monkeypatch.setattr(hf_utils.transformers, "__version__", "5.2.9")
+ monkeypatch.setattr(hf_utils.logger, "warning", warnings.append)
+
+ assert hf_utils.pre_check_config(DummyModel()) is False
+ assert len(warnings) == 1
+ assert "pre_check_config()" in warnings[0]
+ assert f"transformers>={MIN_SUPPORTED_TRANSFORMERS_VERSION}" in warnings[0]
+
+
def test_qwen3_5_moe():
from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeSparseMoeBlock
From 8eb1c24212c6faf3c9782a2e422e596e39f07f52 Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Wed, 18 Mar 2026 13:51:22 +0000
Subject: [PATCH 4/6] cleanup
---
README.md | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/README.md b/README.md
index 0ca1ed1..d4499fe 100644
--- a/README.md
+++ b/README.md
@@ -13,6 +13,8 @@
Defuser converts select Hugging Face Transformers `5.3.0+` fused or stacked MoE and MLP blocks back into plain, per-expert `nn.Linear` modules. It keeps the forward math intact while exposing individual projections again so quantizers, activation capture, debugging hooks, and checkpoint tooling can work against a simple module layout instead of fused expert tensors.
+Defuser is designed and CI-tested for `transformers>=5.3.0`, and support is only offered for that version range.
+
## Purpose
Defuser exists for cases where newer Transformers modeling code optimizes model structure in ways that are good for runtime, but harder for tooling that needs direct access to individual projections.
@@ -32,7 +34,7 @@ from defuser import convert_model, replace_fused_blocks
- `replace_fused_blocks(model_type)` patches supported HF model classes before `from_pretrained()` or direct model construction.
- `convert_model(model, cleanup_original=True, max_layers=None)` converts an already loaded model in place. This is the runtime defusion path used for `qwen3_5_moe` style checkpoints.
-- Defuser only supports `transformers>=5.3.0`. Older versions log a warning on these public APIs and are skipped as unsupported.
+- Defuser is designed and CI-tested for `transformers>=5.3.0`, and support is only offered for that version range. Older versions log a warning on these public APIs and are skipped as unsupported.
## Supported Models
From 954eb56d9c13f13e4acf56dd4377599daa57713a Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Wed, 18 Mar 2026 14:08:01 +0000
Subject: [PATCH 5/6] regression patch
---
defuser/model_registry.py | 4 ++--
defuser/modeling/replace_modules.py | 11 ++++++++---
2 files changed, 10 insertions(+), 5 deletions(-)
diff --git a/defuser/model_registry.py b/defuser/model_registry.py
index 65d9038..b9ca11e 100644
--- a/defuser/model_registry.py
+++ b/defuser/model_registry.py
@@ -117,10 +117,10 @@ class PATCH(str, Enum):
],
},
"gpt_oss": {
- # "min_transformers_version": "", # When `gpt_oss` was added to `transformers`, it was already implemented as "fused experts."
+ "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
},
"llama4": {
- # "min_transformers_version": "", # When `llama4` was added to `transformers`, it was already implemented as "fused experts."
+ "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
PATCH.DEFUSE: "defuser.modeling.fused_moe.llama4",
},
}
diff --git a/defuser/modeling/replace_modules.py b/defuser/modeling/replace_modules.py
index 34df127..8f0cc62 100644
--- a/defuser/modeling/replace_modules.py
+++ b/defuser/modeling/replace_modules.py
@@ -267,11 +267,16 @@ def apply_replacements(
_log_first_moe_block(model, "before replacement")
- # Custom replacements first
+ # Models with an explicit DEFUSE workflow (for example llama4) need their
+ # custom replacements applied directly. All other models should still keep
+ # the legacy behavior: auto-detect fused experts first, then apply any
+ # registered custom replacements.
if is_model_patchable(model):
_apply_custom_replacements(model, max_layers=max_layers)
- elif auto_detect_moe and is_transformers_version_greater_or_equal_5():
- _handle_moe_modules(model)
+ else:
+ if auto_detect_moe and is_transformers_version_greater_or_equal_5():
+ _handle_moe_modules(model)
+ _apply_custom_replacements(model, max_layers=max_layers)
_log_first_moe_block(model, "after replacement")
From 0edb50d0dc20a265e2d68102066eb7b1211836c6 Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Wed, 18 Mar 2026 14:18:09 +0000
Subject: [PATCH 6/6] refractor llama4 patching to be generic
---
defuser/checkpoint_ops.py | 7 ++
defuser/defuser.py | 3 +
defuser/model_registry.py | 9 +-
defuser/modeling/fused_moe/__init__.py | 0
defuser/modeling/fused_moe/llama4.py | 103 -----------------
defuser/modeling/glm4v.py | 1 +
defuser/modeling/model_patches.py | 6 +
defuser/modeling/moe_experts_interface.py | 109 ++++++++++++++++--
defuser/modeling/replace_modules.py | 49 ++------
defuser/modeling/unfused_moe/glm4_moe.py | 8 +-
defuser/modeling/unfused_moe/mixtral.py | 5 +-
defuser/modeling/unfused_moe/qwen2_moe.py | 4 +-
defuser/modeling/unfused_moe/qwen3_moe.py | 4 +-
defuser/modeling/unfused_moe/qwen3_next.py | 4 +-
.../modeling/unfused_moe/qwen3_omni_moe.py | 4 +-
defuser/modeling/update_module.py | 1 +
defuser/utils/common.py | 3 +
defuser/utils/device.py | 10 +-
defuser/utils/hf.py | 4 +-
defuser/utils/model.py | 10 +-
tests/test_convert_model.py | 21 ++++
21 files changed, 187 insertions(+), 178 deletions(-)
delete mode 100644 defuser/modeling/fused_moe/__init__.py
delete mode 100644 defuser/modeling/fused_moe/llama4.py
diff --git a/defuser/checkpoint_ops.py b/defuser/checkpoint_ops.py
index 2309fb3..63a9887 100644
--- a/defuser/checkpoint_ops.py
+++ b/defuser/checkpoint_ops.py
@@ -24,6 +24,7 @@ def __init__(self, expert_dim: int = 0, proj_dim: int = 0):
@staticmethod
def _expert_target(pattern: str, expert_idx: int) -> str:
+ """Expand one target pattern into the per-expert key for ``expert_idx``."""
if "*" in pattern:
return pattern.replace("*", str(expert_idx))
return pattern.replace(".0.", f".{expert_idx}.", 1)
@@ -32,6 +33,7 @@ def _expert_target(pattern: str, expert_idx: int) -> str:
def convert(
self, input_dict: dict[str, torch.Tensor], source_patterns: list[str], target_patterns: list[str], **kwargs
) -> dict[str, torch.Tensor]:
+ """Split one fused gate/up tensor into cloned per-expert gate and up tensors."""
if len(target_patterns) != 2:
raise ValueError("SplitFusedExpertGateUpProj expects exactly two target patterns.")
@@ -50,6 +52,7 @@ def convert(
@property
def reverse_op(self) -> ConversionOps:
+ """Return the inverse merge op used when writing fused checkpoints."""
return MergeSplitExpertGateUpProj()
@@ -66,6 +69,7 @@ def __init__(self, expert_dim: int = 0, proj_dim: int = 0):
def convert(
self, input_dict: dict[str, list[torch.Tensor]], source_patterns: list[str], target_patterns: list[str], **kwargs
) -> dict[str, torch.Tensor]:
+ """Merge per-expert gate/up tensors back into one fused expert tensor."""
if len(source_patterns) != 2:
raise ValueError("MergeSplitExpertGateUpProj expects exactly two source patterns.")
if len(target_patterns) != 1:
@@ -102,6 +106,7 @@ def __init__(self, expert_dim: int = 0):
@staticmethod
def _expert_target(pattern: str, expert_idx: int) -> str:
+ """Expand one target pattern into the per-expert key for ``expert_idx``."""
if "*" in pattern:
return pattern.replace("*", str(expert_idx))
return pattern.replace(".0.", f".{expert_idx}.", 1)
@@ -110,6 +115,7 @@ def _expert_target(pattern: str, expert_idx: int) -> str:
def convert(
self, input_dict: dict[str, torch.Tensor], source_patterns: list[str], target_patterns: list[str], **kwargs
) -> dict[str, torch.Tensor]:
+ """Split one fused expert down projection into cloned per-expert tensors."""
if len(target_patterns) != 1:
raise ValueError("SplitFusedExpertDownProj expects a single target pattern.")
@@ -126,4 +132,5 @@ def convert(
@property
def reverse_op(self) -> ConversionOps:
+ """Return the inverse merge op used when writing fused checkpoints."""
return MergeModulelist(dim=self.expert_dim)
diff --git a/defuser/defuser.py b/defuser/defuser.py
index 9a9ccbe..345e909 100644
--- a/defuser/defuser.py
+++ b/defuser/defuser.py
@@ -22,6 +22,7 @@
logger = LogBar(__name__)
def get_checkpoint_conversion_mapping(model_type):
+ """Return Defuser's checkpoint remapping rules for one registered model type."""
from transformers import conversion_mapping
if not hasattr(conversion_mapping, "orig_get_checkpoint_conversion_mapping"):
@@ -41,6 +42,7 @@ class PatchError(Exception):
def replace_fused_blocks(model_type: str) -> bool:
+ """Patch supported HF model classes so future loads instantiate defused blocks."""
if warn_if_public_api_transformers_unsupported("replace_fused_blocks()", logger):
return False
@@ -117,6 +119,7 @@ def convert_model(
cleanup_original: bool = False,
max_layers: int | None = None,
) -> bool:
+ """Convert one loaded model in place from fused experts to defused modules."""
if warn_if_public_api_transformers_unsupported("convert_model()", logger):
return False
diff --git a/defuser/model_registry.py b/defuser/model_registry.py
index b9ca11e..9a4d7f0 100644
--- a/defuser/model_registry.py
+++ b/defuser/model_registry.py
@@ -12,7 +12,7 @@
class PATCH(str, Enum):
REPLACE_MODULE = "replace_module"
- DEFUSE = "defuse"
+ EXPERTS_DEFUSE = "experts_defuse"
MODEL_CONFIG = {
@@ -121,6 +121,11 @@ class PATCH(str, Enum):
},
"llama4": {
"min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
- PATCH.DEFUSE: "defuser.modeling.fused_moe.llama4",
+ PATCH.EXPERTS_DEFUSE: [
+ {
+ "module_class": "transformers.models.llama4.modeling_llama4.Llama4TextExperts",
+ "forward_impl": "batched_input",
+ }
+ ],
},
}
diff --git a/defuser/modeling/fused_moe/__init__.py b/defuser/modeling/fused_moe/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/defuser/modeling/fused_moe/llama4.py b/defuser/modeling/fused_moe/llama4.py
deleted file mode 100644
index 9a4cc05..0000000
--- a/defuser/modeling/fused_moe/llama4.py
+++ /dev/null
@@ -1,103 +0,0 @@
-# Adapted from intel/auto-round
-# at https://github.com/intel/auto-round/blob/main/auto_round/modeling/fused_moe/llama4.py
-
-import torch
-import transformers
-from packaging import version
-
-transformers_version = version.parse(transformers.__version__)
-if transformers_version < version.parse("5.0.0"):
- from transformers.modeling_utils import no_init_weights
-else:
- from transformers.initialization import no_init_weights
-from transformers.models.llama4.modeling_llama4 import Llama4Config, Llama4TextMLP
-
-from defuser.modeling.replace_modules import ReplacementModuleBase
-from defuser.utils.model import _update_parameter, unsupported_meta_device
-from defuser.utils.device import clear_memory
-
-
-class SequentialLlama4TextExperts(torch.nn.ModuleList):
- def __init__(self, config, original):
- self.num_experts = original.gate_up_proj.shape[0]
- target_device = next(original.parameters()).device
- with no_init_weights(), torch.device("meta"):
- super().__init__([Llama4TextMLP(config) for _ in range(self.num_experts)])
-
- def _materialize_weights(self, original) -> None:
- if not unsupported_meta_device(original):
- intermediate_size = original.down_proj.shape[1]
-
- for i in range(self.num_experts):
- gate_up = original.gate_up_proj[i]
- down = original.down_proj[i]
- gate_proj = gate_up[:, :intermediate_size]
- up_proj = gate_up[:, intermediate_size:]
- _update_parameter(self[i].gate_proj, "weight", gate_proj.t().contiguous())
- _update_parameter(self[i].up_proj, "weight", up_proj.t().contiguous())
- _update_parameter(self[i].down_proj, "weight", down.t().contiguous())
- del gate_up, down, gate_proj, up_proj
- original.to_empty(device="meta") # release original experts parameters
-
-
-class SequentialLlama4TextMoe(ReplacementModuleBase):
- def __init__(self, original, config):
- super().__init__(original)
- config = config.text_config
- self.top_k = config.num_experts_per_tok
- self.hidden_dim = config.hidden_size
- self.num_experts = config.num_local_experts
- with torch.device("meta"):
- self.experts = SequentialLlama4TextExperts(config, original.experts)
-
- self.router = original.router
- self.shared_expert = original.shared_expert
-
- def forward(self, hidden_states: torch.Tensor):
- hidden_states = hidden_states.reshape(-1, self.hidden_dim)
- router_logits = self.router(hidden_states)
- if isinstance(router_logits, tuple):
- router_scores, router_logits = router_logits
- router_scores = router_scores.t()
- else:
- # transformers < 4.54.0 only returns router_logits
- router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)
-
- router_scores = (
- torch.full_like(router_logits, float("-inf"))
- .scatter_(1, router_indices, router_top_value)
- .transpose(0, 1)
- )
- router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)
-
- out = self.shared_expert(hidden_states)
-
- # Only process experts that actually received tokens (expert_hit pattern),
- # skipping experts with zero routing weight to save compute during calibration.
- with torch.no_grad():
- expert_hit = torch.greater(router_scores.sum(dim=-1), 0).nonzero()
- for expert_idx in expert_hit:
- expert_idx = expert_idx[0]
- out += self.experts[expert_idx](hidden_states) * router_scores[expert_idx].reshape(-1, 1)
-
- return out, router_logits
-
- @classmethod
- def original_module_class(cls) -> str:
- """Return the class name of the module this replaces."""
- return "Llama4TextMoe"
-
- @classmethod
- def from_original(
- cls,
- original: torch.nn.Module,
- config: Llama4Config,
- **kwargs,
- ) -> "SequentialLlama4TextMoe":
- """Create an instance from the original module."""
- return cls(original, config)
-
- def _materialize_weights(self) -> None:
- original = self._get_original_module()
- self.experts._materialize_weights(original.experts)
- clear_memory()
diff --git a/defuser/modeling/glm4v.py b/defuser/modeling/glm4v.py
index 254ea16..462b0d4 100644
--- a/defuser/modeling/glm4v.py
+++ b/defuser/modeling/glm4v.py
@@ -14,6 +14,7 @@ def __init__(self, config):
self.activation_fn = ACT2FN[config.hidden_act]
def forward(self, hidden_states):
+ """Reproduce the original fused GLM4V text MLP using split linear layers."""
gate = self.gate_proj(hidden_states)
up = self.up_proj(hidden_states)
# Match the original fused `gate_up_proj.chunk(2, dim=-1)` activation path.
diff --git a/defuser/modeling/model_patches.py b/defuser/modeling/model_patches.py
index 1982588..b1202bf 100644
--- a/defuser/modeling/model_patches.py
+++ b/defuser/modeling/model_patches.py
@@ -18,6 +18,7 @@
def register_model_class_patch(model_type: str):
+ """Register a one-time class patch that runs before model construction."""
def decorator(func: Callable):
_MODEL_CLASS_PATCH_REGISTRY[model_type] = func
return func
@@ -26,6 +27,7 @@ def decorator(func: Callable):
def register_model_patch(model_type: str):
+ """Register a runtime patch that runs on an instantiated model object."""
def decorator(func: Callable):
_MODEL_PATCH_REGISTRY[model_type] = func
return func
@@ -34,6 +36,7 @@ def decorator(func: Callable):
@register_model_class_patch("qwen3_omni_moe")
def patch_qwen3_omni_text_class() -> list[str]:
+ """Teach HF init code how to initialize unfused qwen3-omni thinker experts."""
from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import Qwen3OmniMoeForConditionalGeneration, Qwen3OmniMoePreTrainedModel
from defuser.modeling.unfused_moe.qwen3_omni_moe import LinearQwen3OmniMoeThinkerTextSparseMoeBlock
orig_init_weights = Qwen3OmniMoePreTrainedModel._init_weights
@@ -66,6 +69,7 @@ def patched_init_weights(self, module):
@register_model_patch("qwen3_omni_moe")
def patch_qwen3_omni_text_runtime(model) -> list[str]:
+ """Restore text-only ``forward`` and ``generate`` behavior after class swapping."""
model_cls = type(model)
if not getattr(model_cls, "__module__", "").startswith("transformers.models.qwen3_omni_moe."):
return []
@@ -95,6 +99,7 @@ def forward(self, *args, **kwargs):
def apply_model_class_patches(model_type) -> list[str]:
+ """Run any registered pre-construction patch for ``model_type``."""
patch_model_class = _MODEL_CLASS_PATCH_REGISTRY.get(model_type)
if patch_model_class is None:
return []
@@ -106,6 +111,7 @@ def apply_model_class_patches(model_type) -> list[str]:
def apply_model_patches(model) -> list[str]:
+ """Run any registered runtime patch for the instantiated ``model``."""
config = getattr(model, "config", None)
model_type = getattr(config, "model_type", None)
patch = _MODEL_PATCH_REGISTRY.get(model_type)
diff --git a/defuser/modeling/moe_experts_interface.py b/defuser/modeling/moe_experts_interface.py
index 7756042..99e0f8d 100644
--- a/defuser/modeling/moe_experts_interface.py
+++ b/defuser/modeling/moe_experts_interface.py
@@ -25,10 +25,13 @@
# Now the model uses linear_loop forward which supports quantized nn.Linear layers
"""
+from types import MethodType
+
import torch
from logbar import LogBar
from torch import nn
+from defuser.model_registry import MODEL_CONFIG, PATCH
from defuser.utils.device import clear_memory, to_meta
from defuser import DEBUG_ON
@@ -45,6 +48,7 @@
# Expert implementation name - change this if transformers want to use a different name
LINEAR_LOOP_IMPL = "linear_loop"
+BATCHED_INPUT_IMPL = "batched_input"
# Known expert projection patterns for reference
# These are used as hints when auto-detection needs to infer projection properties
@@ -79,6 +83,17 @@ def is_linear_loop_available() -> bool:
return HAS_EXPERTS_INTERFACE
+def _apply_expert_gate(module: nn.Module, gate_out: torch.Tensor, up_out: torch.Tensor) -> torch.Tensor:
+ """Apply the expert's activation path using either a custom gate hook or ``act_fn``."""
+ if hasattr(module, "_apply_gate"):
+ return module._apply_gate(torch.cat([gate_out, up_out], dim=-1))
+
+ act_fn = getattr(module, "act_fn", None)
+ if act_fn is None:
+ raise AttributeError(f"{module.__class__.__name__} must define either `_apply_gate` or `act_fn`.")
+ return act_fn(gate_out) * up_out
+
+
def linear_loop_experts_forward(
self: nn.Module,
hidden_states: torch.Tensor,
@@ -152,13 +167,7 @@ def linear_loop_experts_forward(
expert = getattr(self, str(expert_idx))
gate_out = expert.gate_proj(expert_input) # (num_samples, intermediate_dim)
up_out = expert.up_proj(expert_input) # (num_samples, intermediate_dim)
-
- # Apply gating
- if hasattr(self, "_apply_gate"):
- gate_up_out = torch.cat([gate_out, up_out], dim=-1)
- gated_out = self._apply_gate(gate_up_out) # (num_samples, intermediate_dim)
- else:
- gated_out = self.act_fn(gate_out) * up_out # (num_samples, intermediate_dim)
+ gated_out = _apply_expert_gate(self, gate_out, up_out)
# Down projection
expert_out = expert.down_proj(gated_out) # (num_samples, hidden_dim)
@@ -180,6 +189,31 @@ def linear_loop_experts_forward(
return final_hidden_states
+def batched_input_experts_forward(self: nn.Module, hidden_states: torch.Tensor) -> torch.Tensor:
+ """Run defused experts for models that feed experts as expert-major input batches.
+
+ Llama4 is the current example: upstream code repeats and pre-weights tokens,
+ then calls ``experts(hidden_states)`` where the leading dimension is laid out
+ as ``[expert0_tokens, expert1_tokens, ...]``. This forward keeps that public
+ contract while still executing per-expert ``nn.Linear`` modules internally.
+ """
+ if DEBUG_ON: logger.debug(f"Using {BATCHED_INPUT_IMPL} experts forward for {self.__class__.__name__}")
+
+ hidden_dim = hidden_states.size(-1)
+ expert_inputs = hidden_states.view(self.num_experts, -1, hidden_dim)
+ expert_outputs = []
+
+ for expert_idx in range(self.num_experts):
+ expert = getattr(self, str(expert_idx))
+ expert_input = expert_inputs[expert_idx]
+ gate_out = expert.gate_proj(expert_input)
+ up_out = expert.up_proj(expert_input)
+ gated_out = _apply_expert_gate(self, gate_out, up_out)
+ expert_outputs.append(expert.down_proj(gated_out))
+
+ return torch.stack(expert_outputs, dim=0).reshape(-1, hidden_dim)
+
+
def register_linear_loop_experts() -> bool:
"""Register the linear_loop experts implementation with transformers.
@@ -201,6 +235,45 @@ def register_linear_loop_experts() -> bool:
return True
+def _model_experts_defuse_specs(model: nn.Module) -> list[dict]:
+ """Return declarative experts-defusion specs for the current model type."""
+ config = getattr(model, "config", None)
+ model_type = getattr(config, "model_type", None)
+ if model_type is None:
+ return []
+
+ specs = MODEL_CONFIG.get(model_type, {}).get(PATCH.EXPERTS_DEFUSE, [])
+ if isinstance(specs, dict):
+ return [specs]
+ return list(specs)
+
+
+def _module_class_path(module: nn.Module) -> str:
+ """Return a stable import-style class path for matching model specs."""
+ return f"{module.__class__.__module__}.{module.__class__.__name__}"
+
+
+def _matching_experts_defuse_spec(module: nn.Module, specs: list[dict]) -> dict | None:
+ """Find the first declarative experts-defusion spec that matches ``module``."""
+ module_path = _module_class_path(module)
+ module_name = module.__class__.__name__
+
+ for spec in specs:
+ target = spec.get("module_class")
+ if target in {module_path, module_name}:
+ return spec
+ return None
+
+
+def _install_instance_forward(module: nn.Module, implementation: str) -> None:
+ """Attach a generic forward implementation directly to one experts module."""
+ if implementation == BATCHED_INPUT_IMPL:
+ module.forward = MethodType(batched_input_experts_forward, module)
+ return
+
+ raise ValueError(f"Unsupported experts forward implementation: {implementation}")
+
+
def _detect_expert_projections(module: nn.Module) -> dict[str, dict]:
"""Detect which expert projections exist in the module.
@@ -613,11 +686,26 @@ def prepare_model_for_moe_quantization(model: nn.Module, implementation: str = L
"This requires transformers >= 5.0.0 with MOE integration support."
)
- # Unfuse all fused experts modules (only those supporting @use_experts_implementation)
+ # Unfuse all fused experts modules, including models that need a generic
+ # instance-level forward override instead of transformers' decorator path.
unfused_modules = []
+ decorated_unfused_modules = []
+ experts_defuse_specs = _model_experts_defuse_specs(model)
for name, module in model.named_modules():
+ spec = _matching_experts_defuse_spec(module, experts_defuse_specs)
+ if spec is not None and _unfuse_experts_weights_inplace(
+ module,
+ check_decorator=False,
+ projection_names=spec.get("projection_names"),
+ ):
+ _install_instance_forward(module, spec["forward_impl"])
+ unfused_modules.append(name)
+ if DEBUG_ON: logger.debug(f"[MoE Prep] Unfused '{name}' via declarative spec")
+ continue
+
if _unfuse_experts_weights_inplace(module):
unfused_modules.append(name)
+ decorated_unfused_modules.append(name)
if DEBUG_ON: logger.debug(f"[MoE Prep] Unfused '{name}'")
# Only set config if we actually unfused something
@@ -627,8 +715,9 @@ def prepare_model_for_moe_quantization(model: nn.Module, implementation: str = L
if DEBUG_ON: logger.info(f"[MoE Prep] Unfused {len(unfused_modules)} MOE experts modules")
clear_memory()
- # Set config for linear_loop forward
- if hasattr(model, "config"):
+ # Set config for linear_loop forward only when the upstream model uses
+ # the decorator-based experts interface.
+ if decorated_unfused_modules and hasattr(model, "config"):
saved_impl = getattr(model.config, "experts_implementation", None)
impl_to_set = saved_impl if saved_impl else implementation
model.config._experts_implementation = impl_to_set
diff --git a/defuser/modeling/replace_modules.py b/defuser/modeling/replace_modules.py
index 8f0cc62..dfc0385 100644
--- a/defuser/modeling/replace_modules.py
+++ b/defuser/modeling/replace_modules.py
@@ -6,7 +6,6 @@
# Adapted from intel/auto-round
# at https://github.com/intel/auto-round/blob/main/auto_round/modeling/fused_moe/replace_modules.py
-import importlib
import weakref
from abc import ABC, abstractmethod
from dataclasses import dataclass
@@ -20,34 +19,12 @@
from defuser import DEBUG_ON
-from defuser.model_registry import MODEL_CONFIG, PATCH
-
logger = LogBar(__name__)
-def is_model_patchable(model: torch.nn.Module) -> bool:
- """Check if the model has a custom replacement registered via MODEL_CONFIG.
-
- Returns True if the model's model_type matches a key in MODEL_CONFIG.
- """
- if hasattr(model, "config") and hasattr(model.config, "model_type"):
- return model.config.model_type in MODEL_CONFIG and PATCH.DEFUSE in MODEL_CONFIG[model.config.model_type]
- return False
-
-
-def _import_required_replacements(model: torch.nn.Module) -> None:
- """Import replacement modules required for the model's defuse workflow."""
- if not is_model_patchable(model):
- return
- model_type = model.config.model_type
- module_path = MODEL_CONFIG[model_type].get(PATCH.DEFUSE)
- if not module_path:
- return
- importlib.import_module(module_path)
- logger.debug(f"Loaded replacement module for {model_type}: {module_path}")
-
-
def materialize_model(model: torch.nn.Module) -> None:
+ """Materialize any deferred replacement weights attached to ``model``."""
+
def _materialize_module(module: torch.nn.Module) -> None:
if isinstance(module, ReplacementModuleBase):
module.materialize_weights()
@@ -73,6 +50,8 @@ def _materialize_module(module: torch.nn.Module) -> None:
def release_original_module_(model: torch.nn.Module) -> None:
+ """Drop references to original fused modules after replacement is complete."""
+
def _clear_source_module(module: torch.nn.Module) -> None:
if isinstance(module, ReplacementModuleBase):
module.release_original_module()
@@ -263,20 +242,16 @@ def apply_replacements(
Returns:
The model with modules replaced.
"""
- _import_required_replacements(model)
-
_log_first_moe_block(model, "before replacement")
- # Models with an explicit DEFUSE workflow (for example llama4) need their
- # custom replacements applied directly. All other models should still keep
- # the legacy behavior: auto-detect fused experts first, then apply any
- # registered custom replacements.
- if is_model_patchable(model):
- _apply_custom_replacements(model, max_layers=max_layers)
- else:
- if auto_detect_moe and is_transformers_version_greater_or_equal_5():
- _handle_moe_modules(model)
- _apply_custom_replacements(model, max_layers=max_layers)
+ # Run the generic MoE tensor defusion pass first so models with supported
+ # fused experts can stay on their upstream module structure.
+ if auto_detect_moe and is_transformers_version_greater_or_equal_5():
+ _handle_moe_modules(model)
+
+ # Fall back to replacement modules for any models that still need a custom
+ # structural wrapper after the generic experts pass.
+ _apply_custom_replacements(model, max_layers=max_layers)
_log_first_moe_block(model, "after replacement")
diff --git a/defuser/modeling/unfused_moe/glm4_moe.py b/defuser/modeling/unfused_moe/glm4_moe.py
index 24b1299..61353f1 100644
--- a/defuser/modeling/unfused_moe/glm4_moe.py
+++ b/defuser/modeling/unfused_moe/glm4_moe.py
@@ -12,6 +12,7 @@ class LinearGlm4MoeMoE(nn.Module):
"""
def __init__(self, config):
+ """Build a GLM4 MoE block that exposes one explicit module per routed expert."""
super().__init__()
from transformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeMLP, Glm4MoeTopkRouter
@@ -34,6 +35,7 @@ def __init__(self, config):
self.top_k = config.num_experts_per_tok
def route_tokens_to_experts(self, router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ """Replicate GLM4's grouped expert selection and top-k weight normalization."""
# Keep the expert selection identical to the upstream GLM4 MoE router.
router_logits = router_logits.sigmoid()
router_logits_for_choice = router_logits + self.gate.e_score_correction_bias
@@ -59,10 +61,7 @@ def route_tokens_to_experts(self, router_logits: torch.Tensor) -> tuple[torch.Te
return topk_indices, topk_weights
def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
- r"""
- CALL FOR CONTRIBUTION! I don't have time to optimise this right now, but expert weights need to be fused
- to not have to do a loop here (deepseek has 256 experts soooo yeah).
- """
+ """Run the routed experts one by one and accumulate their weighted outputs."""
final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts))
expert_mask = expert_mask.permute(2, 0, 1)
@@ -85,6 +84,7 @@ def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weig
return final_hidden_states.type(hidden_states.dtype)
def forward(self, hidden_states):
+ """Apply routed experts plus the shared expert branch, matching upstream GLM4 MoE."""
residuals = hidden_states
orig_shape = hidden_states.shape
router_logits = self.gate(hidden_states)
diff --git a/defuser/modeling/unfused_moe/mixtral.py b/defuser/modeling/unfused_moe/mixtral.py
index bf62878..edfbf20 100644
--- a/defuser/modeling/unfused_moe/mixtral.py
+++ b/defuser/modeling/unfused_moe/mixtral.py
@@ -10,6 +10,8 @@
class MixtralBlockSparseTop2MLP(nn.Module):
+ """Per-expert Mixtral MLP with explicit gate, up, and down projections."""
+
def __init__(self, config: MixtralConfig):
super().__init__()
self.hidden_size = config.hidden_size
@@ -20,6 +22,7 @@ def __init__(self, config: MixtralConfig):
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
+ """Apply the standard SwiGLU-style Mixtral expert math."""
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj
@@ -53,7 +56,7 @@ def __init__(self, config):
self.jitter_noise = config.router_jitter_noise
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- """ """
+ """Match HF Mixtral MoE routing while executing explicit per-expert modules."""
batch_size, sequence_length, hidden_dim = hidden_states.shape
if self.training and self.jitter_noise > 0:
hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
diff --git a/defuser/modeling/unfused_moe/qwen2_moe.py b/defuser/modeling/unfused_moe/qwen2_moe.py
index 8a49df6..61ef68e 100644
--- a/defuser/modeling/unfused_moe/qwen2_moe.py
+++ b/defuser/modeling/unfused_moe/qwen2_moe.py
@@ -9,6 +9,8 @@
class LinearQwen2MoeSparseMoeBlock(nn.Module):
+ """Qwen2 MoE block rewritten to expose one ``nn.Module`` per expert."""
+
def __init__(self, config):
super().__init__()
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeMLP, Qwen2MoeTopKRouter
@@ -27,7 +29,7 @@ def __init__(self, config):
self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- """ """
+ """Route tokens exactly like HF Qwen2 MoE, then run explicit expert modules."""
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
_, routing_weights, selected_experts = self.gate(hidden_states)
diff --git a/defuser/modeling/unfused_moe/qwen3_moe.py b/defuser/modeling/unfused_moe/qwen3_moe.py
index 90f4888..89d382d 100644
--- a/defuser/modeling/unfused_moe/qwen3_moe.py
+++ b/defuser/modeling/unfused_moe/qwen3_moe.py
@@ -11,6 +11,8 @@
class LinearQwen3MoeSparseMoeBlock(nn.Module):
+ """Qwen3 MoE block rewritten to expose one ``nn.Module`` per expert."""
+
def __init__(self, config):
super().__init__()
from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeMLP, Qwen3MoeTopKRouter
@@ -26,7 +28,7 @@ def __init__(self, config):
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- """ """
+ """Route tokens exactly like HF Qwen3 MoE, then run explicit expert modules."""
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
_, routing_weights, selected_experts = self.gate(hidden_states)
diff --git a/defuser/modeling/unfused_moe/qwen3_next.py b/defuser/modeling/unfused_moe/qwen3_next.py
index 10bfc2f..d9bdcc2 100644
--- a/defuser/modeling/unfused_moe/qwen3_next.py
+++ b/defuser/modeling/unfused_moe/qwen3_next.py
@@ -8,6 +8,8 @@
from torch.nn import functional as F
class LinearQwen3NextSparseMoeBlock(nn.Module):
+ """Qwen3-Next MoE block rewritten to expose one ``nn.Module`` per expert."""
+
def __init__(self, config):
super().__init__()
from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextMLP, Qwen3NextTopKRouter
@@ -26,7 +28,7 @@ def __init__(self, config):
self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- """ """
+ """Route tokens exactly like HF Qwen3-Next MoE, then run explicit experts."""
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
_, routing_weights, selected_experts = self.gate(hidden_states)
diff --git a/defuser/modeling/unfused_moe/qwen3_omni_moe.py b/defuser/modeling/unfused_moe/qwen3_omni_moe.py
index 0fb6476..94c0d4f 100644
--- a/defuser/modeling/unfused_moe/qwen3_omni_moe.py
+++ b/defuser/modeling/unfused_moe/qwen3_omni_moe.py
@@ -7,6 +7,8 @@
import torch.nn as nn
class LinearQwen3OmniMoeThinkerTextSparseMoeBlock(nn.Module):
+ """Text thinker MoE block for qwen3-omni with explicit per-expert modules."""
+
def __init__(self, config):
super().__init__()
from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import (
@@ -28,7 +30,7 @@ def __init__(self, config):
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- """ """
+ """Route tokens exactly like HF qwen3-omni text MoE, then run explicit experts."""
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
_, routing_weights, selected_experts = self.gate(hidden_states)
diff --git a/defuser/modeling/update_module.py b/defuser/modeling/update_module.py
index c56aa63..6a4b2cf 100644
--- a/defuser/modeling/update_module.py
+++ b/defuser/modeling/update_module.py
@@ -14,6 +14,7 @@ def update_module(
cleanup_original: bool = True,
max_layers: int | None = None,
):
+ """Run Defuser's replacement pipeline and optionally drop original modules."""
model = apply_replacements(model, max_layers=max_layers)
if cleanup_original:
diff --git a/defuser/utils/common.py b/defuser/utils/common.py
index 30620ae..842797e 100644
--- a/defuser/utils/common.py
+++ b/defuser/utils/common.py
@@ -33,18 +33,21 @@ def env_flag(name: str, default: str | bool | None = "0") -> bool:
@lru_cache(None)
def is_transformers_version_greater_or_equal_5():
+ """Cache the coarse ``transformers>=5`` capability check used by fast paths."""
import transformers
return version.parse(transformers.__version__) >= version.parse("5.0.0")
def is_supported_transformers_version() -> bool:
+ """Return whether the installed transformers version is supported by Defuser's public API."""
import transformers
return version.parse(transformers.__version__) >= version.parse(MIN_SUPPORTED_TRANSFORMERS_VERSION)
def warn_if_public_api_transformers_unsupported(api_name: str, logger) -> bool:
+ """Emit a single consistent warning when the runtime transformers version is too old."""
import transformers
if is_supported_transformers_version():
diff --git a/defuser/utils/device.py b/defuser/utils/device.py
index eaf800e..ebb118c 100644
--- a/defuser/utils/device.py
+++ b/defuser/utils/device.py
@@ -35,6 +35,7 @@ def clear_memory(
tensor: torch.Tensor | list[torch.Tensor] | None = None,
device_list: tuple | list | str | torch.device | None = None,
):
+ """Release Python references and flush accelerator caches after tensor surgery."""
# ------------------------
# Clear CPU-side references
# ------------------------
@@ -87,14 +88,7 @@ def clear_memory(
def unsupported_meta_device(model):
- """Checks if the model is a valid model.
-
- Args:
- model: The model to be checked.
-
- Returns:
- bool: True if the model is valid, False otherwise.
- """
+ """Return ``True`` when a model mixes meta and real devices in unsupported ways."""
target_device = None
for param in model.parameters():
if target_device is None:
diff --git a/defuser/utils/hf.py b/defuser/utils/hf.py
index f8b7d1f..af370f5 100644
--- a/defuser/utils/hf.py
+++ b/defuser/utils/hf.py
@@ -6,7 +6,6 @@
# Adapted from intel/auto-round
# at https://github.com/intel/auto-round/blob/main/auto_round/modeling/unfused_moe/__init__.py
-import importlib
import os
from typing import Final
@@ -48,6 +47,7 @@ def modelscope_requested() -> bool:
def get_file_path_via_model_name(model_or_path: str, file_name):
+ """Resolve one HF or ModelScope file path from either a repo id or local directory."""
from huggingface_hub import hf_hub_download
# 1) local folder
@@ -74,6 +74,7 @@ def get_file_path_via_model_name(model_or_path: str, file_name):
def pre_check_config(model_name: str | torch.nn.Module):
+ """Quickly decide whether a model likely still needs Defuser's runtime conversion."""
if warn_if_public_api_transformers_unsupported("pre_check_config()", logger):
return False
@@ -101,6 +102,7 @@ def pre_check_config(model_name: str | torch.nn.Module):
with open(file_path, "r") as f:
index_data = json.load(f)
+ # A fused gate/up tensor in the weight map means runtime defusion is still needed.
model_keys = list(index_data.get("weight_map", {}).keys())
for key in model_keys:
if "gate_up_proj" in key:
diff --git a/defuser/utils/model.py b/defuser/utils/model.py
index 349e1ed..a7bbc76 100644
--- a/defuser/utils/model.py
+++ b/defuser/utils/model.py
@@ -14,20 +14,14 @@ def _update_parameter(
name: str,
data: torch.Tensor,
) -> None:
+ """Replace one module parameter while preserving its ``requires_grad`` flag."""
old_param = getattr(module, name)
new_param = torch.nn.Parameter(data, requires_grad=old_param.requires_grad)
setattr(module, name, new_param)
def unsupported_meta_device(model):
- """Checks if the model is a valid model for auto_round.
-
- Args:
- model: The model to be checked.
-
- Returns:
- bool: True if the model is valid, False otherwise.
- """
+ """Return ``True`` when mixed real/meta parameters make lazy materialization unsafe."""
target_device = None
for param in model.parameters():
if target_device is None:
diff --git a/tests/test_convert_model.py b/tests/test_convert_model.py
index d688d2e..a2a4789 100644
--- a/tests/test_convert_model.py
+++ b/tests/test_convert_model.py
@@ -924,6 +924,27 @@ def test_llama4():
torch.testing.assert_close(expert0.down_proj.weight, expected_down)
+def test_llama4_experts_forward_matches_fused_math():
+ model = Llama4ForConditionalGeneration(_tiny_llama4_config())
+ fused_experts = model.language_model.model.layers[0].feed_forward.experts
+
+ hidden_states = torch.randn(fused_experts.num_experts * 5, model.config.text_config.hidden_size, dtype=torch.float32)
+ with torch.no_grad():
+ expected = fused_experts(hidden_states)
+
+ converted = convert_model(model, cleanup_original=False, max_layers=1)
+ assert converted
+
+ split_experts = model.language_model.model.layers[0].feed_forward.experts
+ _assert_unfused_expert_module(split_experts)
+ materialize_model(model.language_model.model.layers[0])
+ with torch.no_grad():
+ actual = split_experts(hidden_states)
+
+ # The batched-input generic path should preserve the original llama4 experts math.
+ torch.testing.assert_close(actual, expected)
+
+
def test_llama4_split_forward_matches_fused_math():
from transformers.models.llama4.modeling_llama4 import Llama4TextMLP