Skip to content

Commit 541dd31

Browse files
♟️ Add comprehensive chess training example with domain tokenizer (#27)
* Add chess training example with Lichess database parser * Add examples/chess.py: downloads and parses Lichess PGN databases for transformer training * Streams large PGN files, removes metadata/annotations, filters games with 2+ moves * Uses temporary directory, outputs clean move sequences ready for training * Add zstandard dependency in examples-dependencies optional group for PGN decompression * feat: Add deterministic SAN ChessTokenizer * Introduces `ChessTokenizer`, a serializable tokenizer with a pre-generated vocabulary for Standard Algebraic Notation (SAN). * The vocabulary includes all valid moves, promotions, castling, move numbers, check/mate symbols, and game termination markers. * Implements a robust `encode` method to correctly parse PGN strings into distinct tokens. * Complete chess training example with ChessTokenizer * Add examples/chess_tokenizer.py: deterministic tokenizer for chess moves with comprehensive SAN vocabulary * Transform chess.py into full training pipeline with ChessDataLoader, model training, and move generation demo * Add tests/examples/test_chess_tokenizer.py: comprehensive test suite for chess tokenizer functionality * Include chess-optimized model config (larger context, sliding windows) and famous opening position demos * Explicit boolean Co-authored-by: Aleks <ayeganov@users.noreply.github.com> * Refactor chess.py: remove duplication and magic numbers * Replace manual token generation loop with model.generate() method to eliminate code duplication * Generate chess moves one at a time instead of batch generation for clearer logic and easier debugging * Extract magic number 80 to GAME_PREVIEW_MAX_LENGTH constant and add type hints to module-level constants * Remove unused torch.nn.functional import and add missing return type hints --------- Co-authored-by: Aleksandr V Yeganov <ayeganov@gmail.com> Co-authored-by: Aleks <ayeganov@users.noreply.github.com>
1 parent 29a2730 commit 541dd31

8 files changed

Lines changed: 809 additions & 14 deletions

File tree

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,11 @@ projects using ScratchGPT.
6868

6969
Please take a look at the [simple example](./examples/simple.py) in the examples folder.
7070

71+
**Note:** Some examples require additional dependencies. To run all examples, install the optional dependencies:
72+
```bash
73+
uv sync --extra examples-dependencies
74+
```
75+
7176
## Usage
7277

7378
### Training

examples/chess.py

Lines changed: 354 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,354 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Chess Engine Training Example - Train a transformer to predict chess moves using ScratchGPT
4+
5+
This script demonstrates training a GPT-style model on chess games from the Lichess database.
6+
It downloads a collection of games in PGN format, parses them into move sequences,
7+
and trains a transformer to continue games by predicting the next moves.
8+
9+
The model learns chess patterns without knowing the rules - it just sees that certain
10+
move sequences tend to follow others in master games from Lichess.
11+
12+
Usage:
13+
python chess.py
14+
python chess.py -g https://database.lichess.org/blitz/lichess_db_blitz_rated_2024-01.pgn.zst
15+
"""
16+
17+
import argparse
18+
import re
19+
import sys
20+
import tempfile
21+
import time
22+
from pathlib import Path
23+
from urllib.parse import urlparse
24+
from urllib.request import urlretrieve
25+
26+
import torch
27+
import zstandard as zstd
28+
from torch.optim import AdamW
29+
30+
from examples.chess_tokenizer import ChessTokenizer
31+
from scratchgpt import (
32+
ScratchGPTArchitecture,
33+
ScratchGPTConfig,
34+
ScratchGPTTraining,
35+
Trainer,
36+
TransformerLanguageModel,
37+
save_tokenizer,
38+
)
39+
from scratchgpt.data import create_data_source
40+
41+
# Alternative: use character-level tokenization
42+
# from scratchgpt import CharTokenizer
43+
44+
# Default Lichess database file
45+
DEFAULT_LICHESS_URL: str = "https://database.lichess.org/standard/lichess_db_standard_rated_2016-02.pgn.zst"
46+
GAME_PREVIEW_MAX_LENGTH: int = 80
47+
48+
49+
def parse_args() -> argparse.Namespace:
50+
"""Parse command line arguments."""
51+
parser = argparse.ArgumentParser(description="Train a chess move predictor using ScratchGPT")
52+
parser.add_argument(
53+
"-g",
54+
"--game-url",
55+
type=str,
56+
default=DEFAULT_LICHESS_URL,
57+
help=f"Lichess database URL to download (default: {DEFAULT_LICHESS_URL})",
58+
)
59+
return parser.parse_args()
60+
61+
62+
class ChessDataLoader:
63+
"""Handles downloading and parsing of Lichess chess databases."""
64+
65+
def __init__(self, game_url: str) -> None:
66+
self.game_url = game_url
67+
68+
def download_and_parse(self) -> str:
69+
"""Download, decompress, and parse chess games into clean move sequences."""
70+
with tempfile.TemporaryDirectory() as tmp_dir:
71+
temp_path = Path(tmp_dir)
72+
print(f"Working in temporary directory: {temp_path}")
73+
pgn_file = self._download_and_decompress(temp_path)
74+
games_text = self._parse_pgn_to_games(pgn_file)
75+
return games_text
76+
77+
def _download_and_decompress(self, temp_dir: Path) -> Path:
78+
"""Download and decompress the Lichess database file."""
79+
filename = Path(urlparse(self.game_url).path).name
80+
compressed_file = temp_dir / filename
81+
82+
print(f"Downloading: {filename}")
83+
print("This may take several minutes depending on file size...")
84+
urlretrieve(self.game_url, compressed_file)
85+
86+
pgn_file = temp_dir / filename.replace(".zst", "")
87+
print(f"Decompressing: {filename}")
88+
89+
dctx = zstd.ZstdDecompressor()
90+
with open(compressed_file, "rb") as compressed_fp, open(pgn_file, "wb") as output_fp:
91+
dctx.copy_stream(compressed_fp, output_fp)
92+
93+
# Remove compressed file to save space
94+
compressed_file.unlink()
95+
return pgn_file
96+
97+
def _parse_pgn_to_games(self, pgn_file: Path) -> str:
98+
"""Parse PGN file and extract move sequences."""
99+
print(f"Parsing games from: {pgn_file.name}")
100+
101+
games = []
102+
current_game_lines = []
103+
games_processed = 0
104+
105+
with open(pgn_file, encoding="utf-8", errors="ignore") as f:
106+
for line_num, line in enumerate(f, 1):
107+
line = line.strip()
108+
109+
if line_num % 1_000_000 == 0:
110+
print(f"Processed {line_num:,} lines, found {games_processed:,} games")
111+
if line.startswith("["):
112+
continue
113+
if not line:
114+
continue
115+
116+
current_game_lines.append(line)
117+
118+
if any(result in line for result in ["1-0", "0-1", "1/2-1/2", "*"]):
119+
game_text = " ".join(current_game_lines).strip()
120+
clean_text = self._clean_game_text(game_text)
121+
# Only keep games with more than 2 moves
122+
has_many_moves = len(clean_text.split()) > 2
123+
if has_many_moves:
124+
games.append(clean_text)
125+
games_processed += 1
126+
127+
# Reset for next game
128+
current_game_lines = []
129+
130+
print(f"Extracted {len(games)} valid games")
131+
return "\n".join(games)
132+
133+
def _clean_game_text(self, game_text: str) -> str:
134+
"""Clean annotations and comments from game text."""
135+
# Remove comments in curly braces
136+
game_text = re.sub(r"\{[^}]*\}", " ", game_text)
137+
138+
# Remove evaluation annotations like [%eval 0.5]
139+
game_text = re.sub(r"\[%[^\]]*\]", " ", game_text)
140+
141+
# Clean up multiple spaces
142+
game_text = re.sub(r"\s+", " ", game_text).strip()
143+
144+
# Remove game results from the end
145+
for result in ["1-0", "0-1", "1/2-1/2", "*"]:
146+
suffix = f" {result}"
147+
if game_text.endswith(suffix):
148+
game_text = game_text.removesuffix(suffix)
149+
break
150+
return game_text
151+
152+
153+
def create_chess_config(tokenizer_vocab_size: int) -> ScratchGPTConfig:
154+
"""Create a configuration optimized for chess move prediction."""
155+
# Chess-optimized architecture
156+
architecture = ScratchGPTArchitecture(
157+
block_size=256, # Longer context for chess games (can see ~60-80 moves)
158+
embedding_size=384, # Balanced size for chess vocabulary
159+
num_heads=8, # Good attention for chess patterns
160+
num_blocks=6, # Sufficient depth for chess understanding
161+
vocab_size=tokenizer_vocab_size,
162+
)
163+
164+
# Training config optimized for chess patterns
165+
training = ScratchGPTTraining(
166+
max_epochs=15, # Chess patterns learn faster than language
167+
learning_rate=3e-4, # Standard rate works well for chess
168+
batch_size=32, # Good balance for chess sequences
169+
dropout_rate=0.1, # Lower dropout for structured chess patterns
170+
random_seed=1337,
171+
iteration_type="chunking",
172+
)
173+
174+
return ScratchGPTConfig(architecture=architecture, training=training)
175+
176+
177+
def generate_chess_moves(
178+
device: torch.device,
179+
model: TransformerLanguageModel,
180+
tokenizer,
181+
game_start: str,
182+
max_moves: int = 8,
183+
temperature: float = 0.8,
184+
) -> str:
185+
"""
186+
Generate chess moves one at a time.
187+
188+
Uses moderate temperature to balance chess-like patterns with some creativity.
189+
"""
190+
model.eval()
191+
192+
current_game = game_start
193+
194+
with torch.no_grad():
195+
for _ in range(max_moves):
196+
# Encode current game state
197+
context = torch.tensor(tokenizer.encode(current_game)).unsqueeze(0).to(device)
198+
199+
# Generate tokens for one move (typically 4-6 tokens)
200+
context = model.generate(context=context, max_new_tokens=6, temperature=temperature)
201+
current_game = tokenizer.decode(context[0].tolist())
202+
203+
return current_game
204+
205+
206+
def main() -> None:
207+
print("Chess Move Prediction Training with ScratchGPT")
208+
print("=" * 60)
209+
210+
# Parse arguments
211+
args = parse_args()
212+
213+
# Step 1: Download and parse chess data
214+
print("\n--- Downloading and Parsing Chess Games ---")
215+
data_loader = ChessDataLoader(args.game_url)
216+
games_text = data_loader.download_and_parse()
217+
218+
if not games_text.strip():
219+
print("ERROR: No games were parsed successfully!")
220+
sys.exit(1)
221+
222+
# Show sample of parsed games
223+
sample_games = games_text.split("\n")[:3]
224+
print("\nSample parsed games:")
225+
for i, game in enumerate(sample_games, 1):
226+
preview = game[:GAME_PREVIEW_MAX_LENGTH] + "..." if len(game) > GAME_PREVIEW_MAX_LENGTH else game
227+
print(f"{i}: {preview}")
228+
229+
# Step 2: Setup tokenizer
230+
print("\n--- Creating Chess Tokenizer ---")
231+
tokenizer = ChessTokenizer()
232+
print(f"Chess vocabulary size: {tokenizer.vocab_size:,}")
233+
234+
# Alternative approach using character-level tokenization:
235+
# tokenizer = CharTokenizer(text=games_text)
236+
# print(f"Character vocabulary size: {tokenizer.vocab_size}")
237+
#
238+
# Trade-offs:
239+
# - ChessTokenizer: Domain-specific, understands chess moves as units (~10k vocab)
240+
# - CharTokenizer: General, treats chess as character sequences (~60 vocab)
241+
# - ChessTokenizer should learn chess patterns more efficiently
242+
243+
# Step 3: Create chess-optimized configuration
244+
print("\n--- Creating Chess Model Configuration ---")
245+
config = create_chess_config(tokenizer.vocab_size)
246+
print(
247+
f"Model configuration: {config.architecture.embedding_size}D embeddings, "
248+
f"{config.architecture.num_blocks} blocks, {config.architecture.num_heads} heads"
249+
)
250+
# Step 4: Setup device and model
251+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
252+
print(f"\nUsing device: {device}")
253+
254+
if device.type == "cpu":
255+
print("⚠️ WARNING: Training on CPU will be slow!")
256+
print(" Expected time: 1-2 hours per epoch")
257+
response = input("Continue? (y/N): ")
258+
if response.lower() != "y":
259+
sys.exit(1)
260+
261+
model = TransformerLanguageModel(config)
262+
model = model.to(device)
263+
total_params = sum(p.numel() for p in model.parameters())
264+
print(f"Model parameters: {total_params:,}")
265+
266+
# Step 5: Setup training
267+
optimizer = AdamW(model.parameters(), lr=config.training.learning_rate, betas=(0.9, 0.95), weight_decay=0.01)
268+
269+
# Create temporary file for chess games and data source
270+
with tempfile.TemporaryDirectory() as tmp_dir:
271+
temp_path = Path(tmp_dir)
272+
chess_games_file = temp_path / "chess_games.txt"
273+
274+
# Save parsed games to file
275+
with open(chess_games_file, "w", encoding="utf-8") as f:
276+
f.write(games_text)
277+
278+
# Create data source using ScratchGPT's standard approach
279+
data_source = create_data_source(str(chess_games_file))
280+
281+
# Create experiment directory
282+
experiment_dir = temp_path / "chess_experiment"
283+
284+
# Create trainer
285+
trainer = Trainer(
286+
model=model, config=config.training, optimizer=optimizer, experiment_path=experiment_dir, device=device
287+
)
288+
289+
# Save tokenizer
290+
save_tokenizer(experiment_dir, tokenizer)
291+
292+
# Step 6: Training
293+
print("\n--- Starting Chess Training ---")
294+
print("The model will learn to predict chess moves based on grandmaster games")
295+
print("Press Ctrl-C to stop training early and proceed to move generation demo")
296+
297+
start_time = time.time()
298+
299+
try:
300+
trainer.train(data_source=data_source, tokenizer=tokenizer)
301+
print(f"\n✅ Training completed in {time.time() - start_time:.1f} seconds")
302+
except KeyboardInterrupt:
303+
print(f"\n⚠️ Training interrupted after {time.time() - start_time:.1f} seconds")
304+
print("Proceeding with chess move generation demo...")
305+
306+
# Step 7: Chess Move Generation Demo
307+
print("\n--- Chess Move Generation Demo ---")
308+
model.eval()
309+
310+
# Test with famous chess openings
311+
test_positions = [
312+
"1. e4 e5 2. Nf3", # Italian Game start
313+
"1. d4 d5 2. c4", # Queen's Gambit
314+
"1. e4 c5", # Sicilian Defense
315+
"1. Nf3 Nf6 2. c4", # English Opening
316+
"1. e4 e6 2. d4", # French Defense
317+
]
318+
319+
print("Generating continuations for famous chess openings:")
320+
print("=" * 70)
321+
322+
for position in test_positions:
323+
print(f"\nPosition: {position}")
324+
print("-" * 50)
325+
326+
# Generate continuation
327+
continuation = generate_chess_moves(
328+
device=device, model=model, tokenizer=tokenizer, game_start=position + " ", max_moves=6, temperature=0.8
329+
)
330+
331+
# Extract generated part
332+
generated_part = continuation[len(position) :].strip()
333+
334+
# Show first several moves of continuation
335+
generated_moves = generated_part.split()[:12] # Show ~6 moves
336+
if generated_moves:
337+
print(f"Continuation: {' '.join(generated_moves)}")
338+
else:
339+
print("Generated: (no valid continuation)")
340+
341+
print("\n" + "=" * 70)
342+
print("Chess move prediction training complete!")
343+
print("\nWhat the model learned:")
344+
print("- Chess move patterns from thousands of grandmaster games")
345+
print("- Common responses to popular openings")
346+
print("- Typical piece development and tactical motifs")
347+
print("- The model doesn't know chess rules, just statistical patterns!")
348+
349+
print(f"\nExperiment saved temporarily to: {experiment_dir}")
350+
print("All files will be cleaned up when the script exits.")
351+
352+
353+
if __name__ == "__main__":
354+
main()

0 commit comments

Comments
 (0)