Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ from defuser import convert_model, replace_fused_blocks
| `qwen3_omni_moe` | `replace_fused_blocks("qwen3_omni_moe")` before load | Replaces the thinker text sparse MoE block with a defused per-expert linear block and applies small runtime compatibility patches for text `forward()` and `generate()`. |
| `glm4_moe` | `replace_fused_blocks("glm4_moe")` before load | Replaces `Glm4MoeMoE` with a defused per-expert linear MoE block. |
| `glm4v` | `replace_fused_blocks("glm4v")` before load | Replaces the fused text MLP with split `gate_proj`, `up_proj`, and `down_proj` layers. Also splits fused checkpoint `mlp.gate_up_proj.weight` into `mlp.gate_proj.weight` + `mlp.up_proj.weight`. |
| `gpt_oss` | `convert_model(model)` after load | Runtime expert tensor defusion. Splits fused transposed expert `gate_up_proj` into per-expert `gate_proj` + `up_proj`, carries over expert biases, and converts fused expert tensors into numbered expert `nn.Linear` modules. |
| `llama4` | `convert_model(model)` after load | Runtime expert tensor defusion. Splits fused transposed expert `gate_up_proj` into per-expert `gate_proj` + `up_proj`, converts fused expert tensors into numbered expert `nn.Linear` modules, and preserves the llama4 batched expert-input execution contract. |

## Workflow Summary

Expand Down
2 changes: 1 addition & 1 deletion defuser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium

from defuser.utils.hf import env_flag
from defuser.utils.common import env_flag

DEBUG_ON = env_flag("DEBUG")

Expand Down
13 changes: 9 additions & 4 deletions defuser/checkpoint_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
from transformers.core_model_loading import Chunk, Concatenate, ConversionOps, MergeModulelist


def _owned_contiguous_clone(tensor: torch.Tensor) -> torch.Tensor:
"""Return a contiguous tensor with its own storage using a single clone."""
return tensor.clone(memory_format=torch.contiguous_format)


class OwnedChunk(Chunk):
"""Split fused tensors into independent chunks so save/load keeps both weights."""

Expand All @@ -12,7 +17,7 @@ def convert(
split = super().convert(input_dict, source_patterns, target_patterns, **kwargs)
# `torch.chunk()` returns views into shared storage, which can make safetensors
# drop one side of the split tensor during save. Clone each chunk to own storage.
return {name: tensor.contiguous().clone() for name, tensor in split.items()}
return {name: _owned_contiguous_clone(tensor) for name, tensor in split.items()}


class SplitFusedExpertGateUpProj(ConversionOps):
Expand Down Expand Up @@ -45,8 +50,8 @@ def convert(
for expert_idx in range(num_experts):
expert_tensor = tensor.select(self.expert_dim, expert_idx)
gate_proj, up_proj = torch.chunk(expert_tensor, 2, dim=self.proj_dim)
split_tensors[self._expert_target(target_patterns[0], expert_idx)] = gate_proj.contiguous().clone()
split_tensors[self._expert_target(target_patterns[1], expert_idx)] = up_proj.contiguous().clone()
split_tensors[self._expert_target(target_patterns[0], expert_idx)] = _owned_contiguous_clone(gate_proj)
split_tensors[self._expert_target(target_patterns[1], expert_idx)] = _owned_contiguous_clone(up_proj)

return split_tensors

Expand Down Expand Up @@ -126,7 +131,7 @@ def convert(
split_tensors: dict[str, torch.Tensor] = {}
for expert_idx in range(num_experts):
expert_tensor = tensor.select(self.expert_dim, expert_idx)
split_tensors[self._expert_target(target_patterns[0], expert_idx)] = expert_tensor.contiguous().clone()
split_tensors[self._expert_target(target_patterns[0], expert_idx)] = _owned_contiguous_clone(expert_tensor)

return split_tensors

Expand Down
11 changes: 5 additions & 6 deletions defuser/defuser.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,8 @@ def get_checkpoint_conversion_mapping(model_type):
conversion_mapping.orig_get_checkpoint_conversion_mapping = conversion_mapping.get_checkpoint_conversion_mapping

cfg = MODEL_CONFIG.get(model_type)
if cfg:
return deepcopy(cfg.get("checkpoint_mapping", []))

from transformers import conversion_mapping
if cfg and "checkpoint_mapping" in cfg:
return deepcopy(cfg["checkpoint_mapping"])

return conversion_mapping.orig_get_checkpoint_conversion_mapping(model_type)

Expand All @@ -52,6 +50,7 @@ def replace_fused_blocks(model_type: str) -> bool:
if cfg is None:
return False

patched_any = False
for orig_path, custom_path in cfg.get(PATCH.REPLACE_MODULE, []):
orig_module_path, orig_class_name = orig_path.rsplit(".", 1)
custom_module_path, custom_class_name = custom_path.rsplit(".", 1)
Expand Down Expand Up @@ -81,15 +80,15 @@ def replace_fused_blocks(model_type: str) -> bool:
conversion_mapping.get_checkpoint_conversion_mapping = get_checkpoint_conversion_mapping
transformers.modeling_utils.get_checkpoint_conversion_mapping = get_checkpoint_conversion_mapping
logger.info(f"Patched {orig_path} -> {custom_path}")
return True
patched_any = True

except Exception as e:
if isinstance(e, PatchError):
raise e

logger.warning(f"Failed to patch {orig_path}: {e}")
return False
return False
return patched_any


def check_model_compatibility(model: nn.Module) -> bool:
Expand Down
42 changes: 26 additions & 16 deletions defuser/modeling/moe_experts_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,11 +357,12 @@ def _unfuse_single_projection(
) -> list | None:
"""Unfuse a single projection from 3D Parameter to a list of Linear layers.

Optimized to minimize device allocations and copies:
Optimized to keep peak device memory low while preserving the module's
original device placement:
- Moves the full 3D tensor to CPU in a single transfer
- Performs batch transpose on CPU if needed
- Creates Linear shells on meta device (no allocation)
- Directly assigns weight slices as Parameters (zero-copy on CPU)
- Releases the original fused parameter before allocating defused linears
- Re-materializes each expert linear back onto ``target_device``

Args:
module: The experts module
Expand Down Expand Up @@ -391,6 +392,7 @@ def _unfuse_single_projection(

source_device = param.device
is_meta = source_device.type == "meta"
weight_requires_grad = param.requires_grad

# Prepare weight slices on CPU in batch (single D2H transfer + batch transpose)
if not is_meta:
Expand All @@ -413,6 +415,19 @@ def _unfuse_single_projection(
if not bias_cpu.is_contiguous():
bias_cpu = bias_cpu.contiguous()
bias_slices = bias_cpu.unbind(0)
bias_requires_grad = bias_param.requires_grad

# Drop the original fused parameter before allocating the defused
# per-expert linears back on the original device.
try:
setattr(module, proj_name, to_meta(param))
param = None
if has_bias:
setattr(module, bias_name, to_meta(bias_param))
bias_param = None
if DEBUG_ON: logger.debug(f"Released memory for {proj_name} using to_meta()")
except Exception:
pass

# Create Linear shells on meta device (no memory allocation)
linears = []
Expand All @@ -421,23 +436,18 @@ def _unfuse_single_projection(
linear = nn.Linear(in_features, out_features, bias=has_bias, dtype=dtype, device="meta")

if not is_meta:
# Direct parameter assignment — no copy, just references the CPU tensor slice
linear.weight = nn.Parameter(weight_slices[i])
weight = weight_slices[i]
if target_device.type != "cpu":
weight = weight.to(device=target_device, dtype=dtype)
linear.weight = nn.Parameter(weight, requires_grad=weight_requires_grad)
if has_bias:
linear.bias = nn.Parameter(bias_slices[i])
bias = bias_slices[i]
if target_device.type != "cpu":
bias = bias.to(device=target_device, dtype=bias.dtype)
linear.bias = nn.Parameter(bias, requires_grad=bias_requires_grad)

linears.append(linear)

# Release original parameter memory
if not is_meta:
try:
setattr(module, proj_name, to_meta(param))
if has_bias:
setattr(module, bias_name, to_meta(bias_param))
if DEBUG_ON: logger.debug(f"Released memory for {proj_name} using to_meta()")
except Exception:
pass

return linears


Expand Down
33 changes: 33 additions & 0 deletions defuser/modeling/unfused_moe/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# SPDX-FileCopyrightText: 2026 ModelCloud.ai
# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium

import torch
from torch import nn


def run_routed_experts(
    experts: nn.ModuleList,
    hidden_states: torch.Tensor,
    routing_weights: torch.Tensor,
    selected_experts: torch.Tensor,
    num_experts: int,
) -> torch.Tensor:
    """Dispatch each token to its top-k selected experts and sum the weighted outputs.

    Args:
        experts: One module per expert; each maps ``(tokens, hidden_dim)`` inputs
            to outputs of the same shape.
        hidden_states: Flattened token activations of shape ``(tokens, hidden_dim)``.
        routing_weights: Per-token gate weights of shape ``(tokens, top_k)``.
        selected_experts: Expert indices chosen per token, shape ``(tokens, top_k)``.
        num_experts: Total number of experts available for routing.

    Returns:
        Tensor of shape ``(tokens, hidden_dim)`` holding the routed-expert mixture.
    """
    num_tokens = hidden_states.shape[0]
    hidden_dim = hidden_states.shape[-1]
    output = torch.zeros(
        (num_tokens, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
    )

    # Boolean routing mask of shape (num_experts, top_k, tokens) so that each
    # expert's token/slot assignments can be recovered with a single torch.where.
    mask = nn.functional.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0)
    active = torch.nonzero(mask.sum(dim=(-1, -2)), as_tuple=False).flatten()

    # Only run experts that received at least one token.
    for expert_id in active.tolist():
        slot_idx, token_idx = torch.where(mask[expert_id])
        expert_in = hidden_states[token_idx]
        expert_out = experts[expert_id](expert_in) * routing_weights[token_idx, slot_idx, None]
        # index_add_ accumulates contributions from every (token, slot) pair.
        output.index_add_(0, token_idx, expert_out.to(hidden_states.dtype))

    return output
29 changes: 8 additions & 21 deletions defuser/modeling/unfused_moe/mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from transformers import MixtralConfig
from transformers.activations import ACT2FN

from defuser.modeling.unfused_moe.common import run_routed_experts


class MixtralBlockSparseTop2MLP(nn.Module):
"""Per-expert Mixtral MLP with explicit gate, up, and down projections."""
Expand Down Expand Up @@ -63,27 +65,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = hidden_states.view(-1, hidden_dim)
_, routing_weights, selected_experts = self.gate(hidden_states)
routing_weights = routing_weights.to(hidden_states.dtype)

final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
final_hidden_states = run_routed_experts(
self.experts,
hidden_states,
routing_weights,
selected_experts,
self.num_experts,
)

# One hot encode the selected experts to create an expert mask
# this will be used to easily index which expert is going to be solicited
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_layer = self.experts[expert_idx]
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
# Index the correct hidden states and compute the expert hidden state for
# the current expert. We need to make sure to multiply the output hidden
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]

# However `index_add_` only support torch tensors for indexing so we'll use
# the `top_x` tensor here.
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states
31 changes: 8 additions & 23 deletions defuser/modeling/unfused_moe/qwen2_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import torch.nn as nn
from torch.nn import functional as F

from defuser.modeling.unfused_moe.common import run_routed_experts


class LinearQwen2MoeSparseMoeBlock(nn.Module):
"""Qwen2 MoE block rewritten to expose one ``nn.Module`` per expert."""
Expand Down Expand Up @@ -34,31 +36,14 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = hidden_states.view(-1, hidden_dim)
_, routing_weights, selected_experts = self.gate(hidden_states)
routing_weights = routing_weights.to(hidden_states.dtype)

final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
final_hidden_states = run_routed_experts(
self.experts,
hidden_states,
routing_weights,
selected_experts,
self.num_experts,
)

# One hot encode the selected experts to create an expert mask
# this will be used to easily index which expert is going to be solicited
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

# Loop over all available experts in the model and perform the computation on each expert
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_layer = self.experts[expert_idx]
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))

# Index the correct hidden states and compute the expert hidden state for
# the current expert. We need to make sure to multiply the output hidden
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]

# However `index_add_` only support torch tensors for indexing so we'll use
# the `top_x` tensor here.
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))

shared_expert_output = self.shared_expert(hidden_states)
shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output

Expand Down
34 changes: 8 additions & 26 deletions defuser/modeling/unfused_moe/qwen3_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import torch
import torch.nn as nn

from defuser.modeling.unfused_moe.common import run_routed_experts


class LinearQwen3MoeSparseMoeBlock(nn.Module):
"""Qwen3 MoE block rewritten to expose one ``nn.Module`` per expert."""
Expand All @@ -33,32 +35,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = hidden_states.view(-1, hidden_dim)
_, routing_weights, selected_experts = self.gate(hidden_states)
routing_weights = routing_weights.to(hidden_states.dtype)

final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
final_hidden_states = run_routed_experts(
self.experts,
hidden_states,
routing_weights,
selected_experts,
self.num_experts,
)

# One hot encode the selected experts to create an expert mask
# this will be used to easily index which expert is going to be solicited
with torch.no_grad():
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

# Loop over all available experts in the model and perform the computation on each expert
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
if expert_idx == self.num_experts:
continue
expert_layer = self.experts[expert_idx]
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))

# Index the correct hidden states and compute the expert hidden state for
# the current expert. We need to make sure to multiply the output hidden
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]

# However `index_add_` only support torch tensors for indexing so we'll use
# the `top_x` tensor here.
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states
31 changes: 8 additions & 23 deletions defuser/modeling/unfused_moe/qwen3_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import torch.nn as nn
from torch.nn import functional as F

from defuser.modeling.unfused_moe.common import run_routed_experts

class LinearQwen3NextSparseMoeBlock(nn.Module):
"""Qwen3-Next MoE block rewritten to expose one ``nn.Module`` per expert."""

Expand All @@ -33,31 +35,14 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = hidden_states.view(-1, hidden_dim)
_, routing_weights, selected_experts = self.gate(hidden_states)
routing_weights = routing_weights.to(hidden_states.dtype)

final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
final_hidden_states = run_routed_experts(
self.experts,
hidden_states,
routing_weights,
selected_experts,
self.num_experts,
)

# One hot encode the selected experts to create an expert mask
# this will be used to easily index which expert is going to be solicited
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

# Loop over all available experts in the model and perform the computation on each expert
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_layer = self.experts[expert_idx]
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))

# Index the correct hidden states and compute the expert hidden state for
# the current expert. We need to make sure to multiply the output hidden
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]

# However `index_add_` only support torch tensors for indexing so we'll use
# the `top_x` tensor here.
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))

shared_expert_output = self.shared_expert(hidden_states)
shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output

Expand Down
Loading