diff --git a/README.md b/README.md index d4499fe..b5b4292 100644 --- a/README.md +++ b/README.md @@ -49,6 +49,8 @@ from defuser import convert_model, replace_fused_blocks | `qwen3_omni_moe` | `replace_fused_blocks("qwen3_omni_moe")` before load | Replaces the thinker text sparse MoE block with a defused per-expert linear block and applies small runtime compatibility patches for text `forward()` and `generate()`. | | `glm4_moe` | `replace_fused_blocks("glm4_moe")` before load | Replaces `Glm4MoeMoE` with a defused per-expert linear MoE block. | | `glm4v` | `replace_fused_blocks("glm4v")` before load | Replaces the fused text MLP with split `gate_proj`, `up_proj`, and `down_proj` layers. Also splits fused checkpoint `mlp.gate_up_proj.weight` into `mlp.gate_proj.weight` + `mlp.up_proj.weight`. | +| `gpt_oss` | `convert_model(model)` after load | Runtime expert tensor defusion. Splits fused transposed expert `gate_up_proj` into per-expert `gate_proj` + `up_proj`, carries over expert biases, and converts fused expert tensors into numbered expert `nn.Linear` modules. | +| `llama4` | `convert_model(model)` after load | Runtime expert tensor defusion. Splits fused transposed expert `gate_up_proj` into per-expert `gate_proj` + `up_proj`, converts fused expert tensors into numbered expert `nn.Linear` modules, and preserves the llama4 batched expert-input execution contract. | ## Workflow Summary diff --git a/defuser/__init__.py b/defuser/__init__.py index bac6e84..1e479ee 100644 --- a/defuser/__init__.py +++ b/defuser/__init__.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium -from defuser.utils.hf import env_flag +from defuser.utils.common import env_flag DEBUG_ON = env_flag("DEBUG") diff --git a/defuser/checkpoint_ops.py b/defuser/checkpoint_ops.py index 63a9887..9b5d32d 100644 --- a/defuser/checkpoint_ops.py +++ b/defuser/checkpoint_ops.py @@ -2,6 +2,11 @@ from transformers.core_model_loading import Chunk, Concatenate, ConversionOps, MergeModulelist +def _owned_contiguous_clone(tensor: torch.Tensor) -> torch.Tensor: + """Return a contiguous tensor with its own storage using a single clone.""" + return tensor.clone(memory_format=torch.contiguous_format) + + class OwnedChunk(Chunk): """Split fused tensors into independent chunks so save/load keeps both weights.""" @@ -12,7 +17,7 @@ def convert( split = super().convert(input_dict, source_patterns, target_patterns, **kwargs) # `torch.chunk()` returns views into shared storage, which can make safetensors # drop one side of the split tensor during save. Clone each chunk to own storage. - return {name: tensor.contiguous().clone() for name, tensor in split.items()} + return {name: _owned_contiguous_clone(tensor) for name, tensor in split.items()} class SplitFusedExpertGateUpProj(ConversionOps): @@ -45,8 +50,8 @@ def convert( for expert_idx in range(num_experts): expert_tensor = tensor.select(self.expert_dim, expert_idx) gate_proj, up_proj = torch.chunk(expert_tensor, 2, dim=self.proj_dim) - split_tensors[self._expert_target(target_patterns[0], expert_idx)] = gate_proj.contiguous().clone() - split_tensors[self._expert_target(target_patterns[1], expert_idx)] = up_proj.contiguous().clone() + split_tensors[self._expert_target(target_patterns[0], expert_idx)] = _owned_contiguous_clone(gate_proj) + split_tensors[self._expert_target(target_patterns[1], expert_idx)] = _owned_contiguous_clone(up_proj) return split_tensors @@ -126,7 +131,7 @@ def convert( split_tensors: dict[str, torch.Tensor] = {} for expert_idx in range(num_experts): expert_tensor = tensor.select(self.expert_dim, expert_idx) - split_tensors[self._expert_target(target_patterns[0], expert_idx)] = expert_tensor.contiguous().clone() + split_tensors[self._expert_target(target_patterns[0], expert_idx)] = _owned_contiguous_clone(expert_tensor) return split_tensors diff --git a/defuser/defuser.py b/defuser/defuser.py index 345e909..719eaf2 100644 --- a/defuser/defuser.py +++ b/defuser/defuser.py @@ -29,10 +29,8 @@ def get_checkpoint_conversion_mapping(model_type): conversion_mapping.orig_get_checkpoint_conversion_mapping = conversion_mapping.get_checkpoint_conversion_mapping cfg = MODEL_CONFIG.get(model_type) - if cfg: - return deepcopy(cfg.get("checkpoint_mapping", [])) - - from transformers import conversion_mapping + if cfg and "checkpoint_mapping" in cfg: + return deepcopy(cfg["checkpoint_mapping"]) return conversion_mapping.orig_get_checkpoint_conversion_mapping(model_type) @@ -52,6 +50,7 @@ def replace_fused_blocks(model_type: str) -> bool: if cfg is None: return False + patched_any = False for orig_path, custom_path in cfg.get(PATCH.REPLACE_MODULE, []): orig_module_path, orig_class_name = orig_path.rsplit(".", 1) custom_module_path, custom_class_name = custom_path.rsplit(".", 1) @@ -81,7 +80,7 @@ def replace_fused_blocks(model_type: str) -> bool: conversion_mapping.get_checkpoint_conversion_mapping = get_checkpoint_conversion_mapping transformers.modeling_utils.get_checkpoint_conversion_mapping = get_checkpoint_conversion_mapping logger.info(f"Patched {orig_path} -> {custom_path}") - return True + patched_any = True except Exception as e: if isinstance(e, PatchError): @@ -89,7 +88,7 @@ def replace_fused_blocks(model_type: str) -> bool: logger.warning(f"Failed to patch {orig_path}: {e}") return False - return False + return patched_any def check_model_compatibility(model: nn.Module) -> bool: diff --git a/defuser/modeling/moe_experts_interface.py b/defuser/modeling/moe_experts_interface.py index 99e0f8d..ad2b9d1 100644 --- a/defuser/modeling/moe_experts_interface.py +++ b/defuser/modeling/moe_experts_interface.py @@ -357,11 +357,12 @@ def _unfuse_single_projection( ) -> list | None: """Unfuse a single projection from 3D Parameter to a list of Linear layers. - Optimized to minimize device allocations and copies: + Optimized to keep peak device memory low while preserving the module's + original device placement: - Moves the full 3D tensor to CPU in a single transfer - Performs batch transpose on CPU if needed - - Creates Linear shells on meta device (no allocation) - - Directly assigns weight slices as Parameters (zero-copy on CPU) + - Releases the original fused parameter before allocating defused linears + - Re-materializes each expert linear back onto ``target_device`` Args: module: The experts module @@ -391,6 +392,7 @@ def _unfuse_single_projection( source_device = param.device is_meta = source_device.type == "meta" + weight_requires_grad = param.requires_grad # Prepare weight slices on CPU in batch (single D2H transfer + batch transpose) if not is_meta: @@ -413,6 +415,19 @@ def _unfuse_single_projection( if not bias_cpu.is_contiguous(): bias_cpu = bias_cpu.contiguous() bias_slices = bias_cpu.unbind(0) + bias_requires_grad = bias_param.requires_grad + + # Drop the original fused parameter before allocating the defused + # per-expert linears back on the original device. + try: + setattr(module, proj_name, to_meta(param)) + param = None + if has_bias: + setattr(module, bias_name, to_meta(bias_param)) + bias_param = None + if DEBUG_ON: logger.debug(f"Released memory for {proj_name} using to_meta()") + except Exception: + pass # Create Linear shells on meta device (no memory allocation) linears = [] @@ -421,23 +436,18 @@ def _unfuse_single_projection( linear = nn.Linear(in_features, out_features, bias=has_bias, dtype=dtype, device="meta") if not is_meta: - # Direct parameter assignment — no copy, just references the CPU tensor slice - linear.weight = nn.Parameter(weight_slices[i]) + weight = weight_slices[i] + if target_device.type != "cpu": + weight = weight.to(device=target_device, dtype=dtype) + linear.weight = nn.Parameter(weight, requires_grad=weight_requires_grad) if has_bias: - linear.bias = nn.Parameter(bias_slices[i]) + bias = bias_slices[i] + if target_device.type != "cpu": + bias = bias.to(device=target_device, dtype=bias.dtype) + linear.bias = nn.Parameter(bias, requires_grad=bias_requires_grad) linears.append(linear) - # Release original parameter memory - if not is_meta: - try: - setattr(module, proj_name, to_meta(param)) - if has_bias: - setattr(module, bias_name, to_meta(bias_param)) - if DEBUG_ON: logger.debug(f"Released memory for {proj_name} using to_meta()") - except Exception: - pass - return linears diff --git a/defuser/modeling/unfused_moe/common.py b/defuser/modeling/unfused_moe/common.py new file mode 100644 index 0000000..b5b7c16 --- /dev/null +++ b/defuser/modeling/unfused_moe/common.py @@ -0,0 +1,33 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +import torch +from torch import nn + + +def run_routed_experts( + experts: nn.ModuleList, + hidden_states: torch.Tensor, + routing_weights: torch.Tensor, + selected_experts: torch.Tensor, + num_experts: int, +) -> torch.Tensor: + """Run a standard top-k routed MoE expert loop over explicit expert modules.""" + hidden_dim = hidden_states.shape[-1] + final_hidden_states = torch.zeros( + (hidden_states.shape[0], hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0) + expert_hit = torch.nonzero(expert_mask.sum(dim=(-1, -2)), as_tuple=False).flatten() + + for expert_idx in expert_hit.tolist(): + expert_layer = experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] + final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + + return final_hidden_states diff --git a/defuser/modeling/unfused_moe/mixtral.py b/defuser/modeling/unfused_moe/mixtral.py index edfbf20..a2a2786 100644 --- a/defuser/modeling/unfused_moe/mixtral.py +++ b/defuser/modeling/unfused_moe/mixtral.py @@ -8,6 +8,8 @@ from transformers import MixtralConfig from transformers.activations import ACT2FN +from defuser.modeling.unfused_moe.common import run_routed_experts + class MixtralBlockSparseTop2MLP(nn.Module): """Per-expert Mixtral MLP with explicit gate, up, and down projections.""" @@ -63,27 +65,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = hidden_states.view(-1, hidden_dim) _, routing_weights, selected_experts = self.gate(hidden_states) routing_weights = routing_weights.to(hidden_states.dtype) - - final_hidden_states = torch.zeros( - (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + final_hidden_states = run_routed_experts( + self.experts, + hidden_states, + routing_weights, + selected_experts, + self.num_experts, ) - - # One hot encode the selected experts to create an expert mask - # this will be used to easily index which expert is going to be sollicitated - expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) - - expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - for expert_idx in expert_hit: - expert_layer = self.experts[expert_idx] - idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) - # Index the correct hidden states and compute the expert hidden state for - # the current expert. We need to make sure to multiply the output hidden - # states by `routing_weights` on the corresponding tokens (top-1 and top-2) - current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) - current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] - - # However `index_add_` only support torch tensors for indexing so we'll use - # the `top_x` tensor here. - final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) return final_hidden_states diff --git a/defuser/modeling/unfused_moe/qwen2_moe.py b/defuser/modeling/unfused_moe/qwen2_moe.py index 61ef68e..7368fa7 100644 --- a/defuser/modeling/unfused_moe/qwen2_moe.py +++ b/defuser/modeling/unfused_moe/qwen2_moe.py @@ -7,6 +7,8 @@ import torch.nn as nn from torch.nn import functional as F +from defuser.modeling.unfused_moe.common import run_routed_experts + class LinearQwen2MoeSparseMoeBlock(nn.Module): """Qwen2 MoE block rewritten to expose one ``nn.Module`` per expert.""" @@ -34,31 +36,14 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = hidden_states.view(-1, hidden_dim) _, routing_weights, selected_experts = self.gate(hidden_states) routing_weights = routing_weights.to(hidden_states.dtype) - - final_hidden_states = torch.zeros( - (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + final_hidden_states = run_routed_experts( + self.experts, + hidden_states, + routing_weights, + selected_experts, + self.num_experts, ) - # One hot encode the selected experts to create an expert mask - # this will be used to easily index which expert is going to be sollicitated - expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) - - # Loop over all available experts in the model and perform the computation on each expert - expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - for expert_idx in expert_hit: - expert_layer = self.experts[expert_idx] - idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) - - # Index the correct hidden states and compute the expert hidden state for - # the current expert. We need to make sure to multiply the output hidden - # states by `routing_weights` on the corresponding tokens (top-1 and top-2) - current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) - current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] - - # However `index_add_` only support torch tensors for indexing so we'll use - # the `top_x` tensor here. - final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) - shared_expert_output = self.shared_expert(hidden_states) shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output diff --git a/defuser/modeling/unfused_moe/qwen3_moe.py b/defuser/modeling/unfused_moe/qwen3_moe.py index 89d382d..1dc5fd0 100644 --- a/defuser/modeling/unfused_moe/qwen3_moe.py +++ b/defuser/modeling/unfused_moe/qwen3_moe.py @@ -9,6 +9,8 @@ import torch import torch.nn as nn +from defuser.modeling.unfused_moe.common import run_routed_experts + class LinearQwen3MoeSparseMoeBlock(nn.Module): """Qwen3 MoE block rewritten to expose one ``nn.Module`` per expert.""" @@ -33,32 +35,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = hidden_states.view(-1, hidden_dim) _, routing_weights, selected_experts = self.gate(hidden_states) routing_weights = routing_weights.to(hidden_states.dtype) - - final_hidden_states = torch.zeros( - (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + final_hidden_states = run_routed_experts( + self.experts, + hidden_states, + routing_weights, + selected_experts, + self.num_experts, ) - - # One hot encode the selected experts to create an expert mask - # this will be used to easily index which expert is going to be solicited - with torch.no_grad(): - expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) - - # Loop over all available experts in the model and perform the computation on each expert - expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - for expert_idx in expert_hit: - if expert_idx == self.num_experts: - continue - expert_layer = self.experts[expert_idx] - idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) - - # Index the correct hidden states and compute the expert hidden state for - # the current expert. We need to make sure to multiply the output hidden - # states by `routing_weights` on the corresponding tokens (top-1 and top-2) - current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) - current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] - - # However `index_add_` only support torch tensors for indexing so we'll use - # the `top_x` tensor here. - final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) return final_hidden_states diff --git a/defuser/modeling/unfused_moe/qwen3_next.py b/defuser/modeling/unfused_moe/qwen3_next.py index d9bdcc2..af58062 100644 --- a/defuser/modeling/unfused_moe/qwen3_next.py +++ b/defuser/modeling/unfused_moe/qwen3_next.py @@ -7,6 +7,8 @@ import torch.nn as nn from torch.nn import functional as F +from defuser.modeling.unfused_moe.common import run_routed_experts + class LinearQwen3NextSparseMoeBlock(nn.Module): """Qwen3-Next MoE block rewritten to expose one ``nn.Module`` per expert.""" @@ -33,31 +35,14 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = hidden_states.view(-1, hidden_dim) _, routing_weights, selected_experts = self.gate(hidden_states) routing_weights = routing_weights.to(hidden_states.dtype) - - final_hidden_states = torch.zeros( - (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + final_hidden_states = run_routed_experts( + self.experts, + hidden_states, + routing_weights, + selected_experts, + self.num_experts, ) - # One hot encode the selected experts to create an expert mask - # this will be used to easily index which expert is going to be sollicitated - expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) - - # Loop over all available experts in the model and perform the computation on each expert - expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - for expert_idx in expert_hit: - expert_layer = self.experts[expert_idx] - idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) - - # Index the correct hidden states and compute the expert hidden state for - # the current expert. We need to make sure to multiply the output hidden - # states by `routing_weights` on the corresponding tokens (top-1 and top-2) - current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) - current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] - - # However `index_add_` only support torch tensors for indexing so we'll use - # the `top_x` tensor here. - final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) - shared_expert_output = self.shared_expert(hidden_states) shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output diff --git a/defuser/modeling/unfused_moe/qwen3_omni_moe.py b/defuser/modeling/unfused_moe/qwen3_omni_moe.py index 94c0d4f..798629a 100644 --- a/defuser/modeling/unfused_moe/qwen3_omni_moe.py +++ b/defuser/modeling/unfused_moe/qwen3_omni_moe.py @@ -6,6 +6,8 @@ import torch import torch.nn as nn +from defuser.modeling.unfused_moe.common import run_routed_experts + class LinearQwen3OmniMoeThinkerTextSparseMoeBlock(nn.Module): """Text thinker MoE block for qwen3-omni with explicit per-expert modules.""" @@ -35,29 +37,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = hidden_states.view(-1, hidden_dim) _, routing_weights, selected_experts = self.gate(hidden_states) routing_weights = routing_weights.to(hidden_states.dtype) - - final_hidden_states = torch.zeros( - (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + final_hidden_states = run_routed_experts( + self.experts, + hidden_states, + routing_weights, + selected_experts, + self.num_experts, ) - - # One hot encode the selected experts to create an expert mask - # this will be used to easily index which expert is going to be sollicitated - expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) - - # Loop over all available experts in the model and perform the computation on each expert - expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - for expert_idx in expert_hit: - expert_layer = self.experts[expert_idx] - idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) - - # Index the correct hidden states and compute the expert hidden state for - # the current expert. We need to make sure to multiply the output hidden - # states by `routing_weights` on the corresponding tokens (top-1 and top-2) - current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) - current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] - - # However `index_add_` only support torch tensors for indexing so we'll use - # the `top_x` tensor here. - final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) return final_hidden_states diff --git a/defuser/utils/hf.py b/defuser/utils/hf.py index af370f5..9846bd6 100644 --- a/defuser/utils/hf.py +++ b/defuser/utils/hf.py @@ -16,27 +16,12 @@ from transformers import AutoConfig from defuser.model_registry import MODEL_CONFIG -from defuser.utils.common import warn_if_public_api_transformers_unsupported +from defuser.utils.common import env_flag, warn_if_public_api_transformers_unsupported logger = LogBar(__name__) _ENV_VAR: Final[str] = "GPTQMODEL_USE_MODELSCOPE" -TRUTHFUL = {"1", "true", "yes", "on", "y"} - - -def env_flag(name: str, default: str | bool | None = "0") -> bool: - """Return ``True`` when an env var is set to a truthy value.""" - - value = os.getenv(name) - if value is None: - if default is None: - return False - if isinstance(default, bool): - return default - value = default - return str(value).strip().lower() in TRUTHFUL - def modelscope_requested() -> bool: """ diff --git a/defuser/utils/model.py b/defuser/utils/model.py index a7bbc76..6292c52 100644 --- a/defuser/utils/model.py +++ b/defuser/utils/model.py @@ -8,6 +8,8 @@ import torch +from defuser.utils.device import unsupported_meta_device + def _update_parameter( module: torch.nn.Module, @@ -18,20 +20,3 @@ def _update_parameter( old_param = getattr(module, name) new_param = torch.nn.Parameter(data, requires_grad=old_param.requires_grad) setattr(module, name, new_param) - - -def unsupported_meta_device(model): - """Return ``True`` when mixed real/meta parameters make lazy materialization unsafe.""" - target_device = None - for param in model.parameters(): - if target_device is None: - target_device = param.device - if param.device != target_device: - if param.device.type == "meta" or target_device.type == "meta": - return True - if target_device.type == "meta": - if hasattr(model, "path"): - return False - else: - return True - return False diff --git a/tests/test_convert_model.py b/tests/test_convert_model.py index a2a4789..53465b8 100644 --- a/tests/test_convert_model.py +++ b/tests/test_convert_model.py @@ -4,6 +4,7 @@ # Contact: qubitium@modelcloud.ai, x.com/qubitium from types import SimpleNamespace +import pytest import torch from safetensors.torch import save_file from torch import nn @@ -41,8 +42,8 @@ import defuser.defuser as defuser_api import defuser.utils.hf as hf_utils from defuser import convert_model, replace_fused_blocks -from defuser.checkpoint_ops import OwnedChunk, SplitFusedExpertGateUpProj -from defuser.model_registry import MODEL_CONFIG +from defuser.checkpoint_ops import OwnedChunk, SplitFusedExpertDownProj, SplitFusedExpertGateUpProj +from defuser.model_registry import MODEL_CONFIG, PATCH from defuser.modeling.replace_modules import ReplacementModuleBase, apply_replacements, materialize_model from defuser.modeling.unfused_moe.glm4_moe import LinearGlm4MoeMoE from defuser.modeling.unfused_moe.mixtral import LinearMixtralSparseMoeBlock @@ -467,6 +468,54 @@ def test_replace_fused_blocks_returns_false_for_unregistered_model(): assert replace_fused_blocks("unsupported_model_type") is False +def test_replace_fused_blocks_applies_all_registered_replacements(): + import sys + import types + + orig1 = types.ModuleType("dummy_orig1") + orig2 = types.ModuleType("dummy_orig2") + custom1 = types.ModuleType("dummy_custom1") + custom2 = types.ModuleType("dummy_custom2") + + class OriginalOne: + pass + + class OriginalTwo: + pass + + class ReplacementOne: + pass + + class ReplacementTwo: + pass + + orig1.OriginalOne = OriginalOne + orig2.OriginalTwo = OriginalTwo + custom1.ReplacementOne = ReplacementOne + custom2.ReplacementTwo = ReplacementTwo + + sys.modules["dummy_orig1"] = orig1 + sys.modules["dummy_orig2"] = orig2 + sys.modules["dummy_custom1"] = custom1 + sys.modules["dummy_custom2"] = custom2 + MODEL_CONFIG["dummy_multi_patch"] = { + "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, + PATCH.REPLACE_MODULE: [ + ("dummy_orig1.OriginalOne", "dummy_custom1.ReplacementOne"), + ("dummy_orig2.OriginalTwo", "dummy_custom2.ReplacementTwo"), + ], + } + + try: + assert replace_fused_blocks("dummy_multi_patch") is True + assert orig1.OriginalOne is ReplacementOne + assert orig2.OriginalTwo is ReplacementTwo + finally: + MODEL_CONFIG.pop("dummy_multi_patch", None) + for name in ("dummy_orig1", "dummy_orig2", "dummy_custom1", "dummy_custom2"): + sys.modules.pop(name, None) + + def test_model_registry_requires_transformers_5_3_or_newer(): assert {cfg["min_transformers_version"] for cfg in MODEL_CONFIG.values()} == {MIN_SUPPORTED_TRANSFORMERS_VERSION} @@ -549,6 +598,25 @@ def test_qwen3_5_moe(): torch.testing.assert_close(expert0.down_proj.weight, expected_down) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +def test_qwen3_5_moe_runtime_defusion_preserves_cuda_device(): + model = Qwen3_5MoeForConditionalGeneration(_tiny_qwen3_5_moe_config()).cuda().eval() + + converted = convert_model(model, cleanup_original=False, max_layers=1) + assert converted + + expert0 = getattr(model.model.language_model.layers[0].mlp.experts, "0") + assert expert0.gate_proj.weight.device.type == "cuda" + assert expert0.up_proj.weight.device.type == "cuda" + assert expert0.down_proj.weight.device.type == "cuda" + + hidden_states = torch.randn(2, model.config.text_config.hidden_size, device="cuda") + with torch.no_grad(): + output = expert0.gate_proj(hidden_states) + + assert output.device.type == "cuda" + + def test_mixtral(): model_type = "mixtral" replace_fused_blocks(model_type) @@ -592,6 +660,46 @@ def test_mixtral_checkpoint_mapping_splits_fused_experts(): torch.testing.assert_close(split[".experts.0.up_proj.weight"], fused_gate_up[0, 3:]) torch.testing.assert_close(split[".experts.3.gate_proj.weight"], fused_gate_up[3, :3]) torch.testing.assert_close(split[".experts.3.up_proj.weight"], fused_gate_up[3, 3:]) + assert split[".experts.0.gate_proj.weight"].is_contiguous() + assert split[".experts.0.up_proj.weight"].is_contiguous() + assert split[".experts.0.gate_proj.weight"].untyped_storage().data_ptr() != fused_gate_up.untyped_storage().data_ptr() + assert split[".experts.0.up_proj.weight"].untyped_storage().data_ptr() != fused_gate_up.untyped_storage().data_ptr() + + +def test_mixtral_checkpoint_mapping_splits_down_proj_into_owned_contiguous_tensors(): + split_op = SplitFusedExpertDownProj() + fused_down = torch.arange(4 * 8 * 3, dtype=torch.float32).reshape(4, 8, 3) + + split = split_op.convert( + {".experts.down_proj": fused_down}, + [".experts.down_proj"], + [".experts.0.down_proj.weight"], + ) + + torch.testing.assert_close(split[".experts.0.down_proj.weight"], fused_down[0]) + torch.testing.assert_close(split[".experts.3.down_proj.weight"], fused_down[3]) + assert split[".experts.0.down_proj.weight"].is_contiguous() + assert split[".experts.3.down_proj.weight"].is_contiguous() + assert split[".experts.0.down_proj.weight"].untyped_storage().data_ptr() != fused_down.untyped_storage().data_ptr() + assert split[".experts.3.down_proj.weight"].untyped_storage().data_ptr() != fused_down.untyped_storage().data_ptr() + + +def test_registered_models_without_custom_checkpoint_mapping_keep_transformers_fallback(): + from transformers import conversion_mapping + + upstream_get_mapping = getattr( + conversion_mapping, + "orig_get_checkpoint_conversion_mapping", + conversion_mapping.get_checkpoint_conversion_mapping, + ) + upstream_mapping = upstream_get_mapping("qwen2_moe") + + replace_fused_blocks("qwen2_moe") + + patched_mapping = conversion_mapping.get_checkpoint_conversion_mapping("qwen2_moe") + assert len(patched_mapping) == len(upstream_mapping) + assert [item.source_patterns for item in patched_mapping] == [item.source_patterns for item in upstream_mapping] + assert [item.target_patterns for item in patched_mapping] == [item.target_patterns for item in upstream_mapping] def test_mixtral_from_pretrained_loads_fused_checkpoint_into_defused_model(tmp_path): @@ -804,6 +912,10 @@ def test_glm4v_checkpoint_mapping_splits_gate_up_proj(): torch.testing.assert_close(split["mlp.gate_proj.weight"], fused[:3]) torch.testing.assert_close(split["mlp.up_proj.weight"], fused[3:]) assert split["mlp.gate_proj.weight"].data_ptr() != split["mlp.up_proj.weight"].data_ptr() + assert split["mlp.gate_proj.weight"].is_contiguous() + assert split["mlp.up_proj.weight"].is_contiguous() + assert split["mlp.gate_proj.weight"].untyped_storage().data_ptr() != fused.untyped_storage().data_ptr() + assert split["mlp.up_proj.weight"].untyped_storage().data_ptr() != fused.untyped_storage().data_ptr() def test_glm4v_split_forward_matches_fused_math():