
Commit 7dd10ba

Added replace_fused_blocks() (#23)
* Added replace_fused_blocks(), which replaces fused MLP blocks in the modeling module with unfused counterparts prior to model loading, and bypasses the transformers conversion_mapping.
* Added LinearQwen2MoeSparseMoeBlock.
* Added LinearMixtralSparseMoeBlock.
* Added LinearQwen3NextSparseMoeBlock.
* Added LinearQwen3OmniMoeThinkerTextSparseMoeBlock.
* Added apply_model_class_patches().
* Added LinearGlm4MoeMoE.
* Updated version to 0.0.11.

Signed-off-by: ZX-ModelCloud <zx@modelcloud.ai>
1 parent 6317736 commit 7dd10ba

13 files changed

Lines changed: 688 additions & 38 deletions
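For orientation, a minimal usage sketch of the new load-time path (the call site and checkpoint path are illustrative, not part of this commit; replace_fused_blocks() and convert_model() are the exported entrypoints):

    # Hypothetical usage: swap modeling classes before from_pretrained() so
    # checkpoint weights load directly into the unfused blocks.
    from defuser import replace_fused_blocks
    from transformers import AutoModelForCausalLM

    replace_fused_blocks("qwen3_moe")  # e.g. Qwen3MoeSparseMoeBlock -> LinearQwen3MoeSparseMoeBlock
    model = AutoModelForCausalLM.from_pretrained("path/to/qwen3-moe-checkpoint")

The package-level wrapper below imports defuser.defuser lazily at call time, which avoids import cycles between the package __init__ and the conversion module.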

defuser/__init__.py

Lines changed: 9 additions & 1 deletion
@@ -14,4 +14,12 @@ def convert_model(*args, **kwargs):
     return _convert_model(*args, **kwargs)


-__all__ = ["convert_model"]
+def replace_fused_blocks(*args, **kwargs):
+    """Lazily import conversion entrypoint to avoid import-time cycles."""
+    from .defuser import replace_fused_blocks as _replace_fused_blocks
+
+    return _replace_fused_blocks(*args, **kwargs)
+
+
+__all__ = ["convert_model", "replace_fused_blocks"]

defuser/defuser.py

Lines changed: 81 additions & 6 deletions
@@ -2,17 +2,82 @@
 # SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai
 # SPDX-License-Identifier: Apache-2.0
 # Contact: qubitium@modelcloud.ai, x.com/qubitium
+import importlib
+
 from torch import nn

-from defuser.model_registry import MODEL_CONFIG
-from defuser.modeling.model_patches import apply_model_patches
+from defuser.model_registry import MODEL_CONFIG, PATCH
+from defuser.modeling.model_patches import apply_model_class_patches, apply_model_patches
 from defuser.modeling.update_module import update_module
 from packaging import version
 import transformers
 from logbar import LogBar

 logger = LogBar(__name__)

+def get_checkpoint_conversion_mapping(model_type):
+    from transformers import conversion_mapping
+
+    if not hasattr(conversion_mapping, "orig_get_checkpoint_conversion_mapping"):
+        conversion_mapping.orig_get_checkpoint_conversion_mapping = conversion_mapping.get_checkpoint_conversion_mapping
+
+    cfg = MODEL_CONFIG.get(model_type)
+    if cfg:
+        return cfg.get("checkpoint_mapping", [])
+
+    return conversion_mapping.orig_get_checkpoint_conversion_mapping(model_type)
+
+
+class PatchError(Exception):
+    pass
+
+
+def replace_fused_blocks(model_type: str) -> bool:
+    apply_model_class_patches(model_type)
+
+    cfg = MODEL_CONFIG[model_type]
+    for orig_path, custom_path in cfg.get(PATCH.REPLACE_MODULE, []):
+        orig_module_path, orig_class_name = orig_path.rsplit(".", 1)
+        custom_module_path, custom_class_name = custom_path.rsplit(".", 1)
+
+        try:
+            orig_module = importlib.import_module(orig_module_path)
+            custom_module = importlib.import_module(custom_module_path)
+
+            # Validate class existence before patching
+            if not hasattr(orig_module, orig_class_name):
+                raise PatchError(f"Original class [{orig_class_name}] not found: {orig_module}")
+
+            if not hasattr(custom_module, custom_class_name):
+                raise PatchError(f"Custom class [{custom_class_name}] not found: {custom_module}")
+
+            custom_class = getattr(custom_module, custom_class_name)
+            setattr(orig_module, orig_class_name, custom_class)
+
+            if version.parse(transformers.__version__) >= version.parse("5.0.0"):
+                from transformers import conversion_mapping
+
+                if not hasattr(conversion_mapping, "orig_get_checkpoint_conversion_mapping"):
+                    conversion_mapping.orig_get_checkpoint_conversion_mapping = (
+                        conversion_mapping.get_checkpoint_conversion_mapping
+                    )
+
+                conversion_mapping.get_checkpoint_conversion_mapping = get_checkpoint_conversion_mapping
+                transformers.modeling_utils.get_checkpoint_conversion_mapping = get_checkpoint_conversion_mapping
+            logger.info(f"Patched {orig_path} -> {custom_path}")
+            return True
+
+        except Exception as e:
+            if isinstance(e, PatchError):
+                raise
+
+            logger.warning(f"Failed to patch {orig_path}: {e}")
+            return False
+    return False
+
+
 def check_model_compatibility(model: nn.Module) -> bool:
     """Validate model type and transformers version compatibility."""
     config = getattr(model, "config", None)
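Assuming transformers v5 consults get_checkpoint_conversion_mapping(model_type) to rewrite checkpoint keys while loading, returning the registry's "checkpoint_mapping" (empty unless configured) for registered model types turns key conversion into a pass-through, so unfused checkpoint keys land directly on the swapped-in Linear* blocks. A sketch of the patched lookup's behavior, inferred from the hunk above:

    get_checkpoint_conversion_mapping("mixtral")  # -> [] (registered: bypass conversion)
    get_checkpoint_conversion_mapping("llama")    # -> transformers' saved original mapping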
@@ -36,7 +101,7 @@ def convert_model(
     model: nn.Module,
     cleanup_original: bool = False,
     max_layers: int | None = None,
-) -> nn.Module:
+) -> bool:
     if max_layers is not None and max_layers < 1:
         raise ValueError("max_layers must be >= 1 when provided")

@@ -113,14 +178,24 @@ def convert_model(
     # and the runtime model implementation that operates on defused weights.

     if not check_model_compatibility(model):
-        return model
+        return False

     apply_model_patches(model)

-    return update_module(
+    # If fused blocks were already structurally replaced at load time,
+    # there is no need to perform runtime defusing again.
+    if MODEL_CONFIG[model.config.model_type].get(PATCH.REPLACE_MODULE):
+        return False
+
+    # Perform runtime defusing of fused projections: split already-loaded
+    # fused modules (e.g., gate_up_proj/down_proj) into independent expert
+    # layers: gate_proj / up_proj / down_proj.
+    update_module(
         model,
         cleanup_original=cleanup_original,
         max_layers=max_layers,
     )

-__all__ = ["convert_model"]
+    return True
+
+__all__ = ["convert_model", "replace_fused_blocks"]

defuser/model_registry.py

Lines changed: 47 additions & 1 deletion
@@ -2,16 +2,41 @@
 # SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai
 # SPDX-License-Identifier: Apache-2.0
 # Contact: qubitium@modelcloud.ai, x.com/qubitium
+from enum import Enum
+
+
+class PATCH(str, Enum):
+    REPLACE_MODULE = "replace_module"
+

 MODEL_CONFIG = {
     "mixtral": {
         "min_transformers_version": "5.0.0",
+        PATCH.REPLACE_MODULE: [
+            (
+                "transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock",
+                "defuser.modeling.unfused_moe.mixtral.LinearMixtralSparseMoeBlock",
+            )
+        ],
     },
     "qwen2_moe": {
         "min_transformers_version": "5.0.0",
+        PATCH.REPLACE_MODULE: [
+            (
+                "transformers.models.qwen2_moe.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock",
+                "defuser.modeling.unfused_moe.qwen2_moe.LinearQwen2MoeSparseMoeBlock",
+            )
+        ],
     },
     "qwen3_moe": {
         "min_transformers_version": "5.0.0",
+        # REPLACE_MODULE swaps only the modeling structure, not the weights.
+        PATCH.REPLACE_MODULE: [
+            (
+                "transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock",
+                "defuser.modeling.unfused_moe.qwen3_moe.LinearQwen3MoeSparseMoeBlock",
+            )
+        ],
     },
     "qwen3_5_moe": {
         "min_transformers_version": "5.2.0",
@@ -21,8 +46,29 @@
     },
     "qwen3_next": {
         "min_transformers_version": "5.0.0",
+        PATCH.REPLACE_MODULE: [
+            (
+                "transformers.models.qwen3_next.modeling_qwen3_next.Qwen3NextSparseMoeBlock",
+                "defuser.modeling.unfused_moe.qwen3_next.LinearQwen3NextSparseMoeBlock",
+            )
+        ],
     },
     "qwen3_omni_moe": {
-        "min_transformers_version": "5.2.0",
+        "min_transformers_version": "5.0.0",
+        PATCH.REPLACE_MODULE: [
+            (
+                "transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe.Qwen3OmniMoeThinkerTextSparseMoeBlock",
+                "defuser.modeling.unfused_moe.qwen3_omni_moe.LinearQwen3OmniMoeThinkerTextSparseMoeBlock",
+            )
+        ],
+    },
+    "glm4_moe": {
+        "min_transformers_version": "5.0.0",
+        PATCH.REPLACE_MODULE: [
+            (
+                "transformers.models.glm4_moe.modeling_glm4_moe.Glm4MoeMoE",
+                "defuser.modeling.unfused_moe.glm4_moe.LinearGlm4MoeMoE",
+            )
+        ],
     },
 }
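Supporting another fused MoE model is then one registry entry plus an unfused block implementation; a hypothetical entry (model type and module paths invented for illustration):

    "my_moe": {
        "min_transformers_version": "5.0.0",
        PATCH.REPLACE_MODULE: [
            (
                "transformers.models.my_moe.modeling_my_moe.MyMoeSparseMoeBlock",
                "defuser.modeling.unfused_moe.my_moe.LinearMyMoeSparseMoeBlock",
            )
        ],
    },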

defuser/modeling/model_patches.py

Lines changed: 52 additions & 2 deletions
@@ -8,20 +8,61 @@
 from logbar import LogBar

 from defuser import DEBUG_ON
+import torch

 logger = LogBar(__name__)


+_MODEL_CLASS_PATCH_REGISTRY: dict[str, Callable] = {}
 _MODEL_PATCH_REGISTRY: dict[str, Callable] = {}


+def register_model_class_patch(model_type: str):
+    def decorator(func: Callable):
+        _MODEL_CLASS_PATCH_REGISTRY[model_type] = func
+        return func
+
+    return decorator
+
+
 def register_model_patch(model_type: str):
     def decorator(func: Callable):
         _MODEL_PATCH_REGISTRY[model_type] = func
         return func

     return decorator

+
+@register_model_class_patch("qwen3_omni_moe")
+def patch_qwen3_omni_text_class() -> list[str]:
+    from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import Qwen3OmniMoePreTrainedModel
+    from defuser.modeling.unfused_moe.qwen3_omni_moe import LinearQwen3OmniMoeThinkerTextSparseMoeBlock
+
+    orig_init_weights = Qwen3OmniMoePreTrainedModel._init_weights
+
+    def patched_init_weights(self, module):
+        try:
+            orig_init_weights(self, module)
+        except AttributeError:
+            # Fallback for unfused experts: init each split projection directly.
+            if isinstance(module, LinearQwen3OmniMoeThinkerTextSparseMoeBlock):
+                std = self.config.initializer_range
+                experts = module.experts
+
+                if hasattr(experts, "gate_proj"):
+                    torch.nn.init.normal_(experts.gate_proj.weight, 0.0, std)
+                if hasattr(experts, "up_proj"):
+                    torch.nn.init.normal_(experts.up_proj.weight, 0.0, std)
+                if hasattr(experts, "down_proj"):
+                    torch.nn.init.normal_(experts.down_proj.weight, 0.0, std)
+
+                if hasattr(module, "gate"):
+                    torch.nn.init.normal_(module.gate.weight, 0.0, std)
+            else:
+                raise
+
+    Qwen3OmniMoePreTrainedModel._init_weights = patched_init_weights
+
+    return []
+

 @register_model_patch("qwen3_omni_moe")
 def patch_qwen3_omni_text_runtime(model) -> list[str]:
@@ -43,7 +84,6 @@ def generate(self, *args, return_audio=None, **kwargs):
         applied.append("generate")

     if "forward" not in model_cls.__dict__:
-
         def forward(self, *args, **kwargs):
             return self.thinker(*args, **kwargs)
@@ -54,6 +94,17 @@ def forward(self, *args, **kwargs):
     return applied


+def apply_model_class_patches(model_type) -> list[str]:
+    patch_model_class = _MODEL_CLASS_PATCH_REGISTRY.get(model_type)
+    if patch_model_class is None:
+        return []
+
+    applied = patch_model_class()
+    if applied and DEBUG_ON:
+        logger.debug(f"Applied model class patches for model_type={model_type}: {', '.join(applied)}")
+    return applied
+
+
 def apply_model_patches(model) -> list[str]:
     config = getattr(model, "config", None)
     model_type = getattr(config, "model_type", None)
@@ -65,4 +116,3 @@ def apply_model_patches(model) -> list[str]:
     if applied and DEBUG_ON:
         logger.debug(f"Applied model patches for model_type={model_type}: {', '.join(applied)}")
     return applied
-
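New class-level patches follow the same decorator pattern as the existing runtime patches; a hypothetical registration (model type and patch body invented for illustration):

    @register_model_class_patch("my_moe")
    def patch_my_moe_class() -> list[str]:
        # Monkey-patch modeling classes here; runs before any weights load.
        # Returned names are used only for DEBUG_ON logging.
        return ["_init_weights"]

Class patches run per model_type before loading (via replace_fused_blocks), while register_model_patch hooks operate on an already-instantiated model.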

defuser/modeling/unfused_moe/__init__.py

Whitespace-only changes.
defuser/modeling/unfused_moe/glm4_moe.py

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
+# SPDX-FileCopyrightText: 2026 ModelCloud.ai
+# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai
+# SPDX-License-Identifier: Apache-2.0
+# Contact: qubitium@modelcloud.ai, x.com/qubitium
+
+import torch
+import torch.nn as nn
+
+
+class LinearGlm4MoeMoE(nn.Module):
+    """A mixed expert module containing shared experts."""
+
+    def __init__(self, config):
+        super().__init__()
+        from transformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeMLP, Glm4MoeTopkRouter
+
+        self.config = config
+        self.experts = nn.ModuleList(
+            [
+                Glm4MoeMLP(config, intermediate_size=config.moe_intermediate_size)
+                for _ in range(config.n_routed_experts)
+            ]
+        )
+        self.gate = Glm4MoeTopkRouter(config)
+        self.shared_experts = Glm4MoeMLP(
+            config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
+        )
+
+    def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
+        r"""
+        CALL FOR CONTRIBUTION! Expert weights need to be fused so we do not have
+        to loop over experts here (DeepSeek has 256 experts, so the loop is costly).
+        """
+        final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
+        expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts))
+        expert_mask = expert_mask.permute(2, 0, 1)
+
+        for expert_idx in range(len(self.experts)):
+            expert = self.experts[expert_idx]
+            mask = expert_mask[expert_idx]
+            token_indices, weight_indices = torch.where(mask)
+
+            if token_indices.numel() > 0:
+                expert_weights = topk_weights[token_indices, weight_indices]
+                expert_input = hidden_states[token_indices]
+                expert_output = expert(expert_input)
+                weighted_output = expert_output * expert_weights.unsqueeze(-1)
+                final_hidden_states.index_add_(0, token_indices, weighted_output)
+
+        # In the original DeepSeek, expert outputs are gathered once we leave this
+        # module, so the MoE module is itself an IsolatedParallel module and all
+        # experts are "local": we shard but do not gather.
+        return final_hidden_states.type(hidden_states.dtype)
+
+    def forward(self, hidden_states):
+        residuals = hidden_states
+        orig_shape = hidden_states.shape
+        topk_indices, topk_weights = self.gate(hidden_states)
+        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+        hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
+        hidden_states = hidden_states + self.shared_experts(residuals)
+        return hidden_states
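A smoke-test sketch for the unfused block (the Glm4MoeConfig import path and field names are assumptions based on the attributes used above, not part of this commit):

    import torch
    from transformers.models.glm4_moe.configuration_glm4_moe import Glm4MoeConfig
    from defuser.modeling.unfused_moe.glm4_moe import LinearGlm4MoeMoE

    # Tiny illustrative config; real GLM-4 MoE checkpoints are far larger.
    cfg = Glm4MoeConfig(hidden_size=64, moe_intermediate_size=32,
                        n_routed_experts=4, n_shared_experts=1, num_experts_per_tok=2)
    block = LinearGlm4MoeMoE(cfg)
    out = block(torch.randn(1, 8, cfg.hidden_size))  # shape preserved: (1, 8, 64)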
