-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathhelpers.py
More file actions
92 lines (82 loc) · 2.33 KB
/
helpers.py
File metadata and controls
92 lines (82 loc) · 2.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from chess import SQUARES, Move, KNIGHT, BISHOP, ROOK, QUEEN
import numpy as np
import torch
# Square-index offsets for the 64 non-promotion move planes: for each of the
# eight sliding directions, offsets for 1..7 steps, followed by the eight
# knight jumps.  Order matters — downstream code indexes planes positionally.
_SLIDE_STEPS = (
    8,   # up
    -8,  # down
    -1,  # left
    1,   # right
    7,   # up-left
    9,   # up-right
    -9,  # down-left
    -7,  # down-right
)
MOVE_PLANES = [step * count for step in _SLIDE_STEPS for count in range(1, 8)]
# knight moves
MOVE_PLANES += [17, 15, 10, 6, -17, -15, -10, -6]
# Square-index offsets for the promotion planes: a single pawn step forward or
# a forward capture to either side, for white (positive) and black (negative).
# Order: up, up-left, up-right, down, down-right, down-left.
UNDERPROMOTION_MOVE_PLANES = [8, 7, 9, -8, -7, -9]
# Piece types for the 24 promotion planes (6 pawn deltas x 4 pieces).
PROMOTION_PIECES = [KNIGHT, BISHOP, ROOK, QUEEN]

# (rank-delta, file-delta) pairs aligned positionally with MOVE_PLANES and
# UNDERPROMOTION_MOVE_PLANES.  The raw square-index offsets in those lists are
# ambiguous on their own (e.g. +7 is both "up-left 1 step" and "right 7
# steps"), so each plane's direction must be tracked separately to detect
# wrap-around at the board edge.
_QUEEN_DIRECTIONS = [
    (1, 0),    # up
    (-1, 0),   # down
    (0, -1),   # left
    (0, 1),    # right
    (1, -1),   # up-left
    (1, 1),    # up-right
    (-1, -1),  # down-left
    (-1, 1),   # down-right
]
_KNIGHT_DELTAS = [
    (2, 1), (2, -1), (1, 2), (1, -2),
    (-2, -1), (-2, 1), (-1, -2), (-1, 2),
]
_MOVE_DELTAS = [
    (steps * d_rank, steps * d_file)
    for d_rank, d_file in _QUEEN_DIRECTIONS
    for steps in range(1, 8)
] + _KNIGHT_DELTAS
_UNDERPROMOTION_DELTAS = [(1, 0), (1, -1), (1, 1), (-1, 0), (-1, 1), (-1, -1)]


def _build_plane(trip, d_rank, d_file, promotion=None):
    """Return the 64-entry move list for one plane.

    A destination is accepted only when it stays on the board *in the plane's
    direction*.  BUGFIX: the previous check ``square + trip in SQUARES`` only
    tested ``0 <= destination < 64``, so horizontal components wrapped around
    the board edge (e.g. the "right 1" plane from h1 produced the bogus move
    h1->a2).  Off-board entries keep the original sentinel ``Move(square, 100)``
    so the table shape and placeholder convention are unchanged.
    """
    plane = []
    for square in SQUARES:
        rank_index, file_index = divmod(square, 8)
        if 0 <= rank_index + d_rank < 8 and 0 <= file_index + d_file < 8:
            plane.append(Move(square, square + trip, promotion=promotion))
        else:
            plane.append(Move(square, 100))
    return plane


planes = []
for trip, (d_rank, d_file) in zip(MOVE_PLANES, _MOVE_DELTAS):
    planes.append(_build_plane(trip, d_rank, d_file))
for trip, (d_rank, d_file) in zip(UNDERPROMOTION_MOVE_PLANES, _UNDERPROMOTION_DELTAS):
    for piece in PROMOTION_PIECES:
        planes.append(_build_plane(trip, d_rank, d_file, promotion=piece))

# 88 planes (64 queen/knight + 24 promotion) x 8x8 from-squares.  Appears to
# follow an AlphaZero-style move encoding — TODO confirm against the trainer.
moves = np.array(planes).reshape((88, 8, 8))
def policy_to_move_probabilities(policy: torch.Tensor):
    """
    Convert the policy output of the neural network to a dictionary of move
    probabilities, keyed by the moves in the module-level 88x8x8 lookup table.
    """
    table_moves = moves.flatten()
    probabilities = policy.flatten().detach().numpy()
    return dict(zip(table_moves, probabilities))
def move_probabilities_to_policy(move_probabilities: dict):
    """
    Convert a dictionary of move probabilities into an 88x8x8 policy array.

    Entries of the lookup table that have no probability in the dictionary
    are left at zero.
    """
    policy = np.zeros((88, 8, 8))
    # Walk every cell of the move table and fill in its probability, if any.
    for index, table_move in np.ndenumerate(moves):
        if table_move in move_probabilities:
            policy[index] = move_probabilities[table_move]
    return policy