-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
119 lines (104 loc) · 3.28 KB
/
utils.py
File metadata and controls
119 lines (104 loc) · 3.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""
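Utility functions for the seq2seq translation code (based on the PyTorch
tutorial listed under Resources): the Lang vocabulary class and helpers that
format elapsed / estimated-remaining time for progress reporting.
Date:
- Jan. 28, 2023
Resources:
- https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html#the-seq2seq-model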
"""
import time
import math
import matplotlib.pyplot as plt
plt.switch_backend('agg')
import matplotlib.ticker as ticker
import numpy as np
from params import *
class Lang:
    """Vocabulary for one language: word <-> index maps plus per-word counts.

    Indices 0 and 1 are reserved for the "SOS" and "EOS" markers. These appear
    only in ``index2word``, not in ``word2index``/``word2count``. Every other
    word gets the next free index (starting at 2) the first time it is seen.
    """

    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {0: "SOS", 1: "EOS"}
        # Next index to hand out; 0 and 1 are already taken by SOS/EOS.
        self.n_words = 2

    def addSentence(self, sentence):
        """Register each token of *sentence*, splitting on single spaces."""
        for token in sentence.split(' '):
            self.addWord(token)

    def addWord(self, word):
        """Count one occurrence of *word*, assigning it a new index if unseen."""
        if word in self.word2index:
            self.word2count[word] += 1
            return
        new_index = self.n_words
        self.word2index[word] = new_index
        self.index2word[new_index] = word
        self.word2count[word] = 1
        self.n_words = new_index + 1
def asMinutes(s):
    """Render a duration of *s* seconds as ``'<m>m <s>s'``.

    Minutes are the floor of ``s / 60``. The leftover seconds are truncated
    to an integer by the ``%d`` conversion.
    """
    whole_minutes = math.floor(s / 60)
    leftover_seconds = s - whole_minutes * 60
    return '%dm %ds' % (whole_minutes, leftover_seconds)
def timeSince(since, percent):
    """Report elapsed time and a linear estimate of the time remaining.

    Args:
        since: Start timestamp on the ``time.time()`` clock. The elapsed time is
            computed as ``time.time() - since``.
        percent: Fraction of the work completed so far. The extrapolation
            ``elapsed / percent`` assumes a fraction (0.25 means 25%), not a
            0-100 value.

    Returns:
        ``'<elapsed> (- <remaining>)'``, with both durations formatted by
        ``asMinutes``. If ``percent`` is not positive, the remaining time cannot
        be extrapolated. In that case ``'?'`` is shown for it instead of raising
        ``ZeroDivisionError`` (percent == 0) or printing a nonsensical negative
        estimate (percent < 0).
    """
    elapsed = time.time() - since
    if percent <= 0:
        return '%s (- ?)' % asMinutes(elapsed)
    # If `percent` of the job took `elapsed` seconds, the whole job is
    # estimated to take elapsed / percent seconds.
    estimated_total = elapsed / percent
    remaining = estimated_total - elapsed
    return '%s (- %s)' % (asMinutes(elapsed), asMinutes(remaining))
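# Data Preprocessing (currently disabled: everything below is commented out).
# NOTE: re-enabling it requires `import re` and `import unicodedata`, which this
# file does not import. MAX_LENGTH and eng_prefixes are presumably supplied by
# `from params import *` — confirm before uncommenting.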
# def unicodeToAscii(s):
# # Turn a Unicode string to plain ASCII, thanks to
# # https://stackoverflow.com/a/518232/2809427
# return ''.join(
# c for c in unicodedata.normalize('NFD', s)
# if unicodedata.category(c) != 'Mn'
# )
#
# def normalizeString(s):
# # Lowercase, trim, and remove non-letter characters
# s = unicodeToAscii(s.lower().strip())
# s = re.sub(r"([.!?])", r" \1", s)
# s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
# return s
#
#
# def filterPair(p):
# return len(p[0].split(' ')) < MAX_LENGTH and \
# len(p[1].split(' ')) < MAX_LENGTH and \
# p[1].startswith(eng_prefixes)
#
#
# def filterPairs(pairs):
# return [pair for pair in pairs if filterPair(pair)]
#
#
# def readLangs(data_path, lang1, lang2, reverse=False):
# print("Reading lines...")
# # Read the file and split into lines
# lines = open(data_path, encoding='utf-8').\
# read().strip().split('\n')
#
# # Split every line into pairs and normalize
# pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]
#
# # Reverse pairs, make Lang instances
# if reverse:
# pairs = [list(reversed(p)) for p in pairs]
# input_lang = Lang(lang2)
# output_lang = Lang(lang1)
# else:
# input_lang = Lang(lang1)
# output_lang = Lang(lang2)
#
# return input_lang, output_lang, pairs
#
#
# def prepareData(data_path, lang1, lang2, reverse=False):
# input_lang, output_lang, pairs = readLangs(data_path, lang1, lang2, reverse)
# print("Read %s sentence pairs" % len(pairs))
# pairs = filterPairs(pairs)
# print("Trimmed to %s sentence pairs" % len(pairs))
# print("Counting words...")
# for pair in pairs:
# input_lang.addSentence(pair[0])
# output_lang.addSentence(pair[1])
# print("Counted words:")
# print("Input Lang.: {}, Number of words: {}".format(input_lang.name, str(input_lang.n_words)))
# print("Output Lang.: {}, Number of words: {}".format(output_lang.name, str(output_lang.n_words)))
# return input_lang, output_lang, pairs