@@ -36,11 +36,17 @@ def test_parameters_initialized_properly(self) -> None:
3636 """TTQ parameters should be initialized per paper (Eq. 2)."""
3737 layer = TTQLinear (64 , 32 )
3838 weight_mean_abs = layer .weight .data .abs ().mean ()
39+
40+ # wp, wn and delta are stored in inverse softplus space, apply softplus to compare
41+ wp_actual = torch .nn .functional .softplus (layer .wp )
42+ wn_actual = torch .nn .functional .softplus (layer .wn )
43+ delta_actual = torch .nn .functional .softplus (layer .delta )
44+
3945 # wp and wn should be initialized to E[|W|]
40- assert torch .allclose (layer . wp , weight_mean_abs , rtol = 1e-5 )
41- assert torch .allclose (layer . wn , weight_mean_abs , rtol = 1e-5 )
46+ assert torch .allclose (wp_actual , weight_mean_abs , rtol = 1e-5 )
47+ assert torch .allclose (wn_actual , weight_mean_abs , rtol = 1e-5 )
4248 # delta should be initialized to 0.7 * E[|W|]
43- assert torch .allclose (layer . delta , 0.7 * weight_mean_abs , rtol = 1e-5 )
49+ assert torch .allclose (delta_actual , 0.7 * weight_mean_abs , rtol = 1e-5 )
4450
4551 def test_numerical_stability_during_training (self ) -> None :
4652 """Training should not produce NaN losses."""
@@ -98,11 +104,17 @@ def test_parameters_initialized_properly(self) -> None:
98104 """TTQ parameters should be initialized per paper (Eq. 2)."""
99105 layer = TTQConv2d (3 , 16 , kernel_size = 3 )
100106 weight_mean_abs = layer .weight .data .abs ().mean ()
107+
108+ # wp, wn and delta are stored in inverse softplus space, apply softplus to compare
109+ wp_actual = torch .nn .functional .softplus (layer .wp )
110+ wn_actual = torch .nn .functional .softplus (layer .wn )
111+ delta_actual = torch .nn .functional .softplus (layer .delta )
112+
101113 # wp and wn should be initialized to E[|W|]
102- assert torch .allclose (layer . wp , weight_mean_abs , rtol = 1e-5 )
103- assert torch .allclose (layer . wn , weight_mean_abs , rtol = 1e-5 )
114+ assert torch .allclose (wp_actual , weight_mean_abs , rtol = 1e-5 )
115+ assert torch .allclose (wn_actual , weight_mean_abs , rtol = 1e-5 )
104116 # delta should be initialized to 0.7 * E[|W|]
105- assert torch .allclose (layer . delta , 0.7 * weight_mean_abs , rtol = 1e-5 )
117+ assert torch .allclose (delta_actual , 0.7 * weight_mean_abs , rtol = 1e-5 )
106118
107119 def test_numerical_stability_during_training (self ) -> None :
108120 """Training should not produce NaN losses."""
0 commit comments