diff --git a/maia2/inference.py b/maia2/inference.py index f4f2cfa..d3a4f3e 100644 --- a/maia2/inference.py +++ b/maia2/inference.py @@ -60,8 +60,13 @@ def get_preds(model, dataloader, all_moves_dict_reversed): legal_moves = legal_moves.to(device) logits_maia, _, logits_value = model(boards, elos_self, elos_oppo) - logits_maia_legal = logits_maia * legal_moves - probs = logits_maia_legal.softmax(dim=-1).cpu().tolist() + if not legal_moves.any(dim=-1).all(): + # no legal moves available, so just output zeros for all probabilities + # this should prevent any moves from being present in the output dictionary + probs = torch.zeros_like(logits_maia, device='cpu').tolist() + else: + logits_maia_legal = logits_maia + legal_moves.log() + probs = logits_maia_legal.softmax(dim=-1).cpu().tolist() logits_value = (logits_value / 2 + 0.5).clamp(0, 1).cpu().tolist() @@ -154,8 +159,14 @@ def inference_each(model, prepared, fen, elo_self, elo_oppo): legal_moves = legal_moves.unsqueeze(dim=0).to(device) logits_maia, _, logits_value = model(board_input, elo_self, elo_oppo) - logits_maia_legal = logits_maia * legal_moves - probs = logits_maia_legal.softmax(dim=-1).cpu().tolist() + + if not legal_moves.any(dim=-1).all(): + # no legal moves available, so just output zeros for all probabilities + # this should prevent any moves from being present in the output dictionary + probs = torch.zeros_like(logits_maia, device='cpu').tolist() + else: + logits_maia_legal = logits_maia + legal_moves.log() + probs = logits_maia_legal.softmax(dim=-1).cpu().tolist() logits_value = (logits_value / 2 + 0.5).clamp(0, 1).item()