import torch as T
from torch.utils.data import DataLoader, Dataset
import pickle
import os
import numpy as np
from collections import Counter
EOS = '<eos>'
START = '<start>'
UNKNOWN = '<unknown>'
# Words masked as '<expletive>' when translating indices back to strings.
EXPLETIVES = frozenset(['arse',
'ass',
'asshole',
'bastard',
'bitch',
'bitchin',
'blowjob',
'blowjobs',
'bollocks',
'child-fucker',
'cock',
'crap',
'cunt',
'christ',
'damn',
'fuck',
'gangbang',
'goddamn',
'godsdamn',
'hell',
'jesus',
'motherfucker',
'nigga',
'nigger',
'penis',
'penises',
'pussy',
'puta',
'shit',
'shitass',
'twat',
'vagina',
'vaginas',
'whore',
'slut',
'skank'
              ])
class UbuntuDialogDataset(Dataset):
def __init__(self,
root='.',
wordcount_pkl='wordcount.pkl',
usercount_pkl='usercount.pkl',
turncount_pkl='turncount.pkl',
max_sentence_lengths_pkl='max_sentence_lengths.pkl',
vocab_size=159996, # excluding padding, <unknown>, <eos>, <start>
user_size=None,
min_word_occurrence=None,
min_user_occurrence=5,
max_sentence_length_allowed=50,
max_turns_allowed=10,
                 coalesce_types=(),  # avoid a mutable default argument
):
        '''
        Each item is a dict with three entries:
        * addressee: list of addressee names, one per turn
        * speaker: list of speaker names, one per turn
        * words: list of tokenized sentences, one per turn
        None of the tokenized words has the form '<xxx>', because NLTK
        always tokenizes '<' and '>' as separate tokens, so '<xxx>' is
        safe to use for special tokens such as '<unknown>' and '<path>'.
        No user name has the form '<xxx>' either, so the same holds for
        users.
        '''
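        # One dialogue pickle is assumed to look like this (hypothetical
        # values, for illustration only):
        #   {'addressee': ['bob', 'alice'],
        #    'speaker':   ['alice', 'bob'],
        #    'words':     [['hi', 'bob'], ['hello', 'alice']]}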
self.max_sentence_length_allowed = max_sentence_length_allowed
self._pkls = []
self._max_turns_allowed = max_turns_allowed
with open(turncount_pkl, 'rb') as f:
self._turncount = pickle.load(f)
with open(max_sentence_lengths_pkl, 'rb') as f:
self._max_sentence_lengths = pickle.load(f)
with open(wordcount_pkl, 'rb') as f:
self._wordcount = pickle.load(f)
        with open(usercount_pkl, 'rb') as f:
            # encoding= decodes 8-bit user-name strings from
            # Python-2-era pickles that are not valid UTF-8
            self._usercount = pickle.load(f, encoding='windows-1252')
        for curdir in os.listdir(root):
            subdir = os.path.join(root, curdir)
            if not os.path.isdir(subdir):
                continue  # skip stray files such as the count pickles
            self._pkls.extend(os.path.join(subdir, f)
                              for f in os.listdir(subdir)
                              if f.endswith('.pkl'))
if coalesce_types is not None:
self._coalesce(coalesce_types)
        if not min_word_occurrence:
            self._vocab = [''] + [w for w, _ in self._wordcount.most_common(vocab_size)]
        else:
            self._vocab = [''] + [w for w, c in self._wordcount.items() if c >= min_word_occurrence]
        self._vocab.append(START)
        self._vocab.append(EOS)
        if not min_user_occurrence:
            self._users = [''] + [u for u, _ in self._usercount.most_common(user_size)]
        else:
            self._users = [''] + [u for u, c in self._usercount.items() if c >= min_user_occurrence]
self._ivocab = {w: i for i, w in enumerate(self._vocab)}
self._iusers = {u: i for i, u in enumerate(self._users)}
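        # Index layout: 0 is padding (''), 1..len(vocab)-1 are real words
        # (with START and EOS appended at the end), and len(vocab) itself
        # is reserved for the unknown-word index; users follow the same
        # scheme with the null user at 0.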
    def _coalesce(self, types):
        '''
        Coalesces classes of tokens into single special tokens.
        Right now any word starting with '/' is folded into a '<path>'
        token, which is certainly inaccurate; feel free to refine this.
        '''
        for w in list(self._wordcount.keys()):
            if 'path' in types and w.startswith('/'):
                self._wordcount['<path>'] = self._wordcount.get('<path>', 0) + self._wordcount[w]
                del self._wordcount[w]
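    # Illustrative example (hypothetical counts): with types=['path'],
    # Counter({'/etc/fstab': 3, '/usr/bin/gcc': 2, 'hello': 7}) becomes
    # Counter({'hello': 7, '<path>': 5}).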
def __len__(self):
return len(self._pkls)
def __getitem__(self, i):
return self.get_indexed_item(i)
    def get_raw_item(self, i):
        '''
        Gets a dialogue with words/users as strings.
        '''
        with open(self._pkls[i], 'rb') as f:
            item = pickle.load(f)
        for t in range(len(item['words'])):
            # Truncate to leave room for the START and EOS tokens.
            item['words'][t] = item['words'][t][:self.max_sentence_length_allowed - 2]
            item['words'][t].insert(0, START)
            item['words'][t].append(EOS)
        return item
    def get_indexed_item(self, i):
        '''
        Gets a dialogue with words/users translated into indices.
        '''
        item = self.get_raw_item(i)
        addressee_idx = [self.index_user(u) for u in item['addressee']]
        speaker_idx = [self.index_user(u) for u in item['speaker']]
        word_idx = [[self.index_word(w) for w in s] for s in item['words']]
        if len(speaker_idx) > self._max_turns_allowed:
            # +1 so the window ending at the last turn can also be sampled
            start = np.random.randint(len(speaker_idx) - self._max_turns_allowed + 1)
        else:
            start = 0
        addressee_idx = addressee_idx[start:start + self._max_turns_allowed]
        speaker_idx = speaker_idx[start:start + self._max_turns_allowed]
        word_idx = word_idx[start:start + self._max_turns_allowed]
        return addressee_idx, speaker_idx, word_idx
    def translate_item(self, addressee_idx, speaker_idx, word_idx):
        '''
        Translates user/word indices back into strings, removes trailing
        padding from each sentence, and masks expletives.
        '''
        addressee = [self.get_user(i) for i in addressee_idx] if addressee_idx is not None else None
        speaker = [self.get_user(i) for i in speaker_idx] if speaker_idx is not None else None
        words = [[self.get_word(i) for i in s] for s in word_idx] if word_idx is not None else None
        if words is not None:
            for s in words:
                while s and s[-1] == '':
                    s.pop()
            for word_list in words:
                for i in range(len(word_list)):
                    if word_list[i] in EXPLETIVES:
                        word_list[i] = '<expletive>'
        return addressee, speaker, words
    def index(self, addressee, speaker, sentences, padding=False):
        '''
        Translates user/word strings into indices, optionally zero-padding
        all sentences to the same length.
        '''
        addressee = [self.index_user(u) for u in addressee] if addressee is not None else None
        speaker = [self.index_user(u) for u in speaker] if speaker is not None else None
        words = [[self.index_word(w) for w in s] for s in sentences] if sentences is not None else None
        if padding:
            maxlen = max(len(s) for s in words)
            for s in words:
                s.extend([0] * (maxlen - len(s)))
        return addressee, speaker, words
@property
def unknown_word_index(self):
return len(self._vocab)
@property
def start_token_index(self):
return self.index_word(START)
@property
def end_token_index(self):
return self.index_word(EOS)
@property
def unknown_user_index(self):
return len(self._users)
@property
def vocab(self):
'''
The vocabulary
First word is padding; the real vocabulary starts from index 1
'''
return self._vocab
@property
def users(self):
'''
The users
First user is null; the real list of users starts from index 1
'''
return self._users
def get_word(self, i):
if 0 <= i < len(self._vocab): # padding already in self._vocab
return self._vocab[i]
elif i == self.unknown_word_index:
return UNKNOWN
else:
raise ValueError('index out of range')
def get_user(self, i):
if 0 <= i < len(self._users): # null already in self._users
return self._users[i]
elif i == self.unknown_user_index:
return UNKNOWN
else:
raise ValueError('index out of range')
    def index_word(self, word):
        return self._ivocab.get(word, self.unknown_word_index)
    def index_user(self, user):
        return self._iusers.get(user, self.unknown_user_index)
def collate_as_tensor(samples):
'''
Returns:
turns: LongTensor of (batch_size,)
sentence_lengths_padded: LongTensor of (batch_size, max(turns))
speaker_padded: LongTensor of (batch_size, max(turns))
addressee_padded: LongTensor of (batch_size, max(turns))
words_padded: LongTensor of (batch_size, max(turns), max(sentence_lengths))
words_reverse_padded: LongTensor of (batch_size, max(turns), max(sentence_lengths))
* For bidirectional RNN on sentence level
'''
batch_size = len(samples)
addressees, speakers, sentences = zip(*samples)
    def pad_by_turns(seqs, max_turns):
        # Pad each per-dialogue list with zeros up to max_turns entries.
        return np.array([s + [0] * (max_turns - len(s)) for s in seqs])
turns = np.array([len(s) for s in speakers])
sentence_lengths = [[len(s) for s in sent] for sent in sentences]
max_turns = np.max(turns)
sentence_lengths_padded = pad_by_turns(sentence_lengths, max_turns)
speaker_padded = pad_by_turns(speakers, max_turns)
addressee_padded = pad_by_turns(addressees, max_turns)
max_sentence_length = np.max(sentence_lengths_padded)
words_padded = np.zeros((batch_size, max_turns, max_sentence_length), dtype=np.int64)
words_reverse_padded = np.zeros((batch_size, max_turns, max_sentence_length), dtype=np.int64)
for i in range(batch_size):
for j in range(turns[i]):
words_padded[i, j, :sentence_lengths[i][j]] = sentences[i][j]
words_reverse_padded[i, j, :sentence_lengths[i][j]] = sentences[i][j][::-1]
    # Convert to torch LongTensors, as the docstring promises; drop this
    # block to keep plain numpy arrays for a non-PyTorch pipeline.
    turns = T.from_numpy(turns)
    sentence_lengths_padded = T.from_numpy(sentence_lengths_padded)
    speaker_padded = T.from_numpy(speaker_padded)
    addressee_padded = T.from_numpy(addressee_padded)
    words_padded = T.from_numpy(words_padded)
    words_reverse_padded = T.from_numpy(words_reverse_padded)
return turns, sentence_lengths_padded, speaker_padded, addressee_padded, words_padded, words_reverse_padded
class UbuntuDialogDataLoader(DataLoader):
def __init__(self, dataset, batch_size=1, num_workers=0):
        super().__init__(dataset,
                         batch_size=batch_size,
                         shuffle=True,
                         collate_fn=collate_as_tensor,
                         num_workers=num_workers,
                         drop_last=True,
                         )
# Usage:
# dataset = UbuntuDialogDataset('../ubuntu-ranking-dataset-creator/src/dialogs', 'wordcount.pkl', 'usercount.pkl')
# dataloader = UbuntuDialogDataLoader(dataset, 16)
# for item in dataloader:
# turns, sentence_lengths_padded, speaker_padded, addressee_padded, words_padded, words_reverse_padded = item
# ...
#
# wordcount.pkl and usercount.pkl are generated by the stage-1 parser.
# For asynchronous data loading, pass a worker count, e.g.
# dataloader = UbuntuDialogDataLoader(dataset, 16, 4)
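
# A minimal, self-contained sketch of the collate output, built from
# hand-made index lists rather than real dialogue pickles (every value
# below is a hypothetical illustration, not data from the corpus):
if __name__ == '__main__':
    # Two dialogues: one with 2 turns, one with 3 turns.
    sample_a = ([1, 2], [3, 4], [[5, 6, 7], [8, 9]])
    sample_b = ([1, 2, 1], [3, 4, 3], [[5], [6, 7], [8, 9, 10, 11]])
    turns, lengths, speakers, addressees, words, words_rev = \
        collate_as_tensor([sample_a, sample_b])
    print(turns)        # tensor([2, 3])
    print(lengths)      # per-turn sentence lengths, zero-padded to 3 turns
    print(words.shape)  # (2, 3, 4): batch x max_turns x max_sentence_length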