From de4396869a7e198bbb4ec14d96bb3b64709b262d Mon Sep 17 00:00:00 2001 From: Marco Date: Tue, 7 Oct 2025 06:22:14 +0900 Subject: [PATCH 1/3] fix masked softmax --- maia2/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/maia2/inference.py b/maia2/inference.py index f4f2cfa..09e9710 100644 --- a/maia2/inference.py +++ b/maia2/inference.py @@ -60,7 +60,7 @@ def get_preds(model, dataloader, all_moves_dict_reversed): legal_moves = legal_moves.to(device) logits_maia, _, logits_value = model(boards, elos_self, elos_oppo) - logits_maia_legal = logits_maia * legal_moves + logits_maia_legal = logits_maia + legal_moves.log() probs = logits_maia_legal.softmax(dim=-1).cpu().tolist() logits_value = (logits_value / 2 + 0.5).clamp(0, 1).cpu().tolist() From c3b54d9e84248757534e8ddc2fd19b9dedbdd3b4 Mon Sep 17 00:00:00 2001 From: Marco Date: Sun, 16 Nov 2025 02:06:25 -0800 Subject: [PATCH 2/3] fix in inference_each also --- maia2/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/maia2/inference.py b/maia2/inference.py index 09e9710..8da5e91 100644 --- a/maia2/inference.py +++ b/maia2/inference.py @@ -154,7 +154,7 @@ def inference_each(model, prepared, fen, elo_self, elo_oppo): legal_moves = legal_moves.unsqueeze(dim=0).to(device) logits_maia, _, logits_value = model(board_input, elo_self, elo_oppo) - logits_maia_legal = logits_maia * legal_moves + logits_maia_legal = logits_maia + legal_moves.log() probs = logits_maia_legal.softmax(dim=-1).cpu().tolist() logits_value = (logits_value / 2 + 0.5).clamp(0, 1).item() From 4017653518f78f9b7a1e7abf325c62b972ab02c9 Mon Sep 17 00:00:00 2001 From: Marco Cognetta Date: Sat, 21 Feb 2026 23:52:10 -0800 Subject: [PATCH 3/3] add guard for degenerate case where no legal moves are available --- maia2/inference.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/maia2/inference.py b/maia2/inference.py index 8da5e91..d3a4f3e 100644 --- a/maia2/inference.py +++ b/maia2/inference.py @@ -60,8 +60,13 @@ def get_preds(model, dataloader, all_moves_dict_reversed): legal_moves = legal_moves.to(device) logits_maia, _, logits_value = model(boards, elos_self, elos_oppo) - logits_maia_legal = logits_maia + legal_moves.log() - probs = logits_maia_legal.softmax(dim=-1).cpu().tolist() + if not legal_moves.any(dim=-1).all(): + # no legal moves available, so just output zeros for all probabilities + # this should prevent any moves from being present in the output dictionary + probs = torch.zeros_like(logits_maia, device='cpu').tolist() + else: + logits_maia_legal = logits_maia + legal_moves.log() + probs = logits_maia_legal.softmax(dim=-1).cpu().tolist() logits_value = (logits_value / 2 + 0.5).clamp(0, 1).cpu().tolist() @@ -154,8 +159,14 @@ def inference_each(model, prepared, fen, elo_self, elo_oppo): legal_moves = legal_moves.unsqueeze(dim=0).to(device) logits_maia, _, logits_value = model(board_input, elo_self, elo_oppo) - logits_maia_legal = logits_maia + legal_moves.log() - probs = logits_maia_legal.softmax(dim=-1).cpu().tolist() + + if not legal_moves.any(dim=-1).all(): + # no legal moves available, so just output zeros for all probabilities + # this should prevent any moves from being present in the output dictionary + probs = torch.zeros_like(logits_maia, device='cpu').tolist() + else: + logits_maia_legal = logits_maia + legal_moves.log() + probs = logits_maia_legal.softmax(dim=-1).cpu().tolist() logits_value = (logits_value / 2 + 0.5).clamp(0, 1).item()