-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel_setup.py
More file actions
67 lines (56 loc) · 2.26 KB
/
model_setup.py
File metadata and controls
67 lines (56 loc) · 2.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import sys, os
import json
class Parameters:
""" Class for setting all parameters pertaining to an experiment
"""
def __init__(self, loc=None) -> None:
"""Set experiment parameters
Attributes:
loc (str, optional): location of pre-existing parameters to be loaded. If None,
a new dict of experiment parameters is created. Defaults to None.
"""
if loc is None:
# if no load path is given, use default parameters
self.params = {} # experimental parameters
self.params["epochs"] = 100
self.params["batch_size"] = 64
self.params["lr"] = 1e-4 # learning rate
self.params["al1"] = 10.0 # l1 activity regularization
self.params["nodes"] = 500 # number of recurrent nodes
self.params["outputs"] = 100 # number of output nodes
self.params["reset_interval"] = 10 # > 1 is stateful
self.params["context"] = True # whether to give model context signal
else:
# load experimental parameters from file
self.params = self.load_params(loc)
def save_params(self, path):
""" Save class parameters to .json file
Args:
path (str): file location; where to store .json file
"""
with open(f"{path}/model_parameters.json", "w") as f:
json.dump(self.params, f, indent=4)
def load_params(self, path):
""" Load class parameters from .json file
Args:
path (str): File location of JSON model specification
Returns:
loaded_params (dict): experiment parameter dictionary loaded from file
"""
# load parameters from json file
file = f"{path}/model_parameters.json"
with open(file, "r") as f:
loaded_params = json.load(f)
return loaded_params
if __name__ == "__main__":
    # Path to the model directory, in the form mydir/experiment_name.
    # Falls back to ./VPC when no command-line argument is given.
    try:
        path = sys.argv[1]
    except IndexError:
        path = "./VPC"
        print(f"No model path given. Default to {path}")
    # Create the directory tree if needed. exist_ok avoids the race between
    # checking for existence and creating the directory.
    os.makedirs(path, exist_ok=True)
    # Write the default parameters into <path>/model_parameters.json.
    parameters = Parameters()
    parameters.save_params(path)