199 changes: 193 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,197 @@
<a href="https://github.com/ModelCloud/Defuser/blob/main/LICENSE"><img src="https://img.shields.io/pypi/l/Defuser"></a>
<a href="https://huggingface.co/modelcloud/"><img src="https://img.shields.io/badge/🤗%20Hugging%20Face-ModelCloud-%23ff8811.svg"></a>
</p>
Model defuser helper for HF Transformers >= 5.0. In HF Transformers 5.x releases, many MoE modules became auto-stacked or auto-fused by new modeling code, which has benefits but also downsides.

* The goal is to provide naive module/layer forwarding code for all models supported by HF Transformers, where
runtime weight- and structure-level optimizations such as weight merging, stacking, and fusing are reversed so the
model operates in a simple, naive state.
* In some cases, such as quantization libraries, inference must be run while each module's input/output is
captured individually, and this package helps complete that task.
Defuser converts select Hugging Face Transformers `5.3.0+` fused or stacked MoE and MLP blocks back into plain, per-expert `nn.Linear` modules. It keeps the forward math intact while exposing individual projections again so quantizers, activation capture, debugging hooks, and checkpoint tooling can work against a simple module layout instead of fused expert tensors.

Defuser is designed and CI-tested for `transformers>=5.3.0`, and support is only offered for that version range.

## Purpose

Defuser exists for cases where newer Transformers modeling code optimizes model structure in ways that are good for runtime, but harder for tooling that needs direct access to individual projections.

Depending on the model family, Defuser can:

- patch a supported model class before load so HF instantiates a defused block directly
- split fused tensors such as `gate_up_proj` into `gate_proj` + `up_proj`
- convert 3D expert tensors into numbered expert `nn.Linear` modules
- preserve the original fused math while presenting a naive module structure again
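The split described above can be sketched in plain PyTorch. This is a minimal illustration, not Defuser's internals: the shapes, expert count, and module layout here are made up for the example.

```python
import torch
from torch import nn

# Hypothetical shapes for illustration: 4 experts, hidden size 8, moe_intermediate 16.
num_experts, hidden, inter = 4, 8, 16

# Fused layout: gate and up weights stacked along the projection dim for every expert.
gate_up = torch.randn(num_experts, 2 * inter, hidden)

experts = nn.ModuleList()
for e in range(num_experts):
    gate_w, up_w = gate_up[e].split(inter, dim=0)  # two [inter, hidden] slices
    expert = nn.Module()
    expert.gate_proj = nn.Linear(hidden, inter, bias=False)
    expert.up_proj = nn.Linear(hidden, inter, bias=False)
    expert.gate_proj.weight.data.copy_(gate_w)
    expert.up_proj.weight.data.copy_(up_w)
    experts.append(expert)

# The per-expert Linear reproduces the fused slice's math exactly.
x = torch.randn(2, hidden)
print(torch.allclose(experts[0].gate_proj(x), x @ gate_up[0, :inter].T))  # True
```

The key property is the last line: defusion changes the module layout, not the forward math.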

Public API:

```python
from defuser import convert_model, replace_fused_blocks
```

- `replace_fused_blocks(model_type)` patches supported HF model classes before `from_pretrained()` or direct model construction.
- `convert_model(model, cleanup_original=True, max_layers=None)` converts an already loaded model in place. This is the runtime defusion path used for `qwen3_5_moe` style checkpoints.
- Support is only offered for `transformers>=5.3.0`; older versions log a warning on these public APIs and are skipped as unsupported.
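The version gate can be sketched with a simplified comparison. Defuser itself compares against `transformers.__version__` via `packaging.version`; the stdlib-only parser and helper names below are illustrative stand-ins.

```python
MIN_SUPPORTED_TRANSFORMERS_VERSION = "5.3.0"

def _parse(ver: str) -> tuple[int, ...]:
    # Simplified stand-in for packaging.version.parse; handles plain x.y.z only.
    return tuple(int(part) for part in ver.split(".")[:3])

def is_supported_transformers_version(current: str) -> bool:
    # True when the installed version meets Defuser's minimum.
    return _parse(current) >= _parse(MIN_SUPPORTED_TRANSFORMERS_VERSION)

print(is_supported_transformers_version("5.3.0"))   # True
print(is_supported_transformers_version("4.57.1"))  # False
```

When the check fails, both public APIs log a warning and return `False` instead of raising.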

## Supported Models

| Model type | Recommended entrypoint | Defused op performed |
| --- | --- | --- |
| `mixtral` | `replace_fused_blocks("mixtral")` before load | Replaces `MixtralSparseMoeBlock` with `LinearMixtralSparseMoeBlock`. Also remaps legacy Mixtral checkpoint keys and splits fused expert `gate_up_proj` tensors into per-expert `gate_proj` and `up_proj`, plus per-expert `down_proj`. |
| `qwen2_moe` | `replace_fused_blocks("qwen2_moe")` before load | Replaces `Qwen2MoeSparseMoeBlock` with a defused per-expert linear MoE block. |
| `qwen3_moe` | `replace_fused_blocks("qwen3_moe")` before load | Replaces `Qwen3MoeSparseMoeBlock` with a defused per-expert linear MoE block. |
| `qwen3_5_moe` | `convert_model(model)` after load | Runtime expert tensor defusion. Splits fused `gate_up_proj` into `gate_proj` + `up_proj` and converts 3D expert tensors into numbered expert `nn.Linear` modules. |
| `qwen3_5_moe_text` | `convert_model(model)` after load | Same runtime expert tensor defusion path as `qwen3_5_moe`, applied to the text-only backbone. |
| `qwen3_next` | `replace_fused_blocks("qwen3_next")` before load | Replaces `Qwen3NextSparseMoeBlock` with a defused per-expert linear MoE block. |
| `qwen3_omni_moe` | `replace_fused_blocks("qwen3_omni_moe")` before load | Replaces the thinker text sparse MoE block with a defused per-expert linear block and applies small runtime compatibility patches for text `forward()` and `generate()`. |
| `glm4_moe` | `replace_fused_blocks("glm4_moe")` before load | Replaces `Glm4MoeMoE` with a defused per-expert linear MoE block. |
| `glm4v` | `replace_fused_blocks("glm4v")` before load | Replaces the fused text MLP with split `gate_proj`, `up_proj`, and `down_proj` layers. Also splits fused checkpoint `mlp.gate_up_proj.weight` into `mlp.gate_proj.weight` + `mlp.up_proj.weight`. |

## Workflow Summary

Use `replace_fused_blocks()` for model families that Defuser can patch before load:

```python
from defuser import replace_fused_blocks
from transformers import MixtralForCausalLM

replace_fused_blocks("mixtral")
model = MixtralForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1",
    dtype="auto",
    device_map="auto",
)
```

Use `convert_model()` for already loaded models whose expert tensors still need runtime defusion:

```python
from defuser import convert_model

converted = convert_model(model)
print(converted) # True when runtime defusion happened
```

## Real Qwen3.5 MoE Example

The example below is written for the `transformers==5.3.0` public API surface and uses the real Hugging Face model `Qwen/Qwen3.5-35B-A3B-Instruct`. Defuser supports `transformers>=5.3.0`.

### Fused Weights Before And After

Before `convert_model(model)`:

```text
+--------------------------------------------------------+---------------------------------------------+
| State dict key | Layout |
+--------------------------------------------------------+---------------------------------------------+
| model.language_model.layers.0.mlp.experts.gate_up_proj | fused gate+up tensor for all experts |
| | [num_experts, 2 * moe_intermediate, hidden] |
| model.language_model.layers.0.mlp.experts.down_proj | fused per-expert down tensor |
| | [num_experts, hidden, moe_intermediate] |
+--------------------------------------------------------+---------------------------------------------+
```

After `convert_model(model)`:

```text
+-----------------------------------------------------------------+--------------------------------------+
| State dict key | Layout |
+-----------------------------------------------------------------+--------------------------------------+
| model.language_model.layers.0.mlp.experts.0.gate_proj.weight | expert 0 gate projection |
| model.language_model.layers.0.mlp.experts.0.up_proj.weight | expert 0 up projection |
| model.language_model.layers.0.mlp.experts.0.down_proj.weight | expert 0 down projection |
| ... repeated for experts 1..N-1 | numbered expert nn.Linear modules |
+-----------------------------------------------------------------+--------------------------------------+
```

### Sample 1: Inspect The Conversion In Place

```python
from defuser import convert_model
from transformers import Qwen3_5MoeForConditionalGeneration

model_id = "Qwen/Qwen3.5-35B-A3B-Instruct"

model = Qwen3_5MoeForConditionalGeneration.from_pretrained(
    model_id,
    dtype="auto",
    device_map="auto",
)

prefix = "model.language_model.layers.0.mlp.experts"

before = [name for name, _ in model.named_parameters() if name.startswith(prefix)]
print(before)
# [
# "model.language_model.layers.0.mlp.experts.gate_up_proj",
# "model.language_model.layers.0.mlp.experts.down_proj",
# ]

converted = convert_model(model)
assert converted is True

after = [name for name, _ in model.named_parameters() if name.startswith(prefix)]
print(after[:6])
# [
# "model.language_model.layers.0.mlp.experts.0.down_proj.weight",
# "model.language_model.layers.0.mlp.experts.0.gate_proj.weight",
# "model.language_model.layers.0.mlp.experts.0.up_proj.weight",
# "model.language_model.layers.0.mlp.experts.1.down_proj.weight",
# "model.language_model.layers.0.mlp.experts.1.gate_proj.weight",
# "model.language_model.layers.0.mlp.experts.1.up_proj.weight",
# ]
```

### Sample 2: Convert And Keep Using The Model Normally

```python
import torch

from defuser import convert_model
from transformers import AutoProcessor, Qwen3_5MoeForConditionalGeneration

model_id = "Qwen/Qwen3.5-35B-A3B-Instruct"

model = Qwen3_5MoeForConditionalGeneration.from_pretrained(
    model_id,
    dtype="auto",
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(model_id)

convert_model(model)

messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Explain mixture-of-experts routing in one sentence."},
        ],
    }
]

inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_dict=True,
    return_tensors="pt",
)
inputs = inputs.to(model.device)

with torch.inference_mode():
    output_ids = model.generate(**inputs, max_new_tokens=64)

generated_ids = [
    out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, output_ids)
]
text = processor.batch_decode(
    generated_ids,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False,
)[0]
print(text)
```

After conversion, the first routed expert in the first MoE layer is exposed as normal submodules:

```python
expert0 = model.model.language_model.layers[0].mlp.experts[0]
print(type(expert0.gate_proj).__name__) # Linear
print(type(expert0.up_proj).__name__) # Linear
print(type(expert0.down_proj).__name__) # Linear
```
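One motivating use case from the intro is per-projection activation capture. With defused experts, an ordinary forward hook works on each projection. The sketch below uses a stand-in `nn.Linear` so it runs without the 35B checkpoint; on a real converted model you would hook `model.model.language_model.layers[0].mlp.experts[0].gate_proj` instead.

```python
import torch
from torch import nn

# Stand-in for a defused expert projection (hypothetical small shapes).
expert0 = nn.Module()
expert0.gate_proj = nn.Linear(8, 16, bias=False)

captured = {}

def save_io(module, inputs, output):
    # Record the per-projection input/output that a fused expert tensor would hide.
    captured["in"] = inputs[0].detach()
    captured["out"] = output.detach()

handle = expert0.gate_proj.register_forward_hook(save_io)
expert0.gate_proj(torch.randn(2, 8))
handle.remove()

print(captured["out"].shape)  # torch.Size([2, 16])
```

This is exactly the kind of hook that cannot target a single expert's gate projection while the weights live inside one fused 3D tensor.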
7 changes: 7 additions & 0 deletions defuser/checkpoint_ops.py
Expand Up @@ -24,6 +24,7 @@ def __init__(self, expert_dim: int = 0, proj_dim: int = 0):

@staticmethod
def _expert_target(pattern: str, expert_idx: int) -> str:
"""Expand one target pattern into the per-expert key for ``expert_idx``."""
if "*" in pattern:
return pattern.replace("*", str(expert_idx))
return pattern.replace(".0.", f".{expert_idx}.", 1)
Expand All @@ -32,6 +33,7 @@ def _expert_target(pattern: str, expert_idx: int) -> str:
def convert(
self, input_dict: dict[str, torch.Tensor], source_patterns: list[str], target_patterns: list[str], **kwargs
) -> dict[str, torch.Tensor]:
"""Split one fused gate/up tensor into cloned per-expert gate and up tensors."""
if len(target_patterns) != 2:
raise ValueError("SplitFusedExpertGateUpProj expects exactly two target patterns.")

Expand All @@ -50,6 +52,7 @@ def convert(

@property
def reverse_op(self) -> ConversionOps:
"""Return the inverse merge op used when writing fused checkpoints."""
return MergeSplitExpertGateUpProj()


Expand All @@ -66,6 +69,7 @@ def __init__(self, expert_dim: int = 0, proj_dim: int = 0):
def convert(
self, input_dict: dict[str, list[torch.Tensor]], source_patterns: list[str], target_patterns: list[str], **kwargs
) -> dict[str, torch.Tensor]:
"""Merge per-expert gate/up tensors back into one fused expert tensor."""
if len(source_patterns) != 2:
raise ValueError("MergeSplitExpertGateUpProj expects exactly two source patterns.")
if len(target_patterns) != 1:
Expand Down Expand Up @@ -102,6 +106,7 @@ def __init__(self, expert_dim: int = 0):

@staticmethod
def _expert_target(pattern: str, expert_idx: int) -> str:
"""Expand one target pattern into the per-expert key for ``expert_idx``."""
if "*" in pattern:
return pattern.replace("*", str(expert_idx))
return pattern.replace(".0.", f".{expert_idx}.", 1)
Expand All @@ -110,6 +115,7 @@ def _expert_target(pattern: str, expert_idx: int) -> str:
def convert(
self, input_dict: dict[str, torch.Tensor], source_patterns: list[str], target_patterns: list[str], **kwargs
) -> dict[str, torch.Tensor]:
"""Split one fused expert down projection into cloned per-expert tensors."""
if len(target_patterns) != 1:
raise ValueError("SplitFusedExpertDownProj expects a single target pattern.")

Expand All @@ -126,4 +132,5 @@ def convert(

@property
def reverse_op(self) -> ConversionOps:
"""Return the inverse merge op used when writing fused checkpoints."""
return MergeModulelist(dim=self.expert_dim)
19 changes: 18 additions & 1 deletion defuser/defuser.py
Expand Up @@ -10,13 +10,19 @@
from defuser.model_registry import MODEL_CONFIG, PATCH
from defuser.modeling.model_patches import apply_model_class_patches, apply_model_patches
from defuser.modeling.update_module import update_module
from defuser.utils.common import (
MIN_SUPPORTED_TRANSFORMERS_VERSION,
is_supported_transformers_version,
warn_if_public_api_transformers_unsupported,
)
from packaging import version
import transformers
from logbar import LogBar

logger = LogBar(__name__)

def get_checkpoint_conversion_mapping(model_type):
"""Return Defuser's checkpoint remapping rules for one registered model type."""
from transformers import conversion_mapping

if not hasattr(conversion_mapping, "orig_get_checkpoint_conversion_mapping"):
Expand All @@ -36,6 +42,10 @@ class PatchError(Exception):


def replace_fused_blocks(model_type: str) -> bool:
"""Patch supported HF model classes so future loads instantiate defused blocks."""
if warn_if_public_api_transformers_unsupported("replace_fused_blocks()", logger):
return False

apply_model_class_patches(model_type)

cfg = MODEL_CONFIG.get(model_type)
Expand All @@ -60,7 +70,7 @@ def replace_fused_blocks(model_type: str) -> bool:
custom_class = getattr(custom_module, custom_class_name)
setattr(orig_module, orig_class_name, custom_class)

if version.parse(transformers.__version__) >= version.parse("5.0.0"):
if version.parse(transformers.__version__) >= version.parse(MIN_SUPPORTED_TRANSFORMERS_VERSION):
from transformers import conversion_mapping

if not hasattr(conversion_mapping, "orig_get_checkpoint_conversion_mapping"):
Expand Down Expand Up @@ -89,6 +99,9 @@ def check_model_compatibility(model: nn.Module) -> bool:
if model_type not in MODEL_CONFIG:
return False

if not is_supported_transformers_version():
return False

min_ver = MODEL_CONFIG[model_type].get("min_transformers_version")
current_ver = version.parse(transformers.__version__)
if min_ver and current_ver < version.parse(min_ver):
Expand All @@ -106,6 +119,10 @@ def convert_model(
cleanup_original: bool = False,
max_layers: int | None = None,
) -> bool:
"""Convert one loaded model in place from fused experts to defused modules."""
if warn_if_public_api_transformers_unsupported("convert_model()", logger):
return False

if max_layers is not None and max_layers < 1:
raise ValueError("max_layers must be >= 1 when provided")
