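"""Pre-tokenize a streaming HuggingFace dataset into fixed-size uint16 numpy shards.

Example invocation (flag values here are illustrative, not prescribed by this
script; all flags shown are defined in the argument parser below):

    python pretokenize_data.py \
        --dataset mlfoundations/dclm-baseline-1.0-parquet \
        --tokenizer togethercomputer/LLaMA-2-7B-32K \
        --output_dir ~/datasets/dclm_tokenized \
        --num_proc 8
"""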
import argparse
import glob
import math
import multiprocessing as mp
import os

import numpy as np
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer

# Global tokenizer for multiprocessing, populated per worker process
_tokenizer = None


def init_worker(tokenizer_name):
    """Initialize the tokenizer in each worker process."""
    global _tokenizer
    _tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    # Set a very large max length to effectively disable truncation and
    # suppress the tokenizer's sequence-length warnings
    _tokenizer.model_max_length = int(1e9)
    if _tokenizer.eos_token_id is None:
        raise ValueError(f"Tokenizer {tokenizer_name} must have an EOS token.")

def tokenize_doc(doc):
    """Tokenize a single document and append the EOS token."""
    global _tokenizer
    text = doc.get("text")
    if not text or not isinstance(text, str):
        return np.array([], dtype=np.uint16)
    tokens = _tokenizer.encode(text, add_special_tokens=False)
    tokens.append(_tokenizer.eos_token_id)
    # Validate the range *before* casting to uint16: a uint16 array can never
    # fail this check, because out-of-range IDs would already have wrapped.
    tokens_array = np.array(tokens, dtype=np.int64)
    if not ((0 <= tokens_array) & (tokens_array < 2**16)).all():
        raise ValueError(
            f"Token IDs exceed uint16 range. Vocab size: {_tokenizer.vocab_size}"
        )
    return tokens_array.astype(np.uint16)

def main(args):
    target_tokens = args.total_tokens
    shard_size = args.shard_size
    expected_shards = math.ceil(target_tokens / shard_size)

    # Check existing shards so a finished run is not redone
    if os.path.isdir(args.output_dir):
        existing = glob.glob(os.path.join(args.output_dir, "train_*.npy"))
        existing_count = sum(1 for f in existing if os.path.basename(f)[6:-4].isdigit())
        if existing_count >= expected_shards:
            print(
                f"Found {existing_count} shards (>= {expected_shards} expected). Skipping."
            )
            return
        else:
            print(f"Found {existing_count}/{expected_shards} shards. Continuing.")
    os.makedirs(args.output_dir, exist_ok=True)

    print(f"Dataset: {args.dataset}")
    print(f"Tokenizer: {args.tokenizer}")
    print(f"Target: {target_tokens / 1e9:.1f}B tokens")
    print(f"Shard size: {shard_size / 1e6:.0f}M tokens")
    print(f"Expected shards: {expected_shards}")
    # Load and shuffle the dataset in streaming mode
    dataset = load_dataset(
        args.dataset, split="train", streaming=True, trust_remote_code=True
    )
    dataset = dataset.shuffle(seed=args.seed, buffer_size=args.buffer_size)

    # os.cpu_count() can return None, so fall back to a single process
    num_proc = (
        args.num_proc if args.num_proc > 0 else max(1, (os.cpu_count() or 1) * 3 // 4)
    )
    print(f"Using {num_proc} processes")
    # Initialize processing state
    shard_idx = 0
    shard_buffer = np.empty(shard_size, dtype=np.uint16)
    tokens_in_shard = 0
    total_tokens = 0

    with mp.Pool(num_proc, initializer=init_worker, initargs=(args.tokenizer,)) as pool:
        with tqdm(total=target_tokens, unit="tokens", desc="Tokenizing") as pbar:
            for doc_tokens in pool.imap(
                tokenize_doc, iter(dataset), chunksize=args.chunk_size
            ):
                if total_tokens >= target_tokens:
                    break
                if len(doc_tokens) == 0:
                    continue

                # Process tokens from this document, splitting across
                # shard boundaries as needed
                doc_idx = 0
                while doc_idx < len(doc_tokens) and total_tokens < target_tokens:
                    space_left = shard_size - tokens_in_shard
                    doc_left = len(doc_tokens) - doc_idx
                    global_left = target_tokens - total_tokens
                    take = min(space_left, doc_left, global_left)
                    if take == 0:
                        break

                    # Copy tokens to the shard buffer
                    shard_buffer[tokens_in_shard : tokens_in_shard + take] = doc_tokens[
                        doc_idx : doc_idx + take
                    ]
                    tokens_in_shard += take
                    total_tokens += take
                    pbar.update(take)
                    doc_idx += take

                    # Save the shard once it is full
                    if tokens_in_shard == shard_size:
                        shard_path = os.path.join(
                            args.output_dir, f"train_{shard_idx:06d}.npy"
                        )
                        np.save(shard_path, shard_buffer)
                        tqdm.write(
                            f"Saved shard {shard_idx} ({shard_size / 1e6:.0f}M tokens)"
                        )
                        shard_idx += 1
                        tokens_in_shard = 0

    # Save the final partial shard, if any
    if tokens_in_shard > 0:
        shard_path = os.path.join(args.output_dir, f"train_{shard_idx:06d}.npy")
        np.save(shard_path, shard_buffer[:tokens_in_shard])
        tqdm.write(
            f"Saved final shard {shard_idx} ({tokens_in_shard / 1e6:.1f}M tokens)"
        )
    print(f"Completed. Total tokens: {total_tokens:,}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Pre-tokenize a dataset and save it as numpy shards"
    )
    parser.add_argument(
        "--dataset",
        default="mlfoundations/dclm-baseline-1.0-parquet",
        help="HuggingFace dataset name",
    )
    parser.add_argument(
        "--tokenizer",
        default="togethercomputer/LLaMA-2-7B-32K",
        help="Transformers tokenizer name",
    )
    parser.add_argument(
        "--output_dir",
        default="~/datasets/dclm_tokenized",
        help="Output root directory for shards",
    )
    parser.add_argument(
        "--shard_size",
        type=int,
        default=1024**3,
        help="Tokens per shard (default: 2**30, about 1.07B tokens)",
    )
    parser.add_argument(
        "--total_tokens",
        type=int,
        default=10 * 1024**3,
        help="Total tokens to process (default: 10 * 2**30, about 10.7B tokens)",
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Seed for dataset shuffling"
    )
    parser.add_argument(
        "--buffer_size", type=int, default=10000, help="Shuffle buffer size"
    )
    parser.add_argument(
        "--num_proc", type=int, default=-1, help="Number of worker processes (-1 for auto)"
    )
    parser.add_argument(
        "--chunk_size", type=int, default=256, help="imap chunksize passed to the worker pool"
    )
    args = parser.parse_args()
    args.output_dir = os.path.expanduser(args.output_dir)
    main(args)
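
# A minimal sketch of reading the shards back downstream (assumes the default
# output directory above; not part of this script's interface):
#
#     import glob, os
#     import numpy as np
#     shards = sorted(glob.glob(os.path.expanduser("~/datasets/dclm_tokenized/train_*.npy")))
#     tokens = np.load(shards[0])  # 1-D uint16 array of token IDs, documents delimited by EOS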