Skip to content

Commit db13507

Browse files
committed
Fix critical TTQ beta mismatch bug
Bug: softplus was applied twice — once inside ttq_quantize and again when computing beta. As a result, the quantized weights used scale X while dequantization used scale softplus(X), completely breaking learning. Fix: ttq_quantize now returns a (quantized, wp_pos, wn_pos) tuple so the layers can reuse exactly the same scales when computing beta. Result: the model was previously stuck at 10% accuracy (random guessing); it should now learn.
1 parent 0a265ce commit db13507

3 files changed

Lines changed: 8 additions & 7 deletions

File tree

bitnet/nn/ttq_conv2d.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,10 @@ def forward(self, x: Tensor) -> Tensor:
3434
x_quant, gamma = quantize_activations(x, self.num_bits)
3535

3636
# TTQ weight quantization with learned scales
37-
w_quant = ttq_quantize(self.weight, self.wp, self.wn, self.delta)
37+
w_quant, wp_pos, wn_pos = ttq_quantize(self.weight, self.wp, self.wn, self.delta)
3838

3939
# Use average of positive scales as beta for dequantization
40-
beta = (f.softplus(self.wp) + f.softplus(self.wn)) / 2
40+
beta = (wp_pos + wn_pos) / 2
4141

4242
out = f.conv2d(x_quant, w_quant, self.bias, self.stride, self.padding, self.dilation, self.groups)
4343
return dequantize(out, gamma, beta, self.num_bits)

bitnet/nn/ttq_linear.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,10 @@ def forward(self, x: Tensor) -> Tensor:
4040
x_quant, gamma = quantize_activations(x, self.num_bits)
4141

4242
# TTQ weight quantization with learned scales
43-
w_quant = ttq_quantize(self.weight, self.wp, self.wn, self.delta)
43+
w_quant, wp_pos, wn_pos = ttq_quantize(self.weight, self.wp, self.wn, self.delta)
4444

4545
# Use average of positive scales as beta for dequantization
46-
beta = (f.softplus(self.wp) + f.softplus(self.wn)) / 2
46+
beta = (wp_pos + wn_pos) / 2
4747

4848
out = f.linear(x_quant, w_quant, self.bias)
4949
return dequantize(out, gamma, beta, self.num_bits)

bitnet/nn/ttq_quantization.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from torch import Tensor
44

55

6-
def ttq_quantize(weight: Tensor, wp: Tensor, wn: Tensor, delta: Tensor) -> Tensor:
6+
def ttq_quantize(weight: Tensor, wp: Tensor, wn: Tensor, delta: Tensor) -> tuple[Tensor, Tensor, Tensor]:
77
"""Quantize weights to {-wn, 0, +wp} using Trained Ternary Quantization.
88
99
TTQ (Zhu et al., ICLR 2017) learns per-layer positive/negative scales
@@ -19,7 +19,7 @@ def ttq_quantize(weight: Tensor, wp: Tensor, wn: Tensor, delta: Tensor) -> Tenso
1919
delta: Learnable threshold
2020
2121
Returns:
22-
Quantized tensor in {-wn, 0, +wp}
22+
Tuple of (quantized weights, wp_positive, wn_positive)
2323
"""
2424
# Ensure scales and threshold are positive with softplus (maintains gradients)
2525
wp_pos = f.softplus(wp)
@@ -35,4 +35,5 @@ def ttq_quantize(weight: Tensor, wp: Tensor, wn: Tensor, delta: Tensor) -> Tenso
3535
quantized[neg_mask] = -wn_pos
3636

3737
# Straight-through estimator for gradients
38-
return quantized + (weight - weight.detach())
38+
quantized_ste = quantized + (weight - weight.detach())
39+
return quantized_ste, wp_pos, wn_pos

0 commit comments

Comments
 (0)