-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathmain.py
More file actions
69 lines (54 loc) · 2.14 KB
/
main.py
File metadata and controls
69 lines (54 loc) · 2.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from util import read_data, split_data, get_vocab, get_word2idx_idx2word, prepare_input_data, prepare_test_data, get_embedding_matrix
from train import Summarization
import sys
def preprocess_data(abstracts_path, titles_path, vocab_size=80000):
    """Load the corpus, split it, and build the word/index vocabulary.

    Args:
        abstracts_path: Path to the abstracts file, forwarded to ``read_data``.
        titles_path: Path to the titles file, forwarded to ``read_data``.
        vocab_size: ``size`` argument forwarded to ``get_vocab``. Defaults to
            80000, the value that was previously hard-coded.

    Returns:
        ``(train_df, val_df, word2idx, idx2word)``: the two splits produced
        by ``split_data`` and the two lookup tables produced by
        ``get_word2idx_idx2word``.
    """
    print('Reading data ...')
    total_df = read_data(abstracts_path, titles_path)
    train_df, val_df = split_data(total_df)
    print('Preparing vocabulary ...')
    # The vocabulary is built from the training split only, so the
    # validation split does not contribute words to it.
    vocab = get_vocab(train_df, size=vocab_size)
    word2idx, idx2word = get_word2idx_idx2word(vocab)
    return train_df, val_df, word2idx, idx2word
def _parse_args(argv):
    """Validate the command line for :func:`main`.

    Expected form: ``main.py <choice> [load_path] [test_file]`` where
    ``choice`` is ``train``, ``val`` or ``test``.

    Returns:
        A ``(choice, load_path, test_file)`` tuple when the arguments are
        usable (missing optional values are ``None``); otherwise prints the
        reason and returns ``None``.
    """
    if len(argv) < 2:
        print('Usage: python main.py {train|val|test} [load_path] [test_file]')
        return None
    choice = argv[1]
    # Explicit length checks replace the former bare `except:` clauses,
    # which would also have swallowed unrelated errors.
    load_path = argv[2] if len(argv) > 2 else None
    test_file = argv[3] if len(argv) > 3 else None

    # Reject unknown commands before any expensive data/model preparation.
    if choice not in ('train', 'val', 'test'):
        print('This command is not supported.')
        return None
    if choice in ('val', 'test') and load_path is None:
        print('Please specify some path to load model weights from.')
        return None
    if test_file is not None and choice in ('train', 'val'):
        print('This command is not supported.')
        return None
    if choice == 'test' and test_file is None:
        print('Please provide path to test csv file containing abstracts only.')
        return None
    return choice, load_path, test_file


def main(argv=None):
    """Command-line entry point: train, validate, or test the summarizer.

    Args:
        argv: Argument vector in ``sys.argv`` form (program name first).
            Defaults to ``sys.argv``; exposed so the entry point can also be
            driven programmatically.
    """
    if argv is None:
        argv = sys.argv
    parsed = _parse_args(argv)
    if parsed is None:
        return
    choice, load_path, test_file = parsed

    print('Choice: ', choice)
    if load_path is not None:
        print('Load path: ', load_path)
    if test_file is not None:
        print('Test file: ', test_file)

    # word2idx/idx2word come from the training split, so this step is
    # needed for every command, including 'test'.
    train_df, val_df, word2idx, idx2word = preprocess_data('./data/abstracts.pkl', './data/titles.pkl')
    print('Preparing embedding matrix ...')
    emb_matrix = get_embedding_matrix(word2idx, idx2word, './data/glove_vectors.txt', 'glove')
    # NOTE(review): emb_dim=300 presumably matches the width of the GloVe
    # vectors file — confirm if the embeddings file changes.
    summarization = Summarization(emb_matrix, emb_dim=300, hidden_dim=128, word2idx=word2idx, idx2word=idx2word)
    print('Preparing Input data ...')
    if choice == 'train':
        train_data = prepare_input_data(train_df, word2idx)
        summarization.train(train_data, use_prev=load_path)
    elif choice == 'val':
        eval_data = prepare_input_data(val_df, word2idx)
        summarization.eval(eval_data, val_df, load_path=load_path, evaluation='val', print_samples=True)
    else:  # choice == 'test' (guaranteed by _parse_args)
        test_df, eval_data = prepare_test_data(test_file, word2idx)
        summarization.eval(eval_data, test_df, load_path=load_path, evaluation='test', print_samples=True)
if __name__ == '__main__':
    # Run the command-line interface only when executed as a script,
    # not when this module is imported.
    main()