Conversation
This doesn't affect the results too much in most cases, but it can cause some big leakages. On the example testset in the README, the worst case is that 5.1% of the probability mass is leaked to invalid moves. That position is with
Pull request overview

Updates inference-time move masking so illegal moves receive zero probability after softmax, by converting the legal-move mask into additive `-inf` logits (via `.log()` on a 0/1 mask) rather than multiplying logits by the mask.

Changes:
- Replace `logits * legal_moves` with `logits + legal_moves.log()` in batch inference (`get_preds`).
- Replace `logits * legal_moves` with `logits + legal_moves.log()` in single-position inference (`inference_each`).
maia2/inference.py

```diff
  logits_maia, _, logits_value = model(boards, elos_self, elos_oppo)
- logits_maia_legal = logits_maia * legal_moves
+ logits_maia_legal = logits_maia + legal_moves.log()
```
`legal_moves.log()` will produce `-inf` for masked entries; if a position ever has zero legal moves (e.g., checkmate/stalemate) or the mask is accidentally all zeros, softmax over all `-inf` yields NaN probabilities (0/0). Consider explicitly handling the all-masked case (early return / raise) or using a safe fallback before softmax.
```python
logits_maia_legal = logits_maia + legal_moves.log()
# If a position has no legal moves (all zeros in the mask), logits_maia_legal will be all -inf;
# softmax over all -inf yields NaNs. For such rows, fall back to the unmasked logits.
has_legal = legal_moves.sum(dim=-1, keepdim=True) > 0
if not has_legal.all():
    logits_maia_legal = torch.where(has_legal, logits_maia_legal, logits_maia)
```
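The row-wise fallback suggested above can be sketched end to end; this is a standalone illustration with hypothetical tensors, not the actual maia2 batch shapes:

```python
import torch

# Two positions: the first has legal moves, the second (e.g. a terminal
# position, or an accidentally all-zero mask) has none.
logits_maia = torch.tensor([[1.0, 2.0, 0.0],
                            [1.0, 2.0, 0.0]])
legal_moves = torch.tensor([[1.0, 0.0, 1.0],
                            [0.0, 0.0, 0.0]])

masked = logits_maia + legal_moves.log()

# Without a guard, the all -inf row softmaxes to NaN (0/0).
naive = torch.softmax(masked, dim=-1)

# Row-wise fallback: keep the unmasked logits for rows with no legal moves.
has_legal = legal_moves.sum(dim=-1, keepdim=True) > 0
probs = torch.softmax(torch.where(has_legal, masked, logits_maia), dim=-1)
```

Here `naive[1]` is all NaN, while `probs` is NaN-free: masked entries in valid rows still get exactly zero probability, and the empty-mask row falls back to the plain softmax.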
maia2/inference.py

```diff
  logits_maia, _, logits_value = model(board_input, elo_self, elo_oppo)
- logits_maia_legal = logits_maia * legal_moves
+ logits_maia_legal = logits_maia + legal_moves.log()
```
Same masking approach here: if `legal_moves` is all zeros, `legal_moves.log()` makes the entire row `-inf` and softmax will return NaNs. Please add a guard for terminal positions / empty legal-move sets or otherwise ensure at least one legal move before applying softmax.
```python
# Guard against the case where there are no legal moves (all zeros), which would make
# legal_moves.log() equal to -inf everywhere and softmax return NaNs.
if (legal_moves > 0).any():
    legal_moves_for_mask = legal_moves
else:
    # In terminal positions with no legal moves, use an all-ones mask so that .log()
    # produces zeros and does not introduce -inf values.
    legal_moves_for_mask = torch.ones_like(legal_moves)
logits_maia_legal = logits_maia + legal_moves_for_mask.log()
```
Thank you Brother Copilot 🧑 I added a guard for the case where no legal moves are available. Copilot's suggestion would have given every move (all 1880 moves in the `all_moves_dict`) non-zero probability in this case, which I think is wrong. My implementation instead outputs an all-zero probability distribution, so no moves are returned in the move probability dict output, which I think is more in the spirit of the function. This should probably be an assertion, though, so that it doesn't pass through silently: I can't think of a case where we would actually want to run inference with no moves available (the win prob also doesn't make sense here).
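The guard described here could look roughly like the following; this is a sketch of the idea, not the actual maia2 code, and `masked_policy` is a hypothetical helper name:

```python
import torch

def masked_policy(logits_maia: torch.Tensor, legal_moves: torch.Tensor) -> torch.Tensor:
    """Return move probabilities; an all-zero distribution when no move is legal."""
    if not (legal_moves > 0).any():
        # No legal moves (terminal position): rather than softmax over all -inf
        # logits (NaNs) or spreading mass over every move, return zeros so that
        # no move shows up in the output dict.
        return torch.zeros_like(logits_maia)
    return torch.softmax(logits_maia + legal_moves.log(), dim=-1)

logits = torch.tensor([0.3, 1.2, -0.7])
terminal = masked_policy(logits, torch.zeros(3))               # all zeros
normal = masked_policy(logits, torch.tensor([1.0, 1.0, 0.0]))  # sums to 1, last entry 0
```

The design choice is deliberate: an all-zero row propagates to an empty move-probability dict downstream, whereas the all-ones fallback would silently produce a uniform-looking distribution over every move in the vocabulary.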
During inference, logits corresponding to illegal moves are supposed to be masked out so that they have probability 0 after softmax. Right now, this is done by multiplying the logits with the legal-move mask before the softmax layer. However, the legal-move mask is a set of 0s and 1s and the logits are in unnormalized log space, so multiplying by 0 sets an illegal logit to 0, which corresponds to exp(0) = 1 in unnormalized real space: the invalid moves end up with non-zero probability after softmax. The correct way is to add `-inf` to all invalid logits so that they have probability 0 in real space. See the following example:
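A minimal sketch of the difference, with hypothetical logits rather than the README testset position:

```python
import torch

# Hypothetical 4-move policy output; only the first two moves are legal.
logits = torch.tensor([2.0, 1.0, -3.0, 0.5])
legal_moves = torch.tensor([1.0, 1.0, 0.0, 0.0])

# Multiplicative masking: a zeroed logit means exp(0) = 1 in real space,
# so illegal moves keep non-zero probability after softmax.
probs_mul = torch.softmax(logits * legal_moves, dim=-1)
leak_mul = probs_mul[legal_moves == 0].sum().item()

# Additive masking: log(0) = -inf and exp(-inf) = 0, so illegal moves
# get exactly zero probability.
probs_add = torch.softmax(logits + legal_moves.log(), dim=-1)
leak_add = probs_add[legal_moves == 0].sum().item()

print(f"leak with multiplicative mask: {leak_mul:.4f}")  # non-zero leakage
print(f"leak with additive -inf mask:  {leak_add:.4f}")  # exactly zero
```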