-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
234 lines (178 loc) · 9.81 KB
/
train.py
File metadata and controls
234 lines (178 loc) · 9.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
import os
import random
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from datetime import datetime
import networks.config as config
from core.logger import initialize_logger, log_info, log_warning, log_debug, log_error, start_loading_animation, stop_loading_animation, TermColors
from Minesweeper import MinesweeperGame
from networks.Minesweeper_CNN import MinesweeperCNN

# Initialize the logger at import time (debug mode on), before any training output.
initialize_logger(app_name="MinesweeperAI-Trainer", config_debug_mode=True)
def encode_board(board):
    """
    Encode a game board array into a 9-channel one-hot representation.

    Channels 0-7 mark cells holding the numbers 1-8; channel 8 marks
    cells equal to -1 (unrevealed). Cells with value 0 (revealed empty)
    are zero in every channel.

    - board: a single (height, width) board array
    - returns: a (9, height, width) float32 numpy array
    """
    planes = [board == value for value in range(1, 9)]
    planes.append(board == -1)
    return np.stack(planes).astype(np.float32)
class MinesweeperDataset(Dataset):
    """
    A PyTorch Dataset that dynamically generates Minesweeper training
    samples with randomized board sizes and mine densities.
    """
    def __init__(self, num_samples):
        # num_samples only controls the reported length; every __getitem__
        # call generates a fresh random board regardless of idx.
        self.num_samples = num_samples
    def __len__(self):
        return self.num_samples
    def __getitem__(self, idx):
        """
        Generate and return one training sample (input, label).

        Each call randomly generates a board within the ranges defined in
        config; idx is ignored. Retries in a loop until a valid board is
        produced. Returns (encoded input tensor, float32 label tensor).
        """
        while True:
            try:
                # 1. Randomize board size and mine density.
                h = random.randint(config.MIN_BOARD_HEIGHT, config.MAX_BOARD_HEIGHT)
                w = random.randint(config.MIN_BOARD_WIDTH, config.MAX_BOARD_WIDTH)
                density = random.uniform(config.MIN_MINE_DENSITY, config.MAX_MINE_DENSITY)
                num_mines = int(h * w * density)
                # Ensure the mine count is valid.
                if num_mines <= 0: num_mines = 1
                # Ensure at least 9 non-mine cells remain clickable (for the first click).
                if num_mines >= h * w - 9: continue
                # 2. Generate a game position.
                game = MinesweeperGame(h, w, num_mines, {})
                first_r, first_c = random.randint(0, h - 1), random.randint(0, w - 1)
                game._place_mines(first_r, first_c)
                game._calculate_numbers()
                solution = np.array(game.solution_board)
                y_true = (solution == -1).astype(np.float32)  # label is 1 where solution_board is -1 (presumably mine cells — confirm against MinesweeperGame)
                player_view = -np.ones((h, w), dtype=np.int8)  # start fully unrevealed (-1 everywhere)
                safe_cells = np.argwhere(solution != -1)  # coordinates of all safe cells
                if len(safe_cells) == 0: continue  # no safe cells (theoretically impossible unless the board is tiny, but just in case)
                # Reveal a random proportion of the safe cells.
                num_to_reveal = int(len(safe_cells) * random.uniform(config.REVEAL_PERCENT_MIN, config.REVEAL_PERCENT_MAX))
                if num_to_reveal == 0: num_to_reveal = 1  # reveal at least one cell
                if num_to_reveal > len(safe_cells): num_to_reveal = len(safe_cells)  # never exceed the number of revealable cells
                reveal_indices = np.random.choice(len(safe_cells), num_to_reveal, replace=False)
                for index in reveal_indices:
                    r, c = safe_cells[index]
                    player_view[r, c] = solution[r, c]  # copy the true number into the player's view
                X_encoded = encode_board(player_view)
                return torch.from_numpy(X_encoded), torch.from_numpy(y_true)
            except (ValueError, IndexError) as e:
                # Board generation can occasionally fail; just retry with new parameters.
                continue
def collate_fn_pad(batch):
    """
    Custom collate_fn for variable-sized boards.

    Pads every sample in the batch up to the largest height and width
    present: inputs are zero-padded, labels are padded with -1 so the
    padded area can later be masked out of the loss.
    """
    target_h = max(inputs.shape[1] for inputs, _ in batch)
    target_w = max(inputs.shape[2] for inputs, _ in batch)
    stacked_inputs = []
    stacked_labels = []
    for inputs, labels in batch:
        extra_h = target_h - inputs.shape[1]
        extra_w = target_w - inputs.shape[2]
        pad_spec = (0, extra_w, 0, extra_h)  # (left, right, top, bottom) on the last two dims
        stacked_inputs.append(nn.functional.pad(inputs, pad_spec, mode='constant', value=0))
        stacked_labels.append(nn.functional.pad(labels, pad_spec, mode='constant', value=-1))
    return torch.stack(stacked_inputs), torch.stack(stacked_labels)
def main():
    """
    Train the Minesweeper CNN.

    Builds the model, optimizer, scheduler, and train/validation loaders,
    runs the epoch loop with masked BCE loss over unrevealed cells, and
    checkpoints the best model (by validation loss) into a timestamped
    directory under config.MODEL_DIR.
    """
    log_info("=============== 开始训练扫雷AI模型 ===============")
    log_info(f"设备: {config.DEVICE}")
    log_info(f"训练尺寸范围: 高({config.MIN_BOARD_HEIGHT}-{config.MAX_BOARD_HEIGHT}), 宽({config.MIN_BOARD_WIDTH}-{config.MAX_BOARD_WIDTH})")
    log_info(f"雷密度范围: {config.MIN_MINE_DENSITY*100:.1f}% - {config.MAX_MINE_DENSITY*100:.1f}%")
    # Each run gets its own timestamped checkpoint directory.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    current_model_dir = os.path.join(config.MODEL_DIR, timestamp)
    if not os.path.exists(current_model_dir):
        os.makedirs(current_model_dir)
        log_info(f"已创建本次训练的模型目录: '{current_model_dir}'")
    else:
        log_warning(f"目录 '{current_model_dir}' 已存在,可能会覆盖文件。")
    model = MinesweeperCNN(config.NUM_RES_BLOCKS, config.NUM_FILTERS).to(config.DEVICE)
    # BCEWithLogitsLoss expects raw logits; labels are 0/1 floats.
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=config.LEARNING_RATE)
    # NOTE(review): `verbose` is deprecated in recent PyTorch releases — confirm the installed version still accepts it.
    scheduler = ReduceLROnPlateau(optimizer, 'min', factor=config.LR_DECAY_FACTOR, patience=config.LR_PATIENCE, verbose=True)
    # Training dataset and loader (samples are generated on the fly).
    train_dataset = MinesweeperDataset(config.STEPS_PER_EPOCH * config.BATCH_SIZE)
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=True,  # reshuffle at the start of every epoch
        num_workers=config.NUM_WORKERS,
        pin_memory=True,
        collate_fn=collate_fn_pad
    )
    # Validation dataset and loader.
    val_dataset = MinesweeperDataset(config.VALIDATION_STEPS * config.BATCH_SIZE)
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=False,  # no need to shuffle validation data
        num_workers=config.NUM_WORKERS,
        pin_memory=True,
        collate_fn=collate_fn_pad
    )
    log_info(f"使用 {train_loader.num_workers} 个CPU进程进行数据加载")
    log_info(f"每个Epoch训练 {len(train_loader.dataset)} 样本,验证 {len(val_loader.dataset)} 样本")
    best_val_loss = float('inf')
    for epoch in range(config.NUM_EPOCHS):
        model.train()
        total_train_loss = 0
        train_batches = 0  # count only batches that actually contribute a loss
        log_info(f"{TermColors.CYAN}--- 第 {epoch + 1}/{config.NUM_EPOCHS} 轮 (Epoch) ---{TermColors.RESET}")
        progress_bar = tqdm(train_loader, desc=f"训练中", bar_format="{l_bar}{bar:20}{r_bar}", dynamic_ncols=True)
        for inputs, labels in progress_bar:
            inputs, labels = inputs.to(config.DEVICE, non_blocking=True), labels.to(config.DEVICE, non_blocking=True)
            optimizer.zero_grad()
            outputs = model(inputs)
            # Loss is computed only on unrevealed cells (channel 8 == 1) that
            # are not padding (labels padded with -1 by collate_fn_pad).
            mask = (inputs[:, 8, :, :] == 1) & (labels != -1)
            if mask.sum() == 0:
                log_debug("批次中没有需要预测的未揭开单元格或只有填充区域,跳过此批次。")
                continue
            loss = criterion(outputs[mask], labels[mask])
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()
            train_batches += 1
            progress_bar.set_postfix(loss=f"{loss.item():.4f}")
        # Fix: average over the batches that contributed a loss — dividing by
        # len(train_loader) deflated the average whenever batches were skipped.
        avg_train_loss = total_train_loss / train_batches if train_batches > 0 else 0
        log_info(f"平均训练损失: {TermColors.GREEN}{avg_train_loss:.4f}{TermColors.RESET}")
        model.eval()
        total_val_loss = 0
        val_batches = 0  # same skipped-batch accounting for validation
        start_loading_animation("正在验证模型...", animation_style_key='dots', animation_color=TermColors.YELLOW)
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(config.DEVICE, non_blocking=True), labels.to(config.DEVICE, non_blocking=True)
                outputs = model(inputs)
                mask = (inputs[:, 8, :, :] == 1) & (labels != -1)
                if mask.sum() == 0:
                    log_debug("验证批次中没有需要预测的未揭开单元格或只有填充区域,跳过此批次。")
                    continue
                loss = criterion(outputs[mask], labels[mask])
                total_val_loss += loss.item()
                val_batches += 1
        stop_loading_animation(success=True)
        # Fix: same unbiased average for validation; this value also drives the
        # LR scheduler and best-model selection, so the bias mattered here.
        avg_val_loss = total_val_loss / val_batches if val_batches > 0 else 0
        log_info(f"平均验证损失: {TermColors.YELLOW}{avg_val_loss:.4f}{TermColors.RESET}")
        scheduler.step(avg_val_loss)
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            model_path = os.path.join(current_model_dir, config.BEST_MODEL_NAME)
            torch.save(model.state_dict(), model_path)
            log_info(f"{TermColors.MAGENTA}🎉 新的最佳模型已保存到 '{model_path}'! 验证损失降至: {best_val_loss:.4f}{TermColors.RESET}")
        else:
            log_debug(f"验证损失未改善,当前最佳: {best_val_loss:.4f}")
    log_info("=============== 训练完成 ===============")
if __name__ == "__main__":
    main()