-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
48 lines (37 loc) · 981 Bytes
/
train.py
File metadata and controls
48 lines (37 loc) · 981 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import glob
from encoder import Encoder
from model import SequenceModel, IdentifiersModel
from scrapper import get_data
# Sample inputs handed to ``model.train(test_cases=...)`` by ``train()``.
#
# Identifier cases are sequences of single characters.
# NOTE(review): "Ract" is not a prefix of "React" -- confirm whether intentional.
IDENTIFIERS_TEST_CASES = [
    list("imp"),
    list("Ract"),
    list("Chi"),
]

# Sequence cases are sequences of whole tokens (whitespace-separated here).
SEQUENCES_TEST_CASES = [
    "import React from".split(),
    "import type".split(),
    "import {".split(),
]
def train(model_type, iterations=None):
    """Build, train and save the model selected by ``model_type``.

    ``model_type`` is passed to ``get_data`` and to the model constructor.
    ``"sequences"`` selects ``SequenceModel`` trained against
    ``SEQUENCES_TEST_CASES``. Any other value selects ``IdentifiersModel``
    trained against ``IDENTIFIERS_TEST_CASES``.

    Args:
        model_type: Name of the model / data set to train.
        iterations: Value passed as ``iterations`` to ``model.train``.
            When ``None`` (the default), it is 100 for ``"sequences"`` and
            20 otherwise, matching the previous hard-coded values.

    Training may be stopped early with Ctrl-C (``KeyboardInterrupt``).
    ``model.save()`` is still called afterwards.
    """
    data = get_data(model_type)
    encoder = Encoder(data)

    if model_type == "sequences":
        default_iterations = 100
        test_cases = SEQUENCES_TEST_CASES
        model = SequenceModel(model_type, encoder)
    else:
        default_iterations = 20
        test_cases = IDENTIFIERS_TEST_CASES
        model = IdentifiersModel(model_type, encoder)

    if iterations is None:
        iterations = default_iterations

    try:
        model.train(test_cases=test_cases, iterations=iterations)
    except KeyboardInterrupt:
        # Deliberate: Ctrl-C only ends training early. Execution continues
        # to model.save() below instead of aborting the script.
        print("training interrupted")
    print("saving")
    model.save()
if __name__ == "__main__":
    import sys

    # Without this check, running the script with no argument dies with an
    # IndexError traceback. sys.exit(str) prints the message to stderr and
    # exits with status 1.
    if len(sys.argv) < 2:
        sys.exit("usage: %s MODEL_TYPE" % sys.argv[0])
    model_type = sys.argv[1]
    train(model_type)