-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdistribution.py
More file actions
85 lines (60 loc) · 1.91 KB
/
distribution.py
File metadata and controls
85 lines (60 loc) · 1.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import torch
import torch.distributions as tdist
import math
class DistributionInterface:
    """Base interface for action distributions built from ``mu`` and a state.

    The distribution-specific methods must be overridden by subclasses and
    raise ``NotImplementedError`` otherwise. The training hooks (``train``,
    ``eval``, ``zero_grad``, ``step``) and checkpoint hooks (``set_ckpt``,
    ``get_ckpt``) default to no-ops so that distributions without learnable
    parameters need not override them.
    """

    def get_scale(self, state):
        """Return the scale parameter of the distribution for ``state``."""
        raise NotImplementedError

    def sample(self, mu, state):
        """Draw a sample from the distribution built from ``mu`` and ``state``."""
        raise NotImplementedError

    def log_prob(self, act, mu, state):
        """Return the log-probability of ``act`` under the distribution."""
        raise NotImplementedError

    def entropy(self, mu, state):
        """Return the entropy of the distribution built from ``mu`` and ``state``."""
        raise NotImplementedError

    def log_density(self, act, mu, state):
        """Backward-compatible alias for :meth:`log_prob`.

        Subclasses implement ``log_prob``; forwarding here means callers using
        the older name get a result instead of an error.
        """
        return self.log_prob(act, mu, state)

    def train(self):
        """Switch to training mode (no-op by default)."""

    def eval(self):
        """Switch to evaluation mode (no-op by default)."""

    def zero_grad(self):
        """Clear accumulated gradients (no-op by default)."""

    def step(self):
        """Apply one optimizer step (no-op by default)."""

    def set_ckpt(self, ckpt):
        """Restore state from the checkpoint dict ``ckpt`` (no-op by default)."""

    def get_ckpt(self):
        """Return a dict of state to save; empty by default."""
        return {}
class FixedGaussianDistribution(DistributionInterface):
    """Distribution whose scale is a constant, independent of the state.

    ``dist`` is a callable invoked as ``dist(mu, scale)`` — presumably a
    ``torch.distributions`` class such as ``Normal``; confirm at call site.
    A fresh distribution object is constructed on every query.
    """

    def __init__(self, dist, scale):
        self.dist, self.scale = dist, scale

    def _make_dist(self, mu):
        # Build the distribution from mu and the fixed scale.
        return self.dist(mu, self.scale)

    def get_scale(self, state):
        # The state is ignored: the scale never varies.
        return self.scale

    def sample(self, mu, state):
        return self._make_dist(mu).sample()

    def log_prob(self, act, mu, state):
        return self._make_dist(mu).log_prob(act)

    def entropy(self, mu, state):
        return self._make_dist(mu).entropy()
class NetGaussianDistribution(DistributionInterface):
    """Distribution whose scale is predicted from the state by a network.

    ``network``'s output is exponentiated to obtain the scale, i.e. it
    predicts the log-scale. ``opt`` is an optimizer whose ``zero_grad`` and
    ``step`` are forwarded to — presumably over ``network``'s parameters;
    confirm at call site. ``dist`` is invoked as ``dist(mu, scale)``.
    """

    def __init__(self, dist, network, opt):
        self.dist = dist
        self.network = network
        self.opt = opt

    def _device(self, state):
        """Return the device of the network's parameters.

        Falls back to ``state``'s device when the network has no parameters.
        """
        first_param = next(self.network.parameters(), None)
        return state.device if first_param is None else first_param.device

    def get_scale(self, state):
        """Return ``exp(network(state))`` as a CPU tensor.

        The state is moved to whatever device the network lives on instead of
        a hard-coded CUDA device, so CPU-only setups work; the result is
        brought back to the CPU as before.
        """
        log_scale = self.network(state.to(self._device(state)))
        return torch.exp(log_scale.cpu())

    def sample(self, mu, state):
        """Draw a sample from ``dist(mu, get_scale(state))``."""
        return self.dist(mu, self.get_scale(state)).sample()

    def log_prob(self, act, mu, state):
        """Return the log-probability of ``act`` under ``dist(mu, get_scale(state))``."""
        return self.dist(mu, self.get_scale(state)).log_prob(act)

    def entropy(self, mu, state):
        """Return the entropy of ``dist(mu, get_scale(state))``."""
        return self.dist(mu, self.get_scale(state)).entropy()

    def train(self):
        """Put the scale network in training mode."""
        self.network.train()

    def eval(self):
        """Put the scale network in evaluation mode."""
        self.network.eval()

    def zero_grad(self):
        """Clear gradients via the optimizer."""
        self.opt.zero_grad()

    def step(self):
        """Apply one optimizer step."""
        self.opt.step()

    def set_ckpt(self, ckpt):
        """Load network weights from ``ckpt['netdist']``.

        Raises:
            KeyError: if ``ckpt`` has no ``'netdist'`` entry. (An explicit
                check rather than ``assert``, which ``python -O`` strips.)
        """
        if 'netdist' not in ckpt:
            raise KeyError("checkpoint is missing the 'netdist' entry")
        self.network.load_state_dict(ckpt['netdist'])

    def get_ckpt(self):
        """Return the network weights under the ``'netdist'`` key."""
        return {'netdist': self.network.state_dict()}