Commit 90f58ae

Fix critical TTQ double-scaling bug
Bug: TTQ quantized weights are ALREADY scaled by wp/wn:

    w_quant = {-wn_pos, 0, +wp_pos}

But we were applying beta = (wp_pos + wn_pos) / 2 in dequantization, scaling AGAIN. This is double-scaling.

BitNet comparison:
- BitNet: quantize to {-1, 0, +1}, then scale with beta
- TTQ: quantize to {-wn, 0, +wp} (already scaled!), so beta should be 1.0

Fix: Set beta = 1.0, since TTQ weights are pre-scaled. This was causing the model to be stuck at 10% accuracy.
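The double-scaling can be seen in a minimal sketch. The `ttq_quantize` below is a hypothetical stand-in for the repo's helper, assuming the standard TTQ scheme (threshold on |w|, learned per-layer scales wp/wn baked into the output):

```python
import torch

def ttq_quantize(w, wp, wn, delta):
    # Hypothetical sketch: weights above the threshold become +wp,
    # below -threshold become -wn, the rest zero. The output already
    # carries the learned scales wp/wn -- it is NOT in {-1, 0, +1}.
    threshold = delta * w.abs().max()
    w_quant = torch.where(w > threshold, wp,
              torch.where(w < -threshold, -wn, torch.zeros_like(w)))
    return w_quant, wp, wn

w = torch.tensor([0.9, -0.8, 0.05])
wp = torch.tensor(0.5)
wn = torch.tensor(0.4)
w_quant, wp_pos, wn_pos = ttq_quantize(w, wp, wn, delta=0.5)

# w_quant is already {+0.5, -0.4, 0.0}; multiplying by
# beta = (wp + wn) / 2 = 0.45 afterwards scales a second time.
double_scaled = w_quant * (wp_pos + wn_pos) / 2
```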
1 parent 28594ae commit 90f58ae

2 files changed

Lines changed: 8 additions & 6 deletions


bitnet/nn/ttq_conv2d.py

Lines changed: 4 additions & 3 deletions
@@ -32,11 +32,12 @@ def forward(self, x: Tensor) -> Tensor:
         x = f.layer_norm(x, x.shape[1:])
         x_quant, gamma = quantize_activations(x, self.num_bits)
 
-        # TTQ weight quantization with learned scales
+        # TTQ weight quantization with learned scales (already scaled!)
         w_quant, wp_pos, wn_pos = ttq_quantize(self.weight, self.wp, self.wn, self.delta)
 
-        # Use average of positive scales as beta for dequantization
-        beta = (wp_pos + wn_pos) / 2
+        # Beta = 1.0 because quantized weights are already scaled by wp/wn
+        # Unlike BitNet, which scales {-1, 0, +1} with beta in dequant, TTQ pre-scales
+        beta = torch.ones_like(wp_pos)
 
         out = f.conv2d(x_quant, w_quant, self.bias, self.stride, self.padding, self.dilation, self.groups)
         return dequantize(out, gamma, beta, self.num_bits)

bitnet/nn/ttq_linear.py

Lines changed: 4 additions & 3 deletions
@@ -38,11 +38,12 @@ def forward(self, x: Tensor) -> Tensor:
         x = f.layer_norm(x, x.shape[1:])
         x_quant, gamma = quantize_activations(x, self.num_bits)
 
-        # TTQ weight quantization with learned scales
+        # TTQ weight quantization with learned scales (already scaled!)
         w_quant, wp_pos, wn_pos = ttq_quantize(self.weight, self.wp, self.wn, self.delta)
 
-        # Use average of positive scales as beta for dequantization
-        beta = (wp_pos + wn_pos) / 2
+        # Beta = 1.0 because quantized weights are already scaled by wp/wn
+        # Unlike BitNet, which scales {-1, 0, +1} with beta in dequant, TTQ pre-scales
+        beta = torch.ones_like(wp_pos)
 
         out = f.linear(x_quant, w_quant, self.bias)
         return dequantize(out, gamma, beta, self.num_bits)
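A quick numeric check of the fix, under the assumption (consistent with the diffs above) that `dequantize` multiplies the layer output by beta:

```python
import torch

# Toy pre-scaled TTQ weights: already in {-wn, 0, +wp}
wp, wn = 0.5, 0.4
w_quant = torch.tensor([wp, -wn, 0.0])
x = torch.tensor([1.0, 1.0, 1.0])

# 0.5 - 0.4 + 0.0 = 0.1, already at full-precision scale
out = x @ w_quant

beta_old = (wp + wn) / 2  # 0.45: scales the output a second time
beta_new = 1.0            # fix: weights are pre-scaled
wrong = out * beta_old    # 0.045 -> double-scaled
right = out * beta_new    # 0.1   -> correct
```

With the old beta, every layer's output is shrunk by roughly a factor of (wp + wn) / 2, which compounds across layers and plausibly explains the model collapsing to chance-level (10%) accuracy.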
