train.py
#!/usr/bin/env python3
"""TRN training entry point.
Usage:
# Toy run (synthetic data, no real dataset needed):
python train.py --synthetic --steps 200 --model-size toy
# Real data (pre-tokenized .txt file, char-level tokenization):
python train.py --data path/to/text.txt --steps 10000 --model-size trn_100m
"""
from __future__ import annotations

import argparse
import sys
import tempfile
from pathlib import Path

import numpy as np
import torch

# Allow running from project root without install
sys.path.insert(0, str(Path(__file__).parent / "src"))

from trimemory.config import TRNConfig
from trimemory.data import PackedDataset
from trimemory.model import TRNModel
from trimemory.trainer import TrainConfig, Trainer


def _make_synthetic_dataset(
    tmp_dir: Path, n_tokens: int, seq_len: int, vocab_size: int
) -> PackedDataset:
"""Create a temporary random token binary file and return a PackedDataset."""
path = tmp_dir / "synthetic.bin"
rng = np.random.default_rng(seed=0)
data = rng.integers(0, vocab_size, size=n_tokens, dtype=np.uint16)
data.tofile(str(path))
    return PackedDataset(path, seq_len)


def _tokenize_chars(text: str, tmp_dir: Path, seq_len: int) -> PackedDataset:
    """Char-level tokenization: map characters to uint16 token ids."""
chars = sorted(set(text))
char_to_id = {c: i for i, c in enumerate(chars)}
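    # Example: "abba" -> chars == ["a", "b"], char_to_id == {"a": 0, "b": 1},
    # so the token stream becomes [0, 1, 1, 0].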
tokens = np.array([char_to_id[c] for c in text], dtype=np.uint16)
path = tmp_dir / "tokens.bin"
tokens.tofile(str(path))
    return PackedDataset(path, seq_len)


def main(argv: list[str] | None = None) -> None:
parser = argparse.ArgumentParser(description="Train a Temporal Resonance Network")
parser.add_argument(
"--model-size",
default="toy",
choices=["toy", "trn_100m", "trn_400m", "trn_1b"],
)
parser.add_argument("--data", default=None, help="Path to .txt file (char tokenized)")
parser.add_argument(
"--synthetic",
action="store_true",
help="Use synthetic random data (for testing)",
)
parser.add_argument("--steps", type=int, default=1_000)
parser.add_argument("--lr", type=float, default=3e-4)
parser.add_argument("--batch-size", type=int, default=8)
parser.add_argument(
"--seq-len",
type=int,
default=None,
help="Override seq_len (default: from model config)",
)
parser.add_argument("--warmup", type=int, default=100)
parser.add_argument("--checkpoint-dir", default="checkpoints")
parser.add_argument(
"--save-every",
type=int,
default=0,
help="Save checkpoint every N steps (0=disabled)",
)
parser.add_argument("--device", default="cpu")
parser.add_argument("--eval-at-end", action="store_true", default=True)
    args = parser.parse_args(argv)

    # Build model config
cfg_factory = {
"toy": TRNConfig.toy,
"trn_100m": TRNConfig.trn_100m,
"trn_400m": TRNConfig.trn_400m,
"trn_1b": TRNConfig.trn_1b,
}[args.model_size]
model_cfg = cfg_factory()
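    # --seq-len overrides the config default; otherwise train at the model's max_seq_len.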
seq_len = args.seq_len or model_cfg.max_seq_len
print(
f"Model: {args.model_size} | d_model={model_cfg.d_model} "
f"layers={model_cfg.n_layers} K={model_cfg.n_oscillators}"
    )

    # Build dataset (synthetic or from text file)
with tempfile.TemporaryDirectory() as _tmp:
tmp_dir = Path(_tmp)
if args.synthetic:
print("Using synthetic random dataset")
n_tokens = max(seq_len * 200, 50_000)
dataset = _make_synthetic_dataset(tmp_dir, n_tokens, seq_len, model_cfg.vocab_size)
        elif args.data:
            text = Path(args.data).read_text(encoding="utf-8")
            # Guard: char-level token ids must fit the model's embedding table.
            if len(set(text)) > model_cfg.vocab_size:
                parser.error(f"Text has {len(set(text))} distinct chars > vocab_size={model_cfg.vocab_size}")
            dataset = _tokenize_chars(text, tmp_dir, seq_len)
            print(f"Loaded text -> {len(dataset)} sequences of length {seq_len}")
else:
parser.error("Provide --data or --synthetic")
print(f"Dataset: {len(dataset)} sequences")
model = TRNModel(model_cfg)
n_params = model.num_parameters()
print(f"Parameters: {n_params:,} (non-embedding)")
train_cfg = TrainConfig(
max_steps=args.steps,
warmup_steps=args.warmup,
lr=args.lr,
batch_size=args.batch_size,
save_interval=args.save_every,
checkpoint_dir=args.checkpoint_dir,
device=args.device,
)
trainer = Trainer(model, dataset, cfg=train_cfg)
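        # train() is assumed to return the per-step loss history (one float per step).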
losses = trainer.train()
if losses:
n10 = min(10, len(losses))
first10 = sum(losses[:n10]) / n10
last10 = sum(losses[-n10:]) / n10
direction = "improved" if last10 < first10 else "did not improve"
            print(
                f"\nTraining complete. Loss: {first10:.4f} -> {last10:.4f} ({direction})"
            )

        if args.eval_at_end and losses:
            # No held-out eval set is built here; instead, approximate perplexity
            # from the training loss history: perplexity = exp(mean cross-entropy
            # loss in nats per token), averaged over the last few steps.
            import math

            n10 = min(10, len(losses))
            final_loss = sum(losses[-n10:]) / n10
            print(f"Approx final perplexity (train): {math.exp(final_loss):.2f}")


if __name__ == "__main__":
    main()