Skip to content

Commit 8dc0dfe

Browse files
committed
Fix mypy: Add type annotations in mixup_criterion
1 parent 42af6a3 commit 8dc0dfe

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

experiments/training/loops.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,9 @@ def mixup_criterion(
6666
criterion: nn.Module, pred: torch.Tensor, y_a: torch.Tensor, y_b: torch.Tensor, lam: float
6767
) -> torch.Tensor:
6868
"""Compute mixup loss."""
69-
return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)
69+
loss_a: torch.Tensor = criterion(pred, y_a)
70+
loss_b: torch.Tensor = criterion(pred, y_b)
71+
return lam * loss_a + (1 - lam) * loss_b
7072

7173

7274
def train_epoch(

0 commit comments

Comments
 (0)