-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodels.py
More file actions
179 lines (148 loc) · 6.66 KB
/
models.py
File metadata and controls
179 lines (148 loc) · 6.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
import torch
import numpy as np
from train_tools import l1_reg
class VPC_RNN(torch.nn.Module):
    """ RNN model for the variational position reconstruction task
    """
    def __init__(self, params, device = None, **kwargs):
        """
        Args:
            params (dict): dictionary of model parameters. Keys used here:
                "context" (bool/int), "nodes" (int), "outputs" (int),
                "lr" (float); "al1" and "reset_interval" are read later by
                train/val/inference steps.
            device (str, optional): device to send torch tensors to.
                Typically "cuda" for training, and "cpu" for inference.
                Defaults to None (auto-selects cuda when available).
        """
        super().__init__(**kwargs)
        self.params = params
        # Use cuda if available
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device
        n_in = int(2 + params["context"] * 6) # 6 possible contexts, 2 velocities
        # initialize layers
        self.g = torch.nn.RNN(
            input_size = n_in,
            hidden_size = params["nodes"],
            nonlinearity= "relu",
            bias=False,
            batch_first=True, device = self.device)
        torch.nn.init.eye_(self.g.weight_hh_l0) # identity initialization
        self.p = torch.nn.Linear(params["nodes"], params["outputs"], bias=False, device = self.device)
        # BUGFIX: create eps on the model device; a CPU-resident eps makes
        # torch.maximum(self.eps, ...) raise a device-mismatch error on CUDA.
        self.eps = torch.tensor(1e-14, device = self.device) # small epsilon for center of mass estimate
        self.activation = torch.nn.ReLU()
        self.loss_fn = torch.nn.MSELoss()
        # BUGFIX: was self.to(device), a silent no-op when device is None;
        # use the resolved self.device so the module always lands there.
        self.to(self.device)
        self.optimizer = torch.optim.Adam(self.parameters(), lr=params["lr"])
    def reset_state(self, shape):
        """Return a zero initial hidden state of the given shape on self.device."""
        return torch.zeros(shape, device = self.device)
    def decode_phases(self, p, r):
        """Estimate one spatial center per output unit as a center of mass.

        Args:
            p: output activity, shape (BS, T, Np).
            r: true positions, shape (BS, T, 2).
        Returns:
            mu: per-unit center estimates, shape (BS, Np, 2).
        """
        # Normalize each unit's activity over time; eps guards division by zero.
        pt = p/torch.maximum(self.eps, torch.sum(p, dim = -2, keepdim = True))
        mu = torch.sum(r[...,None, :]*pt[...,None], dim = 1)
        return mu
    def decode_position(self, p, mu):
        """Reconstruct positions as activity-weighted averages of unit centers.

        Args:
            p: output activity, shape (BS, T, Np).
            mu: per-unit centers, shape (BS, Np, 2).
        Returns:
            rhat: decoded positions, shape (BS, T, 2).
        """
        # Normalize activity across units at each timestep; eps guards zero rows.
        po = p/(torch.maximum(self.eps, torch.sum(p, dim = -1, keepdim = True)))
        rhat = torch.sum(mu[:,None]*po[...,None], dim = -2)
        return rhat
    def forward(self, inputs, g_prev=None):
        """Run one forward pass.

        Args:
            inputs: pair (v, r) of input signal (BS, T, Nin) and true
                positions (BS, T, 2).
            g_prev (Tensor, optional): previous hidden state (BS, Ng);
                zeros are used when None.
        Returns:
            tuple: (yhat, g, p, mu) — decoded positions, recurrent states,
            output activity and per-unit center estimates.
        """
        v = inputs[0] # input signal
        r = inputs[1] # true position
        if g_prev is None:
            initial_state = self.reset_state((v.shape[0], self.params["nodes"]))
        else:
            initial_state = g_prev
        g, _ = self.g(v, initial_state[None])
        p = self.activation(self.p(g))
        mu = self.decode_phases(p, r)
        yhat = self.decode_position(p, mu)
        return yhat, g, p, mu
    def train_step(self, x, y, g_prev = None):
        """One optimizer step on batch (x, y); returns (loss, yhat, g)."""
        self.optimizer.zero_grad(set_to_none=True)
        yhat, g, p, mu = self(x, g_prev)
        # L1 penalty on recurrent activity encourages sparse representations.
        activity_reg = l1_reg(self.params["al1"], g)
        loss = self.loss_fn(yhat, y) + activity_reg
        # parameter update
        loss.backward()
        self.optimizer.step()
        return loss, yhat, g
    def val_step(self, x, y, g_prev = None):
        """Validation step: same loss as train_step but without gradients."""
        # val step and train step are equal, except for gradient
        with torch.no_grad():
            yhat, g, p, mu = self(x, g_prev)
            activity_reg = l1_reg(self.params["al1"], g)
            val_loss = self.loss_fn(yhat, y) + activity_reg
        return val_loss, yhat, g
    def inference(self, dataset):
        """Run model in inference mode, returning states and predictions
        Args:
            dataset : iterable of length N, returning input output pairs of torch tensors
                should contain tensors of ((v, r), r) with shape
                (((BS, T, Nin), (BS, T, 2)), (BS, T, 2)),
                where BS is the batch size, T the number of timesteps
                and Nin the number of inputs to the recurrent layer
                (velocities + optional border + context signals)
        Returns:
            list of four np.ndarrays, concatenated over samples along axis 0:
                gs: Recurrent states, shape (N*BS, T, Ng)
                ps: Output states, shape (N*BS, T, Np)
                centers: Center estimates, shape (N*BS, Np, 2)
                preds: Model predictions, shape (N*BS, T, 2)
        """
        gs = []
        ps = []
        centers = []
        preds = []
        with torch.no_grad():
            rnn_state = None # sample initial state
            for i, (x, y_true) in enumerate(dataset):
                # Reset the hidden state every reset_interval samples;
                # otherwise carry over the last state of the previous sample.
                reset_state = (i % self.params["reset_interval"]) == 0
                if reset_state:
                    rnn_state = None
                else:
                    # g is always bound here: i == 0 takes the reset branch,
                    # so this only runs after at least one forward pass.
                    rnn_state = g[:, -1].detach().clone().to(self.device)
                y_pred, g, p, center = self(x, rnn_state)
                gs.append(g.cpu().numpy())
                ps.append(p.cpu().numpy())
                centers.append(center.cpu().numpy())
                preds.append(y_pred.cpu().numpy())
        # concat in same order
        return [np.concatenate(var, axis = 0) for var in [gs, ps, centers, preds]]
class VPC_FF(VPC_RNN):
    """ Feed Forward deep network for the variational position reconstruction task
    """
    def __init__(self, params, device = None, **kwargs):
        """
        Args:
            params (dict): dictionary of model parameters
            device (str, optional): device to send torch tensors to.
                Typically "cuda" for training, and "cpu" for inference.
                Defaults to None.
        """
        super().__init__(params, device, **kwargs)
        # 2 velocity channels plus 6 optional one-hot context channels
        n_in = int(2 + params["context"] * 6)
        # Swap the parent's recurrent core for a three-layer bias-free MLP.
        self.g = torch.nn.Sequential(
            torch.nn.Linear(n_in, 64, bias = False),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 128, bias = False),
            torch.nn.ReLU(),
            torch.nn.Linear(128, params["nodes"], bias = False),
            torch.nn.ReLU())
        self.p = torch.nn.Linear(params["nodes"], params["outputs"], bias = False)
        self.to(self.device)
        # Rebuild the optimizer so it tracks the freshly created layers.
        self.optimizer = torch.optim.Adam(self.parameters(), lr=params["lr"])
    def forward(self, inputs, g_prev = None):
        """
        Forward pass of the model.
        Parameters:
            inputs (list): pair of (input signal, true position) tensors.
            g_prev (Tensor, optional): Previous hidden state tensor. Ignored
                for the FF network; kept for interface compatibility.
        Returns:
            Tuple: the predicted output tensor (yhat), the current hidden
            state (g), the output activation (p), and the decoded center (mu).
        """
        signal, position = inputs[0], inputs[1]
        hidden = self.g(signal)
        output = self.activation(self.p(hidden))
        centers = self.decode_phases(output, position)
        prediction = self.decode_position(output, centers)
        return prediction, hidden, output, centers