Skip to content

Commit 42af6a3

Browse files
committed
Fix CI: Add type assertions for TTQ layer parameters
- Add runtime type assertions in TTQ forward() methods - Resolves mypy errors about Parameter | Tensor types - Parameters are always Tensors after register_parameter()
1 parent c5f5485 commit 42af6a3

2 files changed

Lines changed: 10 additions & 0 deletions

File tree

bitnet/nn/ttq_conv2d.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,5 +40,10 @@ def inv_softplus(x: float) -> float:
4040
def forward(self, x: Tensor) -> Tensor:
4141
# Pure TTQ: Only quantize weights, use FP32 activations
4242
# This is TTQ as described in the original paper
43+
# Type assertions for mypy - these are always Tensors after register_parameter
44+
assert isinstance(self.wp, Tensor)
45+
assert isinstance(self.wn, Tensor)
46+
assert isinstance(self.delta, Tensor)
47+
4348
w_quant, _, _ = ttq_quantize(self.weight, self.wp, self.wn, self.delta)
4449
return f.conv2d(x, w_quant, self.bias, self.stride, self.padding, self.dilation, self.groups)

bitnet/nn/ttq_linear.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,5 +45,10 @@ def inv_softplus(x: float) -> float:
4545
def forward(self, x: Tensor) -> Tensor:
4646
# Pure TTQ: Only quantize weights, use FP32 activations
4747
# This is TTQ as described in the original paper
48+
# Type assertions for mypy - these are always Tensors after register_parameter
49+
assert isinstance(self.wp, Tensor)
50+
assert isinstance(self.wn, Tensor)
51+
assert isinstance(self.delta, Tensor)
52+
4853
w_quant, _, _ = ttq_quantize(self.weight, self.wp, self.wn, self.delta)
4954
return f.linear(x, w_quant, self.bias)

0 commit comments

Comments
 (0)