
Fix masked softmax in inference#9

Open
mcognetta wants to merge 3 commits into CSSLab:main from mcognetta:fix_masked_softmax
Conversation

@mcognetta

During inference, logits corresponding to illegal moves are supposed to be masked out so that they have probability 0 after softmax. Currently, this is done by multiplying the logits by the legal-move mask before the softmax layer. However, the legal-move mask is a set of 0s and 1s and the logits live in unnormalized log space, so multiplying a logit by 0 sets it to 0, which corresponds to e^0 = 1 in unnormalized real space. As a result, invalid moves still receive non-zero probability after softmax. The correct approach is to add -inf to the invalid logits, so that they map to exactly 0 in real space.

See the following example:

>>> x = torch.rand((3, 5)); x
tensor([[0.8924, 0.8796, 0.4099, 0.8877, 0.2908],
        [0.3853, 0.9005, 0.4000, 0.5488, 0.3874],
        [0.0757, 0.2595, 0.5019, 0.2247, 0.0422]])

>>> mask = torch.rand((3, 5)) > 0.6; mask
tensor([[False, False, False,  True, False],
        [ True,  True, False, False,  True],
        [False,  True, False,  True,  True]])

# without masking
>>> torch.softmax(x, dim = -1)
tensor([[0.2411, 0.2380, 0.1488, 0.2400, 0.1321],
        [0.1704, 0.2852, 0.1729, 0.2007, 0.1708],
        [0.1706, 0.2050, 0.2613, 0.1980, 0.1650]])

# current masking method; notice that everything is non-zero and invalid moves all have the same probability
>>> torch.softmax(x * mask, dim = -1)
tensor([[0.1555, 0.1555, 0.1555, 0.3779, 0.1555],
        [0.1986, 0.3324, 0.1351, 0.1351, 0.1990],
        [0.1789, 0.2318, 0.1789, 0.2239, 0.1866]])

# new masking method; masked out values are given 0 probability
>>> torch.softmax(x + mask.log(), dim = -1)
tensor([[0.0000, 0.0000, 0.0000, 1.0000, 0.0000],
        [0.2720, 0.4554, 0.0000, 0.0000, 0.2726],
        [0.0000, 0.3610, 0.0000, 0.3486, 0.2904]])

@mcognetta
Author

mcognetta commented Oct 8, 2025

This doesn't affect the results much in most cases, but it can cause substantial leakage. On the example test set in the README, the worst case is that 5.1% of the probability mass leaks to invalid moves.

That position is 'rn1q1rk1/ppp2ppp/4bn2/3p3P/4p3/P3P3/1PPPBPPb/RNBQK3 w Q - 0 11' with elos (1500, 1498). The move probs are

{'a3a4': 0.1074, 'e1f1': 0.1011, 'h5h6': 0.0953, 'g2g3': 0.0947, 'b1c3': 0.0733, 'd2d3': 0.0564, 'e2f1': 0.0533, 'a1a2': 0.049, 'g2g4': 0.0477, 'b2b4': 0.04, 'd2d4': 0.0365, 'e2g4': 0.0322, 'e2b5': 0.0312, 'b2b3': 0.0279, 'f2f3': 0.0269, 'f2f4': 0.0216, 'c2c4': 0.0201, 'c2c3': 0.0118, 'e2f3': 0.0069, 'e2a6': 0.0055, 'e2c4': 0.0051, 'e2d3': 0.0051}

with sum(move_probs.values()) = 0.949.


Copilot AI left a comment

Pull request overview

Updates inference-time move masking so illegal moves receive zero probability after softmax by converting the legal-move mask into additive -inf logits (via log() on a 0/1 mask), rather than multiplying logits by the mask.

Changes:

  • Replace logits * legal_moves with logits + legal_moves.log() in batch inference (get_preds).
  • Replace logits * legal_moves with logits + legal_moves.log() in single-position inference (inference_each).



  logits_maia, _, logits_value = model(boards, elos_self, elos_oppo)
- logits_maia_legal = logits_maia * legal_moves
+ logits_maia_legal = logits_maia + legal_moves.log()

Copilot AI Feb 22, 2026


legal_moves.log() will produce -inf for masked entries; if a position ever has zero legal moves (e.g., checkmate/stalemate) or the mask is accidentally all zeros, softmax over all -inf yields NaN probabilities (0/0). Consider explicitly handling the all-masked case (early return / raise) or using a safe fallback before softmax.

Suggested change

    logits_maia_legal = logits_maia + legal_moves.log()
    # If a position has no legal moves (all zeros in the mask), logits_maia_legal will be all -inf;
    # softmax over all -inf yields NaNs. For such rows, fall back to the unmasked logits.
    has_legal = legal_moves.sum(dim=-1, keepdim=True) > 0
    if not has_legal.all():
        logits_maia_legal = torch.where(has_legal, logits_maia_legal, logits_maia)


  logits_maia, _, logits_value = model(board_input, elo_self, elo_oppo)
- logits_maia_legal = logits_maia * legal_moves
+ logits_maia_legal = logits_maia + legal_moves.log()

Copilot AI Feb 22, 2026


Same masking approach here: if legal_moves is all zeros, legal_moves.log() makes the entire row -inf and softmax will return NaNs. Please add a guard for terminal positions / empty legal-move sets or otherwise ensure at least one legal move before applying softmax.

Suggested change

    # Guard against the case where there are no legal moves (all zeros), which would make
    # legal_moves.log() equal to -inf everywhere and softmax return NaNs.
    if (legal_moves > 0).any():
        legal_moves_for_mask = legal_moves
    else:
        # In terminal positions with no legal moves, use an all-ones mask so that .log()
        # produces zeros and does not introduce -inf values.
        legal_moves_for_mask = torch.ones_like(legal_moves)
    logits_maia_legal = logits_maia + legal_moves_for_mask.log()

@mcognetta
Author

mcognetta commented Feb 22, 2026

Thank you Brother Copilot 🧑‍✈️

I added a guard for the case where no legal moves are available. Copilot's suggestion would have given every move (all 1880 moves in the all_moves_dict) non-zero probability in this case, which I think is wrong. My implementation instead outputs an all-zero probability distribution, so no moves appear in the move-probability dict output, which I think is more in the spirit of the function.

This should probably have an assertion or something, though, so that it doesn't just pass through silently. There is no case I can think of where we would actually want to run inference when there are no moves available (the win probability also doesn't make sense here).
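As a sketch of the guard described above (standalone Python with illustrative names; the actual PR code operates on torch tensors), the no-legal-moves case returns an all-zero distribution rather than NaNs:

```python
import math

def masked_move_probs(logits, legal_mask):
    # No legal moves (e.g., checkmate/stalemate): return an all-zero
    # distribution, matching the behavior described above.
    if not any(legal_mask):
        return [0.0] * len(logits)
    # Otherwise, additive -inf masking followed by a stable softmax.
    masked = [x if keep else float("-inf") for x, keep in zip(logits, legal_mask)]
    mx = max(masked)
    exps = [math.exp(x - mx) for x in masked]
    total = sum(exps)
    return [e / total for e in exps]
```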
