# dataset.py
import random
import time

import torch
from torch.utils.data import Dataset
from datasets import load_dataset

from utils import tokenize, partition, filter_empty

class TextDataset(Dataset):
    """
    Args:
        name (string): Dataset name, either "pg19" or "scientific_papers"
        cache_dir (string): The directory in which to cache the downloaded datasets and from which to retrieve them
        split (string): train / validation / test
        seq_len (int): Length per block
        block_len (int): Number of blocks to return per batch
        device (string): cuda / cpu
        sep_padding (bool): Whether to randomly insert [SEP]-padding windows between blocks
        max_len (int, optional): If set, truncate each document to at most max_len blocks
    """
    def __init__(self,
                 name,
                 cache_dir,
                 split,
                 seq_len,
                 block_len,
                 device="cuda",
                 sep_padding=False,
                 max_len=None
                 ):
        super().__init__()
        if name == "pg19":
            self.data = load_dataset("pg19", split=split, cache_dir=cache_dir)
            key = "text"
        elif name == "scientific_papers":
            self.data = load_dataset("scientific_papers", "arxiv", split=split, cache_dir=cache_dir)
            key = "article"
        else:
            raise ValueError(f"Unknown dataset name: {name}")
        print("Dataset loaded")
        start = time.time()
        # self.data becomes List[Tensor], each of shape (total_len, seq_len);
        # documents with fewer than block_len + 1 windows are dropped so that
        # __getitem__ can always draw a full block plus one lookahead window.
        self.data = filter_empty(partition(tokenize([data[key] for data in self.data]),
                                           max_len=seq_len),
                                 min_len=block_len + 1)
        print(f"Dataset tokenized and partitioned in {time.time() - start:.2f}s")
        self.seq_len = seq_len
        self.block_len = block_len
        self.device = device
        if sep_padding:
            self.data = self.add_sep_padding(self.data, w=128)
        if max_len is not None:
            self.data = [x[:max_len] for x in self.data]
        self.size = sum(x.size(0) for x in self.data) // block_len
    @staticmethod
    def add_sep_padding(data, w, seq_len=512, p=0.1, sep_token=102):
        """
        Randomly inserts runs of [SEP]-filled windows between the w-token
        windows of each document, then pads the tail so every document's
        length is a multiple of seq_len tokens.

        Args:
            data (List[Tensor]): List of token tensors
            w (int): Window size in tokens
            seq_len (int): Sequence length
            p (float): Probability of inserting a padding run after a window
            sep_token (int): Token id associated with [SEP]
        Returns:
            padded (List[Tensor]): List of padded tensors of shape (-1, seq_len)
        """
        assert seq_len % w == 0
        assert seq_len // w > 1
        padded = []
        for x in data:
            ans = []
            x = x.view(-1, w)
            for window in x:
                ans.append(window)
                if random.random() < p:
                    # Insert between 1 and seq_len // w - 2 padding windows.
                    length = random.randrange(1, (seq_len // w) - 1)
                    for t in range(length):
                        pad = torch.full((w,), sep_token)
                        # Mark the boundaries of the run with distinct token ids.
                        if t == 0:
                            pad[0] = 50
                        if t == length - 1:
                            pad[-1] = 51
                        ans.append(pad)
            # Pad the tail so the total length is a multiple of seq_len.
            while (len(ans) * w) % seq_len != 0:
                ans.append(torch.full((w,), sep_token))
            padded.append(torch.stack(ans).view(-1, seq_len))
        for pad in padded:
            assert pad.shape == (pad.size(0), seq_len)
        return padded
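
    # Illustrative note (not from the original source): with the defaults
    # w=128 and seq_len=512, each inserted run is 1 or 2 [SEP]-filled windows
    # of 128 tokens, with its first token set to id 50 and its last to id 51,
    # presumably so a model can tell where a padding run begins and ends.
    # The ids are consistent with a BERT-style vocabulary, where 102 is [SEP]
    # and low ids fall in the [unused] range, but that is an assumption.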
    def __getitem__(self, index):
        """
        The index argument is ignored; a block is sampled at random.
        Returns:
            output (Tensor): Tensor with shape (block_len, seq_len + 1)
        """
        bidx = random.randrange(0, len(self.data))
        tidx = random.randrange(0, self.data[bidx].size(0) - self.block_len)
        # Append the first token of each following window as an extra column,
        # so every row carries seq_len inputs plus one lookahead target token.
        last_token = self.data[bidx][tidx + 1:tidx + self.block_len + 1, 0].unsqueeze(-1).long()
        data = self.data[bidx][tidx:tidx + self.block_len].long()
        data = torch.cat([data, last_token], dim=-1)
        assert data.shape == (self.block_len, self.seq_len + 1)
        return data.to(self.device)

    def __len__(self):
        return self.size
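

if __name__ == "__main__":
    # Minimal usage sketch: the cache path and hyperparameters below are
    # illustrative assumptions, not values prescribed by this repository.
    from torch.utils.data import DataLoader

    dataset = TextDataset(
        name="pg19",
        cache_dir="./cache",  # assumed local cache directory
        split="validation",
        seq_len=512,
        block_len=8,
        device="cpu",  # switch to "cuda" if a GPU is available
    )
    # batch_size=1 because __getitem__ already returns a whole
    # (block_len, seq_len + 1) block; sampling is random, so shuffling
    # in the DataLoader is unnecessary.
    loader = DataLoader(dataset, batch_size=1)
    block = next(iter(loader))
    print(block.shape)  # torch.Size([1, 8, 513])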