Skip to content

Commit 341be4f

Browse files
authored
Merge pull request #28 from LabStrangeLoop/feature/chemistry
Feature/chemistry
2 parents 541dd31 + d41325f commit 341be4f

1 file changed

Lines changed: 316 additions & 0 deletions

File tree

examples/chemistry.py

Lines changed: 316 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,316 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Chemical Reaction Prediction Example - Train a transformer to predict reaction products
4+
5+
This script demonstrates training a GPT-style model on chemical reactions from the USPTO
6+
patent database. It downloads reactions in SMILES notation and trains a transformer to
7+
complete reactions by predicting products from reactants.
8+
9+
The model learns chemical transformation patterns without knowing any chemistry - it just
10+
sees that certain molecular structures tend to produce other structures.
11+
12+
We use special tokens [BOS] and [EOS] to mark the beginning and end of each reaction,
13+
helping the model learn when to stop generating.
14+
15+
"""
16+
17+
import shutil
18+
import sys
19+
import tempfile
20+
import time
21+
from pathlib import Path
22+
23+
import torch
24+
from datasets import Dataset, load_dataset
25+
from torch.optim import AdamW
26+
27+
from scratchgpt import (
28+
CharTokenizer,
29+
ScratchGPTArchitecture,
30+
ScratchGPTConfig,
31+
ScratchGPTTraining,
32+
Trainer,
33+
TransformerLanguageModel,
34+
save_tokenizer,
35+
)
36+
from scratchgpt.data import create_data_source
37+
38+
# Test reactions for demonstration: (reactant SMILES, human-readable label).
# NOTE(review): a few reactant strings look like chemist shorthand rather
# than strict SMILES (e.g. "Cl2", "HBr") — confirm this is intentional;
# these are demo prompts only, never training data.
TEST_REACTIONS = [
    ("CC(=O)O.CCO", "Esterification (acetic acid + ethanol)"),
    ("c1ccccc1.Cl2", "Chlorination (benzene + chlorine)"),
    ("CC=C.HBr", "Addition (propene + HBr)"),
    ("CC(=O)Cl.N", "Amide formation (acetyl chloride + ammonia)"),
    ("CCO.[O]", "Oxidation (ethanol + oxygen)"),
]

# Display configuration
MAX_DISPLAY_LENGTH: int = 80  # max characters printed before "..." truncation
SEPARATOR_WIDTH: int = 70  # width of the "=" separator lines in the demo output

# Summary text printed after the prediction demo finishes.
TRAINING_SUMMARY = """
Chemical reaction prediction training complete!

What the model learned:
- Patterns of how molecular structures transform in reactions
- Common functional group conversions (esters, amides, etc.)
- Product structures that typically result from given reactants
- When to stop generating (using [EOS] token)
- The model doesn't know chemistry rules, just statistical patterns!
"""
62+
63+
64+
def truncate_for_display(text: str, max_length: int = MAX_DISPLAY_LENGTH) -> str:
    """Clip *text* to *max_length* characters, appending "..." when clipped."""
    if len(text) <= max_length:
        return text
    return text[:max_length] + "..."
69+
70+
71+
def load_reaction_dataset() -> Dataset:
    """
    Fetch the USPTO-50k reaction dataset (train split) from HuggingFace.

    Each row holds one patent-extracted reaction in SMILES notation
    with atom mapping.
    """
    print("Loading USPTO-50k reaction dataset from HuggingFace...")
    print("This dataset contains 50,000 chemical reactions from US patents.")

    loaded: Dataset = load_dataset("pingzhili/uspto-50k", split="train")

    print(f"✓ Loaded {len(loaded):,} reactions")
    return loaded
85+
86+
87+
def prepare_reaction_text(dataset: Dataset) -> str:
    """
    Turn the dataset's reaction SMILES into one newline-separated training string.

    Reactions use the format ``reactants >> products`` (e.g.
    ``CC(=O)O.CCO >> CC(=O)OCC.O``). Each one is wrapped as
    ``[BOS]reaction[EOS]`` so the model can learn sequence boundaries.
    Returns an empty string when no reaction column can be located.
    """
    print("\nPreparing reaction data for training...")

    column_names: list[str] = dataset.column_names
    print(f"Dataset columns: {column_names}")

    # Probe the likely column names in priority order.
    candidates = ("rxn_smiles", "reaction_smiles", "text", "reaction", "smiles", "rxn")
    reaction_column: str | None = next(
        (name for name in candidates if name in column_names), None
    )

    if reaction_column is None:
        print("ERROR: Could not find reaction column!")
        print(f"Available columns: {column_names}")
        return ""

    print(f"Using column: '{reaction_column}'")

    # Keep only rows that actually contain a reaction arrow, wrapped in
    # the [BOS]/[EOS] boundary tokens.
    reactions: list[str] = [
        f"[BOS]{smiles}[EOS]"
        for smiles in (str(row[reaction_column]).strip() for row in dataset)
        if smiles and ">>" in smiles
    ]

    print(f"Extracted {len(reactions):,} valid reactions")

    # Show a few examples so the run log documents the data format.
    print("\nSample reactions (with special tokens):")
    for index, reaction in enumerate(reactions[:3], start=1):
        print(f" {index}. {truncate_for_display(reaction)}")

    full_text: str = "\n".join(reactions)
    print(f"\nTotal text length: {len(full_text):,} characters")

    return full_text
138+
139+
140+
def create_chemistry_config(vocab_size: int) -> ScratchGPTConfig:
    """
    Build a ScratchGPT configuration tuned for reaction SMILES.

    Compared to prose or chess, reactions are long (100-300 characters)
    with moderate pattern complexity, so a mid-sized transformer with a
    256-token context is used.
    """
    return ScratchGPTConfig(
        architecture=ScratchGPTArchitecture(
            block_size=256,
            embedding_size=256,
            num_heads=8,
            num_blocks=6,
            vocab_size=vocab_size,
        ),
        training=ScratchGPTTraining(
            max_epochs=15,
            learning_rate=3e-4,
            batch_size=32,
            dropout_rate=0.1,
            random_seed=1337,
            iteration_type="chunking",
        ),
    )
167+
168+
169+
def generate_reaction_products(
    device: torch.device,
    model: TransformerLanguageModel,
    tokenizer: CharTokenizer,
    reactants: str,
    max_tokens: int = 150,
) -> str:
    """
    Complete a reaction: given reactants, sample the product side after '>>'.

    The prompt is ``[BOS]<reactants>>>``; generation halts at the [EOS]
    token when the tokenizer's vocabulary contains it, otherwise it runs
    for *max_tokens* new tokens.
    """
    model.eval()

    # SMILES must be whitespace-free.
    cleaned: str = reactants.strip().replace(" ", "")

    # A stop token is only usable if [EOS] made it into the vocabulary.
    stop_id: int | None = None
    if "[EOS]" in tokenizer.vocabulary:
        stop_id = tokenizer.encode("[EOS]")[0]

    prompt_ids = tokenizer.encode(f"[BOS]{cleaned}>>")

    with torch.no_grad():
        context: torch.Tensor = torch.tensor(prompt_ids).unsqueeze(0).to(device)
        output: torch.Tensor = model.generate(
            context, max_new_tokens=max_tokens, temperature=0.8, stop_token=stop_id
        )

    return tokenizer.decode(output[0].tolist())
202+
203+
204+
def main() -> None:
    """Run the full pipeline: load data, train the model, and demo predictions.

    The experiment directory lives in a temporary folder that is always
    removed on exit — even if training or the demo raises — so repeated
    runs leave nothing behind. (Fix: the original only removed the
    directory on the happy path, leaking it on any error.)
    """
    print("Chemical Reaction Prediction with ScratchGPT")
    print("=" * 60)

    # Load dataset from HuggingFace
    print("\n--- Loading Chemical Reaction Dataset ---")
    dataset: Dataset = load_reaction_dataset()

    # Prepare reaction text
    reactions_text: str = prepare_reaction_text(dataset)

    if not reactions_text.strip():
        print("ERROR: No valid reactions were extracted!")
        sys.exit(1)

    # Create character-level tokenizer over the reaction corpus.
    print("\n--- Creating Character Tokenizer ---")
    tokenizer: CharTokenizer = CharTokenizer(text=reactions_text)
    print(f"Vocabulary size: {tokenizer.vocab_size}")
    print("Includes special tokens: [BOS] (begin) and [EOS] (end)")
    sample_chars: list[str] = sorted(tokenizer.vocabulary)[:20]
    print(f"Sample characters: {sample_chars}")

    # Create chemistry-optimized config
    print("\n--- Creating Chemistry Model Configuration ---")
    config: ScratchGPTConfig = create_chemistry_config(tokenizer.vocab_size)
    print(
        f"Model: {config.architecture.embedding_size}D embeddings, "
        f"{config.architecture.num_blocks} blocks, {config.architecture.num_heads} heads"
    )

    # Setup device and model
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nUsing device: {device}")

    model: TransformerLanguageModel = TransformerLanguageModel(config)
    model = model.to(device)
    total_params: int = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {total_params:,}")

    # Everything below uses the temp directory, so the try/finally
    # guarantees cleanup regardless of how we leave this function.
    temp_path: Path = Path(tempfile.mkdtemp())
    try:
        reactions_file: Path = temp_path / "reactions.txt"
        experiment_dir: Path = temp_path / "chemistry_experiment"

        print("\nSaving reactions to temporary file...")
        reactions_file.write_text(reactions_text, encoding="utf-8")

        data_source = create_data_source(str(reactions_file))
        optimizer: AdamW = AdamW(model.parameters(), lr=config.training.learning_rate)
        trainer: Trainer = Trainer(
            model=model,
            config=config.training,
            optimizer=optimizer,
            experiment_path=experiment_dir,
            device=device,
        )

        save_tokenizer(experiment_dir, tokenizer)

        # Training — Ctrl-C skips ahead to the prediction demo.
        print("\n--- Starting Chemical Reaction Training ---")
        print("The model will learn to predict reaction products from reactants")
        print("Press Ctrl-C to stop training early and see predictions")

        start_time: float = time.time()

        try:
            trainer.train(data_source=data_source, tokenizer=tokenizer)
            print(f"\n✅ Training completed in {time.time() - start_time:.1f} seconds")
        except KeyboardInterrupt:
            print(f"\n⚠️ Training interrupted after {time.time() - start_time:.1f} seconds")
            print("Proceeding with reaction prediction demo...")

        # Prediction demo
        print("\n--- Chemical Reaction Prediction Demo ---")
        print("Testing the model's ability to predict reaction products")
        print("=" * SEPARATOR_WIDTH)

        for reactants, reaction_name in TEST_REACTIONS:
            print(f"\nReaction: {reaction_name}")
            print(f"Reactants: {reactants}")
            print("-" * 50)

            result: str = generate_reaction_products(device, model, tokenizer, reactants)

            # Strip the boundary tokens before display.
            result_clean: str = result.replace("[BOS]", "").replace("[EOS]", "")

            # Everything after '>>' is the predicted product side.
            if ">>" in result_clean:
                predicted_products: str = result_clean.split(">>", 1)[1].strip()
                print(f"Predicted products: {truncate_for_display(predicted_products)}")
            else:
                print("Generated: (incomplete prediction)")

        print("\n" + "=" * SEPARATOR_WIDTH)

        # Summary
        print(TRAINING_SUMMARY)
        print(f"Experiment saved temporarily to: {experiment_dir}")
        print("All files will be cleaned up when the script exits.")
    finally:
        # Cleanup always runs, even on error or Ctrl-C during the demo.
        shutil.rmtree(temp_path, ignore_errors=True)
313+
314+
315+
# Entry point guard: run the full pipeline only when executed as a script.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)