diff --git a/README.md b/README.md
index b5b4292..0851a76 100644
--- a/README.md
+++ b/README.md
@@ -33,24 +33,46 @@ from defuser import convert_model, replace_fused_blocks
 ```
 - `replace_fused_blocks(model_type)` patches supported HF model classes before `from_pretrained()` or direct model construction.
-- `convert_model(model, cleanup_original=True, max_layers=None)` converts an already loaded model in place. This is the runtime defusion path used for `qwen3_5_moe` style checkpoints.
+- `convert_model(model, cleanup_original=True, max_layers=None, filter=None)` converts an already loaded model in place. This is the runtime defusion path for supported post-load expert and MLP conversions, including `qwen3_5_moe` style checkpoints.
 - Defuser is designed and CI-tested for `transformers>=5.3.0`, and support is only offered for that version range. Older versions log a warning on these public APIs and are skipped as unsupported.
+
+`filter` is an optional list of regex rules evaluated against full module paths such as `model.layers.0.mlp.experts`:
+
+- `+:regex` explicitly includes matching candidate module paths
+- `-:regex` explicitly excludes matching candidate module paths
+- `regex` is shorthand for `+:regex`
+- negative rules take priority over positive rules
+- when `filter` is provided, a candidate module is defused only if it matches at least one positive rule and no negative rules
+
 ## Supported Models
 
-| Model type | Recommended entrypoint | Defused op performed |
+Defuser currently supports the following `model_type` values, as defined in `transformers==5.3.0`.
+
+### `replace_fused_blocks(model_type)` before load
+
+| Model type | Defused op performed |
+| --- | --- |
+| `glm4_moe` | Replaces `Glm4MoeMoE` with a defused per-expert linear MoE block. |
+| `glm4v` | Replaces the fused text MLP with split `gate_proj`, `up_proj`, and `down_proj` layers. Also splits fused checkpoint `mlp.gate_up_proj.weight` into `mlp.gate_proj.weight` + `mlp.up_proj.weight`. 
| +| `mixtral` | Replaces `MixtralSparseMoeBlock` with `LinearMixtralSparseMoeBlock`. Also remaps legacy Mixtral checkpoint keys and splits fused expert `gate_up_proj` tensors into per-expert `gate_proj` and `up_proj`, plus per-expert `down_proj`. | +| `qwen2_moe` | Replaces `Qwen2MoeSparseMoeBlock` with a defused per-expert linear MoE block. | +| `qwen3_moe` | Replaces `Qwen3MoeSparseMoeBlock` with a defused per-expert linear MoE block. | +| `qwen3_next` | Replaces `Qwen3NextSparseMoeBlock` with a defused per-expert linear MoE block. | +| `qwen3_omni_moe` | Replaces both thinker and talker text sparse MoE blocks with defused per-expert linear blocks and applies small runtime compatibility patches for text `forward()` and `generate()`. | + +### `convert_model(model)` after load + +| Pattern | Supported model types | Defused op performed | | --- | --- | --- | -| `mixtral` | `replace_fused_blocks("mixtral")` before load | Replaces `MixtralSparseMoeBlock` with `LinearMixtralSparseMoeBlock`. Also remaps legacy Mixtral checkpoint keys and splits fused expert `gate_up_proj` tensors into per-expert `gate_proj` and `up_proj`, plus per-expert `down_proj`. | -| `qwen2_moe` | `replace_fused_blocks("qwen2_moe")` before load | Replaces `Qwen2MoeSparseMoeBlock` with a defused per-expert linear MoE block. | -| `qwen3_moe` | `replace_fused_blocks("qwen3_moe")` before load | Replaces `Qwen3MoeSparseMoeBlock` with a defused per-expert linear MoE block. | -| `qwen3_5_moe` | `convert_model(model)` after load | Runtime expert tensor defusion. Splits fused `gate_up_proj` into `gate_proj` + `up_proj` and converts 3D expert tensors into numbered expert `nn.Linear` modules. | -| `qwen3_5_moe_text` | `convert_model(model)` after load | Same runtime expert tensor defusion path as `qwen3_5_moe`, applied to the text-only backbone. | -| `qwen3_next` | `replace_fused_blocks("qwen3_next")` before load | Replaces `Qwen3NextSparseMoeBlock` with a defused per-expert linear MoE block. 
| -| `qwen3_omni_moe` | `replace_fused_blocks("qwen3_omni_moe")` before load | Replaces the thinker text sparse MoE block with a defused per-expert linear block and applies small runtime compatibility patches for text `forward()` and `generate()`. | -| `glm4_moe` | `replace_fused_blocks("glm4_moe")` before load | Replaces `Glm4MoeMoE` with a defused per-expert linear MoE block. | -| `glm4v` | `replace_fused_blocks("glm4v")` before load | Replaces the fused text MLP with split `gate_proj`, `up_proj`, and `down_proj` layers. Also splits fused checkpoint `mlp.gate_up_proj.weight` into `mlp.gate_proj.weight` + `mlp.up_proj.weight`. | -| `gpt_oss` | `convert_model(model)` after load | Runtime expert tensor defusion. Splits fused transposed expert `gate_up_proj` into per-expert `gate_proj` + `up_proj`, carries over expert biases, and converts fused expert tensors into numbered expert `nn.Linear` modules. | -| `llama4` | `convert_model(model)` after load | Runtime expert tensor defusion. Splits fused transposed expert `gate_up_proj` into per-expert `gate_proj` + `up_proj`, converts fused expert tensors into numbered expert `nn.Linear` modules, and preserves the llama4 batched expert-input execution contract. | +| Standard routed expert tensors | `deepseek_v2`, `dots1`, `ernie4_5_moe`, `ernie4_5_vl_moe`, `exaone_moe`, `flex_olmo`, `glm4_moe_lite`, `glm4v_moe`, `hunyuan_v1_moe`, `jamba`, `lfm2_moe`, `minimax`, `minimax_m2`, `olmoe`, `qwen3_vl_moe`, `solar_open` | Splits fused expert tensors into numbered expert `nn.Linear` modules with per-expert `gate_proj`, `up_proj`, and `down_proj`. | +| Mixed sparse and shared experts | `deepseek_v3`, `glm_moe_dsa`, `qwen3_5_moe`, `qwen3_5_moe_text` | Runtime expert tensor defusion for routed experts while preserving the model's shared-expert path. 
| +| Transposed or packed expert tensors | `gpt_oss`, `phimoe` | Splits transposed fused expert `gate_up_proj` tensors into per-expert `gate_proj` + `up_proj`, preserves expert bias when present, and converts expert tensors into numbered expert `nn.Linear` modules. | +| Flattened expert layout | `dbrx` | Rebuilds the flattened DBRX expert FFN weights into numbered expert `gate_proj`, `up_proj`, and `down_proj` `nn.Linear` modules. | +| Batched expert-input execution | `llama4` | Runtime expert tensor defusion plus preservation of the llama4 batched expert-input execution contract. | +| Non-gated expert MLPs | `nemotron_h` | Converts routed expert tensors into numbered `up_proj` and `down_proj` `nn.Linear` modules for non-gated experts. | +| Parallel expert blocks | `granitemoe`, `granitemoehybrid`, `granitemoeshared`, `jetmoe` | Converts packed expert weight tensors into numbered expert `linear` modules while keeping grouped expert execution intact. | +| Routed experts with identity experts | `longcat_flash` | Defuses routed experts into numbered `gate_proj`, `up_proj`, and `down_proj` modules and preserves zero or identity experts. | +| Fused dense `gate_up_proj` MLPs | `dia`, `glm`, `glm4`, `glm_image`, `glm_ocr`, `phi3`, `phi4_multimodal`, `zamba2` | Splits fused dense `gate_up_proj` layers into `gate_proj` + `up_proj` and updates the block `forward()` to preserve the original MLP math. | ## Workflow Summary @@ -77,6 +99,20 @@ converted = convert_model(model) print(converted) # True when runtime defusion happened ``` +Use `filter` when only specific blocks should be defused: + +```python +from defuser import convert_model + +convert_model( + model, + filter=[ + r"+:^model\.layers\.0\.mlp\.experts$", + r"-:^model\.layers\.0\.mlp\.experts\.shared_", + ], +) +``` + ## Real Qwen3.5 MoE Example The example below is written for the `transformers==5.3.0` public API surface and uses the real Hugging Face model `Qwen/Qwen3.5-35B-A3B-Instruct`. 
Defuser supports `transformers>=5.3.0`. diff --git a/defuser/defuser.py b/defuser/defuser.py index 719eaf2..5bb49f6 100644 --- a/defuser/defuser.py +++ b/defuser/defuser.py @@ -117,6 +117,7 @@ def convert_model( model: nn.Module, cleanup_original: bool = False, max_layers: int | None = None, + filter: list[str] | None = None, ) -> bool: """Convert one loaded model in place from fused experts to defused modules.""" if warn_if_public_api_transformers_unsupported("convert_model()", logger): @@ -200,7 +201,7 @@ def convert_model( if not check_model_compatibility(model): return False - apply_model_patches(model) + apply_model_patches(model, max_layers=max_layers, filter_rules=filter) # If fused blocks have already been structurally replaced at load model before, # there is no need to perform runtime defusing again @@ -214,6 +215,7 @@ def convert_model( model, cleanup_original=cleanup_original, max_layers=max_layers, + filter_rules=filter, ) return True diff --git a/defuser/model_registry.py b/defuser/model_registry.py index a8617ac..0cfa5bb 100644 --- a/defuser/model_registry.py +++ b/defuser/model_registry.py @@ -16,6 +16,39 @@ class PATCH(str, Enum): MODEL_CONFIG = { + "dbrx": { + "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, + }, + "deepseek_v2": { + "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, + }, + "deepseek_v3": { + "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, + }, + "dia": { + "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, + }, + "dots1": { + "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, + }, + "ernie4_5_moe": { + "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, + }, + "ernie4_5_vl_moe": { + "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, + }, + "exaone_moe": { + "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, + }, + "flex_olmo": { + "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, + }, + "glm": { + 
"min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, + }, + "glm4": { + "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, + }, "mixtral": { "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, PATCH.REPLACE_MODULE: [ @@ -84,6 +117,10 @@ class PATCH(str, Enum): ( "transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe.Qwen3OmniMoeThinkerTextSparseMoeBlock", "defuser.modeling.unfused_moe.qwen3_omni_moe.LinearQwen3OmniMoeThinkerTextSparseMoeBlock", + ), + ( + "transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe.Qwen3OmniMoeTalkerTextSparseMoeBlock", + "defuser.modeling.unfused_moe.qwen3_omni_moe.LinearQwen3OmniMoeTalkerTextSparseMoeBlock", ) ], }, @@ -96,6 +133,9 @@ class PATCH(str, Enum): ) ], }, + "glm4_moe_lite": { + "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, + }, "glm4v": { "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, PATCH.REPLACE_MODULE: [ @@ -116,9 +156,39 @@ class PATCH(str, Enum): ), ], }, + "glm4v_moe": { + "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, + }, + "glm_image": { + "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, + }, + "glm_moe_dsa": { + "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, + }, + "glm_ocr": { + "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, + }, "gpt_oss": { "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, }, + "granitemoe": { + "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, + }, + "granitemoehybrid": { + "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, + }, + "granitemoeshared": { + "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, + }, + "hunyuan_v1_moe": { + "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, + }, + "jamba": { + "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, + }, + "jetmoe": { + "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, + }, "llama4": { 
"min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, PATCH.EXPERTS_DEFUSE: [ @@ -128,7 +198,40 @@ class PATCH(str, Enum): } ], }, + "lfm2_moe": { + "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, + }, + "longcat_flash": { + "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, + }, + "minimax": { + "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, + }, + "minimax_m2": { + "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, + }, + "nemotron_h": { + "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, + }, + "olmoe": { + "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, + }, + "phi3": { + "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, + }, + "phi4_multimodal": { + "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, + }, "phimoe": { "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, }, + "qwen3_vl_moe": { + "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, + }, + "solar_open": { + "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, + }, + "zamba2": { + "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, + }, } diff --git a/defuser/modeling/model_patches.py b/defuser/modeling/model_patches.py index b1202bf..594d515 100644 --- a/defuser/modeling/model_patches.py +++ b/defuser/modeling/model_patches.py @@ -8,6 +8,13 @@ from logbar import LogBar from defuser import DEBUG_ON +from defuser.modeling.runtime_defusion import ( + patch_dbrx_experts, + patch_longcat_flash_experts, + patch_parallel_experts, + patch_split_gate_up_mlp, +) +from defuser.utils.common import compile_module_name_filter, is_within_max_layers, matches_module_name_filter import torch logger = LogBar(__name__) @@ -38,7 +45,10 @@ def decorator(func: Callable): def patch_qwen3_omni_text_class() -> list[str]: """Teach HF init code how to initialize unfused qwen3-omni thinker experts.""" from 
transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import Qwen3OmniMoeForConditionalGeneration, Qwen3OmniMoePreTrainedModel - from defuser.modeling.unfused_moe.qwen3_omni_moe import LinearQwen3OmniMoeThinkerTextSparseMoeBlock + from defuser.modeling.unfused_moe.qwen3_omni_moe import ( + LinearQwen3OmniMoeTalkerTextSparseMoeBlock, + LinearQwen3OmniMoeThinkerTextSparseMoeBlock, + ) orig_init_weights = Qwen3OmniMoePreTrainedModel._init_weights def patched_init_weights(self, module): @@ -46,7 +56,7 @@ def patched_init_weights(self, module): orig_init_weights(self, module) except AttributeError as e: # fallback for unfused experts - if isinstance(module, LinearQwen3OmniMoeThinkerTextSparseMoeBlock): + if isinstance(module, (LinearQwen3OmniMoeThinkerTextSparseMoeBlock, LinearQwen3OmniMoeTalkerTextSparseMoeBlock)): std = self.config.initializer_range experts = module.experts @@ -56,9 +66,18 @@ def patched_init_weights(self, module): torch.nn.init.normal_(experts.up_proj.weight, 0.0, std) if hasattr(experts, "down_proj"): torch.nn.init.normal_(experts.down_proj.weight, 0.0, std) + if isinstance(experts, torch.nn.ModuleList): + for expert in experts: + torch.nn.init.normal_(expert.gate_proj.weight, 0.0, std) + torch.nn.init.normal_(expert.up_proj.weight, 0.0, std) + torch.nn.init.normal_(expert.down_proj.weight, 0.0, std) if hasattr(module, "gate"): torch.nn.init.normal_(module.gate.weight, 0.0, std) + if hasattr(module, "shared_expert"): + module.shared_expert._is_hf_initialized = True + if hasattr(module, "shared_expert_gate"): + torch.nn.init.normal_(module.shared_expert_gate.weight, 0.0, std) else: raise e @@ -68,7 +87,7 @@ def patched_init_weights(self, module): @register_model_patch("qwen3_omni_moe") -def patch_qwen3_omni_text_runtime(model) -> list[str]: +def patch_qwen3_omni_text_runtime(model, max_layers: int | None = None, filter_rules=None) -> list[str]: """Restore text-only ``forward`` and ``generate`` behavior after class swapping.""" model_cls = 
type(model) if not getattr(model_cls, "__module__", "").startswith("transformers.models.qwen3_omni_moe."): @@ -98,6 +117,218 @@ def forward(self, *args, **kwargs): return applied +def _patch_modules_by_class( + model, + patchers: dict[str, Callable], + *, + max_layers: int | None = None, + filter_rules=None, +) -> list[str]: + module_name_filter = compile_module_name_filter(filter_rules) + applied = [] + for name, module in list(model.named_modules()): + if not is_within_max_layers(name, max_layers): + continue + if not matches_module_name_filter(name, module_name_filter): + continue + class_path = f"{module.__class__.__module__}.{module.__class__.__name__}" + patcher = patchers.get(class_path) + if patcher is None: + continue + if patcher(module): + applied.append(name) + return applied + + +def _patch_split_gate_up_mlps( + model, + patchers: dict[str, str], + *, + max_layers: int | None = None, + filter_rules=None, +) -> list[str]: + return _patch_modules_by_class( + model, + { + class_path: (lambda module, variant=variant: patch_split_gate_up_mlp(module, variant=variant)) + for class_path, variant in patchers.items() + }, + max_layers=max_layers, + filter_rules=filter_rules, + ) + + +_STANDARD_SPLIT_GATE_UP_CLASSES = { + "transformers.models.dia.modeling_dia.DiaMLP": "standard", + "transformers.models.glm.modeling_glm.GlmMLP": "standard", + "transformers.models.glm4.modeling_glm4.Glm4MLP": "standard", + "transformers.models.glm_image.modeling_glm_image.GlmImageTextMLP": "standard", + "transformers.models.glm_ocr.modeling_glm_ocr.GlmOcrTextMLP": "standard", + "transformers.models.phi3.modeling_phi3.Phi3MLP": "standard", + "transformers.models.phi4_multimodal.modeling_phi4_multimodal.Phi4MultimodalMLP": "standard", + "transformers.models.phi4_multimodal.modeling_phi4_multimodal.Phi4MultimodalAudioMLP": "phi4_audio", + "transformers.models.zamba2.modeling_zamba2.Zamba2MLP": "zamba2", +} + + +@register_model_patch("dia") +def patch_dia_runtime(model, max_layers: int 
| None = None, filter_rules=None) -> list[str]: + return _patch_split_gate_up_mlps( + model, + {"transformers.models.dia.modeling_dia.DiaMLP": _STANDARD_SPLIT_GATE_UP_CLASSES["transformers.models.dia.modeling_dia.DiaMLP"]}, + max_layers=max_layers, + filter_rules=filter_rules, + ) + + +@register_model_patch("glm") +def patch_glm_runtime(model, max_layers: int | None = None, filter_rules=None) -> list[str]: + return _patch_split_gate_up_mlps( + model, + {"transformers.models.glm.modeling_glm.GlmMLP": _STANDARD_SPLIT_GATE_UP_CLASSES["transformers.models.glm.modeling_glm.GlmMLP"]}, + max_layers=max_layers, + filter_rules=filter_rules, + ) + + +@register_model_patch("glm4") +def patch_glm4_runtime(model, max_layers: int | None = None, filter_rules=None) -> list[str]: + return _patch_split_gate_up_mlps( + model, + {"transformers.models.glm4.modeling_glm4.Glm4MLP": _STANDARD_SPLIT_GATE_UP_CLASSES["transformers.models.glm4.modeling_glm4.Glm4MLP"]}, + max_layers=max_layers, + filter_rules=filter_rules, + ) + + +@register_model_patch("glm_image") +def patch_glm_image_runtime(model, max_layers: int | None = None, filter_rules=None) -> list[str]: + return _patch_split_gate_up_mlps( + model, + {"transformers.models.glm_image.modeling_glm_image.GlmImageTextMLP": _STANDARD_SPLIT_GATE_UP_CLASSES["transformers.models.glm_image.modeling_glm_image.GlmImageTextMLP"]}, + max_layers=max_layers, + filter_rules=filter_rules, + ) + + +@register_model_patch("glm_ocr") +def patch_glm_ocr_runtime(model, max_layers: int | None = None, filter_rules=None) -> list[str]: + return _patch_split_gate_up_mlps( + model, + {"transformers.models.glm_ocr.modeling_glm_ocr.GlmOcrTextMLP": _STANDARD_SPLIT_GATE_UP_CLASSES["transformers.models.glm_ocr.modeling_glm_ocr.GlmOcrTextMLP"]}, + max_layers=max_layers, + filter_rules=filter_rules, + ) + + +@register_model_patch("phi3") +def patch_phi3_runtime(model, max_layers: int | None = None, filter_rules=None) -> list[str]: + return _patch_split_gate_up_mlps( + 
model, + {"transformers.models.phi3.modeling_phi3.Phi3MLP": _STANDARD_SPLIT_GATE_UP_CLASSES["transformers.models.phi3.modeling_phi3.Phi3MLP"]}, + max_layers=max_layers, + filter_rules=filter_rules, + ) + + +@register_model_patch("phi4_multimodal") +def patch_phi4_multimodal_runtime(model, max_layers: int | None = None, filter_rules=None) -> list[str]: + return _patch_split_gate_up_mlps( + model, + { + "transformers.models.phi4_multimodal.modeling_phi4_multimodal.Phi4MultimodalMLP": + _STANDARD_SPLIT_GATE_UP_CLASSES[ + "transformers.models.phi4_multimodal.modeling_phi4_multimodal.Phi4MultimodalMLP" + ], + "transformers.models.phi4_multimodal.modeling_phi4_multimodal.Phi4MultimodalAudioMLP": + _STANDARD_SPLIT_GATE_UP_CLASSES[ + "transformers.models.phi4_multimodal.modeling_phi4_multimodal.Phi4MultimodalAudioMLP" + ], + }, + max_layers=max_layers, + filter_rules=filter_rules, + ) + + +@register_model_patch("zamba2") +def patch_zamba2_runtime(model, max_layers: int | None = None, filter_rules=None) -> list[str]: + return _patch_split_gate_up_mlps( + model, + {"transformers.models.zamba2.modeling_zamba2.Zamba2MLP": _STANDARD_SPLIT_GATE_UP_CLASSES["transformers.models.zamba2.modeling_zamba2.Zamba2MLP"]}, + max_layers=max_layers, + filter_rules=filter_rules, + ) + + +@register_model_patch("dbrx") +def patch_dbrx_runtime(model, max_layers: int | None = None, filter_rules=None) -> list[str]: + return _patch_modules_by_class( + model, + {"transformers.models.dbrx.modeling_dbrx.DbrxExperts": patch_dbrx_experts}, + max_layers=max_layers, + filter_rules=filter_rules, + ) + + +def _patch_parallel_runtime(model, class_path: str, *, max_layers: int | None = None, filter_rules=None) -> list[str]: + return _patch_modules_by_class( + model, + {class_path: patch_parallel_experts}, + max_layers=max_layers, + filter_rules=filter_rules, + ) + + +@register_model_patch("granitemoe") +def patch_granitemoe_runtime(model, max_layers: int | None = None, filter_rules=None) -> list[str]: + 
return _patch_parallel_runtime( + model, + "transformers.models.granitemoe.modeling_granitemoe.GraniteMoeParallelExperts", + max_layers=max_layers, + filter_rules=filter_rules, + ) + + +@register_model_patch("granitemoehybrid") +def patch_granitemoehybrid_runtime(model, max_layers: int | None = None, filter_rules=None) -> list[str]: + return _patch_parallel_runtime( + model, + "transformers.models.granitemoehybrid.modeling_granitemoehybrid.GraniteMoeHybridParallelExperts", + max_layers=max_layers, + filter_rules=filter_rules, + ) + + +@register_model_patch("granitemoeshared") +def patch_granitemoeshared_runtime(model, max_layers: int | None = None, filter_rules=None) -> list[str]: + return _patch_parallel_runtime( + model, + "transformers.models.granitemoeshared.modeling_granitemoeshared.GraniteMoeSharedParallelExperts", + max_layers=max_layers, + filter_rules=filter_rules, + ) + + +@register_model_patch("jetmoe") +def patch_jetmoe_runtime(model, max_layers: int | None = None, filter_rules=None) -> list[str]: + return _patch_parallel_runtime( + model, + "transformers.models.jetmoe.modeling_jetmoe.JetMoeParallelExperts", + max_layers=max_layers, + filter_rules=filter_rules, + ) + + +@register_model_patch("longcat_flash") +def patch_longcat_flash_runtime(model, max_layers: int | None = None, filter_rules=None) -> list[str]: + return _patch_modules_by_class( + model, + {"transformers.models.longcat_flash.modeling_longcat_flash.LongcatFlashExperts": patch_longcat_flash_experts}, + max_layers=max_layers, + filter_rules=filter_rules, + ) + + def apply_model_class_patches(model_type) -> list[str]: """Run any registered pre-construction patch for ``model_type``.""" patch_model_class = _MODEL_CLASS_PATCH_REGISTRY.get(model_type) @@ -110,7 +341,7 @@ def apply_model_class_patches(model_type) -> list[str]: return applied -def apply_model_patches(model) -> list[str]: +def apply_model_patches(model, max_layers: int | None = None, filter_rules=None) -> list[str]: """Run any 
registered runtime patch for the instantiated ``model``.""" config = getattr(model, "config", None) model_type = getattr(config, "model_type", None) @@ -118,7 +349,7 @@ def apply_model_patches(model) -> list[str]: if patch is None: return [] - applied = patch(model) + applied = patch(model, max_layers=max_layers, filter_rules=filter_rules) if applied and DEBUG_ON: logger.debug(f"Applied model patches for model_type={model_type}: {', '.join(applied)}") return applied diff --git a/defuser/modeling/moe_experts_interface.py b/defuser/modeling/moe_experts_interface.py index ad2b9d1..8f48a51 100644 --- a/defuser/modeling/moe_experts_interface.py +++ b/defuser/modeling/moe_experts_interface.py @@ -32,6 +32,7 @@ from torch import nn from defuser.model_registry import MODEL_CONFIG, PATCH +from defuser.utils.common import compile_module_name_filter, matches_module_name_filter from defuser.utils.device import clear_memory, to_meta from defuser import DEBUG_ON @@ -83,8 +84,18 @@ def is_linear_loop_available() -> bool: return HAS_EXPERTS_INTERFACE -def _apply_expert_gate(module: nn.Module, gate_out: torch.Tensor, up_out: torch.Tensor) -> torch.Tensor: - """Apply the expert's activation path using either a custom gate hook or ``act_fn``.""" +def _apply_expert_gate( + module: nn.Module, + gate_out: torch.Tensor | None, + up_out: torch.Tensor, +) -> torch.Tensor: + """Apply the expert activation path for gated and non-gated expert MLPs.""" + if gate_out is None: + act_fn = getattr(module, "act_fn", None) + if act_fn is None: + raise AttributeError(f"{module.__class__.__name__} must define `act_fn` for non-gated experts.") + return act_fn(up_out) + if hasattr(module, "_apply_gate"): return module._apply_gate(torch.cat([gate_out, up_out], dim=-1)) @@ -165,8 +176,12 @@ def linear_loop_experts_forward( # Get this expert's container with its projection layers expert = getattr(self, str(expert_idx)) - gate_out = expert.gate_proj(expert_input) # (num_samples, intermediate_dim) - up_out = 
expert.up_proj(expert_input) # (num_samples, intermediate_dim) + if hasattr(expert, "gate_proj"): + gate_out = expert.gate_proj(expert_input) # (num_samples, intermediate_dim) + up_out = expert.up_proj(expert_input) # (num_samples, intermediate_dim) + else: + gate_out = None + up_out = expert.up_proj(expert_input) gated_out = _apply_expert_gate(self, gate_out, up_out) # Down projection @@ -580,8 +595,15 @@ def _unfuse_experts_weights_inplace( is_transposed = getattr(module, "is_transposed", None) if is_transposed is None: # Infer from shape: typically hidden_dim < intermediate_dim - dim1, dim2 = first_param.shape[1], first_param.shape[2] - is_transposed = dim1 < dim2 + if ( + first_proj_name in {"up_proj", "down_proj"} + and "gate_up_proj" not in detected_projections + and "gate_proj" not in detected_projections + ): + is_transposed = False + else: + dim1, dim2 = first_param.shape[1], first_param.shape[2] + is_transposed = dim1 < dim2 dtype = first_param.dtype target_device = first_param.device if first_param.device.type != "meta" else "cpu" @@ -672,7 +694,11 @@ def _unfuse_experts_weights_inplace( return True -def prepare_model_for_moe_quantization(model: nn.Module, implementation: str = LINEAR_LOOP_IMPL) -> list[str]: +def prepare_model_for_moe_quantization( + model: nn.Module, + implementation: str = LINEAR_LOOP_IMPL, + filter_rules=None, +) -> list[str]: """Prepare a model for MOE quantization using transformers' experts interface. 
This function: @@ -701,7 +727,10 @@ def prepare_model_for_moe_quantization(model: nn.Module, implementation: str = L unfused_modules = [] decorated_unfused_modules = [] experts_defuse_specs = _model_experts_defuse_specs(model) + module_name_filter = compile_module_name_filter(filter_rules) for name, module in model.named_modules(): + if not matches_module_name_filter(name, module_name_filter): + continue spec = _matching_experts_defuse_spec(module, experts_defuse_specs) if spec is not None and _unfuse_experts_weights_inplace( module, diff --git a/defuser/modeling/replace_modules.py b/defuser/modeling/replace_modules.py index dfc0385..97b6d7c 100644 --- a/defuser/modeling/replace_modules.py +++ b/defuser/modeling/replace_modules.py @@ -15,7 +15,12 @@ from logbar import LogBar from tqdm import tqdm -from defuser.utils.common import is_within_max_layers, is_transformers_version_greater_or_equal_5 +from defuser.utils.common import ( + compile_module_name_filter, + is_within_max_layers, + is_transformers_version_greater_or_equal_5, + matches_module_name_filter, +) from defuser import DEBUG_ON @@ -190,7 +195,7 @@ def _log_first_moe_block(model: torch.nn.Module, label: str) -> None: return -def _handle_moe_modules(model: torch.nn.Module) -> list[str]: +def _handle_moe_modules(model: torch.nn.Module, filter_rules=None) -> list[str]: """Handle fused MOE modules using transformers' linear_loop backend. 
Args: @@ -213,7 +218,7 @@ def _handle_moe_modules(model: torch.nn.Module) -> list[str]: return [] # Use transformers' experts interface - unfused = prepare_model_for_moe_quantization(model) + unfused = prepare_model_for_moe_quantization(model, filter_rules=filter_rules) if unfused: logger.info(f"Prepared {len(unfused)} MOE modules for quantization") return unfused @@ -223,6 +228,7 @@ def apply_replacements( model: torch.nn.Module, auto_detect_moe: bool = True, max_layers: int | None = None, + filter_rules=None, ) -> torch.nn.Module: """ Function to apply module replacements to a model. @@ -238,6 +244,8 @@ def apply_replacements( (transformers 5.0+ pattern). Default is True. max_layers: If provided, only replace modules under ``layers.`` where ``idx < max_layers``. + filter_rules: Optional regex rules selecting which candidate module paths + are allowed to be defused. Negative rules take priority over positive ones. Returns: The model with modules replaced. @@ -247,11 +255,11 @@ def apply_replacements( # Run the generic MoE tensor defusion pass first so models with supported # fused experts can stay on their upstream module structure. if auto_detect_moe and is_transformers_version_greater_or_equal_5(): - _handle_moe_modules(model) + _handle_moe_modules(model, filter_rules=filter_rules) # Fall back to replacement modules for any models that still need a custom # structural wrapper after the generic experts pass. - _apply_custom_replacements(model, max_layers=max_layers) + _apply_custom_replacements(model, max_layers=max_layers, filter_rules=filter_rules) _log_first_moe_block(model, "after replacement") @@ -261,6 +269,7 @@ def apply_replacements( def _apply_custom_replacements( model: torch.nn.Module, max_layers: int | None = None, + filter_rules=None, ) -> list: """Scan model and replace registered modules with custom implementations. @@ -271,6 +280,7 @@ def _apply_custom_replacements( List of (name, replacement_class) tuples for replaced modules. 
""" replaced = [] + module_name_filter = compile_module_name_filter(filter_rules) # Step 1: Collect all modules that need replacement if DEBUG_ON: logger.debug("Scanning for modules to replace") @@ -281,6 +291,8 @@ def _apply_custom_replacements( continue if not is_within_max_layers(name, max_layers): continue + if not matches_module_name_filter(name, module_name_filter): + continue class_name = module.__class__.__name__ if ReplacementModuleBase.is_registered(class_name) and ReplacementModuleBase.get_replacement_class( class_name @@ -300,6 +312,10 @@ def _apply_custom_replacements( f"Skipping replacement for {name}: class changed from {class_name} to {module.__class__.__name__}" ) continue + if not matches_module_name_filter(name, module_name_filter): + if DEBUG_ON: + logger.debug(f"Skipping replacement for {name}: module path excluded by filter") + continue replacement_cls = ReplacementModuleBase.get_replacement_class(class_name) if not replacement_cls.is_to_be_replaced(module): if DEBUG_ON: logger.debug(f"Skipping replacement for {name}: no longer matches replacement criteria") diff --git a/defuser/modeling/runtime_defusion.py b/defuser/modeling/runtime_defusion.py new file mode 100644 index 0000000..e05efc8 --- /dev/null +++ b/defuser/modeling/runtime_defusion.py @@ -0,0 +1,313 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from __future__ import annotations + +from types import MethodType + +import torch +from torch import nn + +from defuser.modeling.moe_experts_interface import _ExpertContainer, _install_compact_expert_repr + + +def _activation_fn(module: nn.Module): + act_fn = getattr(module, "activation_fn", None) + if act_fn is None: + act_fn = getattr(module, "act_fn", None) + if act_fn is None: + raise AttributeError(f"{module.__class__.__name__} is missing `activation_fn`/`act_fn`.") + return act_fn + + 
+def _make_linear( + *, + in_features: int, + out_features: int, + weight: torch.Tensor, + bias: torch.Tensor | None = None, +) -> nn.Linear: + linear = nn.Linear( + in_features, + out_features, + bias=bias is not None, + device=weight.device, + dtype=weight.dtype, + ) + with torch.no_grad(): + linear.weight.copy_(weight) + if bias is not None: + linear.bias.copy_(bias) + return linear + + +def _split_gate_up_linear(gate_up_proj: nn.Linear) -> tuple[nn.Linear, nn.Linear]: + split_size = gate_up_proj.out_features // 2 + bias = gate_up_proj.bias + gate_bias = bias[:split_size].contiguous() if bias is not None else None + up_bias = bias[split_size:].contiguous() if bias is not None else None + gate_proj = _make_linear( + in_features=gate_up_proj.in_features, + out_features=split_size, + weight=gate_up_proj.weight[:split_size].contiguous(), + bias=gate_bias, + ) + up_proj = _make_linear( + in_features=gate_up_proj.in_features, + out_features=split_size, + weight=gate_up_proj.weight[split_size:].contiguous(), + bias=up_bias, + ) + return gate_proj, up_proj + + +def _standard_split_gate_up_forward(self: nn.Module, hidden_states: torch.Tensor) -> torch.Tensor: + gate = self.gate_proj(hidden_states) + up = self.up_proj(hidden_states) + return self.down_proj(up * _activation_fn(self)(gate)) + + +def _phi4_audio_split_gate_up_forward(self: nn.Module, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.layer_norm(hidden_states) + gate = self.gate_proj(hidden_states) + up = self.up_proj(hidden_states) + up = up * self.act_fn(gate) + up = self.dropout(up) + hidden_states = self.down_proj(up) + return self.dropout(hidden_states) + + +def _zamba2_split_gate_up_forward( + self: nn.Module, + hidden_state: torch.Tensor, + layer_idx: int | None = None, +) -> torch.Tensor: + layer_idx = self.layer_dic[layer_idx] + adapter_out = self.gate_up_proj_adapter_list[layer_idx](hidden_state) + gate_adapter, up_adapter = torch.chunk(adapter_out, 2, dim=-1) + gate_state = 
self.gate_proj(hidden_state) + gate_adapter + up_state = self.up_proj(hidden_state) + up_adapter + return self.down_proj(self.act_fn(gate_state) * up_state) + + +def patch_split_gate_up_mlp(module: nn.Module, variant: str = "standard") -> bool: + if getattr(module, "_defuser_split_gate_up_runtime", False): + return False + + gate_up_proj = getattr(module, "gate_up_proj", None) + if not isinstance(gate_up_proj, nn.Linear): + return False + + gate_proj, up_proj = _split_gate_up_linear(gate_up_proj) + if variant == "phi4_audio": + gate_proj, up_proj = up_proj, gate_proj + module.add_module("gate_proj", gate_proj) + module.add_module("up_proj", up_proj) + delattr(module, "gate_up_proj") + + if variant == "standard": + module.forward = MethodType(_standard_split_gate_up_forward, module) + elif variant == "phi4_audio": + module.forward = MethodType(_phi4_audio_split_gate_up_forward, module) + elif variant == "zamba2": + module.forward = MethodType(_zamba2_split_gate_up_forward, module) + else: + raise ValueError(f"Unsupported split gate_up MLP variant: {variant}") + + module._defuser_split_gate_up_runtime = True + return True + + +def _parallel_experts_forward(self: nn.Module, inputs: torch.Tensor, expert_size) -> torch.Tensor: + input_list = inputs.split(expert_size, dim=0) + output_list = [] + for expert_idx in range(self.num_experts): + output_list.append(getattr(self, str(expert_idx)).linear(input_list[expert_idx])) + return torch.cat(output_list, dim=0) + + +def patch_parallel_experts(module: nn.Module) -> bool: + if getattr(module, "_defuser_parallel_experts_runtime", False): + return False + + weight = getattr(module, "weight", None) + if not isinstance(weight, nn.Parameter) or weight.dim() != 3: + return False + + for expert_idx in range(module.num_experts): + container = _ExpertContainer() + linear = _make_linear( + in_features=module.input_size, + out_features=module.output_size, + weight=weight[expert_idx].contiguous(), + ) + container.add_module("linear", 
linear) + module.add_module(str(expert_idx), container) + + delattr(module, "weight") + module.forward = MethodType(_parallel_experts_forward, module) + module._unfused_experts = True + module._defuser_parallel_experts_runtime = True + _install_compact_expert_repr(module) + return True + + +def _longcat_flash_forward( + self: nn.Module, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, +) -> torch.Tensor: + final_hidden_states = torch.zeros_like(hidden_states) + if top_k_index.numel() == 0: + return final_hidden_states + + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.total_experts).permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero(as_tuple=False) + for expert_idx_tensor in expert_hit: + expert_idx = int(expert_idx_tensor.item()) + selection_idx, token_idx = torch.where(expert_mask[expert_idx].squeeze(0)) + if token_idx.numel() == 0: + continue + + expert = getattr(self, str(expert_idx)) + current_state = hidden_states[token_idx] + if hasattr(expert, "identity"): + current_hidden_states = expert.identity(current_state) + else: + gate = expert.gate_proj(current_state) + up = expert.up_proj(current_state) + current_hidden_states = self.act_fn(gate) * up + current_hidden_states = expert.down_proj(current_hidden_states) + + current_hidden_states = current_hidden_states * top_k_weights[token_idx, selection_idx, None] + final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(hidden_states.dtype)) + + return final_hidden_states + + +def patch_longcat_flash_experts(module: nn.Module) -> bool: + if getattr(module, "_defuser_longcat_runtime", False): + return False + + gate_up_proj = getattr(module, "gate_up_proj", None) + down_proj = getattr(module, "down_proj", None) + if gate_up_proj is None and down_proj is None and hasattr(module, "0"): + return False + + for expert_idx in range(module.total_experts): + container = _ExpertContainer() + if expert_idx < 
module.num_routed_experts and gate_up_proj is not None and down_proj is not None: + fused_gate_up = gate_up_proj[expert_idx] + split_size = fused_gate_up.shape[0] // 2 + gate_proj = _make_linear( + in_features=module.hidden_size, + out_features=split_size, + weight=fused_gate_up[:split_size].contiguous(), + ) + up_proj = _make_linear( + in_features=module.hidden_size, + out_features=split_size, + weight=fused_gate_up[split_size:].contiguous(), + ) + down_linear = _make_linear( + in_features=module.intermediate_size, + out_features=module.hidden_size, + weight=down_proj[expert_idx].contiguous(), + ) + container.add_module("gate_proj", gate_proj) + container.add_module("up_proj", up_proj) + container.add_module("down_proj", down_linear) + else: + container.add_module("identity", nn.Identity()) + module.add_module(str(expert_idx), container) + + if hasattr(module, "gate_up_proj"): + delattr(module, "gate_up_proj") + if hasattr(module, "down_proj"): + delattr(module, "down_proj") + + module.num_experts = module.total_experts + module.forward = MethodType(_longcat_flash_forward, module) + module._unfused_experts = True + module._defuser_longcat_runtime = True + _install_compact_expert_repr(module) + return True + + +def _dbrx_experts_forward( + self: nn.Module, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, +) -> torch.Tensor: + batch_size = hidden_states.shape[0] + hidden_states = hidden_states.reshape(-1, self.ffn_hidden_size) + next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) + + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero().flatten() + + for expert_idx in expert_hit.tolist(): + idx, token_idx = torch.where(expert_mask[expert_idx]) + expert = getattr(self, str(expert_idx)) + current_state = hidden_states[token_idx] + 
current_hidden_states = self.act_fn(expert.gate_proj(current_state)) * expert.up_proj(current_state) + current_hidden_states = expert.down_proj(current_hidden_states) + current_hidden_states = current_hidden_states.view(-1, self.ffn_hidden_size) * top_k_weights[token_idx, idx, None] + next_states.index_add_(0, token_idx, current_hidden_states) + + return next_states.view(batch_size, -1, self.ffn_hidden_size) + + +def patch_dbrx_experts(module: nn.Module) -> bool: + if getattr(module, "_defuser_dbrx_runtime", False): + return False + + mlp = getattr(module, "mlp", None) + if mlp is None or not hasattr(mlp, "w1") or not hasattr(mlp, "v1") or not hasattr(mlp, "w2"): + return False + + split_shape = (module.num_experts, module.ffn_hidden_size, module.hidden_size) + w1 = mlp.w1.view(split_shape) + v1 = mlp.v1.view(split_shape) + w2 = mlp.w2.view(split_shape) + + for expert_idx in range(module.num_experts): + container = _ExpertContainer() + container.add_module( + "gate_proj", + _make_linear( + in_features=module.ffn_hidden_size, + out_features=module.hidden_size, + weight=w1[expert_idx].t().contiguous(), + ), + ) + container.add_module( + "up_proj", + _make_linear( + in_features=module.ffn_hidden_size, + out_features=module.hidden_size, + weight=v1[expert_idx].t().contiguous(), + ), + ) + container.add_module( + "down_proj", + _make_linear( + in_features=module.hidden_size, + out_features=module.ffn_hidden_size, + weight=w2[expert_idx].contiguous(), + ), + ) + module.add_module(str(expert_idx), container) + + module.act_fn = mlp.activation_fn + delattr(module, "mlp") + module.forward = MethodType(_dbrx_experts_forward, module) + module._unfused_experts = True + module._defuser_dbrx_runtime = True + _install_compact_expert_repr(module) + return True diff --git a/defuser/modeling/unfused_moe/qwen3_omni_moe.py b/defuser/modeling/unfused_moe/qwen3_omni_moe.py index 798629a..c7054c4 100644 --- a/defuser/modeling/unfused_moe/qwen3_omni_moe.py +++ 
b/defuser/modeling/unfused_moe/qwen3_omni_moe.py @@ -46,3 +46,47 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: ) final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) return final_hidden_states + + +class LinearQwen3OmniMoeTalkerTextSparseMoeBlock(nn.Module): + """Text talker MoE block for qwen3-omni with explicit per-expert modules.""" + + def __init__(self, config): + super().__init__() + from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import ( + Qwen3OmniMoeTalkerTextMLP, + Qwen3OmniMoeTalkerTextTopKRouter, + ) + + self.num_experts = config.num_experts + self.top_k = config.num_experts_per_tok + self.norm_topk_prob = config.norm_topk_prob + + self.gate = Qwen3OmniMoeTalkerTextTopKRouter(config) + self.experts = nn.ModuleList( + [ + Qwen3OmniMoeTalkerTextMLP(config, intermediate_size=config.moe_intermediate_size) + for _ in range(self.num_experts) + ] + ) + self.shared_expert = Qwen3OmniMoeTalkerTextMLP( + config, + intermediate_size=config.shared_expert_intermediate_size, + ) + self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + shared_expert_output = self.shared_expert(hidden_states) + _, routing_weights, selected_experts = self.gate(hidden_states) + final_hidden_states = run_routed_experts( + self.experts, + hidden_states, + routing_weights.to(hidden_states.dtype), + selected_experts, + self.num_experts, + ) + shared_expert_output = torch.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output + final_hidden_states = final_hidden_states + shared_expert_output + return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) diff --git a/defuser/modeling/update_module.py b/defuser/modeling/update_module.py index 6a4b2cf..398afca 100644 --- 
a/defuser/modeling/update_module.py +++ b/defuser/modeling/update_module.py @@ -13,9 +13,10 @@ def update_module( model, cleanup_original: bool = True, max_layers: int | None = None, + filter_rules=None, ): """Run Defuser's replacement pipeline and optionally drop original modules.""" - model = apply_replacements(model, max_layers=max_layers) + model = apply_replacements(model, max_layers=max_layers, filter_rules=filter_rules) if cleanup_original: release_original_module_(model) diff --git a/defuser/utils/common.py b/defuser/utils/common.py index 842797e..a1df6cc 100644 --- a/defuser/utils/common.py +++ b/defuser/utils/common.py @@ -6,18 +6,28 @@ # Adapted from intel/auto-round # at https://github.com/intel/auto-round/blob/main/auto_round/utils/common.py +from collections.abc import Sequence +from dataclasses import dataclass from functools import lru_cache -import re from packaging import version +import pcre # Match module paths like "...layers.0..." and capture the numeric layer index. -_LAYER_NAME_RE = re.compile(r"(?:^|\.)layers\.(\d+)(?:\.|$)") +_LAYER_NAME_RE = pcre.compile(r"(?:^|\.)layers\.(\d+)(?:\.|$)") TRUTHFUL = {"1", "true", "yes", "on", "y"} MIN_SUPPORTED_TRANSFORMERS_VERSION = "5.3.0" +@dataclass(frozen=True) +class ModuleNameFilter: + """Compiled include or exclude rules for module path matching.""" + + positive: tuple[pcre.Pattern, ...] + negative: tuple[pcre.Pattern, ...] + + def env_flag(name: str, default: str | bool | None = "0") -> bool: """Return ``True`` when an env var is set to a truthy value.""" @@ -71,3 +81,68 @@ def is_within_max_layers(module_name: str, max_layers: int | None) -> bool: if match is None: return True return int(match.group(1)) < max_layers + + +def compile_module_name_filter( + filter_rules: Sequence[str] | ModuleNameFilter | None, +) -> ModuleNameFilter | None: + """Compile user-facing module filter rules once for repeated matching. 
+ + Rules support three forms: + - ``+:regex`` explicit positive match + - ``-:regex`` explicit negative match + - ``regex`` implicit positive match + + Negative rules take priority over positive rules during matching. + """ + if filter_rules is None: + return None + + if isinstance(filter_rules, ModuleNameFilter): + return filter_rules + + if isinstance(filter_rules, (str, bytes)) or not isinstance(filter_rules, Sequence): + raise TypeError("filter must be a sequence of regex strings") + + positive: list[pcre.Pattern] = [] + negative: list[pcre.Pattern] = [] + for raw_rule in filter_rules: + if not isinstance(raw_rule, str): + raise TypeError("filter rules must be strings") + + if raw_rule.startswith("-:"): + bucket = negative + pattern = raw_rule[2:] + elif raw_rule.startswith("+:"): + bucket = positive + pattern = raw_rule[2:] + else: + bucket = positive + pattern = raw_rule + + bucket.append(pcre.compile(pattern)) + + return ModuleNameFilter( + positive=tuple(positive), + negative=tuple(negative), + ) + + +def matches_module_name_filter( + module_name: str, + filter_rules: Sequence[str] | ModuleNameFilter | None, +) -> bool: + """Return whether ``module_name`` is allowed by the configured filter rules.""" + compiled = compile_module_name_filter(filter_rules) + if compiled is None: + return True + + for pattern in compiled.negative: + if pattern.search(module_name): + return False + + for pattern in compiled.positive: + if pattern.search(module_name): + return True + + return False diff --git a/pyproject.toml b/pyproject.toml index ea58861..e03bf39 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ classifiers = [ ] keywords = ["Defuser", "Model", "HF", "Transformers"] dependencies = [ + "pypcre>=0.2.13", "transformers", ] diff --git a/requirements.txt b/requirements.txt index 4e449cc..d9e36ba 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ +pypcre>=0.2.13 transformers logbar diff --git a/tests/test_candidate_coverage.py 
b/tests/test_candidate_coverage.py new file mode 100644 index 0000000..a2045db --- /dev/null +++ b/tests/test_candidate_coverage.py @@ -0,0 +1,663 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from __future__ import annotations + +from importlib import import_module +from types import SimpleNamespace + +import pytest +import torch +from torch import nn + +from defuser import convert_model +from defuser.model_registry import MODEL_CONFIG + + +class _DummyLayer(nn.Module): + def __init__(self, module: nn.Module): + super().__init__() + self.module = module + + +class _DummyModel(nn.Module): + def __init__(self, model_type: str, module: nn.Module): + super().__init__() + self.config = getattr(module, "config", None) + if self.config is None: + self.config = SimpleNamespace() + self.config.model_type = model_type + self.layers = nn.ModuleList([_DummyLayer(module)]) + + +def _load(module_path: str, attr_name: str): + return getattr(import_module(module_path), attr_name) + + +def _build_config(case: dict): + config = _load(case["config_module"], case["config_name"])() + sub_attr = case.get("sub_attr") + if sub_attr is not None: + config = getattr(config, sub_attr) + for attr, value in { + "hidden_size": 64, + "intermediate_size": 128, + "moe_intermediate_size": 32, + "expert_ffn_hidden_size": 32, + "num_local_experts": 4, + "num_experts": 4, + "n_routed_experts": 4, + "moe_num_experts": 4, + "num_experts_per_tok": 2, + "hidden_act": "silu", + "mlp_hidden_act": "silu", + "activation": "silu", + "activation_function": "silu", + "use_bias": False, + "mlp_bias": False, + "add_bias_linear": False, + "n_group": 1, + "topk_group": 1, + "n_shared_experts": 1, + "num_shared_experts": 1, + "routed_scaling_factor": 1.0, + "norm_topk_prob": False, + "use_expert_bias": False, + "moe_latent_size": None, + "zero_expert_num": 1, + 
"adapter_rank": 8, + "hybrid_layer_ids": [0], + "num_mem_blocks": 1, + "dropout_rate": 0.0, + }.items(): + if hasattr(config, attr): + setattr(config, attr, value) + for attr, value in case.get("config_updates", {}).items(): + if hasattr(config, attr): + setattr(config, attr, value) + return config + + +def _build_module(case: dict) -> nn.Module: + module_cls = _load(case["module_path"], case["class_name"]) + kind = case.get("kind", "config") + if kind == "parallel": + return module_cls(case["num_experts"], case["input_size"], case["output_size"]).eval() + if kind == "zamba2": + return module_cls(_build_config(case), num_fwd_mem_blocks=1, block_id=0).eval() + return module_cls(_build_config(case)).eval() + + +def _wrapped_model(case: dict) -> tuple[_DummyModel, nn.Module]: + module = _build_module(case) + generator = torch.Generator(device="cpu").manual_seed(0) + with torch.no_grad(): + for tensor in list(module.parameters()) + list(module.buffers()): + if tensor.is_floating_point(): + tensor.copy_(torch.randn(tensor.shape, generator=generator, dtype=tensor.dtype, device=tensor.device)) + return _DummyModel(case["model_type"], module), module + + +def _patched_module(model: _DummyModel) -> nn.Module: + return model.layers[0].module + + +def _assert_expert_container(module: nn.Module, attrs: tuple[str, ...]) -> None: + assert hasattr(module, "0") + expert0 = getattr(module, "0") + for attr in attrs: + assert hasattr(expert0, attr) + + +def _standard_hidden(case: dict) -> torch.Tensor: + return torch.randn(case.get("hidden_shape", (5, case["input_dim"])), dtype=torch.float32) + + +STANDARD_MOE_CASES = [ + { + "model_type": "deepseek_v2", + "module_path": "transformers.models.deepseek_v2.modeling_deepseek_v2", + "class_name": "DeepseekV2Experts", + "config_module": "transformers.models.deepseek_v2.configuration_deepseek_v2", + "config_name": "DeepseekV2Config", + "config_updates": {"hidden_size": 64, "moe_intermediate_size": 32, "num_local_experts": 4, "hidden_act": 
"silu"}, + "input_dim": 64, + "route_indices": [[0], [1], [2], [3], [0]], + "expert_attrs": ("gate_proj", "up_proj", "down_proj"), + }, + { + "model_type": "deepseek_v3", + "module_path": "transformers.models.deepseek_v3.modeling_deepseek_v3", + "class_name": "DeepseekV3NaiveMoe", + "config_module": "transformers.models.deepseek_v3.configuration_deepseek_v3", + "config_name": "DeepseekV3Config", + "config_updates": {"hidden_size": 64, "moe_intermediate_size": 32, "num_local_experts": 4, "hidden_act": "silu"}, + "input_dim": 64, + "route_indices": [[0], [1], [2], [3], [0]], + "expert_attrs": ("gate_proj", "up_proj", "down_proj"), + }, + { + "model_type": "dots1", + "module_path": "transformers.models.dots1.modeling_dots1", + "class_name": "Dots1NaiveMoe", + "config_module": "transformers.models.dots1.configuration_dots1", + "config_name": "Dots1Config", + "config_updates": {"hidden_size": 64, "moe_intermediate_size": 32, "num_local_experts": 4, "hidden_act": "silu"}, + "input_dim": 64, + "route_indices": [[0], [1], [2], [3], [0]], + "expert_attrs": ("gate_proj", "up_proj", "down_proj"), + }, + { + "model_type": "ernie4_5_moe", + "module_path": "transformers.models.ernie4_5_moe.modeling_ernie4_5_moe", + "class_name": "Ernie4_5_MoeExperts", + "config_module": "transformers.models.ernie4_5_moe.configuration_ernie4_5_moe", + "config_name": "Ernie4_5_MoeConfig", + "config_updates": {"hidden_size": 64, "moe_intermediate_size": 32, "num_local_experts": 4, "hidden_act": "silu", "use_bias": False}, + "input_dim": 64, + "route_indices": [[0], [1], [2], [3], [0]], + "expert_attrs": ("gate_proj", "up_proj", "down_proj"), + }, + { + "model_type": "ernie4_5_vl_moe", + "module_path": "transformers.models.ernie4_5_vl_moe.modeling_ernie4_5_vl_moe", + "class_name": "Ernie4_5_VLMoeMoeExperts", + "config_module": "transformers.models.ernie4_5_vl_moe.configuration_ernie4_5_vl_moe", + "config_name": "Ernie4_5_VLMoeConfig", + "sub_attr": "text_config", + "config_updates": {"hidden_size": 
64, "moe_intermediate_size": 32, "num_local_experts": 4, "hidden_act": "silu", "use_bias": False}, + "input_dim": 64, + "route_indices": [[0], [1], [2], [3], [0]], + "expert_attrs": ("gate_proj", "up_proj", "down_proj"), + }, + { + "model_type": "exaone_moe", + "module_path": "transformers.models.exaone_moe.modeling_exaone_moe", + "class_name": "ExaoneMoeExperts", + "config_module": "transformers.models.exaone_moe.configuration_exaone_moe", + "config_name": "ExaoneMoeConfig", + "config_updates": {"hidden_size": 64, "moe_intermediate_size": 32, "num_local_experts": 4, "hidden_act": "silu"}, + "input_dim": 64, + "route_indices": [[0], [1], [2], [3], [0]], + "expert_attrs": ("gate_proj", "up_proj", "down_proj"), + }, + { + "model_type": "flex_olmo", + "module_path": "transformers.models.flex_olmo.modeling_flex_olmo", + "class_name": "FlexOlmoExperts", + "config_module": "transformers.models.flex_olmo.configuration_flex_olmo", + "config_name": "FlexOlmoConfig", + "config_updates": {"hidden_size": 64, "moe_intermediate_size": 32, "num_local_experts": 4, "hidden_act": "silu"}, + "input_dim": 64, + "route_indices": [[0], [1], [2], [3], [0]], + "expert_attrs": ("gate_proj", "up_proj", "down_proj"), + }, + { + "model_type": "glm4_moe_lite", + "module_path": "transformers.models.glm4_moe_lite.modeling_glm4_moe_lite", + "class_name": "Glm4MoeLiteNaiveMoe", + "config_module": "transformers.models.glm4_moe_lite.configuration_glm4_moe_lite", + "config_name": "Glm4MoeLiteConfig", + "config_updates": {"hidden_size": 64, "moe_intermediate_size": 32, "num_local_experts": 4, "hidden_act": "silu"}, + "input_dim": 64, + "route_indices": [[0], [1], [2], [3], [0]], + "expert_attrs": ("gate_proj", "up_proj", "down_proj"), + }, + { + "model_type": "glm4v_moe", + "module_path": "transformers.models.glm4v_moe.modeling_glm4v_moe", + "class_name": "Glm4vMoeTextNaiveMoe", + "config_module": "transformers.models.glm4v_moe.configuration_glm4v_moe", + "config_name": "Glm4vMoeConfig", + "sub_attr": 
"text_config", + "config_updates": {"hidden_size": 64, "moe_intermediate_size": 32, "num_local_experts": 4, "hidden_act": "silu"}, + "input_dim": 64, + "route_indices": [[0], [1], [2], [3], [0]], + "expert_attrs": ("gate_proj", "up_proj", "down_proj"), + }, + { + "model_type": "glm_moe_dsa", + "module_path": "transformers.models.glm_moe_dsa.modeling_glm_moe_dsa", + "class_name": "GlmMoeDsaNaiveMoe", + "config_module": "transformers.models.glm_moe_dsa.configuration_glm_moe_dsa", + "config_name": "GlmMoeDsaConfig", + "config_updates": {"hidden_size": 64, "moe_intermediate_size": 32, "num_local_experts": 4, "hidden_act": "silu"}, + "input_dim": 64, + "route_indices": [[0], [1], [2], [3], [0]], + "expert_attrs": ("gate_proj", "up_proj", "down_proj"), + }, + { + "model_type": "hunyuan_v1_moe", + "module_path": "transformers.models.hunyuan_v1_moe.modeling_hunyuan_v1_moe", + "class_name": "HunYuanMoEV1Experts", + "config_module": "transformers.models.hunyuan_v1_moe.configuration_hunyuan_v1_moe", + "config_name": "HunYuanMoEV1Config", + "config_updates": {"hidden_size": 64, "moe_intermediate_size": 32, "num_local_experts": 4, "hidden_act": "silu"}, + "input_dim": 64, + "route_indices": [[0], [1], [2], [3], [0]], + "expert_attrs": ("gate_proj", "up_proj", "down_proj"), + }, + { + "model_type": "jamba", + "module_path": "transformers.models.jamba.modeling_jamba", + "class_name": "JambaExperts", + "config_module": "transformers.models.jamba.configuration_jamba", + "config_name": "JambaConfig", + "config_updates": {"hidden_size": 64, "moe_intermediate_size": 32, "num_experts": 4, "num_local_experts": 4, "hidden_act": "silu"}, + "input_dim": 64, + "route_indices": [[0], [1], [2], [3], [0]], + "expert_attrs": ("gate_proj", "up_proj", "down_proj"), + }, + { + "model_type": "lfm2_moe", + "module_path": "transformers.models.lfm2_moe.modeling_lfm2_moe", + "class_name": "Lfm2MoeExperts", + "config_module": "transformers.models.lfm2_moe.configuration_lfm2_moe", + "config_name": 
"Lfm2MoeConfig", + "config_updates": {"hidden_size": 64, "moe_intermediate_size": 32, "num_experts": 4, "hidden_act": "silu"}, + "input_dim": 64, + "route_indices": [[0], [1], [2], [3], [0]], + "expert_attrs": ("gate_proj", "up_proj", "down_proj"), + }, + { + "model_type": "minimax", + "module_path": "transformers.models.minimax.modeling_minimax", + "class_name": "MiniMaxExperts", + "config_module": "transformers.models.minimax.configuration_minimax", + "config_name": "MiniMaxConfig", + "config_updates": {"hidden_size": 64, "intermediate_size": 128, "num_local_experts": 4, "hidden_act": "silu"}, + "input_dim": 64, + "route_indices": [[0], [1], [2], [3], [0]], + "expert_attrs": ("gate_proj", "up_proj", "down_proj"), + }, + { + "model_type": "minimax_m2", + "module_path": "transformers.models.minimax_m2.modeling_minimax_m2", + "class_name": "MiniMaxM2Experts", + "config_module": "transformers.models.minimax_m2.configuration_minimax_m2", + "config_name": "MiniMaxM2Config", + "config_updates": {"hidden_size": 64, "intermediate_size": 128, "num_local_experts": 4, "hidden_act": "silu"}, + "input_dim": 64, + "route_indices": [[0], [1], [2], [3], [0]], + "expert_attrs": ("gate_proj", "up_proj", "down_proj"), + }, + { + "model_type": "nemotron_h", + "module_path": "transformers.models.nemotron_h.modeling_nemotron_h", + "class_name": "NemotronHExperts", + "config_module": "transformers.models.nemotron_h.configuration_nemotron_h", + "config_name": "NemotronHConfig", + "config_updates": {"hidden_size": 64, "moe_intermediate_size": 96, "n_routed_experts": 4, "mlp_hidden_act": "silu", "moe_latent_size": None}, + "input_dim": 64, + "route_indices": [[0], [1], [2], [3], [0]], + "expert_attrs": ("up_proj", "down_proj"), + }, + { + "model_type": "olmoe", + "module_path": "transformers.models.olmoe.modeling_olmoe", + "class_name": "OlmoeExperts", + "config_module": "transformers.models.olmoe.configuration_olmoe", + "config_name": "OlmoeConfig", + "config_updates": {"hidden_size": 64, 
"moe_intermediate_size": 32, "num_local_experts": 4, "hidden_act": "silu"}, + "input_dim": 64, + "route_indices": [[0], [1], [2], [3], [0]], + "expert_attrs": ("gate_proj", "up_proj", "down_proj"), + }, + { + "model_type": "qwen3_vl_moe", + "module_path": "transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe", + "class_name": "Qwen3VLMoeTextExperts", + "config_module": "transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe", + "config_name": "Qwen3VLMoeConfig", + "sub_attr": "text_config", + "config_updates": {"hidden_size": 64, "moe_intermediate_size": 32, "num_experts": 4, "num_local_experts": 4, "hidden_act": "silu"}, + "input_dim": 64, + "route_indices": [[0], [1], [2], [3], [0]], + "expert_attrs": ("gate_proj", "up_proj", "down_proj"), + }, + { + "model_type": "solar_open", + "module_path": "transformers.models.solar_open.modeling_solar_open", + "class_name": "SolarOpenNaiveMoe", + "config_module": "transformers.models.solar_open.configuration_solar_open", + "config_name": "SolarOpenConfig", + "config_updates": {"hidden_size": 64, "moe_intermediate_size": 32, "num_local_experts": 4, "hidden_act": "silu"}, + "input_dim": 64, + "route_indices": [[0], [1], [2], [3], [0]], + "expert_attrs": ("gate_proj", "up_proj", "down_proj"), + }, + { + "model_type": "dbrx", + "module_path": "transformers.models.dbrx.modeling_dbrx", + "class_name": "DbrxExperts", + "config_module": "transformers.models.dbrx.configuration_dbrx", + "config_name": "DbrxFFNConfig", + "config_updates": {"hidden_size": 32, "ffn_hidden_size": 64, "moe_num_experts": 4}, + "input_dim": 64, + "hidden_shape": (2, 3, 64), + "route_indices": [[0], [1], [2], [3], [0], [1]], + "expert_attrs": ("gate_proj", "up_proj", "down_proj"), + "atol": 2e-4, + "rtol": 2e-5, + }, + { + "model_type": "longcat_flash", + "module_path": "transformers.models.longcat_flash.modeling_longcat_flash", + "class_name": "LongcatFlashExperts", + "config_module": "transformers.models.longcat_flash.configuration_longcat_flash", + 
"config_name": "LongcatFlashConfig", + "config_updates": {"hidden_size": 64, "expert_ffn_hidden_size": 32, "n_routed_experts": 4, "zero_expert_num": 1, "hidden_act": "silu"}, + "input_dim": 64, + "route_indices": [[0, 4], [4, 0], [2, 1], [3, 4], [1, 2]], + "route_weights": [[0.6, 0.4], [0.7, 0.3], [0.5, 0.5], [0.8, 0.2], [0.9, 0.1]], + "expert_attrs": ("gate_proj", "up_proj", "down_proj"), + "identity_expert_index": 4, + }, +] + + +PARALLEL_CASES = [ + { + "model_type": "granitemoe", + "module_path": "transformers.models.granitemoe.modeling_granitemoe", + "class_name": "GraniteMoeParallelExperts", + "kind": "parallel", + "num_experts": 4, + "input_size": 64, + "output_size": 96, + }, + { + "model_type": "granitemoehybrid", + "module_path": "transformers.models.granitemoehybrid.modeling_granitemoehybrid", + "class_name": "GraniteMoeHybridParallelExperts", + "kind": "parallel", + "num_experts": 4, + "input_size": 64, + "output_size": 96, + }, + { + "model_type": "granitemoeshared", + "module_path": "transformers.models.granitemoeshared.modeling_granitemoeshared", + "class_name": "GraniteMoeSharedParallelExperts", + "kind": "parallel", + "num_experts": 4, + "input_size": 64, + "output_size": 96, + }, + { + "model_type": "jetmoe", + "module_path": "transformers.models.jetmoe.modeling_jetmoe", + "class_name": "JetMoeParallelExperts", + "kind": "parallel", + "num_experts": 4, + "input_size": 64, + "output_size": 96, + }, +] + + +DENSE_CASES = [ + { + "label": "dia", + "model_type": "dia", + "module_path": "transformers.models.dia.modeling_dia", + "class_name": "DiaMLP", + "config_module": "transformers.models.dia.configuration_dia", + "config_name": "DiaConfig", + "sub_attr": "decoder_config", + "config_updates": {"hidden_size": 64, "intermediate_size": 128, "hidden_act": "silu"}, + "hidden_size": 64, + }, + { + "label": "glm", + "model_type": "glm", + "module_path": "transformers.models.glm.modeling_glm", + "class_name": "GlmMLP", + "config_module": 
"transformers.models.glm.configuration_glm", + "config_name": "GlmConfig", + "config_updates": {"hidden_size": 64, "intermediate_size": 128, "hidden_act": "silu"}, + "hidden_size": 64, + }, + { + "label": "glm4", + "model_type": "glm4", + "module_path": "transformers.models.glm4.modeling_glm4", + "class_name": "Glm4MLP", + "config_module": "transformers.models.glm4.configuration_glm4", + "config_name": "Glm4Config", + "config_updates": {"hidden_size": 64, "intermediate_size": 128, "hidden_act": "silu"}, + "hidden_size": 64, + }, + { + "label": "glm_image", + "model_type": "glm_image", + "module_path": "transformers.models.glm_image.modeling_glm_image", + "class_name": "GlmImageTextMLP", + "config_module": "transformers.models.glm_image.configuration_glm_image", + "config_name": "GlmImageConfig", + "sub_attr": "text_config", + "config_updates": {"hidden_size": 64, "intermediate_size": 128, "hidden_act": "silu"}, + "hidden_size": 64, + }, + { + "label": "glm_ocr", + "model_type": "glm_ocr", + "module_path": "transformers.models.glm_ocr.modeling_glm_ocr", + "class_name": "GlmOcrTextMLP", + "config_module": "transformers.models.glm_ocr.configuration_glm_ocr", + "config_name": "GlmOcrConfig", + "sub_attr": "text_config", + "config_updates": {"hidden_size": 64, "intermediate_size": 128, "hidden_act": "silu"}, + "hidden_size": 64, + }, + { + "label": "phi3", + "model_type": "phi3", + "module_path": "transformers.models.phi3.modeling_phi3", + "class_name": "Phi3MLP", + "config_module": "transformers.models.phi3.configuration_phi3", + "config_name": "Phi3Config", + "config_updates": {"hidden_size": 64, "intermediate_size": 128, "hidden_act": "silu"}, + "hidden_size": 64, + }, + { + "label": "phi4_multimodal_text", + "model_type": "phi4_multimodal", + "module_path": "transformers.models.phi4_multimodal.modeling_phi4_multimodal", + "class_name": "Phi4MultimodalMLP", + "config_module": "transformers.models.phi4_multimodal.configuration_phi4_multimodal", + "config_name": 
"Phi4MultimodalConfig", + "config_updates": {"hidden_size": 64, "intermediate_size": 128, "hidden_act": "silu"}, + "hidden_size": 64, + }, + { + "label": "phi4_multimodal_audio", + "model_type": "phi4_multimodal", + "module_path": "transformers.models.phi4_multimodal.modeling_phi4_multimodal", + "class_name": "Phi4MultimodalAudioMLP", + "config_module": "transformers.models.phi4_multimodal.configuration_phi4_multimodal", + "config_name": "Phi4MultimodalAudioConfig", + "config_updates": {"hidden_size": 64, "intermediate_size": 128, "activation": "silu", "dropout_rate": 0.0}, + "hidden_size": 64, + }, + { + "label": "zamba2", + "model_type": "zamba2", + "module_path": "transformers.models.zamba2.modeling_zamba2", + "class_name": "Zamba2MLP", + "config_module": "transformers.models.zamba2.configuration_zamba2", + "config_name": "Zamba2Config", + "config_updates": { + "hidden_size": 64, + "intermediate_size": 128, + "hidden_act": "silu", + "adapter_rank": 8, + "hybrid_layer_ids": [0], + "num_mem_blocks": 1, + "add_bias_linear": False, + }, + "hidden_size": 64, + "kind": "zamba2", + }, +] + + +ALL_CANDIDATE_MODEL_TYPES = { + "dbrx", + "deepseek_v2", + "deepseek_v3", + "dia", + "dots1", + "ernie4_5_moe", + "ernie4_5_vl_moe", + "exaone_moe", + "flex_olmo", + "glm", + "glm4", + "glm4_moe", + "glm4_moe_lite", + "glm4v", + "glm4v_moe", + "glm_image", + "glm_moe_dsa", + "glm_ocr", + "gpt_oss", + "granitemoe", + "granitemoehybrid", + "granitemoeshared", + "hunyuan_v1_moe", + "jamba", + "jetmoe", + "lfm2_moe", + "llama4", + "longcat_flash", + "minimax", + "minimax_m2", + "mixtral", + "nemotron_h", + "olmoe", + "phi3", + "phi4_multimodal", + "phimoe", + "qwen2_moe", + "qwen3_5_moe", + "qwen3_moe", + "qwen3_next", + "qwen3_omni_moe", + "qwen3_vl_moe", + "solar_open", + "zamba2", +} + + +def test_model_registry_covers_all_scanned_candidates(): + assert ALL_CANDIDATE_MODEL_TYPES.issubset(MODEL_CONFIG) + + +@pytest.mark.parametrize("case", STANDARD_MOE_CASES, ids=[case["model_type"] 
for case in STANDARD_MOE_CASES]) +def test_standard_moe_candidates_convert_and_preserve_forward(case): + torch.manual_seed(0) + model, original_module = _wrapped_model(case) + hidden_states = _standard_hidden(case) + top_k_index = torch.tensor(case["route_indices"], dtype=torch.long) + route_weights = case.get("route_weights") + if route_weights is None: + top_k_weights = torch.ones(top_k_index.shape, dtype=hidden_states.dtype) + else: + top_k_weights = torch.tensor(route_weights, dtype=hidden_states.dtype) + + with torch.no_grad(): + expected = original_module(hidden_states.clone(), top_k_index, top_k_weights) + + converted = convert_model(model) + assert converted is True + + patched = _patched_module(model) + _assert_expert_container(patched, case["expert_attrs"]) + + with torch.no_grad(): + actual = patched(hidden_states.clone(), top_k_index, top_k_weights) + + torch.testing.assert_close( + actual, + expected, + atol=case.get("atol", 1e-5), + rtol=case.get("rtol", 1.3e-6), + ) + + if case.get("identity_expert_index") is not None: + assert hasattr(getattr(patched, str(case["identity_expert_index"])), "identity") + + +@pytest.mark.parametrize("case", PARALLEL_CASES, ids=[case["model_type"] for case in PARALLEL_CASES]) +def test_parallel_expert_candidates_convert_and_preserve_forward(case): + torch.manual_seed(0) + model, original_module = _wrapped_model(case) + expert_size = [2, 1, 0, 3] + inputs = torch.randn(sum(expert_size), case["input_size"], dtype=torch.float32) + + with torch.no_grad(): + expected = original_module(inputs.clone(), expert_size) + + converted = convert_model(model) + assert converted is True + + patched = _patched_module(model) + _assert_expert_container(patched, ("linear",)) + assert not hasattr(patched, "weight") + + with torch.no_grad(): + actual = patched(inputs.clone(), expert_size) + + torch.testing.assert_close(actual, expected) + + +@pytest.mark.parametrize("case", DENSE_CASES, ids=[case["label"] for case in DENSE_CASES]) +def 
test_dense_candidates_convert_and_preserve_forward(case): + torch.manual_seed(0) + model, original_module = _wrapped_model(case) + hidden_states = torch.randn(3, case["hidden_size"], dtype=torch.float32) + + with torch.no_grad(): + if case["label"] == "zamba2": + expected = original_module(hidden_states.clone(), layer_idx=0) + else: + expected = original_module(hidden_states.clone()) + + converted = convert_model(model) + assert converted is True + + patched = _patched_module(model) + assert hasattr(patched, "gate_proj") + assert hasattr(patched, "up_proj") + assert not hasattr(patched, "gate_up_proj") + + with torch.no_grad(): + if case["label"] == "zamba2": + actual = patched(hidden_states.clone(), layer_idx=0) + else: + actual = patched(hidden_states.clone()) + + torch.testing.assert_close(actual, expected) + + +def test_runtime_model_patches_respect_max_layers(): + module0 = _build_module(next(case for case in DENSE_CASES if case["label"] == "phi3")) + module1 = _build_module(next(case for case in DENSE_CASES if case["label"] == "phi3")) + + class TwoLayerModel(nn.Module): + def __init__(self): + super().__init__() + self.config = SimpleNamespace(model_type="phi3") + self.layers = nn.ModuleList([_DummyLayer(module0), _DummyLayer(module1)]) + + model = TwoLayerModel() + converted = convert_model(model, max_layers=1) + + assert converted is True + assert hasattr(model.layers[0].module, "gate_proj") + assert not hasattr(model.layers[0].module, "gate_up_proj") + assert hasattr(model.layers[1].module, "gate_up_proj") + assert not hasattr(model.layers[1].module, "gate_proj") diff --git a/tests/test_convert_model.py b/tests/test_convert_model.py index e8e2b64..e7920db 100644 --- a/tests/test_convert_model.py +++ b/tests/test_convert_model.py @@ -37,6 +37,7 @@ from transformers.models.llama4.modeling_llama4 import Llama4Config, Llama4ForConditionalGeneration from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import ( Qwen3OmniMoeForConditionalGeneration, + 
Qwen3OmniMoeTalkerTextSparseMoeBlock, Qwen3OmniMoeThinkerTextSparseMoeBlock, ) @@ -51,7 +52,10 @@ from defuser.modeling.unfused_moe.qwen2_moe import LinearQwen2MoeSparseMoeBlock from defuser.modeling.unfused_moe.qwen3_moe import LinearQwen3MoeSparseMoeBlock from defuser.modeling.unfused_moe.qwen3_next import LinearQwen3NextSparseMoeBlock -from defuser.modeling.unfused_moe.qwen3_omni_moe import LinearQwen3OmniMoeThinkerTextSparseMoeBlock +from defuser.modeling.unfused_moe.qwen3_omni_moe import ( + LinearQwen3OmniMoeTalkerTextSparseMoeBlock, + LinearQwen3OmniMoeThinkerTextSparseMoeBlock, +) from defuser.utils.common import MIN_SUPPORTED_TRANSFORMERS_VERSION @@ -116,6 +120,25 @@ def _tiny_qwen3_omni_config(): ) +def _tiny_qwen3_omni_talker_text_config(): + config = Qwen3OmniMoeConfig(enable_audio_output=True).talker_config.text_config + config.hidden_size = 64 + config.intermediate_size = 128 + config.moe_intermediate_size = 32 + config.shared_expert_intermediate_size = 32 + config.num_hidden_layers = 1 + config.num_attention_heads = 4 + config.num_key_value_heads = 4 + config.head_dim = 16 + config.num_experts = 4 + config.num_experts_per_tok = 2 + config.vocab_size = 128 + config.pad_token_id = 0 + config.bos_token_id = 1 + config.eos_token_id = 2 + return config + + def _tiny_qwen3_5_moe_config(): return Qwen3_5MoeConfig( text_config={ @@ -861,6 +884,17 @@ def test_qwen3_omni_defused_forward_matches_fused_math(): ) +def test_qwen3_omni_talker_defused_forward_matches_fused_math(): + config = _tiny_qwen3_omni_talker_text_config() + hidden_states = torch.randn(2, 3, config.hidden_size, dtype=torch.float32) + + _assert_sparse_moe_defused_matches_fused_math( + Qwen3OmniMoeTalkerTextSparseMoeBlock(config), + LinearQwen3OmniMoeTalkerTextSparseMoeBlock(config), + hidden_states, + ) + + def test_glm4_moe_defused_forward_matches_fused_math(): config = _tiny_glm4_moe_config() hidden_states = torch.randn(2, 3, config.hidden_size, dtype=torch.float32) diff --git 
a/tests/test_filter_rules.py b/tests/test_filter_rules.py new file mode 100644 index 0000000..3dbb9f3 --- /dev/null +++ b/tests/test_filter_rules.py @@ -0,0 +1,198 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from __future__ import annotations + +from types import SimpleNamespace + +import pytest +import torch +from torch import nn +from transformers.models.deepseek_v2.configuration_deepseek_v2 import DeepseekV2Config +from transformers.models.deepseek_v2.modeling_deepseek_v2 import DeepseekV2Experts +from transformers.models.phi3.configuration_phi3 import Phi3Config +from transformers.models.phi3.modeling_phi3 import Phi3MLP + +from defuser import convert_model +from defuser.modeling.replace_modules import ReplacementModuleBase, apply_replacements +from defuser.utils.common import compile_module_name_filter, matches_module_name_filter + + +class _DummyLayer(nn.Module): + def __init__(self, module: nn.Module): + super().__init__() + self.module = module + + +class _WrappedModel(nn.Module): + def __init__(self, model_type: str, modules: list[nn.Module]): + super().__init__() + self.config = SimpleNamespace(model_type=model_type) + self.layers = nn.ModuleList([_DummyLayer(module) for module in modules]) + + +def _tiny_phi3_mlp() -> Phi3MLP: + config = Phi3Config( + hidden_size=64, + intermediate_size=128, + hidden_act="silu", + num_hidden_layers=1, + num_attention_heads=4, + num_key_value_heads=4, + vocab_size=128, + ) + return Phi3MLP(config).eval() + + +def _tiny_deepseek_v2_experts() -> DeepseekV2Experts: + config = DeepseekV2Config( + hidden_size=64, + intermediate_size=128, + moe_intermediate_size=32, + num_hidden_layers=1, + num_attention_heads=4, + num_key_value_heads=1, + num_experts_per_tok=2, + n_routed_experts=4, + num_local_experts=4, + vocab_size=128, + ) + return DeepseekV2Experts(config).eval() + + 
+class FilterDummyOriginal(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.ones(1)) + + +class FilterDummyReplacement(ReplacementModuleBase): + @classmethod + def original_module_class(cls) -> str: + return "FilterDummyOriginal" + + @classmethod + def from_original(cls, original: torch.nn.Module, config): + replacement = cls(original) + replacement.config = config + return replacement + + def _materialize_weights(self) -> None: + return + + +def test_filter_rules_default_to_positive_and_negative_has_priority(): + positive = compile_module_name_filter(["layers\\.1\\.module$"]) + assert matches_module_name_filter("layers.1.module", positive) is True + assert matches_module_name_filter("layers.0.module", positive) is False + + overridden = compile_module_name_filter( + ["+:layers\\.0\\.module$", "-:layers\\.0\\.module$"] + ) + assert matches_module_name_filter("layers.0.module", overridden) is False + + +def test_filter_rules_use_pcre_syntax(): + module_filter = compile_module_name_filter([r"+:^layers\.\K0\.module$"]) + assert matches_module_name_filter("layers.0.module", module_filter) is True + assert matches_module_name_filter("layers.1.module", module_filter) is False + + +def test_filter_rules_match_full_hf_style_module_paths(): + module_filter = compile_module_name_filter([r"+:^model\.layers\.0\.mlp\.experts$"]) + + assert matches_module_name_filter("model.layers.0.mlp.experts", module_filter) is True + assert matches_module_name_filter("model.layers.1.mlp.experts", module_filter) is False + + +def test_filter_rules_require_positive_match_when_filter_is_present(): + module_filter = compile_module_name_filter([r"-:^model\.layers\.1\.mlp\.experts$"]) + + assert matches_module_name_filter("model.layers.0.mlp.experts", module_filter) is False + assert matches_module_name_filter("model.layers.1.mlp.experts", module_filter) is False + assert matches_module_name_filter("model.layers.0.mlp.experts", []) is False + + +def 
test_filter_rules_reject_invalid_filter_inputs(): + with pytest.raises(TypeError, match="filter must be a sequence of regex strings"): + compile_module_name_filter(r"+:^layers\.0$") + + with pytest.raises(TypeError, match="filter rules must be strings"): + compile_module_name_filter([r"+:^layers\.0$", 1]) + + +def test_convert_model_filter_limits_dense_runtime_patch(): + model = _WrappedModel("phi3", [_tiny_phi3_mlp(), _tiny_phi3_mlp()]) + + converted = convert_model(model, filter=[r"+:^layers\.\K0\.module$"]) + + assert converted is True + assert hasattr(model.layers[0].module, "gate_proj") + assert not hasattr(model.layers[0].module, "gate_up_proj") + assert hasattr(model.layers[1].module, "gate_up_proj") + assert not hasattr(model.layers[1].module, "gate_proj") + + +def test_convert_model_filter_negative_overrides_positive(): + model = _WrappedModel("phi3", [_tiny_phi3_mlp(), _tiny_phi3_mlp()]) + + converted = convert_model( + model, + filter=[ + r"+:^layers\.0\.module$", + r"-:^layers\.0\.module$", + ], + ) + + assert converted is True + assert hasattr(model.layers[0].module, "gate_up_proj") + assert not hasattr(model.layers[0].module, "gate_proj") + assert hasattr(model.layers[1].module, "gate_up_proj") + assert not hasattr(model.layers[1].module, "gate_proj") + + +def test_convert_model_filter_combines_with_max_layers(): + model = _WrappedModel("phi3", [_tiny_phi3_mlp(), _tiny_phi3_mlp()]) + + converted = convert_model( + model, + max_layers=1, + filter=[r"+:^layers\.1\.module$"], + ) + + assert converted is True + assert hasattr(model.layers[0].module, "gate_up_proj") + assert not hasattr(model.layers[0].module, "gate_proj") + assert hasattr(model.layers[1].module, "gate_up_proj") + assert not hasattr(model.layers[1].module, "gate_proj") + + +def test_convert_model_filter_applies_to_moe_tensor_defusion(): + matched = _WrappedModel("deepseek_v2", [_tiny_deepseek_v2_experts()]) + skipped = _WrappedModel("deepseek_v2", [_tiny_deepseek_v2_experts()]) + + assert 
convert_model(matched, filter=[r"+:^layers\.0\.module$"]) is True + assert hasattr(matched.layers[0].module, "0") + assert not hasattr(matched.layers[0].module, "gate_up_proj") + + assert convert_model(skipped, filter=[r"+:^layers\.1\.module$"]) is True + assert not hasattr(skipped.layers[0].module, "0") + assert hasattr(skipped.layers[0].module, "gate_up_proj") + + +def test_apply_replacements_filter_applies_to_custom_replacements(): + class DummyModel(nn.Module): + def __init__(self): + super().__init__() + self.config = SimpleNamespace() + self.layers = nn.ModuleList([FilterDummyOriginal()]) + + skipped = DummyModel() + apply_replacements(skipped, auto_detect_moe=False, filter_rules=[r"+:^layers\.1$"]) + assert isinstance(skipped.layers[0], FilterDummyOriginal) + + matched = DummyModel() + apply_replacements(matched, auto_detect_moe=False, filter_rules=[r"+:^layers\.0$"]) + assert isinstance(matched.layers[0], FilterDummyReplacement) diff --git a/tests/test_meta_model_defusion.py b/tests/test_meta_model_defusion.py new file mode 100644 index 0000000..ccd0a6a --- /dev/null +++ b/tests/test_meta_model_defusion.py @@ -0,0 +1,763 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from __future__ import annotations + +from importlib import import_module + +import pytest +import torch + +from defuser import convert_model, replace_fused_blocks +from defuser.model_registry import MODEL_CONFIG + + +def _load(module_path: str, attr_name: str): + return getattr(import_module(module_path), attr_name) + + +def _set_if_has(obj, **kwargs) -> None: + for attr, value in kwargs.items(): + if hasattr(obj, attr): + setattr(obj, attr, value) + + +def _mutate_common_config_tree(config, visited: set[int] | None = None) -> None: + if config is None or isinstance(config, (int, float, str, bool, list, tuple, dict)): + return + + if visited is 
None: + visited = set() + if id(config) in visited: + return + visited.add(id(config)) + + _set_if_has( + config, + vocab_size=128, + hidden_size=64, + hidden_dim=64, + d_model=64, + intermediate_size=128, + moe_intermediate_size=32, + shared_expert_intermediate_size=32, + expert_ffn_hidden_size=32, + ffn_hidden_size=128, + num_hidden_layers=2, + num_layers=2, + n_layers=2, + decoder_layers=2, + encoder_layers=1, + num_attention_heads=4, + num_key_value_heads=1, + head_dim=16, + num_heads=4, + encoder_attention_heads=4, + num_local_experts=4, + num_experts=4, + moe_num_experts=4, + n_routed_experts=4, + num_experts_per_tok=2, + top_k=2, + hidden_act="silu", + activation="silu", + activation_function="silu", + mlp_hidden_act="silu", + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + image_size=16, + patch_size=4, + num_channels=3, + out_hidden_size=64, + depth=1, + num_position_embeddings=64, + n_shared_experts=1, + num_shared_experts=1, + n_group=1, + topk_group=1, + use_bias=False, + mlp_bias=False, + add_bias_linear=False, + zero_expert_num=1, + adapter_rank=8, + hybrid_layer_ids=[0], + num_mem_blocks=1, + dropout_rate=0.0, + conv_chunksize=16, + n_window=4, + n_window_infer=4, + max_source_positions=32, + num_mel_bins=16, + encoder_ffn_dim=128, + output_dim=64, + downsample_hidden_size=32, + max_position_embeddings=64, + initializer_range=0.02, + ) + + for attr in ( + "text_config", + "vision_config", + "audio_config", + "decoder_config", + "thinker_config", + "talker_config", + "language_config", + "llm_config", + "attn_config", + "ffn_config", + "code_predictor_config", + ): + _mutate_common_config_tree(getattr(config, attr, None), visited) + + thinker = getattr(config, "thinker_config", None) + if thinker is not None: + for attr in ("text_config", "vision_config", "audio_config"): + _mutate_common_config_tree(getattr(thinker, attr, None), visited) + + talker = getattr(config, "talker_config", None) + if talker is not None: + for attr in ("text_config", 
"speech_config", "code_predictor_config"): + _mutate_common_config_tree(getattr(talker, attr, None), visited) + + +def _build_model_config(case: dict): + config = _load(case["config_module"], case["config_class"])() + _mutate_common_config_tree(config) + + model_type = case["model_type"] + if model_type == "dbrx": + config.max_seq_len = 64 + config.resid_pdrop = 0.0 + config.emb_pdrop = 0.0 + config.attn_config.kv_n_heads = 1 + config.attn_config.clip_qkv = None + config.attn_config.rope_theta = 10000.0 + config.ffn_config.ffn_hidden_size = 128 + config.ffn_config.moe_num_experts = 4 + config.ffn_config.moe_top_k = 2 + config.ffn_config.ffn_act_fn = {"name": "silu"} + elif model_type == "deepseek_v3": + config.first_k_dense_replace = 0 + elif model_type == "ernie4_5_vl_moe": + config.text_config.moe_intermediate_size = [32, 32] + config.text_config.rope_parameters = { + "rope_theta": 10000.0, + "rope_type": "default", + "mrope_section": [2, 2, 4], + } + elif model_type == "glm4_moe": + config.first_k_dense_replace = -1 + elif model_type == "glm_moe_dsa": + config.mlp_layer_types = ["sparse"] * config.num_hidden_layers + elif model_type == "granitemoehybrid": + config.layer_types = ["attention", "mamba"] + config.shared_intermediate_size = 64 + config.mamba_n_heads = 8 + elif model_type == "lfm2_moe": + config.layer_types = ["full_attention", "short_conv"] + config.num_dense_layers = 0 + elif model_type == "qwen3_omni_moe": + config.enable_audio_output = True + config.talker_config.spatial_merge_size = 2 + config.talker_config.thinker_hidden_size = 64 + config.talker_config.text_config.shared_expert_intermediate_size = 32 + code_predictor = getattr(config.talker_config, "code_predictor_config", None) + _set_if_has( + code_predictor, + hidden_size=64, + intermediate_size=128, + num_hidden_layers=1, + num_attention_heads=4, + num_key_value_heads=1, + head_dim=16, + ) + + return config + + +def _find_module_hits(model, class_paths: tuple[str, ...]) -> list[tuple[str, 
str]]: + hits = [] + wanted = set(class_paths) + for name, module in model.named_modules(): + class_path = f"{module.__class__.__module__}.{module.__class__.__name__}" + if class_path in wanted: + hits.append((name, class_path)) + return hits + + +def _assert_meta_parameters(module) -> None: + for _, param in module.named_parameters(recurse=True): + assert param.is_meta + + +def _assert_all_model_parameters_meta(model) -> None: + for _, param in model.named_parameters(): + assert param.is_meta + + +def _validate_defused_module(case: dict, module) -> None: + kind = case["validator"] + + if kind == "experts": + assert hasattr(module, "0") + expert0 = getattr(module, "0") + assert hasattr(expert0, "gate_proj") + assert hasattr(expert0, "up_proj") + assert hasattr(expert0, "down_proj") + assert not hasattr(module, "gate_up_proj") + _assert_meta_parameters(expert0) + return + + if kind == "nongated_experts": + assert hasattr(module, "0") + expert0 = getattr(module, "0") + assert hasattr(expert0, "up_proj") + assert hasattr(expert0, "down_proj") + assert not hasattr(expert0, "gate_proj") + _assert_meta_parameters(expert0) + return + + if kind == "parallel": + assert hasattr(module, "0") + assert hasattr(getattr(module, "0"), "linear") + assert not hasattr(module, "weight") + _assert_meta_parameters(getattr(module, "0")) + return + + if kind == "dense_split": + assert hasattr(module, "gate_proj") + assert hasattr(module, "up_proj") + assert hasattr(module, "down_proj") + assert not hasattr(module, "gate_up_proj") + _assert_meta_parameters(module) + return + + if kind == "longcat": + assert hasattr(module, "0") + assert hasattr(getattr(module, "0"), "gate_proj") + assert not hasattr(module, "gate_up_proj") + assert any(hasattr(getattr(module, str(idx)), "identity") for idx in range(module.num_experts)) + _assert_meta_parameters(getattr(module, "0")) + return + + if kind == "sparse_block": + expert0 = module.experts[0] + assert hasattr(expert0, "gate_proj") + assert 
hasattr(expert0, "up_proj") + assert hasattr(expert0, "down_proj") + assert not hasattr(module.experts, "gate_up_proj") + _assert_meta_parameters(expert0) + return + + raise AssertionError(f"Unsupported validator kind: {kind}") + + +META_MODEL_CASES = [ + { + "model_type": "dbrx", + "mode": "convert", + "model_module": "transformers.models.dbrx.modeling_dbrx", + "model_class": "DbrxForCausalLM", + "config_module": "transformers.models.dbrx.configuration_dbrx", + "config_class": "DbrxConfig", + "target_class_paths": ("transformers.models.dbrx.modeling_dbrx.DbrxExperts",), + "validator": "experts", + }, + { + "model_type": "deepseek_v2", + "mode": "convert", + "model_module": "transformers.models.deepseek_v2.modeling_deepseek_v2", + "model_class": "DeepseekV2ForCausalLM", + "config_module": "transformers.models.deepseek_v2.configuration_deepseek_v2", + "config_class": "DeepseekV2Config", + "target_class_paths": ("transformers.models.deepseek_v2.modeling_deepseek_v2.DeepseekV2Experts",), + "validator": "experts", + }, + { + "model_type": "deepseek_v3", + "mode": "convert", + "model_module": "transformers.models.deepseek_v3.modeling_deepseek_v3", + "model_class": "DeepseekV3ForCausalLM", + "config_module": "transformers.models.deepseek_v3.configuration_deepseek_v3", + "config_class": "DeepseekV3Config", + "target_class_paths": ("transformers.models.deepseek_v3.modeling_deepseek_v3.DeepseekV3NaiveMoe",), + "validator": "experts", + "min_targets": 2, + }, + { + "model_type": "dia", + "mode": "convert", + "model_module": "transformers.models.dia.modeling_dia", + "model_class": "DiaForConditionalGeneration", + "config_module": "transformers.models.dia.configuration_dia", + "config_class": "DiaConfig", + "target_class_paths": ("transformers.models.dia.modeling_dia.DiaMLP",), + "validator": "dense_split", + }, + { + "model_type": "dots1", + "mode": "convert", + "model_module": "transformers.models.dots1.modeling_dots1", + "model_class": "Dots1ForCausalLM", + "config_module": 
"transformers.models.dots1.configuration_dots1", + "config_class": "Dots1Config", + "target_class_paths": ("transformers.models.dots1.modeling_dots1.Dots1NaiveMoe",), + "validator": "experts", + }, + { + "model_type": "ernie4_5_moe", + "mode": "convert", + "model_module": "transformers.models.ernie4_5_moe.modeling_ernie4_5_moe", + "model_class": "Ernie4_5_MoeForCausalLM", + "config_module": "transformers.models.ernie4_5_moe.configuration_ernie4_5_moe", + "config_class": "Ernie4_5_MoeConfig", + "target_class_paths": ("transformers.models.ernie4_5_moe.modeling_ernie4_5_moe.Ernie4_5_MoeExperts",), + "validator": "experts", + }, + { + "model_type": "ernie4_5_vl_moe", + "mode": "convert", + "model_module": "transformers.models.ernie4_5_vl_moe.modeling_ernie4_5_vl_moe", + "model_class": "Ernie4_5_VLMoeForConditionalGeneration", + "config_module": "transformers.models.ernie4_5_vl_moe.configuration_ernie4_5_vl_moe", + "config_class": "Ernie4_5_VLMoeConfig", + "target_class_paths": ( + "transformers.models.ernie4_5_vl_moe.modeling_ernie4_5_vl_moe.Ernie4_5_VLMoeMoeExperts", + ), + "validator": "experts", + }, + { + "model_type": "exaone_moe", + "mode": "convert", + "model_module": "transformers.models.exaone_moe.modeling_exaone_moe", + "model_class": "ExaoneMoeForCausalLM", + "config_module": "transformers.models.exaone_moe.configuration_exaone_moe", + "config_class": "ExaoneMoeConfig", + "target_class_paths": ("transformers.models.exaone_moe.modeling_exaone_moe.ExaoneMoeExperts",), + "validator": "experts", + }, + { + "model_type": "flex_olmo", + "mode": "convert", + "model_module": "transformers.models.flex_olmo.modeling_flex_olmo", + "model_class": "FlexOlmoForCausalLM", + "config_module": "transformers.models.flex_olmo.configuration_flex_olmo", + "config_class": "FlexOlmoConfig", + "target_class_paths": ("transformers.models.flex_olmo.modeling_flex_olmo.FlexOlmoExperts",), + "validator": "experts", + }, + { + "model_type": "glm", + "mode": "convert", + "model_module": 
"transformers.models.glm.modeling_glm", + "model_class": "GlmForCausalLM", + "config_module": "transformers.models.glm.configuration_glm", + "config_class": "GlmConfig", + "target_class_paths": ("transformers.models.glm.modeling_glm.GlmMLP",), + "validator": "dense_split", + }, + { + "model_type": "glm4", + "mode": "convert", + "model_module": "transformers.models.glm4.modeling_glm4", + "model_class": "Glm4ForCausalLM", + "config_module": "transformers.models.glm4.configuration_glm4", + "config_class": "Glm4Config", + "target_class_paths": ("transformers.models.glm4.modeling_glm4.Glm4MLP",), + "validator": "dense_split", + }, + { + "model_type": "glm4_moe", + "mode": "replace", + "model_module": "transformers.models.glm4_moe.modeling_glm4_moe", + "model_class": "Glm4MoeForCausalLM", + "config_module": "transformers.models.glm4_moe.configuration_glm4_moe", + "config_class": "Glm4MoeConfig", + "target_class_paths": ("defuser.modeling.unfused_moe.glm4_moe.LinearGlm4MoeMoE",), + "validator": "sparse_block", + }, + { + "model_type": "glm4_moe_lite", + "mode": "convert", + "model_module": "transformers.models.glm4_moe_lite.modeling_glm4_moe_lite", + "model_class": "Glm4MoeLiteForCausalLM", + "config_module": "transformers.models.glm4_moe_lite.configuration_glm4_moe_lite", + "config_class": "Glm4MoeLiteConfig", + "target_class_paths": ("transformers.models.glm4_moe_lite.modeling_glm4_moe_lite.Glm4MoeLiteNaiveMoe",), + "validator": "experts", + }, + { + "model_type": "glm4v", + "mode": "replace", + "model_module": "transformers.models.glm4v.modeling_glm4v", + "model_class": "Glm4vForConditionalGeneration", + "config_module": "transformers.models.glm4v.configuration_glm4v", + "config_class": "Glm4vConfig", + "target_class_paths": ("defuser.modeling.glm4v.LinearGlm4vTextMLP",), + "validator": "dense_split", + }, + { + "model_type": "glm4v_moe", + "mode": "convert", + "model_module": "transformers.models.glm4v_moe.modeling_glm4v_moe", + "model_class": 
"Glm4vMoeForConditionalGeneration", + "config_module": "transformers.models.glm4v_moe.configuration_glm4v_moe", + "config_class": "Glm4vMoeConfig", + "target_class_paths": ("transformers.models.glm4v_moe.modeling_glm4v_moe.Glm4vMoeTextNaiveMoe",), + "validator": "experts", + }, + { + "model_type": "glm_image", + "mode": "convert", + "model_module": "transformers.models.glm_image.modeling_glm_image", + "model_class": "GlmImageForConditionalGeneration", + "config_module": "transformers.models.glm_image.configuration_glm_image", + "config_class": "GlmImageConfig", + "target_class_paths": ("transformers.models.glm_image.modeling_glm_image.GlmImageTextMLP",), + "validator": "dense_split", + }, + { + "model_type": "glm_moe_dsa", + "mode": "convert", + "model_module": "transformers.models.glm_moe_dsa.modeling_glm_moe_dsa", + "model_class": "GlmMoeDsaForCausalLM", + "config_module": "transformers.models.glm_moe_dsa.configuration_glm_moe_dsa", + "config_class": "GlmMoeDsaConfig", + "target_class_paths": ("transformers.models.glm_moe_dsa.modeling_glm_moe_dsa.GlmMoeDsaNaiveMoe",), + "validator": "experts", + "min_targets": 2, + }, + { + "model_type": "glm_ocr", + "mode": "convert", + "model_module": "transformers.models.glm_ocr.modeling_glm_ocr", + "model_class": "GlmOcrForConditionalGeneration", + "config_module": "transformers.models.glm_ocr.configuration_glm_ocr", + "config_class": "GlmOcrConfig", + "target_class_paths": ("transformers.models.glm_ocr.modeling_glm_ocr.GlmOcrTextMLP",), + "validator": "dense_split", + }, + { + "model_type": "gpt_oss", + "mode": "convert", + "model_module": "transformers.models.gpt_oss.modeling_gpt_oss", + "model_class": "GptOssForCausalLM", + "config_module": "transformers.models.gpt_oss.configuration_gpt_oss", + "config_class": "GptOssConfig", + "target_class_paths": ("transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts",), + "validator": "experts", + "min_targets": 2, + }, + { + "model_type": "granitemoe", + "mode": "convert", + 
+        "model_module": "transformers.models.granitemoe.modeling_granitemoe",
+        "model_class": "GraniteMoeForCausalLM",
+        "config_module": "transformers.models.granitemoe.configuration_granitemoe",
+        "config_class": "GraniteMoeConfig",
+        "target_class_paths": ("transformers.models.granitemoe.modeling_granitemoe.GraniteMoeParallelExperts",),
+        "validator": "parallel",
+        "min_targets": 4,
+    },
+    {
+        "model_type": "granitemoehybrid",
+        "mode": "convert",
+        "model_module": "transformers.models.granitemoehybrid.modeling_granitemoehybrid",
+        "model_class": "GraniteMoeHybridForCausalLM",
+        "config_module": "transformers.models.granitemoehybrid.configuration_granitemoehybrid",
+        "config_class": "GraniteMoeHybridConfig",
+        "target_class_paths": (
+            "transformers.models.granitemoehybrid.modeling_granitemoehybrid.GraniteMoeHybridParallelExperts",
+        ),
+        "validator": "parallel",
+        "min_targets": 4,
+    },
+    {
+        "model_type": "granitemoeshared",
+        "mode": "convert",
+        "model_module": "transformers.models.granitemoeshared.modeling_granitemoeshared",
+        "model_class": "GraniteMoeSharedForCausalLM",
+        "config_module": "transformers.models.granitemoeshared.configuration_granitemoeshared",
+        "config_class": "GraniteMoeSharedConfig",
+        "target_class_paths": (
+            "transformers.models.granitemoeshared.modeling_granitemoeshared.GraniteMoeSharedParallelExperts",
+        ),
+        "validator": "parallel",
+        "min_targets": 4,
+    },
+    {
+        "model_type": "hunyuan_v1_moe",
+        "mode": "convert",
+        "model_module": "transformers.models.hunyuan_v1_moe.modeling_hunyuan_v1_moe",
+        "model_class": "HunYuanMoEV1ForCausalLM",
+        "config_module": "transformers.models.hunyuan_v1_moe.configuration_hunyuan_v1_moe",
+        "config_class": "HunYuanMoEV1Config",
+        "target_class_paths": ("transformers.models.hunyuan_v1_moe.modeling_hunyuan_v1_moe.HunYuanMoEV1Experts",),
+        "validator": "experts",
+    },
+    {
+        "model_type": "jamba",
+        "mode": "convert",
+        "model_module": "transformers.models.jamba.modeling_jamba",
+        "model_class": "JambaForCausalLM",
+        "config_module": "transformers.models.jamba.configuration_jamba",
+        "config_class": "JambaConfig",
+        "target_class_paths": ("transformers.models.jamba.modeling_jamba.JambaExperts",),
+        "validator": "experts",
+    },
+    {
+        "model_type": "jetmoe",
+        "mode": "convert",
+        "model_module": "transformers.models.jetmoe.modeling_jetmoe",
+        "model_class": "JetMoeForCausalLM",
+        "config_module": "transformers.models.jetmoe.configuration_jetmoe",
+        "config_class": "JetMoeConfig",
+        "target_class_paths": ("transformers.models.jetmoe.modeling_jetmoe.JetMoeParallelExperts",),
+        "validator": "parallel",
+        "min_targets": 4,
+    },
+    {
+        "model_type": "lfm2_moe",
+        "mode": "convert",
+        "model_module": "transformers.models.lfm2_moe.modeling_lfm2_moe",
+        "model_class": "Lfm2MoeForCausalLM",
+        "config_module": "transformers.models.lfm2_moe.configuration_lfm2_moe",
+        "config_class": "Lfm2MoeConfig",
+        "target_class_paths": ("transformers.models.lfm2_moe.modeling_lfm2_moe.Lfm2MoeExperts",),
+        "validator": "experts",
+        "min_targets": 2,
+    },
+    {
+        "model_type": "llama4",
+        "mode": "convert",
+        "model_module": "transformers.models.llama4.modeling_llama4",
+        "model_class": "Llama4ForConditionalGeneration",
+        "config_module": "transformers.models.llama4.configuration_llama4",
+        "config_class": "Llama4Config",
+        "target_class_paths": ("transformers.models.llama4.modeling_llama4.Llama4TextExperts",),
+        "validator": "experts",
+        "min_targets": 2,
+    },
+    {
+        "model_type": "longcat_flash",
+        "mode": "convert",
+        "model_module": "transformers.models.longcat_flash.modeling_longcat_flash",
+        "model_class": "LongcatFlashForCausalLM",
+        "config_module": "transformers.models.longcat_flash.configuration_longcat_flash",
+        "config_class": "LongcatFlashConfig",
+        "target_class_paths": ("transformers.models.longcat_flash.modeling_longcat_flash.LongcatFlashExperts",),
+        "validator": "longcat",
+        "min_targets": 2,
+    },
+    {
+        "model_type": "minimax",
+        "mode": "convert",
+        "model_module": "transformers.models.minimax.modeling_minimax",
+        "model_class": "MiniMaxForCausalLM",
+        "config_module": "transformers.models.minimax.configuration_minimax",
+        "config_class": "MiniMaxConfig",
+        "target_class_paths": ("transformers.models.minimax.modeling_minimax.MiniMaxExperts",),
+        "validator": "experts",
+    },
+    {
+        "model_type": "minimax_m2",
+        "mode": "convert",
+        "model_module": "transformers.models.minimax_m2.modeling_minimax_m2",
+        "model_class": "MiniMaxM2ForCausalLM",
+        "config_module": "transformers.models.minimax_m2.configuration_minimax_m2",
+        "config_class": "MiniMaxM2Config",
+        "target_class_paths": ("transformers.models.minimax_m2.modeling_minimax_m2.MiniMaxM2Experts",),
+        "validator": "experts",
+    },
+    {
+        "model_type": "mixtral",
+        "mode": "replace",
+        "model_module": "transformers.models.mixtral.modeling_mixtral",
+        "model_class": "MixtralForCausalLM",
+        "config_module": "transformers.models.mixtral.configuration_mixtral",
+        "config_class": "MixtralConfig",
+        "target_class_paths": ("defuser.modeling.unfused_moe.mixtral.LinearMixtralSparseMoeBlock",),
+        "validator": "sparse_block",
+    },
+    {
+        "model_type": "nemotron_h",
+        "mode": "convert",
+        "model_module": "transformers.models.nemotron_h.modeling_nemotron_h",
+        "model_class": "NemotronHForCausalLM",
+        "config_module": "transformers.models.nemotron_h.configuration_nemotron_h",
+        "config_class": "NemotronHConfig",
+        "target_class_paths": ("transformers.models.nemotron_h.modeling_nemotron_h.NemotronHExperts",),
+        "validator": "nongated_experts",
+        "min_targets": 2,
+    },
+    {
+        "model_type": "olmoe",
+        "mode": "convert",
+        "model_module": "transformers.models.olmoe.modeling_olmoe",
+        "model_class": "OlmoeForCausalLM",
+        "config_module": "transformers.models.olmoe.configuration_olmoe",
+        "config_class": "OlmoeConfig",
+        "target_class_paths": ("transformers.models.olmoe.modeling_olmoe.OlmoeExperts",),
+        "validator": "experts",
+    },
+    {
+        "model_type": "phi3",
+        "mode": "convert",
+        "model_module": "transformers.models.phi3.modeling_phi3",
+        "model_class": "Phi3ForCausalLM",
+        "config_module": "transformers.models.phi3.configuration_phi3",
+        "config_class": "Phi3Config",
+        "target_class_paths": ("transformers.models.phi3.modeling_phi3.Phi3MLP",),
+        "validator": "dense_split",
+    },
+    {
+        "model_type": "phi4_multimodal",
+        "mode": "convert",
+        "model_module": "transformers.models.phi4_multimodal.modeling_phi4_multimodal",
+        "model_class": "Phi4MultimodalForCausalLM",
+        "config_module": "transformers.models.phi4_multimodal.configuration_phi4_multimodal",
+        "config_class": "Phi4MultimodalConfig",
+        "target_class_paths": (
+            "transformers.models.phi4_multimodal.modeling_phi4_multimodal.Phi4MultimodalMLP",
+            "transformers.models.phi4_multimodal.modeling_phi4_multimodal.Phi4MultimodalAudioMLP",
+        ),
+        "validator": "dense_split",
+        "min_targets": 4,
+    },
+    {
+        "model_type": "phimoe",
+        "mode": "convert",
+        "model_module": "transformers.models.phimoe.modeling_phimoe",
+        "model_class": "PhimoeForCausalLM",
+        "config_module": "transformers.models.phimoe.configuration_phimoe",
+        "config_class": "PhimoeConfig",
+        "target_class_paths": ("transformers.models.phimoe.modeling_phimoe.PhimoeExperts",),
+        "validator": "experts",
+        "min_targets": 2,
+    },
+    {
+        "model_type": "qwen2_moe",
+        "mode": "replace",
+        "model_module": "transformers.models.qwen2_moe.modeling_qwen2_moe",
+        "model_class": "Qwen2MoeForCausalLM",
+        "config_module": "transformers.models.qwen2_moe.configuration_qwen2_moe",
+        "config_class": "Qwen2MoeConfig",
+        "target_class_paths": ("defuser.modeling.unfused_moe.qwen2_moe.LinearQwen2MoeSparseMoeBlock",),
+        "validator": "sparse_block",
+    },
+    {
+        "model_type": "qwen3_5_moe",
+        "mode": "convert",
+        "model_module": "transformers.models.qwen3_5_moe.modeling_qwen3_5_moe",
+        "model_class": "Qwen3_5MoeForConditionalGeneration",
+        "config_module": "transformers.models.qwen3_5_moe.configuration_qwen3_5_moe",
+        "config_class": "Qwen3_5MoeConfig",
+        "target_class_paths": ("transformers.models.qwen3_5_moe.modeling_qwen3_5_moe.Qwen3_5MoeExperts",),
+        "validator": "experts",
+        "min_targets": 2,
+    },
+    {
+        "model_type": "qwen3_moe",
+        "mode": "replace",
+        "model_module": "transformers.models.qwen3_moe.modeling_qwen3_moe",
+        "model_class": "Qwen3MoeForCausalLM",
+        "config_module": "transformers.models.qwen3_moe.configuration_qwen3_moe",
+        "config_class": "Qwen3MoeConfig",
+        "target_class_paths": ("defuser.modeling.unfused_moe.qwen3_moe.LinearQwen3MoeSparseMoeBlock",),
+        "validator": "sparse_block",
+    },
+    {
+        "model_type": "qwen3_next",
+        "mode": "replace",
+        "model_module": "transformers.models.qwen3_next.modeling_qwen3_next",
+        "model_class": "Qwen3NextForCausalLM",
+        "config_module": "transformers.models.qwen3_next.configuration_qwen3_next",
+        "config_class": "Qwen3NextConfig",
+        "target_class_paths": ("defuser.modeling.unfused_moe.qwen3_next.LinearQwen3NextSparseMoeBlock",),
+        "validator": "sparse_block",
+    },
+    {
+        "model_type": "qwen3_omni_moe",
+        "mode": "replace",
+        "model_module": "transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe",
+        "model_class": "Qwen3OmniMoeForConditionalGeneration",
+        "config_module": "transformers.models.qwen3_omni_moe.configuration_qwen3_omni_moe",
+        "config_class": "Qwen3OmniMoeConfig",
+        "target_class_paths": (
+            "defuser.modeling.unfused_moe.qwen3_omni_moe.LinearQwen3OmniMoeThinkerTextSparseMoeBlock",
+            "defuser.modeling.unfused_moe.qwen3_omni_moe.LinearQwen3OmniMoeTalkerTextSparseMoeBlock",
+        ),
+        "validator": "sparse_block",
+        "min_targets": 2,
+    },
+    {
+        "model_type": "qwen3_vl_moe",
+        "mode": "convert",
+        "model_module": "transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe",
+        "model_class": "Qwen3VLMoeForConditionalGeneration",
+        "config_module": "transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe",
+        "config_class": "Qwen3VLMoeConfig",
+        "target_class_paths": ("transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe.Qwen3VLMoeTextExperts",),
+        "validator": "experts",
+        "min_targets": 2,
+    },
+    {
+        "model_type": "solar_open",
+        "mode": "convert",
+        "model_module": "transformers.models.solar_open.modeling_solar_open",
+        "model_class": "SolarOpenForCausalLM",
+        "config_module": "transformers.models.solar_open.configuration_solar_open",
+        "config_class": "SolarOpenConfig",
+        "target_class_paths": ("transformers.models.solar_open.modeling_solar_open.SolarOpenNaiveMoe",),
+        "validator": "experts",
+    },
+    {
+        "model_type": "zamba2",
+        "mode": "convert",
+        "model_module": "transformers.models.zamba2.modeling_zamba2",
+        "model_class": "Zamba2ForCausalLM",
+        "config_module": "transformers.models.zamba2.configuration_zamba2",
+        "config_class": "Zamba2Config",
+        "target_class_paths": ("transformers.models.zamba2.modeling_zamba2.Zamba2MLP",),
+        "validator": "dense_split",
+    },
+]
+
+
+def test_meta_model_cases_cover_registered_public_models():
+    assert {case["model_type"] for case in META_MODEL_CASES} == set(MODEL_CONFIG) - {"qwen3_5_moe_text"}
+
+
+@pytest.mark.parametrize("case", META_MODEL_CASES, ids=[case["model_type"] for case in META_MODEL_CASES])
+def test_each_model_defuses_direct_meta_model(case):
+    if case["mode"] == "replace":
+        replace_fused_blocks(case["model_type"])
+
+    config = _build_model_config(case)
+    model_cls = _load(case["model_module"], case["model_class"])
+    with torch.device("meta"):
+        model = model_cls(config)
+
+    _assert_all_model_parameters_meta(model)
+
+    hits = _find_module_hits(model, case["target_class_paths"])
+    assert hits
+    assert set(case["target_class_paths"]).issubset({class_path for _, class_path in hits})
+    assert len(hits) >= case.get("min_targets", 1)
+
+    if case["mode"] == "replace":
+        for path, _ in hits:
+            _validate_defused_module(case, model.get_submodule(path))
+        assert convert_model(model) is False
+        _assert_all_model_parameters_meta(model)
+        return
+
+    target_paths = [path for path, _ in hits]
+    assert convert_model(model) is True
+    _assert_all_model_parameters_meta(model)
+    for path in target_paths:
+        _validate_defused_module(case, model.get_submodule(path))