-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel_train.py
More file actions
60 lines (46 loc) · 1.84 KB
/
model_train.py
File metadata and controls
60 lines (46 loc) · 1.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os
import torch
import hamiltorch
from config import Config
import norm_flows
import distribution as dist
import matplotlib.pyplot as plt
# Experiment configuration object (project-local `config.Config`).
conf = Config()
# Number of HMC samples requested by `sample()`; reuses the training batch size.
n_samples = conf.batch_size
# HMC integrator settings forwarded to hamiltorch.sample() in `sample()`.
step_size = 0.3
num_steps_per_sample = 10
# Seed once at import time so runs are reproducible.
# NOTE(review): presumably seeds torch (and possibly other RNGs) — confirm in hamiltorch.
hamiltorch.set_random_seed(131)
# Starting point of the HMC chain: a 2-element zero tensor.
# NOTE(review): its length must match the dimensionality of the distribution
# whose log_prob is passed to `sample()`.
params_init = torch.zeros(2)
def sample(d):
    """Run Hamiltonian Monte Carlo against ``d`` using hamiltorch.

    The chain starts at the module-level ``params_init`` and uses the
    module-level ``n_samples``, ``step_size`` and ``num_steps_per_sample``
    settings.

    Args:
        d: Any object with a ``log_prob`` callable; that callable is the
            log-density the sampler targets.

    Returns:
        The value returned by ``hamiltorch.sample``
        (NOTE(review): presumably a list of sampled parameter tensors — confirm).
    """
    log_density = d.log_prob
    return hamiltorch.sample(
        log_prob_func=log_density,
        params_init=params_init,
        num_samples=n_samples,
        step_size=step_size,
        num_steps_per_sample=num_steps_per_sample,
    )
def inverse_temperature(epoch, num_epochs=None, beta_min=0.01):
    """Return the linearly annealed inverse temperature (beta) for ``epoch``.

    beta starts at ``beta_min`` for epoch 0, grows by ``1 / num_epochs`` per
    epoch, and is clipped above at 1.0.

    Args:
        epoch: Zero-based epoch index.
        num_epochs: Length of the schedule; must be positive. Defaults to the
            module-level ``conf.epoch_num`` (the original behaviour).
        beta_min: beta at epoch 0. Defaults to 0.01 (the original constant).

    Returns:
        beta as a float, never greater than 1.0.
    """
    if num_epochs is None:
        # Resolved lazily so callers passing num_epochs never need `conf`.
        num_epochs = conf.epoch_num
    return min(1.0, beta_min + epoch / num_epochs)
def main():
    """Train the normalizing-flow VAE on draws from a 2-D standard normal.

    Each epoch draws one fresh batch of ``conf.batch_size`` points, takes a
    single Adam step on ``vae_model.free_energy`` (weighted by the annealed
    inverse temperature from ``inverse_temperature``), prints progress, and
    finally saves a plot of the per-epoch loss via ``plot_loss_moment``.
    """
    # Target data distribution: standard normal in 2 dimensions.
    # NOTE(review): the model is built from conf.x_dim / conf.window_size;
    # confirm it accepts inputs of shape (conf.batch_size, 2).
    d = torch.distributions.MultivariateNormal(torch.zeros(2), torch.eye(2))
    vae_model = norm_flows.VaeNormalizingFlow(
        conf.x_dim, conf.hidden_dim, conf.window_size,
        conf.z_dim, conf.flow_num, conf.flow_type,
    )
    opt = torch.optim.Adam(vae_model.parameters(), lr=1e-3, amsgrad=True)
    losses = []
    for epoch in range(conf.epoch_num):
        # One vectorized draw of shape (batch_size, 2), equivalent in shape to
        # concatenating batch_size single draws.
        data = d.sample((conf.batch_size,))
        z, log_q_z, log_likelihood = vae_model(data)
        beta = inverse_temperature(epoch)
        loss = vae_model.free_energy(z, log_q_z, log_likelihood, beta)
        opt.zero_grad()
        loss.backward()
        opt.step()
        # Keep a plain float: storing the tensor would keep each epoch's
        # autograd history alive, and grad-tracking tensors cannot be
        # converted to NumPy when the losses are plotted.
        loss_value = loss.item()
        losses.append(loss_value)
        print('epoch: %d loss: %.4f' % (epoch, loss_value))
        print('log_q(z): %.4f log_likelihood: %.4f' % (log_q_z.mean(), log_likelihood.mean()))
    plot_loss_moment(losses)
def plot_loss_moment(losses, out_dir='./output/', filename='loss_vae.png'):
    """Plot the training-loss curve and save it as an image file.

    Args:
        losses: Sequence of per-iteration loss values. Items may be Python
            numbers or single-element tensors; each is converted with ``float``.
        out_dir: Directory to write into; created if it does not exist.
            Defaults to the original ``'./output/'``.
        filename: Image file name inside ``out_dir``. Defaults to the original
            ``'loss_vae.png'``.

    Returns:
        The path the figure was written to.
    """
    # float() also accepts 0-d tensors (including grad-tracking ones),
    # which matplotlib cannot convert directly.
    values = [float(v) for v in losses]
    fig, ax = plt.subplots(figsize=(16, 9), dpi=80)
    try:
        ax.plot(values, 'blue', label='train', linewidth=1)
        ax.set_title('Loss change in training')
        ax.set_ylabel('Loss')
        ax.set_xlabel('Iteration')
        ax.legend(loc='upper right')
        # savefig does not create missing directories.
        os.makedirs(out_dir, exist_ok=True)
        path = os.path.join(out_dir, filename)
        fig.savefig(path)
    finally:
        # Release the figure; pyplot otherwise keeps it open for the process lifetime.
        plt.close(fig)
    return path
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()