Skip to content

Commit 2337e57

Browse files
committed
add test
1 parent 69b8402 commit 2337e57

1 file changed

Lines changed: 107 additions & 0 deletions

File tree

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# SPDX-FileCopyrightText: 2026 ModelCloud.ai
2+
# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai
3+
# SPDX-License-Identifier: Apache-2.0
4+
# Contact: qubitium@modelcloud.ai, x.com/qubitium
5+
6+
import gc
7+
import weakref
8+
9+
import pytest
10+
import torch
11+
12+
from defuser.modeling.fused_moe.replace_modules import (
13+
ModuleReplacementTracker,
14+
ReplacementModuleBase,
15+
release_original_module_,
16+
)
17+
18+
19+
class DummyOriginalForTrackerTests(torch.nn.Module):
20+
def __init__(self):
21+
super().__init__()
22+
self.linear = torch.nn.Linear(4, 4)
23+
24+
25+
class DummyReplacementForTrackerTests(ReplacementModuleBase):
26+
@classmethod
27+
def original_module_class(cls) -> str:
28+
return "DummyOriginalForTrackerTests"
29+
30+
@classmethod
31+
def from_original(cls, original: torch.nn.Module, config) -> "DummyReplacementForTrackerTests":
32+
return cls(original)
33+
34+
def _materialize_weights(self) -> None:
35+
pass
36+
37+
38+
class DummyContainerForTrackerTests(torch.nn.Module):
39+
def __init__(self, replacement: ReplacementModuleBase):
40+
super().__init__()
41+
self.replacement = replacement
42+
43+
44+
@pytest.fixture(autouse=True)
45+
def _cleanup_tracker():
46+
tracker = ModuleReplacementTracker.get_instance()
47+
tracker.clear()
48+
yield
49+
tracker.clear()
50+
51+
52+
def test_replacement_owns_original_reference_until_release():
53+
original = DummyOriginalForTrackerTests()
54+
replacement = DummyReplacementForTrackerTests(original)
55+
56+
assert replacement._get_original_module() is original
57+
58+
replacement.release_original_module()
59+
60+
with pytest.raises(RuntimeError, match="already been released"):
61+
replacement._get_original_module()
62+
63+
64+
def test_tracker_metadata_does_not_keep_original_alive_after_release():
65+
tracker = ModuleReplacementTracker.get_instance()
66+
original = DummyOriginalForTrackerTests()
67+
original_ref = weakref.ref(original)
68+
replacement = DummyReplacementForTrackerTests(original)
69+
tracker_name = str(id(replacement))
70+
71+
info = tracker.get_info_by_name(tracker_name)
72+
assert info is not None
73+
assert info.original_module_class == DummyOriginalForTrackerTests.__name__
74+
assert info.replacement_module_class == DummyReplacementForTrackerTests.__name__
75+
assert info.replacement_module_ref() is replacement
76+
77+
replacement.release_original_module()
78+
del original
79+
gc.collect()
80+
81+
assert original_ref() is None
82+
assert tracker.get_info_by_name(tracker_name) is None
83+
84+
85+
def test_release_all_originals_releases_replacement_owned_originals():
86+
tracker = ModuleReplacementTracker.get_instance()
87+
replacement = DummyReplacementForTrackerTests(DummyOriginalForTrackerTests())
88+
tracker_name = str(id(replacement))
89+
90+
tracker.release_all_originals()
91+
92+
with pytest.raises(RuntimeError, match="already been released"):
93+
replacement._get_original_module()
94+
assert tracker.get_info_by_name(tracker_name) is None
95+
96+
97+
def test_release_original_module_helper_releases_all_replacements_in_model():
98+
tracker = ModuleReplacementTracker.get_instance()
99+
replacement = DummyReplacementForTrackerTests(DummyOriginalForTrackerTests())
100+
tracker_name = str(id(replacement))
101+
model = DummyContainerForTrackerTests(replacement)
102+
103+
release_original_module_(model)
104+
105+
with pytest.raises(RuntimeError, match="already been released"):
106+
replacement._get_original_module()
107+
assert tracker.get_info_by_name(tracker_name) is None

0 commit comments

Comments
 (0)