Skip to content

Commit c78ca1d

Browse files
committed
Fix TTQ layer tests: Account for softplus transformation
TTQ parameters (wp, wn, delta) are stored in inverse softplus space to ensure they stay positive during training. Tests need to apply softplus before comparing to expected initialization values.
1 parent 4b43edf commit c78ca1d

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

tests/test_ttq_layers.py

Lines changed: 18 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -36,11 +36,17 @@ def test_parameters_initialized_properly(self) -> None:
3636
"""TTQ parameters should be initialized per paper (Eq. 2)."""
3737
layer = TTQLinear(64, 32)
3838
weight_mean_abs = layer.weight.data.abs().mean()
39+
40+
# wp and wn are stored in inverse softplus space, apply softplus to compare
41+
wp_actual = torch.nn.functional.softplus(layer.wp)
42+
wn_actual = torch.nn.functional.softplus(layer.wn)
43+
delta_actual = torch.nn.functional.softplus(layer.delta)
44+
3945
# wp and wn should be initialized to E[|W|]
40-
assert torch.allclose(layer.wp, weight_mean_abs, rtol=1e-5)
41-
assert torch.allclose(layer.wn, weight_mean_abs, rtol=1e-5)
46+
assert torch.allclose(wp_actual, weight_mean_abs, rtol=1e-5)
47+
assert torch.allclose(wn_actual, weight_mean_abs, rtol=1e-5)
4248
# delta should be initialized to 0.7 * E[|W|]
43-
assert torch.allclose(layer.delta, 0.7 * weight_mean_abs, rtol=1e-5)
49+
assert torch.allclose(delta_actual, 0.7 * weight_mean_abs, rtol=1e-5)
4450

4551
def test_numerical_stability_during_training(self) -> None:
4652
"""Training should not produce NaN losses."""
@@ -98,11 +104,17 @@ def test_parameters_initialized_properly(self) -> None:
98104
"""TTQ parameters should be initialized per paper (Eq. 2)."""
99105
layer = TTQConv2d(3, 16, kernel_size=3)
100106
weight_mean_abs = layer.weight.data.abs().mean()
107+
108+
# wp and wn are stored in inverse softplus space, apply softplus to compare
109+
wp_actual = torch.nn.functional.softplus(layer.wp)
110+
wn_actual = torch.nn.functional.softplus(layer.wn)
111+
delta_actual = torch.nn.functional.softplus(layer.delta)
112+
101113
# wp and wn should be initialized to E[|W|]
102-
assert torch.allclose(layer.wp, weight_mean_abs, rtol=1e-5)
103-
assert torch.allclose(layer.wn, weight_mean_abs, rtol=1e-5)
114+
assert torch.allclose(wp_actual, weight_mean_abs, rtol=1e-5)
115+
assert torch.allclose(wn_actual, weight_mean_abs, rtol=1e-5)
104116
# delta should be initialized to 0.7 * E[|W|]
105-
assert torch.allclose(layer.delta, 0.7 * weight_mean_abs, rtol=1e-5)
117+
assert torch.allclose(delta_actual, 0.7 * weight_mean_abs, rtol=1e-5)
106118

107119
def test_numerical_stability_during_training(self) -> None:
108120
"""Training should not produce NaN losses."""

0 commit comments

Comments (0)