
Commit 170b24b

Refactor (#31)
* clean up readme
* defused weights should stay on the same device
* cleanup
* fix replace_fused_blocks() returning after the first replacement entry instead of applying the full list
* defused tensors should retain the original's contiguous state
* deduplicate logic: move shared code to a helper
1 parent 3d9803f commit 170b24b

14 files changed

Lines changed: 233 additions & 178 deletions

README.md

Lines changed: 2 additions & 0 deletions
@@ -49,6 +49,8 @@ from defuser import convert_model, replace_fused_blocks
 | `qwen3_omni_moe` | `replace_fused_blocks("qwen3_omni_moe")` before load | Replaces the thinker text sparse MoE block with a defused per-expert linear block and applies small runtime compatibility patches for text `forward()` and `generate()`. |
 | `glm4_moe` | `replace_fused_blocks("glm4_moe")` before load | Replaces `Glm4MoeMoE` with a defused per-expert linear MoE block. |
 | `glm4v` | `replace_fused_blocks("glm4v")` before load | Replaces the fused text MLP with split `gate_proj`, `up_proj`, and `down_proj` layers. Also splits fused checkpoint `mlp.gate_up_proj.weight` into `mlp.gate_proj.weight` + `mlp.up_proj.weight`. |
+| `gpt_oss` | `convert_model(model)` after load | Runtime expert tensor defusion. Splits fused transposed expert `gate_up_proj` into per-expert `gate_proj` + `up_proj`, carries over expert biases, and converts fused expert tensors into numbered expert `nn.Linear` modules. |
+| `llama4` | `convert_model(model)` after load | Runtime expert tensor defusion. Splits fused transposed expert `gate_up_proj` into per-expert `gate_proj` + `up_proj`, converts fused expert tensors into numbered expert `nn.Linear` modules, and preserves the llama4 batched expert-input execution contract. |
 
 ## Workflow Summary
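The two table workflows map to two call sites. A minimal usage sketch; the checkpoint IDs and the `AutoModelForCausalLM` loader are illustrative placeholders, not part of this commit:

from transformers import AutoModelForCausalLM

from defuser import convert_model, replace_fused_blocks

# Patch-before-load path: swap the fused MoE block class first so the
# checkpoint loads directly into defused per-expert linears.
replace_fused_blocks("glm4_moe")
glm = AutoModelForCausalLM.from_pretrained("your-org/your-glm4-moe-checkpoint")

# Convert-after-load path: defuse the fused expert tensors in place on an
# already-loaded model (gpt_oss, llama4).
gpt = AutoModelForCausalLM.from_pretrained("your-org/your-gpt-oss-checkpoint")
convert_model(gpt)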

defuser/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # Contact: qubitium@modelcloud.ai, x.com/qubitium
 
-from defuser.utils.hf import env_flag
+from defuser.utils.common import env_flag
 
 DEBUG_ON = env_flag("DEBUG")
 
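`env_flag` itself is untouched here; only its import path moved. For readers following along, a hypothetical sketch of what such a helper typically does (the real implementation lives in `defuser.utils.common` and may differ):

import os

def env_flag(name: str) -> bool:
    # Hypothetical sketch: treat common truthy strings as enabling the flag.
    return os.environ.get(name, "").strip().lower() in {"1", "true", "yes", "on"}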

defuser/checkpoint_ops.py

Lines changed: 9 additions & 4 deletions
@@ -2,6 +2,11 @@
 from transformers.core_model_loading import Chunk, Concatenate, ConversionOps, MergeModulelist
 
 
+def _owned_contiguous_clone(tensor: torch.Tensor) -> torch.Tensor:
+    """Return a contiguous tensor with its own storage using a single clone."""
+    return tensor.clone(memory_format=torch.contiguous_format)
+
+
 class OwnedChunk(Chunk):
     """Split fused tensors into independent chunks so save/load keeps both weights."""
 
@@ -12,7 +17,7 @@ def convert(
         split = super().convert(input_dict, source_patterns, target_patterns, **kwargs)
         # `torch.chunk()` returns views into shared storage, which can make safetensors
         # drop one side of the split tensor during save. Clone each chunk to own storage.
-        return {name: tensor.contiguous().clone() for name, tensor in split.items()}
+        return {name: _owned_contiguous_clone(tensor) for name, tensor in split.items()}
 
 
 class SplitFusedExpertGateUpProj(ConversionOps):
@@ -45,8 +50,8 @@ def convert(
         for expert_idx in range(num_experts):
             expert_tensor = tensor.select(self.expert_dim, expert_idx)
             gate_proj, up_proj = torch.chunk(expert_tensor, 2, dim=self.proj_dim)
-            split_tensors[self._expert_target(target_patterns[0], expert_idx)] = gate_proj.contiguous().clone()
-            split_tensors[self._expert_target(target_patterns[1], expert_idx)] = up_proj.contiguous().clone()
+            split_tensors[self._expert_target(target_patterns[0], expert_idx)] = _owned_contiguous_clone(gate_proj)
+            split_tensors[self._expert_target(target_patterns[1], expert_idx)] = _owned_contiguous_clone(up_proj)
 
         return split_tensors
 
@@ -126,7 +131,7 @@ def convert(
         split_tensors: dict[str, torch.Tensor] = {}
         for expert_idx in range(num_experts):
             expert_tensor = tensor.select(self.expert_dim, expert_idx)
-            split_tensors[self._expert_target(target_patterns[0], expert_idx)] = expert_tensor.contiguous().clone()
+            split_tensors[self._expert_target(target_patterns[0], expert_idx)] = _owned_contiguous_clone(expert_tensor)
 
         return split_tensors
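The helper exists because `torch.chunk()` returns views that alias the fused tensor's storage, and `.contiguous()` on an already-contiguous view is a no-op that still aliases it, so a clone is required either way; the single `clone(memory_format=torch.contiguous_format)` avoids the possible double copy of `.contiguous().clone()`. A small runnable illustration:

import torch

fused = torch.arange(8.0).reshape(4, 2)
gate, up = torch.chunk(fused, 2, dim=0)

# chunk() returns views: both halves alias the fused tensor's storage.
assert gate.untyped_storage().data_ptr() == fused.untyped_storage().data_ptr()

# A single clone yields a contiguous tensor with independent storage.
owned = gate.clone(memory_format=torch.contiguous_format)
assert owned.is_contiguous()
assert owned.untyped_storage().data_ptr() != fused.untyped_storage().data_ptr()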

defuser/defuser.py

Lines changed: 5 additions & 6 deletions
@@ -29,10 +29,8 @@ def get_checkpoint_conversion_mapping(model_type):
     conversion_mapping.orig_get_checkpoint_conversion_mapping = conversion_mapping.get_checkpoint_conversion_mapping
 
     cfg = MODEL_CONFIG.get(model_type)
-    if cfg:
-        return deepcopy(cfg.get("checkpoint_mapping", []))
-
-    from transformers import conversion_mapping
+    if cfg and "checkpoint_mapping" in cfg:
+        return deepcopy(cfg["checkpoint_mapping"])
 
     return conversion_mapping.orig_get_checkpoint_conversion_mapping(model_type)
 
@@ -52,6 +50,7 @@ def replace_fused_blocks(model_type: str) -> bool:
     if cfg is None:
         return False
 
+    patched_any = False
     for orig_path, custom_path in cfg.get(PATCH.REPLACE_MODULE, []):
         orig_module_path, orig_class_name = orig_path.rsplit(".", 1)
         custom_module_path, custom_class_name = custom_path.rsplit(".", 1)
@@ -81,15 +80,15 @@ def replace_fused_blocks(model_type: str) -> bool:
             conversion_mapping.get_checkpoint_conversion_mapping = get_checkpoint_conversion_mapping
             transformers.modeling_utils.get_checkpoint_conversion_mapping = get_checkpoint_conversion_mapping
             logger.info(f"Patched {orig_path} -> {custom_path}")
-            return True
+            patched_any = True
 
         except Exception as e:
             if isinstance(e, PatchError):
                 raise e
 
             logger.warning(f"Failed to patch {orig_path}: {e}")
             return False
-    return False
+    return patched_any
 
 
 def check_model_compatibility(model: nn.Module) -> bool:
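The `patched_any` change fixes replace_fused_blocks() returning after the first replacement entry instead of applying the full list. The control-flow pattern, distilled into a generic sketch (names are illustrative, not the repo's API):

def apply_all(entries, apply_one) -> bool:
    patched_any = False
    for entry in entries:
        try:
            apply_one(entry)
            patched_any = True  # record success but keep iterating
        except Exception as exc:
            print(f"failed to patch {entry}: {exc}")
            return False
    # True only if at least one entry applied and none hard-failed.
    return patched_any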

defuser/modeling/moe_experts_interface.py

Lines changed: 26 additions & 16 deletions
@@ -357,11 +357,12 @@ def _unfuse_single_projection(
 ) -> list | None:
     """Unfuse a single projection from 3D Parameter to a list of Linear layers.
 
-    Optimized to minimize device allocations and copies:
+    Optimized to keep peak device memory low while preserving the module's
+    original device placement:
     - Moves the full 3D tensor to CPU in a single transfer
     - Performs batch transpose on CPU if needed
-    - Creates Linear shells on meta device (no allocation)
-    - Directly assigns weight slices as Parameters (zero-copy on CPU)
+    - Releases the original fused parameter before allocating defused linears
+    - Re-materializes each expert linear back onto ``target_device``
 
     Args:
         module: The experts module
@@ -391,6 +392,7 @@ def _unfuse_single_projection(
 
     source_device = param.device
     is_meta = source_device.type == "meta"
+    weight_requires_grad = param.requires_grad
 
     # Prepare weight slices on CPU in batch (single D2H transfer + batch transpose)
     if not is_meta:
@@ -413,6 +415,19 @@ def _unfuse_single_projection(
             if not bias_cpu.is_contiguous():
                 bias_cpu = bias_cpu.contiguous()
             bias_slices = bias_cpu.unbind(0)
+            bias_requires_grad = bias_param.requires_grad
+
+        # Drop the original fused parameter before allocating the defused
+        # per-expert linears back on the original device.
+        try:
+            setattr(module, proj_name, to_meta(param))
+            param = None
+            if has_bias:
+                setattr(module, bias_name, to_meta(bias_param))
+                bias_param = None
+            if DEBUG_ON: logger.debug(f"Released memory for {proj_name} using to_meta()")
+        except Exception:
+            pass
 
     # Create Linear shells on meta device (no memory allocation)
     linears = []
@@ -421,23 +436,18 @@ def _unfuse_single_projection(
         linear = nn.Linear(in_features, out_features, bias=has_bias, dtype=dtype, device="meta")
 
         if not is_meta:
-            # Direct parameter assignment — no copy, just references the CPU tensor slice
-            linear.weight = nn.Parameter(weight_slices[i])
+            weight = weight_slices[i]
+            if target_device.type != "cpu":
+                weight = weight.to(device=target_device, dtype=dtype)
+            linear.weight = nn.Parameter(weight, requires_grad=weight_requires_grad)
             if has_bias:
-                linear.bias = nn.Parameter(bias_slices[i])
+                bias = bias_slices[i]
+                if target_device.type != "cpu":
+                    bias = bias.to(device=target_device, dtype=bias.dtype)
+                linear.bias = nn.Parameter(bias, requires_grad=bias_requires_grad)
 
         linears.append(linear)
 
-    # Release original parameter memory
-    if not is_meta:
-        try:
-            setattr(module, proj_name, to_meta(param))
-            if has_bias:
-                setattr(module, bias_name, to_meta(bias_param))
-            if DEBUG_ON: logger.debug(f"Released memory for {proj_name} using to_meta()")
-        except Exception:
-            pass
-
     return linears
 
443453

defuser/modeling/unfused_moe/common.py

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+# SPDX-FileCopyrightText: 2026 ModelCloud.ai
+# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai
+# SPDX-License-Identifier: Apache-2.0
+# Contact: qubitium@modelcloud.ai, x.com/qubitium
+
+import torch
+from torch import nn
+
+
+def run_routed_experts(
+    experts: nn.ModuleList,
+    hidden_states: torch.Tensor,
+    routing_weights: torch.Tensor,
+    selected_experts: torch.Tensor,
+    num_experts: int,
+) -> torch.Tensor:
+    """Run a standard top-k routed MoE expert loop over explicit expert modules."""
+    hidden_dim = hidden_states.shape[-1]
+    final_hidden_states = torch.zeros(
+        (hidden_states.shape[0], hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
+    )
+
+    expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0)
+    expert_hit = torch.nonzero(expert_mask.sum(dim=(-1, -2)), as_tuple=False).flatten()
+
+    for expert_idx in expert_hit.tolist():
+        expert_layer = experts[expert_idx]
+        idx, top_x = torch.where(expert_mask[expert_idx])
+        current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+        current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
+        final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+
+    return final_hidden_states
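A quick smoke test for the new helper, using dummy shapes and a plain top-k softmax router (purely illustrative):

import torch
from torch import nn

from defuser.modeling.unfused_moe.common import run_routed_experts

hidden_dim, num_experts, top_k, num_tokens = 16, 4, 2, 8
experts = nn.ModuleList(nn.Linear(hidden_dim, hidden_dim) for _ in range(num_experts))
hidden_states = torch.randn(num_tokens, hidden_dim)

# routing_weights / selected_experts: (num_tokens, top_k), as the gate modules return them.
router_logits = torch.randn(num_tokens, num_experts)
routing_weights, selected_experts = torch.topk(router_logits.softmax(dim=-1), top_k, dim=-1)

out = run_routed_experts(experts, hidden_states, routing_weights, selected_experts, num_experts)
assert out.shape == (num_tokens, hidden_dim)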

defuser/modeling/unfused_moe/mixtral.py

Lines changed: 8 additions & 21 deletions
@@ -8,6 +8,8 @@
 from transformers import MixtralConfig
 from transformers.activations import ACT2FN
 
+from defuser.modeling.unfused_moe.common import run_routed_experts
+
 
 class MixtralBlockSparseTop2MLP(nn.Module):
     """Per-expert Mixtral MLP with explicit gate, up, and down projections."""
@@ -63,27 +65,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = hidden_states.view(-1, hidden_dim)
         _, routing_weights, selected_experts = self.gate(hidden_states)
         routing_weights = routing_weights.to(hidden_states.dtype)
-
-        final_hidden_states = torch.zeros(
-            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
+        final_hidden_states = run_routed_experts(
+            self.experts,
+            hidden_states,
+            routing_weights,
+            selected_experts,
+            self.num_experts,
         )
-
-        # One hot encode the selected experts to create an expert mask
-        # this will be used to easily index which expert is going to be sollicitated
-        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
-
-        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
-        for expert_idx in expert_hit:
-            expert_layer = self.experts[expert_idx]
-            idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
-            # Index the correct hidden states and compute the expert hidden state for
-            # the current expert. We need to make sure to multiply the output hidden
-            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
-            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
-            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
-
-            # However `index_add_` only support torch tensors for indexing so we'll use
-            # the `top_x` tensor here.
-            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
         final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
         return final_hidden_states

defuser/modeling/unfused_moe/qwen2_moe.py

Lines changed: 8 additions & 23 deletions
@@ -7,6 +7,8 @@
 import torch.nn as nn
 from torch.nn import functional as F
 
+from defuser.modeling.unfused_moe.common import run_routed_experts
+
 
 class LinearQwen2MoeSparseMoeBlock(nn.Module):
     """Qwen2 MoE block rewritten to expose one ``nn.Module`` per expert."""
@@ -34,31 +36,14 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = hidden_states.view(-1, hidden_dim)
         _, routing_weights, selected_experts = self.gate(hidden_states)
         routing_weights = routing_weights.to(hidden_states.dtype)
-
-        final_hidden_states = torch.zeros(
-            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
+        final_hidden_states = run_routed_experts(
+            self.experts,
+            hidden_states,
+            routing_weights,
+            selected_experts,
+            self.num_experts,
        )
 
-        # One hot encode the selected experts to create an expert mask
-        # this will be used to easily index which expert is going to be sollicitated
-        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
-
-        # Loop over all available experts in the model and perform the computation on each expert
-        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
-        for expert_idx in expert_hit:
-            expert_layer = self.experts[expert_idx]
-            idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
-
-            # Index the correct hidden states and compute the expert hidden state for
-            # the current expert. We need to make sure to multiply the output hidden
-            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
-            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
-            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
-
-            # However `index_add_` only support torch tensors for indexing so we'll use
-            # the `top_x` tensor here.
-            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
-
         shared_expert_output = self.shared_expert(hidden_states)
         shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output
 

defuser/modeling/unfused_moe/qwen3_moe.py

Lines changed: 8 additions & 26 deletions
@@ -9,6 +9,8 @@
 import torch
 import torch.nn as nn
 
+from defuser.modeling.unfused_moe.common import run_routed_experts
+
 
 class LinearQwen3MoeSparseMoeBlock(nn.Module):
     """Qwen3 MoE block rewritten to expose one ``nn.Module`` per expert."""
@@ -33,32 +35,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = hidden_states.view(-1, hidden_dim)
         _, routing_weights, selected_experts = self.gate(hidden_states)
         routing_weights = routing_weights.to(hidden_states.dtype)
-
-        final_hidden_states = torch.zeros(
-            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
+        final_hidden_states = run_routed_experts(
+            self.experts,
+            hidden_states,
+            routing_weights,
+            selected_experts,
+            self.num_experts,
         )
-
-        # One hot encode the selected experts to create an expert mask
-        # this will be used to easily index which expert is going to be solicited
-        with torch.no_grad():
-            expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
-
-        # Loop over all available experts in the model and perform the computation on each expert
-        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
-        for expert_idx in expert_hit:
-            if expert_idx == self.num_experts:
-                continue
-            expert_layer = self.experts[expert_idx]
-            idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
-
-            # Index the correct hidden states and compute the expert hidden state for
-            # the current expert. We need to make sure to multiply the output hidden
-            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
-            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
-            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
-
-            # However `index_add_` only support torch tensors for indexing so we'll use
-            # the `top_x` tensor here.
-            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
         final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
         return final_hidden_states

defuser/modeling/unfused_moe/qwen3_next.py

Lines changed: 8 additions & 23 deletions
@@ -7,6 +7,8 @@
 import torch.nn as nn
 from torch.nn import functional as F
 
+from defuser.modeling.unfused_moe.common import run_routed_experts
+
 class LinearQwen3NextSparseMoeBlock(nn.Module):
     """Qwen3-Next MoE block rewritten to expose one ``nn.Module`` per expert."""
 
@@ -33,31 +35,14 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = hidden_states.view(-1, hidden_dim)
         _, routing_weights, selected_experts = self.gate(hidden_states)
         routing_weights = routing_weights.to(hidden_states.dtype)
-
-        final_hidden_states = torch.zeros(
-            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
+        final_hidden_states = run_routed_experts(
+            self.experts,
+            hidden_states,
+            routing_weights,
+            selected_experts,
+            self.num_experts,
         )
 
-        # One hot encode the selected experts to create an expert mask
-        # this will be used to easily index which expert is going to be sollicitated
-        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
-
-        # Loop over all available experts in the model and perform the computation on each expert
-        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
-        for expert_idx in expert_hit:
-            expert_layer = self.experts[expert_idx]
-            idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
-
-            # Index the correct hidden states and compute the expert hidden state for
-            # the current expert. We need to make sure to multiply the output hidden
-            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
-            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
-            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
-
-            # However `index_add_` only support torch tensors for indexing so we'll use
-            # the `top_x` tensor here.
-            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
-
         shared_expert_output = self.shared_expert(hidden_states)
         shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output
 
