-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathptb.py
More file actions
103 lines (77 loc) · 2.57 KB
/
ptb.py
File metadata and controls
103 lines (77 loc) · 2.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# coding: utf-8
import sys
import os
sys.path.append('..')
try:
import urllib.request
except ImportError:
raise ImportError('Use Python3!')
import pickle
import numpy as np
# Remote directory holding the raw Penn Treebank text files.
url_base = 'https://raw.githubusercontent.com/tomsercu/lstm/master/data/'

# Split name -> name of the raw text file (remote and local).
key_file = dict(
    train='ptb.train.txt',
    test='ptb.test.txt',
    valid='ptb.valid.txt',
)

# Split name -> name of the cached NumPy array holding that split's word ids.
save_file = dict(
    train='ptb.train.npy',
    test='ptb.test.npy',
    valid='ptb.valid.npy',
)

# Pickled (word_to_id, id_to_word) pair.
vocab_file = 'ptb.vocab.pkl'

# Every download and cache file lives in the directory containing this module.
dataset_dir = os.path.split(os.path.abspath(__file__))[0]
def _download(file_name):
    '''
    Download ``file_name`` from ``url_base`` into ``dataset_dir`` if it is missing.

    The payload is written to ``<file_name>.part`` and renamed to its final
    name only after the transfer has completed, so an interrupted download is
    never mistaken for a finished file by the existence check on a later call.

    If the first attempt fails with ``URLError`` the download is retried once
    with HTTPS certificate verification disabled.  The unverified context is
    passed to that single request only; the process-wide SSL defaults are
    left untouched.

    :param file_name: file name, relative to both ``url_base`` and ``dataset_dir``
    :return: None
    :raises urllib.error.URLError: if the retry fails as well
    '''
    import shutil
    import ssl
    from urllib.error import ContentTooShortError, URLError

    file_path = os.path.join(dataset_dir, file_name)
    if os.path.exists(file_path):
        return

    print('Downloading ' + file_name + ' ... ')
    url = url_base + file_name
    part_path = file_path + '.part'

    def fetch(context=None):
        # Stream the response to the temporary file, then make sure the server
        # delivered as many bytes as it announced.
        with urllib.request.urlopen(url, context=context) as response:
            with open(part_path, 'wb') as out:
                shutil.copyfileobj(response, out)
                received = out.tell()
            expected = response.headers.get('Content-Length')
        if expected is not None and received < int(expected):
            raise ContentTooShortError(
                'retrieval incomplete: got only %i out of %i bytes'
                % (received, int(expected)),
                (part_path, response.headers))

    try:
        try:
            fetch()
        except URLError:
            # SECURITY: best-effort fallback kept from the original code.  The
            # retry skips certificate checks, so the downloaded data is not
            # authenticated.
            print('Retrying without HTTPS certificate verification ... ')
            insecure = ssl.create_default_context()
            # check_hostname must be cleared before verify_mode can be CERT_NONE.
            insecure.check_hostname = False
            insecure.verify_mode = ssl.CERT_NONE
            fetch(insecure)
        os.replace(part_path, file_path)
    finally:
        # Never leave a partial file behind, whatever went wrong.
        if os.path.exists(part_path):
            os.remove(part_path)
    print('Done')
def load_vocab():
    '''
    Return the vocabulary of the training split.

    The vocabulary is read from the pickle cache ``vocab_file`` in
    ``dataset_dir`` when that file exists.  Otherwise the training text is
    downloaded if necessary, ids are assigned to words in order of first
    appearance (starting at 0), and the result is written to the cache.

    :return: tuple ``(word_to_id, id_to_word)`` of two dicts that are inverse
        mappings of each other
    '''
    vocab_path = os.path.join(dataset_dir, vocab_file)

    if os.path.exists(vocab_path):
        # NOTE(security): unpickling can execute arbitrary code.  This is only
        # acceptable because the cache is a local file written by this module;
        # do not point vocab_path at untrusted data.
        with open(vocab_path, 'rb') as f:
            word_to_id, id_to_word = pickle.load(f)
        return word_to_id, id_to_word

    file_name = key_file['train']
    file_path = os.path.join(dataset_dir, file_name)
    _download(file_name)

    # Newlines are replaced textually, so '<eos>' only becomes a token of its
    # own when the newline has whitespace on both sides (presumably true for
    # the PTB files - confirm if other data is used).
    with open(file_path, encoding='utf-8') as f:
        words = f.read().replace('\n', '<eos>').strip().split()

    word_to_id = {}
    id_to_word = {}
    for word in words:
        if word not in word_to_id:
            new_id = len(word_to_id)
            word_to_id[word] = new_id
            id_to_word[new_id] = word

    # Write to a temporary name first so an interrupted dump cannot leave a
    # truncated cache that every later call would fail to unpickle.
    tmp_path = vocab_path + '.tmp'
    with open(tmp_path, 'wb') as f:
        pickle.dump((word_to_id, id_to_word), f)
    os.replace(tmp_path, vocab_path)

    return word_to_id, id_to_word
def load_data(data_type='train'):
    '''
    Load one split of the corpus as an array of word ids.

    The array is read from the ``.npy`` cache in ``dataset_dir`` when it
    exists.  Otherwise the split's text file is downloaded if necessary,
    converted to ids with the vocabulary from ``load_vocab()``, and saved to
    the cache.

    :param data_type: split to load: 'train', 'test' or 'valid'
        ('val' is accepted as an alias for 'valid')
    :return: tuple ``(corpus, word_to_id, id_to_word)`` where ``corpus`` is a
        NumPy array of word ids and the two dicts are the vocabulary mappings
    :raises KeyError: if ``data_type`` is not a known split, or if the split
        contains a word that is missing from the vocabulary
    '''
    if data_type == 'val':
        data_type = 'valid'
    save_path = os.path.join(dataset_dir, save_file[data_type])

    word_to_id, id_to_word = load_vocab()

    if os.path.exists(save_path):
        corpus = np.load(save_path)
        return corpus, word_to_id, id_to_word

    file_name = key_file[data_type]
    file_path = os.path.join(dataset_dir, file_name)
    _download(file_name)

    # Same tokenisation as load_vocab(): newlines become '<eos>' and the text
    # is split on whitespace.
    with open(file_path, encoding='utf-8') as f:
        words = f.read().replace('\n', '<eos>').strip().split()
    corpus = np.array([word_to_id[w] for w in words])

    np.save(save_path, corpus)
    return corpus, word_to_id, id_to_word
if __name__ == '__main__':
    # Running the module as a script loads each split once, which downloads
    # any missing files and fills the on-disk caches.
    for split in ('train', 'val', 'test'):
        load_data(split)