diff --git a/defuser/model_registry.py b/defuser/model_registry.py index 9a4d7f0..a8617ac 100644 --- a/defuser/model_registry.py +++ b/defuser/model_registry.py @@ -128,4 +128,7 @@ class PATCH(str, Enum): } ], }, + "phimoe": { + "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION, + }, } diff --git a/pyproject.toml b/pyproject.toml index a6ede77..ea58861 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ build-backend = "setuptools.build_meta" [project] name = "Defuser" -version = "0.0.15" +version = "0.0.16" description = "Model defuser helper for HF Transformers." readme = "README.md" requires-python = ">=3.9" diff --git a/tests/test_convert_model.py b/tests/test_convert_model.py index 53465b8..e8e2b64 100644 --- a/tests/test_convert_model.py +++ b/tests/test_convert_model.py @@ -33,6 +33,7 @@ ) from transformers.models.qwen3_omni_moe.configuration_qwen3_omni_moe import Qwen3OmniMoeConfig from transformers.models.gpt_oss.modeling_gpt_oss import GptOssConfig, GptOssForCausalLM +from transformers.models.phimoe.modeling_phimoe import PhimoeConfig, PhimoeForCausalLM from transformers.models.llama4.modeling_llama4 import Llama4Config, Llama4ForConditionalGeneration from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import ( Qwen3OmniMoeForConditionalGeneration, @@ -252,6 +253,22 @@ def _tiny_llama4_config(): ) +def _tiny_phimoe_config(): + return PhimoeConfig( + vocab_size=128, + hidden_size=64, + intermediate_size=128, + num_hidden_layers=1, + num_attention_heads=4, + num_key_value_heads=4, + num_local_experts=4, + num_experts_per_tok=2, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + ) + + def _write_single_safetensors_checkpoint(path, state_dict: dict[str, torch.Tensor], config) -> None: config.save_pretrained(path) save_file({name: tensor.detach().cpu().contiguous() for name, tensor in state_dict.items()}, str(path / "model.safetensors")) @@ -1076,3 +1093,61 @@ def test_llama4_split_forward_matches_fused_math(): # The split module should exactly reproduce the original fused MLP math. torch.testing.assert_close(mlp(hidden_states), expected) + + +def test_phimoe(): + from transformers.models.phimoe.modeling_phimoe import PhimoeSparseMoeBlock + + model = PhimoeForCausalLM(_tiny_phimoe_config()) + assert model.config.model_type == "phimoe" + + original_moe_block = model.model.layers[0].mlp + assert isinstance(original_moe_block, PhimoeSparseMoeBlock) + + hidden_dim = original_moe_block.experts.gate_up_proj.shape[-1] + intermediate_dim = original_moe_block.experts.gate_up_proj.shape[1] // 2 + + expected_gate = original_moe_block.experts.gate_up_proj[0, :intermediate_dim, :hidden_dim].contiguous().clone() + expected_up = original_moe_block.experts.gate_up_proj[0, intermediate_dim:, :hidden_dim].contiguous().clone() + expected_down = original_moe_block.experts.down_proj[0, :hidden_dim, :intermediate_dim].contiguous().clone() + + converted = convert_model(model, cleanup_original=False, max_layers=1) + assert converted + + moe_block = model.model.layers[0].mlp + experts = moe_block.experts + + _assert_unfused_expert_module(experts) + expert0 = getattr(experts, "0") + + materialize_model(model.model.layers[0]) + + torch.testing.assert_close(expert0.gate_proj.weight, expected_gate) + torch.testing.assert_close(expert0.up_proj.weight, expected_up) + torch.testing.assert_close(expert0.down_proj.weight, expected_down) + +def test_phimoe_split_forward_matches_fused_math(): + from transformers.models.phimoe.modeling_phimoe import PhimoeExperts + + model = PhimoeForCausalLM(_tiny_phimoe_config()) + fused_experts = model.model.layers[0].mlp.experts + assert isinstance(fused_experts, PhimoeExperts) + + hidden_states = torch.randn(5, model.config.hidden_size, dtype=torch.float32) + top_k_index = torch.zeros((hidden_states.size(0), 1), dtype=torch.long) + top_k_weights = torch.ones((hidden_states.size(0), 1), dtype=hidden_states.dtype) + + with torch.no_grad(): + expected = fused_experts(hidden_states, top_k_index, top_k_weights) + + converted = convert_model(model, cleanup_original=False, max_layers=1) + assert converted + + split_experts = model.model.layers[0].mlp.experts + _assert_unfused_expert_module(split_experts) + materialize_model(model.model.layers[0]) + with torch.no_grad(): + actual = split_experts(hidden_states, top_k_index, top_k_weights) + + # The split experts path should exactly reproduce the original fused experts math. + torch.testing.assert_close(actual, expected)