diff --git a/README.md b/README.md index 17ead0e..d4499fe 100644 --- a/README.md +++ b/README.md @@ -10,10 +10,197 @@

-Model defuser helper for HF Transformers >= 5.0. In HF Transformers 5.x releases, many MoE modules became auto-stacked or auto-fused by new modeling code which has benefits but also downsides. -* Goal is to provide naive module/layer forwarding code for all models supported by HF transformers where run-time - weight and structure level optimizations such weight merging, stacking, fusing are reversed so the model is operating - in a simple naive state. -* There are cases, quantization libraries, where we need to run inference where module input/output needs to be - individually captured and this pkg can help complete this task. +Defuser converts select Hugging Face Transformers `5.3.0+` fused or stacked MoE and MLP blocks back into plain, per-expert `nn.Linear` modules. It keeps the forward math intact while exposing individual projections again so quantizers, activation capture, debugging hooks, and checkpoint tooling can work against a simple module layout instead of fused expert tensors. + +Defuser is designed and CI-tested for `transformers>=5.3.0`, and support is only offered for that version range. + +## Purpose + +Defuser exists for cases where newer Transformers modeling code optimizes model structure in ways that are good for runtime, but harder for tooling that needs direct access to individual projections. + +Depending on the model family, Defuser can: + +- patch a supported model class before load so HF instantiates a defused block directly +- split fused tensors such as `gate_up_proj` into `gate_proj` + `up_proj` +- convert 3D expert tensors into numbered expert `nn.Linear` modules +- preserve the original fused math while presenting a naive module structure again + +Public API: + +```python +from defuser import convert_model, replace_fused_blocks +``` + +- `replace_fused_blocks(model_type)` patches supported HF model classes before `from_pretrained()` or direct model construction. 
+- `convert_model(model, cleanup_original=False, max_layers=None)` converts an already loaded model in place. This is the runtime defusion path used for `qwen3_5_moe` style checkpoints. +- On unsupported Transformers versions (older than `5.3.0`), both public APIs log a warning and return `False` without touching the model. + +## Supported Models + +| Model type | Recommended entrypoint | Defused op performed | +| --- | --- | --- | +| `mixtral` | `replace_fused_blocks("mixtral")` before load | Replaces `MixtralSparseMoeBlock` with `LinearMixtralSparseMoeBlock`. Also remaps legacy Mixtral checkpoint keys and splits fused expert `gate_up_proj` tensors into per-expert `gate_proj` and `up_proj`, plus per-expert `down_proj`. | +| `qwen2_moe` | `replace_fused_blocks("qwen2_moe")` before load | Replaces `Qwen2MoeSparseMoeBlock` with a defused per-expert linear MoE block. | +| `qwen3_moe` | `replace_fused_blocks("qwen3_moe")` before load | Replaces `Qwen3MoeSparseMoeBlock` with a defused per-expert linear MoE block. | +| `qwen3_5_moe` | `convert_model(model)` after load | Runtime expert tensor defusion. Splits fused `gate_up_proj` into `gate_proj` + `up_proj` and converts 3D expert tensors into numbered expert `nn.Linear` modules. | +| `qwen3_5_moe_text` | `convert_model(model)` after load | Same runtime expert tensor defusion path as `qwen3_5_moe`, applied to the text-only backbone. | +| `qwen3_next` | `replace_fused_blocks("qwen3_next")` before load | Replaces `Qwen3NextSparseMoeBlock` with a defused per-expert linear MoE block. | +| `qwen3_omni_moe` | `replace_fused_blocks("qwen3_omni_moe")` before load | Replaces the thinker text sparse MoE block with a defused per-expert linear block and applies small runtime compatibility patches for text `forward()` and `generate()`. | +| `glm4_moe` | `replace_fused_blocks("glm4_moe")` before load | Replaces `Glm4MoeMoE` with a defused per-expert linear MoE block. 
| +| `glm4v` | `replace_fused_blocks("glm4v")` before load | Replaces the fused text MLP with split `gate_proj`, `up_proj`, and `down_proj` layers. Also splits fused checkpoint `mlp.gate_up_proj.weight` into `mlp.gate_proj.weight` + `mlp.up_proj.weight`. | + +## Workflow Summary + +Use `replace_fused_blocks()` for model families that Defuser can patch before load: + +```python +from defuser import replace_fused_blocks +from transformers import MixtralForCausalLM + +replace_fused_blocks("mixtral") +model = MixtralForCausalLM.from_pretrained( + "mistralai/Mixtral-8x7B-v0.1", + dtype="auto", + device_map="auto", +) +``` + +Use `convert_model()` for already loaded models whose expert tensors still need runtime defusion: + +```python +from defuser import convert_model + +converted = convert_model(model) +print(converted) # True when runtime defusion happened +``` + +## Real Qwen3.5 MoE Example + +The example below is written for the `transformers==5.3.0` public API surface and uses the real Hugging Face model `Qwen/Qwen3.5-35B-A3B-Instruct`. Defuser supports `transformers>=5.3.0`. 
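Since support is gated on the installed Transformers version, tooling that wraps Defuser may prefer an explicit up-front check over relying on the warning-and-skip path. A dependency-free sketch of such a guard (the `parse_version` and `defuser_supported` helpers here are illustrative, not part of Defuser's public API):

```python
def parse_version(v: str) -> tuple[int, ...]:
    """Parse the leading numeric release segment, e.g. "5.3.0.dev0" -> (5, 3, 0)."""
    parts = []
    for piece in v.split("."):
        digits = ""
        for ch in piece:
            if ch.isdigit():
                digits += ch
            else:
                break
        if not digits:
            break  # stop at the first non-numeric segment (dev/rc suffixes)
        parts.append(int(digits))
    return tuple(parts)


def defuser_supported(installed: str, floor: str = "5.3.0") -> bool:
    """Return True when the installed transformers version meets Defuser's floor."""
    return parse_version(installed) >= parse_version(floor)
```

In practice `packaging.version.parse(transformers.__version__)` does the same comparison more robustly, and is what Defuser itself uses internally.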
+ +### Fused Weights Before And After + +Before `convert_model(model)`: + +```text ++--------------------------------------------------------+---------------------------------------------+ +| State dict key | Layout | ++--------------------------------------------------------+---------------------------------------------+ +| model.language_model.layers.0.mlp.experts.gate_up_proj | fused gate+up tensor for all experts | +| | [num_experts, 2 * moe_intermediate, hidden] | +| model.language_model.layers.0.mlp.experts.down_proj | fused per-expert down tensor | +| | [num_experts, hidden, moe_intermediate] | ++--------------------------------------------------------+---------------------------------------------+ +``` + +After `convert_model(model)`: + +```text ++-----------------------------------------------------------------+--------------------------------------+ +| State dict key | Layout | ++-----------------------------------------------------------------+--------------------------------------+ +| model.language_model.layers.0.mlp.experts.0.gate_proj.weight | expert 0 gate projection | +| model.language_model.layers.0.mlp.experts.0.up_proj.weight | expert 0 up projection | +| model.language_model.layers.0.mlp.experts.0.down_proj.weight | expert 0 down projection | +| ... 
repeated for experts 1..N-1 | numbered expert nn.Linear modules | ++-----------------------------------------------------------------+--------------------------------------+ +``` + +### Sample 1: Inspect The Conversion In Place + +```python +from defuser import convert_model +from transformers import Qwen3_5MoeForConditionalGeneration + +model_id = "Qwen/Qwen3.5-35B-A3B-Instruct" + +model = Qwen3_5MoeForConditionalGeneration.from_pretrained( + model_id, + dtype="auto", + device_map="auto", +) + +prefix = "model.language_model.layers.0.mlp.experts" + +before = [name for name, _ in model.named_parameters() if name.startswith(prefix)] +print(before) +# [ +# "model.language_model.layers.0.mlp.experts.gate_up_proj", +# "model.language_model.layers.0.mlp.experts.down_proj", +# ] + +converted = convert_model(model) +assert converted is True + +after = [name for name, _ in model.named_parameters() if name.startswith(prefix)] +print(after[:6]) +# [ +# "model.language_model.layers.0.mlp.experts.0.down_proj.weight", +# "model.language_model.layers.0.mlp.experts.0.gate_proj.weight", +# "model.language_model.layers.0.mlp.experts.0.up_proj.weight", +# "model.language_model.layers.0.mlp.experts.1.down_proj.weight", +# "model.language_model.layers.0.mlp.experts.1.gate_proj.weight", +# "model.language_model.layers.0.mlp.experts.1.up_proj.weight", +# ] +``` + +### Sample 2: Convert And Keep Using The Model Normally + +```python +import torch + +from defuser import convert_model +from transformers import AutoProcessor, Qwen3_5MoeForConditionalGeneration + +model_id = "Qwen/Qwen3.5-35B-A3B-Instruct" + +model = Qwen3_5MoeForConditionalGeneration.from_pretrained( + model_id, + dtype="auto", + device_map="auto", +) +processor = AutoProcessor.from_pretrained(model_id) + +convert_model(model) + +messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Explain mixture-of-experts routing in one sentence."}, + ], + } +] + +inputs = processor.apply_chat_template( + 
messages, + tokenize=True, + add_generation_prompt=True, + return_dict=True, + return_tensors="pt", +) +inputs = inputs.to(model.device) + +with torch.inference_mode(): + output_ids = model.generate(**inputs, max_new_tokens=64) + +generated_ids = [ + out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, output_ids) +] +text = processor.batch_decode( + generated_ids, + skip_special_tokens=True, + clean_up_tokenization_spaces=False, +)[0] +print(text) +``` + +After conversion, the first routed expert in the first MoE layer is exposed as normal submodules: + +```python +expert0 = model.model.language_model.layers[0].mlp.experts[0] +print(type(expert0.gate_proj).__name__) # Linear +print(type(expert0.up_proj).__name__) # Linear +print(type(expert0.down_proj).__name__) # Linear +``` diff --git a/defuser/checkpoint_ops.py b/defuser/checkpoint_ops.py index 2309fb3..63a9887 100644 --- a/defuser/checkpoint_ops.py +++ b/defuser/checkpoint_ops.py @@ -24,6 +24,7 @@ def __init__(self, expert_dim: int = 0, proj_dim: int = 0): @staticmethod def _expert_target(pattern: str, expert_idx: int) -> str: + """Expand one target pattern into the per-expert key for ``expert_idx``.""" if "*" in pattern: return pattern.replace("*", str(expert_idx)) return pattern.replace(".0.", f".{expert_idx}.", 1) @@ -32,6 +33,7 @@ def _expert_target(pattern: str, expert_idx: int) -> str: def convert( self, input_dict: dict[str, torch.Tensor], source_patterns: list[str], target_patterns: list[str], **kwargs ) -> dict[str, torch.Tensor]: + """Split one fused gate/up tensor into cloned per-expert gate and up tensors.""" if len(target_patterns) != 2: raise ValueError("SplitFusedExpertGateUpProj expects exactly two target patterns.") @@ -50,6 +52,7 @@ def convert( @property def reverse_op(self) -> ConversionOps: + """Return the inverse merge op used when writing fused checkpoints.""" return MergeSplitExpertGateUpProj() @@ -66,6 +69,7 @@ def __init__(self, expert_dim: int = 0, proj_dim: int = 0): 
def convert( self, input_dict: dict[str, list[torch.Tensor]], source_patterns: list[str], target_patterns: list[str], **kwargs ) -> dict[str, torch.Tensor]: + """Merge per-expert gate/up tensors back into one fused expert tensor.""" if len(source_patterns) != 2: raise ValueError("MergeSplitExpertGateUpProj expects exactly two source patterns.") if len(target_patterns) != 1: @@ -102,6 +106,7 @@ def __init__(self, expert_dim: int = 0): @staticmethod def _expert_target(pattern: str, expert_idx: int) -> str: + """Expand one target pattern into the per-expert key for ``expert_idx``.""" if "*" in pattern: return pattern.replace("*", str(expert_idx)) return pattern.replace(".0.", f".{expert_idx}.", 1) @@ -110,6 +115,7 @@ def _expert_target(pattern: str, expert_idx: int) -> str: def convert( self, input_dict: dict[str, torch.Tensor], source_patterns: list[str], target_patterns: list[str], **kwargs ) -> dict[str, torch.Tensor]: + """Split one fused expert down projection into cloned per-expert tensors.""" if len(target_patterns) != 1: raise ValueError("SplitFusedExpertDownProj expects a single target pattern.") @@ -126,4 +132,5 @@ def convert( @property def reverse_op(self) -> ConversionOps: + """Return the inverse merge op used when writing fused checkpoints.""" return MergeModulelist(dim=self.expert_dim) diff --git a/defuser/defuser.py b/defuser/defuser.py index 1e8ab55..345e909 100644 --- a/defuser/defuser.py +++ b/defuser/defuser.py @@ -10,6 +10,11 @@ from defuser.model_registry import MODEL_CONFIG, PATCH from defuser.modeling.model_patches import apply_model_class_patches, apply_model_patches from defuser.modeling.update_module import update_module +from defuser.utils.common import ( + MIN_SUPPORTED_TRANSFORMERS_VERSION, + is_supported_transformers_version, + warn_if_public_api_transformers_unsupported, +) from packaging import version import transformers from logbar import LogBar @@ -17,6 +22,7 @@ logger = LogBar(__name__) def 
get_checkpoint_conversion_mapping(model_type): + """Return Defuser's checkpoint remapping rules for one registered model type.""" from transformers import conversion_mapping if not hasattr(conversion_mapping, "orig_get_checkpoint_conversion_mapping"): @@ -36,6 +42,10 @@ class PatchError(Exception): def replace_fused_blocks(model_type: str) -> bool: + """Patch supported HF model classes so future loads instantiate defused blocks.""" + if warn_if_public_api_transformers_unsupported("replace_fused_blocks()", logger): + return False + apply_model_class_patches(model_type) cfg = MODEL_CONFIG.get(model_type) @@ -60,7 +70,7 @@ def replace_fused_blocks(model_type: str) -> bool: custom_class = getattr(custom_module, custom_class_name) setattr(orig_module, orig_class_name, custom_class) - if version.parse(transformers.__version__) >= version.parse("5.0.0"): + if version.parse(transformers.__version__) >= version.parse(MIN_SUPPORTED_TRANSFORMERS_VERSION): from transformers import conversion_mapping if not hasattr(conversion_mapping, "orig_get_checkpoint_conversion_mapping"): @@ -89,6 +99,9 @@ def check_model_compatibility(model: nn.Module) -> bool: if model_type not in MODEL_CONFIG: return False + if not is_supported_transformers_version(): + return False + min_ver = MODEL_CONFIG[model_type].get("min_transformers_version") current_ver = version.parse(transformers.__version__) if min_ver and current_ver < version.parse(min_ver): @@ -106,6 +119,10 @@ def convert_model( cleanup_original: bool = False, max_layers: int | None = None, ) -> bool: + """Convert one loaded model in place from fused experts to defused modules.""" + if warn_if_public_api_transformers_unsupported("convert_model()", logger): + return False + if max_layers is not None and max_layers < 1: raise ValueError("max_layers must be >= 1 when provided") diff --git a/defuser/model_registry.py b/defuser/model_registry.py index f9a468a..9a4d7f0 100644 --- a/defuser/model_registry.py +++ b/defuser/model_registry.py @@ 
-7,16 +7,17 @@ from transformers.core_model_loading import WeightConverter, WeightRenaming from defuser.checkpoint_ops import OwnedChunk, SplitFusedExpertDownProj, SplitFusedExpertGateUpProj +from defuser.utils.common import MIN_SUPPORTED_TRANSFORMERS_VERSION class PATCH(str, Enum): REPLACE_MODULE = "replace_module" - DEFUSE = "defuse" + EXPERTS_DEFUSE = "experts_defuse" MODEL_CONFIG = { "mixtral": { - "min_transformers_version": "5.0.0", + "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, PATCH.REPLACE_MODULE: [ ( "transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock", @@ -44,7 +45,7 @@ class PATCH(str, Enum): ], }, "qwen2_moe": { - "min_transformers_version": "5.0.0", + "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, PATCH.REPLACE_MODULE: [ ( "transformers.models.qwen2_moe.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock", @@ -53,7 +54,7 @@ class PATCH(str, Enum): ], }, "qwen3_moe": { - "min_transformers_version": "5.0.0", + "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, # structure path only replaces modeling structure PATCH.REPLACE_MODULE: [ ( @@ -63,13 +64,13 @@ class PATCH(str, Enum): ], }, "qwen3_5_moe": { - "min_transformers_version": "5.2.0", + "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, }, "qwen3_5_moe_text": { - "min_transformers_version": "5.2.0", + "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, }, "qwen3_next": { - "min_transformers_version": "5.0.0", + "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, PATCH.REPLACE_MODULE: [ ( "transformers.models.qwen3_next.modeling_qwen3_next.Qwen3NextSparseMoeBlock", @@ -78,7 +79,7 @@ class PATCH(str, Enum): ], }, "qwen3_omni_moe": { - "min_transformers_version": "5.0.0", + "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, PATCH.REPLACE_MODULE: [ ( "transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe.Qwen3OmniMoeThinkerTextSparseMoeBlock", @@ -87,7 +88,7 @@ class PATCH(str, Enum): ], }, 
"glm4_moe": { - "min_transformers_version": "5.0.0", + "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, PATCH.REPLACE_MODULE: [ ( "transformers.models.glm4_moe.modeling_glm4_moe.Glm4MoeMoE", @@ -96,7 +97,7 @@ class PATCH(str, Enum): ], }, "glm4v": { - "min_transformers_version": "5.0.0", + "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, PATCH.REPLACE_MODULE: [ ( "transformers.models.glm4v.modeling_glm4v.Glm4vTextMLP", @@ -116,10 +117,15 @@ class PATCH(str, Enum): ], }, "gpt_oss": { - # "min_transformers_version": "", # When `gpt_oss` was added to `transformers`, it was already implemented as "fused experts." + "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, }, "llama4": { - # "min_transformers_version": "", # When `llama4` was added to `transformers`, it was already implemented as "fused experts." - PATCH.DEFUSE: "defuser.modeling.fused_moe.llama4", + "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, + PATCH.EXPERTS_DEFUSE: [ + { + "module_class": "transformers.models.llama4.modeling_llama4.Llama4TextExperts", + "forward_impl": "batched_input", + } + ], }, } diff --git a/defuser/modeling/fused_moe/__init__.py b/defuser/modeling/fused_moe/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/defuser/modeling/fused_moe/llama4.py b/defuser/modeling/fused_moe/llama4.py deleted file mode 100644 index 9a4cc05..0000000 --- a/defuser/modeling/fused_moe/llama4.py +++ /dev/null @@ -1,103 +0,0 @@ -# Adapted from intel/auto-round -# at https://github.com/intel/auto-round/blob/main/auto_round/modeling/fused_moe/llama4.py - -import torch -import transformers -from packaging import version - -transformers_version = version.parse(transformers.__version__) -if transformers_version < version.parse("5.0.0"): - from transformers.modeling_utils import no_init_weights -else: - from transformers.initialization import no_init_weights -from transformers.models.llama4.modeling_llama4 import Llama4Config, 
Llama4TextMLP - -from defuser.modeling.replace_modules import ReplacementModuleBase -from defuser.utils.model import _update_parameter, unsupported_meta_device -from defuser.utils.device import clear_memory - - -class SequentialLlama4TextExperts(torch.nn.ModuleList): - def __init__(self, config, original): - self.num_experts = original.gate_up_proj.shape[0] - target_device = next(original.parameters()).device - with no_init_weights(), torch.device("meta"): - super().__init__([Llama4TextMLP(config) for _ in range(self.num_experts)]) - - def _materialize_weights(self, original) -> None: - if not unsupported_meta_device(original): - intermediate_size = original.down_proj.shape[1] - - for i in range(self.num_experts): - gate_up = original.gate_up_proj[i] - down = original.down_proj[i] - gate_proj = gate_up[:, :intermediate_size] - up_proj = gate_up[:, intermediate_size:] - _update_parameter(self[i].gate_proj, "weight", gate_proj.t().contiguous()) - _update_parameter(self[i].up_proj, "weight", up_proj.t().contiguous()) - _update_parameter(self[i].down_proj, "weight", down.t().contiguous()) - del gate_up, down, gate_proj, up_proj - original.to_empty(device="meta") # release original experts parameters - - -class SequentialLlama4TextMoe(ReplacementModuleBase): - def __init__(self, original, config): - super().__init__(original) - config = config.text_config - self.top_k = config.num_experts_per_tok - self.hidden_dim = config.hidden_size - self.num_experts = config.num_local_experts - with torch.device("meta"): - self.experts = SequentialLlama4TextExperts(config, original.experts) - - self.router = original.router - self.shared_expert = original.shared_expert - - def forward(self, hidden_states: torch.Tensor): - hidden_states = hidden_states.reshape(-1, self.hidden_dim) - router_logits = self.router(hidden_states) - if isinstance(router_logits, tuple): - router_scores, router_logits = router_logits - router_scores = router_scores.t() - else: - # transformers < 4.54.0 only 
returns router_logits - router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1) - - router_scores = ( - torch.full_like(router_logits, float("-inf")) - .scatter_(1, router_indices, router_top_value) - .transpose(0, 1) - ) - router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype) - - out = self.shared_expert(hidden_states) - - # Only process experts that actually received tokens (expert_hit pattern), - # skipping experts with zero routing weight to save compute during calibration. - with torch.no_grad(): - expert_hit = torch.greater(router_scores.sum(dim=-1), 0).nonzero() - for expert_idx in expert_hit: - expert_idx = expert_idx[0] - out += self.experts[expert_idx](hidden_states) * router_scores[expert_idx].reshape(-1, 1) - - return out, router_logits - - @classmethod - def original_module_class(cls) -> str: - """Return the class name of the module this replaces.""" - return "Llama4TextMoe" - - @classmethod - def from_original( - cls, - original: torch.nn.Module, - config: Llama4Config, - **kwargs, - ) -> "SequentialLlama4TextMoe": - """Create an instance from the original module.""" - return cls(original, config) - - def _materialize_weights(self) -> None: - original = self._get_original_module() - self.experts._materialize_weights(original.experts) - clear_memory() diff --git a/defuser/modeling/glm4v.py b/defuser/modeling/glm4v.py index 254ea16..462b0d4 100644 --- a/defuser/modeling/glm4v.py +++ b/defuser/modeling/glm4v.py @@ -14,6 +14,7 @@ def __init__(self, config): self.activation_fn = ACT2FN[config.hidden_act] def forward(self, hidden_states): + """Reproduce the original fused GLM4V text MLP using split linear layers.""" gate = self.gate_proj(hidden_states) up = self.up_proj(hidden_states) # Match the original fused `gate_up_proj.chunk(2, dim=-1)` activation path. 
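The GLM4V comment above hinges on one invariant: chunking the fused `gate_up_proj` output must equal running two split projections over the same weights. A self-contained sketch of that equivalence (the dimensions and the SiLU activation are illustrative assumptions, not read from a GLM4V config):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

hidden, inter = 8, 16
x = torch.randn(4, hidden)

# Fused layout: one Linear whose output is chunked into gate and up halves.
gate_up = nn.Linear(hidden, 2 * inter, bias=False)
down = nn.Linear(inter, hidden, bias=False)
act = nn.SiLU()

g, u = gate_up(x).chunk(2, dim=-1)
fused_out = down(act(g) * u)

# Defused form: copy each half of the fused weight into its own Linear.
gate_proj = nn.Linear(hidden, inter, bias=False)
up_proj = nn.Linear(hidden, inter, bias=False)
with torch.no_grad():
    gate_proj.weight.copy_(gate_up.weight[:inter])
    up_proj.weight.copy_(gate_up.weight[inter:])

split_out = down(act(gate_proj(x)) * up_proj(x))
assert torch.allclose(fused_out, split_out, atol=1e-5)
```

The forward math is unchanged; only the module layout differs, which is exactly what downstream quantizers and hooks rely on.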
diff --git a/defuser/modeling/model_patches.py b/defuser/modeling/model_patches.py index 1982588..b1202bf 100644 --- a/defuser/modeling/model_patches.py +++ b/defuser/modeling/model_patches.py @@ -18,6 +18,7 @@ def register_model_class_patch(model_type: str): + """Register a one-time class patch that runs before model construction.""" def decorator(func: Callable): _MODEL_CLASS_PATCH_REGISTRY[model_type] = func return func @@ -26,6 +27,7 @@ def decorator(func: Callable): def register_model_patch(model_type: str): + """Register a runtime patch that runs on an instantiated model object.""" def decorator(func: Callable): _MODEL_PATCH_REGISTRY[model_type] = func return func @@ -34,6 +36,7 @@ def decorator(func: Callable): @register_model_class_patch("qwen3_omni_moe") def patch_qwen3_omni_text_class() -> list[str]: + """Teach HF init code how to initialize unfused qwen3-omni thinker experts.""" from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import Qwen3OmniMoeForConditionalGeneration, Qwen3OmniMoePreTrainedModel from defuser.modeling.unfused_moe.qwen3_omni_moe import LinearQwen3OmniMoeThinkerTextSparseMoeBlock orig_init_weights = Qwen3OmniMoePreTrainedModel._init_weights @@ -66,6 +69,7 @@ def patched_init_weights(self, module): @register_model_patch("qwen3_omni_moe") def patch_qwen3_omni_text_runtime(model) -> list[str]: + """Restore text-only ``forward`` and ``generate`` behavior after class swapping.""" model_cls = type(model) if not getattr(model_cls, "__module__", "").startswith("transformers.models.qwen3_omni_moe."): return [] @@ -95,6 +99,7 @@ def forward(self, *args, **kwargs): def apply_model_class_patches(model_type) -> list[str]: + """Run any registered pre-construction patch for ``model_type``.""" patch_model_class = _MODEL_CLASS_PATCH_REGISTRY.get(model_type) if patch_model_class is None: return [] @@ -106,6 +111,7 @@ def apply_model_class_patches(model_type) -> list[str]: def apply_model_patches(model) -> list[str]: + """Run any registered 
runtime patch for the instantiated ``model``.""" config = getattr(model, "config", None) model_type = getattr(config, "model_type", None) patch = _MODEL_PATCH_REGISTRY.get(model_type) diff --git a/defuser/modeling/moe_experts_interface.py b/defuser/modeling/moe_experts_interface.py index 7756042..99e0f8d 100644 --- a/defuser/modeling/moe_experts_interface.py +++ b/defuser/modeling/moe_experts_interface.py @@ -25,10 +25,13 @@ # Now the model uses linear_loop forward which supports quantized nn.Linear layers """ +from types import MethodType + import torch from logbar import LogBar from torch import nn +from defuser.model_registry import MODEL_CONFIG, PATCH from defuser.utils.device import clear_memory, to_meta from defuser import DEBUG_ON @@ -45,6 +48,7 @@ # Expert implementation name - change this if transformers want to use a different name LINEAR_LOOP_IMPL = "linear_loop" +BATCHED_INPUT_IMPL = "batched_input" # Known expert projection patterns for reference # These are used as hints when auto-detection needs to infer projection properties @@ -79,6 +83,17 @@ def is_linear_loop_available() -> bool: return HAS_EXPERTS_INTERFACE +def _apply_expert_gate(module: nn.Module, gate_out: torch.Tensor, up_out: torch.Tensor) -> torch.Tensor: + """Apply the expert's activation path using either a custom gate hook or ``act_fn``.""" + if hasattr(module, "_apply_gate"): + return module._apply_gate(torch.cat([gate_out, up_out], dim=-1)) + + act_fn = getattr(module, "act_fn", None) + if act_fn is None: + raise AttributeError(f"{module.__class__.__name__} must define either `_apply_gate` or `act_fn`.") + return act_fn(gate_out) * up_out + + def linear_loop_experts_forward( self: nn.Module, hidden_states: torch.Tensor, @@ -152,13 +167,7 @@ def linear_loop_experts_forward( expert = getattr(self, str(expert_idx)) gate_out = expert.gate_proj(expert_input) # (num_samples, intermediate_dim) up_out = expert.up_proj(expert_input) # (num_samples, intermediate_dim) - - # Apply gating - if 
hasattr(self, "_apply_gate"): - gate_up_out = torch.cat([gate_out, up_out], dim=-1) - gated_out = self._apply_gate(gate_up_out) # (num_samples, intermediate_dim) - else: - gated_out = self.act_fn(gate_out) * up_out # (num_samples, intermediate_dim) + gated_out = _apply_expert_gate(self, gate_out, up_out) # Down projection expert_out = expert.down_proj(gated_out) # (num_samples, hidden_dim) @@ -180,6 +189,31 @@ def linear_loop_experts_forward( return final_hidden_states +def batched_input_experts_forward(self: nn.Module, hidden_states: torch.Tensor) -> torch.Tensor: + """Run defused experts for models that feed experts as expert-major input batches. + + Llama4 is the current example: upstream code repeats and pre-weights tokens, + then calls ``experts(hidden_states)`` where the leading dimension is laid out + as ``[expert0_tokens, expert1_tokens, ...]``. This forward keeps that public + contract while still executing per-expert ``nn.Linear`` modules internally. + """ + if DEBUG_ON: logger.debug(f"Using {BATCHED_INPUT_IMPL} experts forward for {self.__class__.__name__}") + + hidden_dim = hidden_states.size(-1) + expert_inputs = hidden_states.view(self.num_experts, -1, hidden_dim) + expert_outputs = [] + + for expert_idx in range(self.num_experts): + expert = getattr(self, str(expert_idx)) + expert_input = expert_inputs[expert_idx] + gate_out = expert.gate_proj(expert_input) + up_out = expert.up_proj(expert_input) + gated_out = _apply_expert_gate(self, gate_out, up_out) + expert_outputs.append(expert.down_proj(gated_out)) + + return torch.stack(expert_outputs, dim=0).reshape(-1, hidden_dim) + + def register_linear_loop_experts() -> bool: """Register the linear_loop experts implementation with transformers. 
@@ -201,6 +235,45 @@ def register_linear_loop_experts() -> bool: return True +def _model_experts_defuse_specs(model: nn.Module) -> list[dict]: + """Return declarative experts-defusion specs for the current model type.""" + config = getattr(model, "config", None) + model_type = getattr(config, "model_type", None) + if model_type is None: + return [] + + specs = MODEL_CONFIG.get(model_type, {}).get(PATCH.EXPERTS_DEFUSE, []) + if isinstance(specs, dict): + return [specs] + return list(specs) + + +def _module_class_path(module: nn.Module) -> str: + """Return a stable import-style class path for matching model specs.""" + return f"{module.__class__.__module__}.{module.__class__.__name__}" + + +def _matching_experts_defuse_spec(module: nn.Module, specs: list[dict]) -> dict | None: + """Find the first declarative experts-defusion spec that matches ``module``.""" + module_path = _module_class_path(module) + module_name = module.__class__.__name__ + + for spec in specs: + target = spec.get("module_class") + if target in {module_path, module_name}: + return spec + return None + + +def _install_instance_forward(module: nn.Module, implementation: str) -> None: + """Attach a generic forward implementation directly to one experts module.""" + if implementation == BATCHED_INPUT_IMPL: + module.forward = MethodType(batched_input_experts_forward, module) + return + + raise ValueError(f"Unsupported experts forward implementation: {implementation}") + + def _detect_expert_projections(module: nn.Module) -> dict[str, dict]: """Detect which expert projections exist in the module. @@ -613,11 +686,26 @@ def prepare_model_for_moe_quantization(model: nn.Module, implementation: str = L "This requires transformers >= 5.0.0 with MOE integration support." ) - # Unfuse all fused experts modules (only those supporting @use_experts_implementation) + # Unfuse all fused experts modules, including models that need a generic + # instance-level forward override instead of transformers' decorator path. 
unfused_modules = [] + decorated_unfused_modules = [] + experts_defuse_specs = _model_experts_defuse_specs(model) for name, module in model.named_modules(): + spec = _matching_experts_defuse_spec(module, experts_defuse_specs) + if spec is not None and _unfuse_experts_weights_inplace( + module, + check_decorator=False, + projection_names=spec.get("projection_names"), + ): + _install_instance_forward(module, spec["forward_impl"]) + unfused_modules.append(name) + if DEBUG_ON: logger.debug(f"[MoE Prep] Unfused '{name}' via declarative spec") + continue + if _unfuse_experts_weights_inplace(module): unfused_modules.append(name) + decorated_unfused_modules.append(name) if DEBUG_ON: logger.debug(f"[MoE Prep] Unfused '{name}'") # Only set config if we actually unfused something @@ -627,8 +715,9 @@ def prepare_model_for_moe_quantization(model: nn.Module, implementation: str = L if DEBUG_ON: logger.info(f"[MoE Prep] Unfused {len(unfused_modules)} MOE experts modules") clear_memory() - # Set config for linear_loop forward - if hasattr(model, "config"): + # Set config for linear_loop forward only when the upstream model uses + # the decorator-based experts interface. 
+ if decorated_unfused_modules and hasattr(model, "config"): saved_impl = getattr(model.config, "experts_implementation", None) impl_to_set = saved_impl if saved_impl else implementation model.config._experts_implementation = impl_to_set diff --git a/defuser/modeling/replace_modules.py b/defuser/modeling/replace_modules.py index 34df127..dfc0385 100644 --- a/defuser/modeling/replace_modules.py +++ b/defuser/modeling/replace_modules.py @@ -6,7 +6,6 @@ # Adapted from intel/auto-round # at https://github.com/intel/auto-round/blob/main/auto_round/modeling/fused_moe/replace_modules.py -import importlib import weakref from abc import ABC, abstractmethod from dataclasses import dataclass @@ -20,34 +19,12 @@ from defuser import DEBUG_ON -from defuser.model_registry import MODEL_CONFIG, PATCH - logger = LogBar(__name__) -def is_model_patchable(model: torch.nn.Module) -> bool: - """Check if the model has a custom replacement registered via MODEL_CONFIG. - - Returns True if the model's model_type matches a key in MODEL_CONFIG. 
- """ - if hasattr(model, "config") and hasattr(model.config, "model_type"): - return model.config.model_type in MODEL_CONFIG and PATCH.DEFUSE in MODEL_CONFIG[model.config.model_type] - return False - - -def _import_required_replacements(model: torch.nn.Module) -> None: - """Import replacement modules required for the model's defuse workflow.""" - if not is_model_patchable(model): - return - model_type = model.config.model_type - module_path = MODEL_CONFIG[model_type].get(PATCH.DEFUSE) - if not module_path: - return - importlib.import_module(module_path) - logger.debug(f"Loaded replacement module for {model_type}: {module_path}") - - def materialize_model(model: torch.nn.Module) -> None: + """Materialize any deferred replacement weights attached to ``model``.""" + def _materialize_module(module: torch.nn.Module) -> None: if isinstance(module, ReplacementModuleBase): module.materialize_weights() @@ -73,6 +50,8 @@ def _materialize_module(module: torch.nn.Module) -> None: def release_original_module_(model: torch.nn.Module) -> None: + """Drop references to original fused modules after replacement is complete.""" + def _clear_source_module(module: torch.nn.Module) -> None: if isinstance(module, ReplacementModuleBase): module.release_original_module() @@ -263,16 +242,17 @@ def apply_replacements( Returns: The model with modules replaced. """ - _import_required_replacements(model) - _log_first_moe_block(model, "before replacement") - # Custom replacements first - if is_model_patchable(model): - _apply_custom_replacements(model, max_layers=max_layers) - elif auto_detect_moe and is_transformers_version_greater_or_equal_5(): + # Run the generic MoE tensor defusion pass first so models with supported + # fused experts can stay on their upstream module structure. 
+ if auto_detect_moe and is_transformers_version_greater_or_equal_5(): _handle_moe_modules(model) + # Fall back to replacement modules for any models that still need a custom + # structural wrapper after the generic experts pass. + _apply_custom_replacements(model, max_layers=max_layers) + _log_first_moe_block(model, "after replacement") return model diff --git a/defuser/modeling/unfused_moe/glm4_moe.py b/defuser/modeling/unfused_moe/glm4_moe.py index 24b1299..61353f1 100644 --- a/defuser/modeling/unfused_moe/glm4_moe.py +++ b/defuser/modeling/unfused_moe/glm4_moe.py @@ -12,6 +12,7 @@ class LinearGlm4MoeMoE(nn.Module): """ def __init__(self, config): + """Build a GLM4 MoE block that exposes one explicit module per routed expert.""" super().__init__() from transformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeMLP, Glm4MoeTopkRouter @@ -34,6 +35,7 @@ def __init__(self, config): self.top_k = config.num_experts_per_tok def route_tokens_to_experts(self, router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Replicate GLM4's grouped expert selection and top-k weight normalization.""" # Keep the expert selection identical to the upstream GLM4 MoE router. router_logits = router_logits.sigmoid() router_logits_for_choice = router_logits + self.gate.e_score_correction_bias @@ -59,10 +61,7 @@ def route_tokens_to_experts(self, router_logits: torch.Tensor) -> tuple[torch.Te return topk_indices, topk_weights def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor): - r""" - CALL FOR CONTRIBUTION! I don't have time to optimise this right now, but expert weights need to be fused - to not have to do a loop here (deepseek has 256 experts soooo yeah). 
- """ + """Run the routed experts one by one and accumulate their weighted outputs.""" final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype) expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts)) expert_mask = expert_mask.permute(2, 0, 1) @@ -85,6 +84,7 @@ def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weig return final_hidden_states.type(hidden_states.dtype) def forward(self, hidden_states): + """Apply routed experts plus the shared expert branch, matching upstream GLM4 MoE.""" residuals = hidden_states orig_shape = hidden_states.shape router_logits = self.gate(hidden_states) diff --git a/defuser/modeling/unfused_moe/mixtral.py b/defuser/modeling/unfused_moe/mixtral.py index bf62878..edfbf20 100644 --- a/defuser/modeling/unfused_moe/mixtral.py +++ b/defuser/modeling/unfused_moe/mixtral.py @@ -10,6 +10,8 @@ class MixtralBlockSparseTop2MLP(nn.Module): + """Per-expert Mixtral MLP with explicit gate, up, and down projections.""" + def __init__(self, config: MixtralConfig): super().__init__() self.hidden_size = config.hidden_size @@ -20,6 +22,7 @@ def __init__(self, config: MixtralConfig): self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): + """Apply the standard SwiGLU-style Mixtral expert math.""" down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) return down_proj @@ -53,7 +56,7 @@ def __init__(self, config): self.jitter_noise = config.router_jitter_noise def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - """ """ + """Match HF Mixtral MoE routing while executing explicit per-expert modules.""" batch_size, sequence_length, hidden_dim = hidden_states.shape if self.training and self.jitter_noise > 0: hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise) diff --git a/defuser/modeling/unfused_moe/qwen2_moe.py b/defuser/modeling/unfused_moe/qwen2_moe.py index 8a49df6..61ef68e 100644 
--- a/defuser/modeling/unfused_moe/qwen2_moe.py +++ b/defuser/modeling/unfused_moe/qwen2_moe.py @@ -9,6 +9,8 @@ class LinearQwen2MoeSparseMoeBlock(nn.Module): + """Qwen2 MoE block rewritten to expose one ``nn.Module`` per expert.""" + def __init__(self, config): super().__init__() from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeMLP, Qwen2MoeTopKRouter @@ -27,7 +29,7 @@ def __init__(self, config): self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - """ """ + """Route tokens exactly like HF Qwen2 MoE, then run explicit expert modules.""" batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) _, routing_weights, selected_experts = self.gate(hidden_states) diff --git a/defuser/modeling/unfused_moe/qwen3_moe.py b/defuser/modeling/unfused_moe/qwen3_moe.py index 90f4888..89d382d 100644 --- a/defuser/modeling/unfused_moe/qwen3_moe.py +++ b/defuser/modeling/unfused_moe/qwen3_moe.py @@ -11,6 +11,8 @@ class LinearQwen3MoeSparseMoeBlock(nn.Module): + """Qwen3 MoE block rewritten to expose one ``nn.Module`` per expert.""" + def __init__(self, config): super().__init__() from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeMLP, Qwen3MoeTopKRouter @@ -26,7 +28,7 @@ def __init__(self, config): ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - """ """ + """Route tokens exactly like HF Qwen3 MoE, then run explicit expert modules.""" batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) _, routing_weights, selected_experts = self.gate(hidden_states) diff --git a/defuser/modeling/unfused_moe/qwen3_next.py b/defuser/modeling/unfused_moe/qwen3_next.py index 10bfc2f..d9bdcc2 100644 --- a/defuser/modeling/unfused_moe/qwen3_next.py +++ b/defuser/modeling/unfused_moe/qwen3_next.py @@ -8,6 +8,8 @@ from torch.nn import functional as F 
class LinearQwen3NextSparseMoeBlock(nn.Module): + """Qwen3-Next MoE block rewritten to expose one ``nn.Module`` per expert.""" + def __init__(self, config): super().__init__() from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextMLP, Qwen3NextTopKRouter @@ -26,7 +28,7 @@ def __init__(self, config): self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - """ """ + """Route tokens exactly like HF Qwen3-Next MoE, then run explicit experts.""" batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) _, routing_weights, selected_experts = self.gate(hidden_states) diff --git a/defuser/modeling/unfused_moe/qwen3_omni_moe.py b/defuser/modeling/unfused_moe/qwen3_omni_moe.py index 0fb6476..94c0d4f 100644 --- a/defuser/modeling/unfused_moe/qwen3_omni_moe.py +++ b/defuser/modeling/unfused_moe/qwen3_omni_moe.py @@ -7,6 +7,8 @@ import torch.nn as nn class LinearQwen3OmniMoeThinkerTextSparseMoeBlock(nn.Module): + """Text thinker MoE block for qwen3-omni with explicit per-expert modules.""" + def __init__(self, config): super().__init__() from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import ( @@ -28,7 +30,7 @@ def __init__(self, config): ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - """ """ + """Route tokens exactly like HF qwen3-omni text MoE, then run explicit experts.""" batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) _, routing_weights, selected_experts = self.gate(hidden_states) diff --git a/defuser/modeling/update_module.py b/defuser/modeling/update_module.py index c56aa63..6a4b2cf 100644 --- a/defuser/modeling/update_module.py +++ b/defuser/modeling/update_module.py @@ -14,6 +14,7 @@ def update_module( cleanup_original: bool = True, max_layers: int | None = None, ): + """Run Defuser's replacement pipeline and 
optionally drop original modules.""" model = apply_replacements(model, max_layers=max_layers) if cleanup_original: diff --git a/defuser/utils/common.py b/defuser/utils/common.py index 8d49a4c..842797e 100644 --- a/defuser/utils/common.py +++ b/defuser/utils/common.py @@ -8,12 +8,14 @@ from functools import lru_cache import re +from packaging import version # Match module paths like "...layers.0..." and capture the numeric layer index. _LAYER_NAME_RE = re.compile(r"(?:^|\.)layers\.(\d+)(?:\.|$)") TRUTHFUL = {"1", "true", "yes", "on", "y"} +MIN_SUPPORTED_TRANSFORMERS_VERSION = "5.3.0" def env_flag(name: str, default: str | bool | None = "0") -> bool: @@ -31,12 +33,33 @@ def env_flag(name: str, default: str | bool | None = "0") -> bool: @lru_cache(None) def is_transformers_version_greater_or_equal_5(): + """Cache the coarse ``transformers>=5`` capability check used by fast paths.""" import transformers - from packaging import version return version.parse(transformers.__version__) >= version.parse("5.0.0") +def is_supported_transformers_version() -> bool: + """Return whether the installed transformers version is supported by Defuser's public API.""" + import transformers + + return version.parse(transformers.__version__) >= version.parse(MIN_SUPPORTED_TRANSFORMERS_VERSION) + + +def warn_if_public_api_transformers_unsupported(api_name: str, logger) -> bool: + """Emit a single consistent warning when the runtime transformers version is too old.""" + import transformers + + if is_supported_transformers_version(): + return False + + logger.warning( + f"Defuser public API `{api_name}` requires transformers>={MIN_SUPPORTED_TRANSFORMERS_VERSION}. " + f"Current version is {transformers.__version__}. This call is unsupported and will be skipped." 
+ ) + return True + + def is_within_max_layers(module_name: str, max_layers: int | None) -> bool: """Return True when module path is within requested layer limit.""" if max_layers is None: diff --git a/defuser/utils/device.py b/defuser/utils/device.py index eaf800e..ebb118c 100644 --- a/defuser/utils/device.py +++ b/defuser/utils/device.py @@ -35,6 +35,7 @@ def clear_memory( tensor: torch.Tensor | list[torch.Tensor] | None = None, device_list: tuple | list | str | torch.device | None = None, ): + """Release Python references and flush accelerator caches after tensor surgery.""" # ------------------------ # Clear CPU-side references # ------------------------ @@ -87,14 +88,7 @@ def clear_memory( def unsupported_meta_device(model): - """Checks if the model is a valid model. - - Args: - model: The model to be checked. - - Returns: - bool: True if the model is valid, False otherwise. - """ + """Return ``True`` when a model mixes meta and real devices in unsupported ways.""" target_device = None for param in model.parameters(): if target_device is None: diff --git a/defuser/utils/hf.py b/defuser/utils/hf.py index b4168f9..af370f5 100644 --- a/defuser/utils/hf.py +++ b/defuser/utils/hf.py @@ -6,7 +6,6 @@ # Adapted from intel/auto-round # at https://github.com/intel/auto-round/blob/main/auto_round/modeling/unfused_moe/__init__.py -import importlib import os from typing import Final @@ -17,6 +16,7 @@ from transformers import AutoConfig from defuser.model_registry import MODEL_CONFIG +from defuser.utils.common import warn_if_public_api_transformers_unsupported logger = LogBar(__name__) @@ -47,6 +47,7 @@ def modelscope_requested() -> bool: def get_file_path_via_model_name(model_or_path: str, file_name): + """Resolve one HF or ModelScope file path from either a repo id or local directory.""" from huggingface_hub import hf_hub_download # 1) local folder @@ -73,6 +74,10 @@ def get_file_path_via_model_name(model_or_path: str, file_name): def pre_check_config(model_name: str | 
torch.nn.Module): + """Quickly decide whether a model likely still needs Defuser's runtime conversion.""" + if warn_if_public_api_transformers_unsupported("pre_check_config()", logger): + return False + if isinstance(model_name, str): config = AutoConfig.from_pretrained(model_name) elif isinstance(model_name, torch.nn.Module): @@ -97,6 +102,7 @@ def pre_check_config(model_name: str | torch.nn.Module): with open(file_path, "r") as f: index_data = json.load(f) + # A fused gate/up tensor in the weight map means runtime defusion is still needed. model_keys = list(index_data.get("weight_map", {}).keys()) for key in model_keys: if "gate_up_proj" in key: diff --git a/defuser/utils/model.py b/defuser/utils/model.py index 349e1ed..a7bbc76 100644 --- a/defuser/utils/model.py +++ b/defuser/utils/model.py @@ -14,20 +14,14 @@ def _update_parameter( name: str, data: torch.Tensor, ) -> None: + """Replace one module parameter while preserving its ``requires_grad`` flag.""" old_param = getattr(module, name) new_param = torch.nn.Parameter(data, requires_grad=old_param.requires_grad) setattr(module, name, new_param) def unsupported_meta_device(model): - """Checks if the model is a valid model for auto_round. - - Args: - model: The model to be checked. - - Returns: - bool: True if the model is valid, False otherwise. 
- """ + """Return ``True`` when mixed real/meta parameters make lazy materialization unsafe.""" target_device = None for param in model.parameters(): if target_device is None: diff --git a/tests/test_convert_model.py b/tests/test_convert_model.py index 7a458ec..a2a4789 100644 --- a/tests/test_convert_model.py +++ b/tests/test_convert_model.py @@ -38,8 +38,11 @@ Qwen3OmniMoeThinkerTextSparseMoeBlock, ) +import defuser.defuser as defuser_api +import defuser.utils.hf as hf_utils from defuser import convert_model, replace_fused_blocks from defuser.checkpoint_ops import OwnedChunk, SplitFusedExpertGateUpProj +from defuser.model_registry import MODEL_CONFIG from defuser.modeling.replace_modules import ReplacementModuleBase, apply_replacements, materialize_model from defuser.modeling.unfused_moe.glm4_moe import LinearGlm4MoeMoE from defuser.modeling.unfused_moe.mixtral import LinearMixtralSparseMoeBlock @@ -47,6 +50,7 @@ from defuser.modeling.unfused_moe.qwen3_moe import LinearQwen3MoeSparseMoeBlock from defuser.modeling.unfused_moe.qwen3_next import LinearQwen3NextSparseMoeBlock from defuser.modeling.unfused_moe.qwen3_omni_moe import LinearQwen3OmniMoeThinkerTextSparseMoeBlock +from defuser.utils.common import MIN_SUPPORTED_TRANSFORMERS_VERSION @@ -328,7 +332,12 @@ def _copy_sparse_moe_weights(original_block: nn.Module, defused_block: nn.Module def _assert_sparse_moe_defused_matches_fused_math( - original_block: nn.Module, defused_block: nn.Module, hidden_states: torch.Tensor + original_block: nn.Module, + defused_block: nn.Module, + hidden_states: torch.Tensor, + *, + atol: float | None = None, + rtol: float | None = None, ) -> None: _seed_floating_tensors(original_block) _copy_sparse_moe_weights(original_block, defused_block) @@ -337,7 +346,12 @@ def _assert_sparse_moe_defused_matches_fused_math( actual = defused_block.eval()(hidden_states) # The defused replacement must preserve the exact MoE matmul path of the fused block. 
- torch.testing.assert_close(actual, expected) + assert_close_kwargs = {} + if atol is not None: + assert_close_kwargs["atol"] = atol + if rtol is not None: + assert_close_kwargs["rtol"] = rtol + torch.testing.assert_close(actual, expected, **assert_close_kwargs) def test_qwen2_moe(): @@ -453,6 +467,56 @@ def test_replace_fused_blocks_returns_false_for_unregistered_model(): assert replace_fused_blocks("unsupported_model_type") is False +def test_model_registry_requires_transformers_5_3_or_newer(): + assert {cfg["min_transformers_version"] for cfg in MODEL_CONFIG.values()} == {MIN_SUPPORTED_TRANSFORMERS_VERSION} + + +def test_replace_fused_blocks_warns_on_unsupported_transformers(monkeypatch): + warnings = [] + + monkeypatch.setattr(defuser_api.transformers, "__version__", "5.2.9") + monkeypatch.setattr(defuser_api.logger, "warning", warnings.append) + + assert defuser_api.replace_fused_blocks("mixtral") is False + assert len(warnings) == 1 + assert "replace_fused_blocks()" in warnings[0] + assert f"transformers>={MIN_SUPPORTED_TRANSFORMERS_VERSION}" in warnings[0] + + +def test_convert_model_warns_on_unsupported_transformers(monkeypatch): + warnings = [] + + class DummyModel(nn.Module): + def __init__(self): + super().__init__() + self.config = SimpleNamespace(model_type="mixtral") + + monkeypatch.setattr(defuser_api.transformers, "__version__", "5.2.9") + monkeypatch.setattr(defuser_api.logger, "warning", warnings.append) + + assert defuser_api.convert_model(DummyModel()) is False + assert len(warnings) == 1 + assert "convert_model()" in warnings[0] + assert f"transformers>={MIN_SUPPORTED_TRANSFORMERS_VERSION}" in warnings[0] + + +def test_pre_check_config_warns_on_unsupported_transformers(monkeypatch): + warnings = [] + + class DummyModel(nn.Module): + def __init__(self): + super().__init__() + self.config = SimpleNamespace(model_type="mixtral") + + monkeypatch.setattr(hf_utils.transformers, "__version__", "5.2.9") + monkeypatch.setattr(hf_utils.logger, 
"warning", warnings.append) + + assert hf_utils.pre_check_config(DummyModel()) is False + assert len(warnings) == 1 + assert "pre_check_config()" in warnings[0] + assert f"transformers>={MIN_SUPPORTED_TRANSFORMERS_VERSION}" in warnings[0] + + def test_qwen3_5_moe(): from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeSparseMoeBlock @@ -680,9 +744,40 @@ def test_glm4_moe_defused_forward_matches_fused_math(): Glm4MoeMoE(config), LinearGlm4MoeMoE(config), hidden_states, + # GLM4 MoE now shows tiny fp32 roundoff drift in 5.3.0 because the fused gate/up matmul + # is compared against two split linears. Keep the tolerance narrow enough to catch real regressions. + atol=3e-4, + rtol=1e-4, ) +def test_defused_models_preserve_output_router_logits_capture(): + cases = [ + ( + "mixtral", + lambda: MixtralForCausalLM(_tiny_mixtral_config()), + ), + ( + "qwen2_moe", + lambda: Qwen2MoeForCausalLM(_tiny_moe_config(Qwen2MoeConfig)), + ), + ( + "qwen3_moe", + lambda: Qwen3MoeForCausalLM(_tiny_moe_config(Qwen3MoeConfig)), + ), + ] + + for model_type, build_model in cases: + replace_fused_blocks(model_type) + model = build_model().eval() + outputs = model(input_ids=torch.tensor([[1, 2, 3]]), output_router_logits=True) + + # Router logits are captured through upstream hooks, so defused blocks must keep the same router module type. 
+ assert outputs.router_logits is not None + assert len(outputs.router_logits) == 1 + assert outputs.router_logits[0].shape == (3, model.config.num_experts) + + def test_glm4v_checkpoint_mapping_splits_gate_up_proj(): from defuser.defuser import get_checkpoint_conversion_mapping @@ -829,6 +924,27 @@ def test_llama4(): torch.testing.assert_close(expert0.down_proj.weight, expected_down) +def test_llama4_experts_forward_matches_fused_math(): + model = Llama4ForConditionalGeneration(_tiny_llama4_config()) + fused_experts = model.language_model.model.layers[0].feed_forward.experts + + hidden_states = torch.randn(fused_experts.num_experts * 5, model.config.text_config.hidden_size, dtype=torch.float32) + with torch.no_grad(): + expected = fused_experts(hidden_states) + + converted = convert_model(model, cleanup_original=False, max_layers=1) + assert converted + + split_experts = model.language_model.model.layers[0].feed_forward.experts + _assert_unfused_expert_module(split_experts) + materialize_model(model.language_model.model.layers[0]) + with torch.no_grad(): + actual = split_experts(hidden_states) + + # The batched-input generic path should preserve the original llama4 experts math. + torch.testing.assert_close(actual, expected) + + def test_llama4_split_forward_matches_fused_math(): from transformers.models.llama4.modeling_llama4 import Llama4TextMLP