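"""Discrete-action Soft Q-Learning (SQL) agent.

Maintains an online (and optionally a target) soft Q-network, acts with a
Boltzmann policy at inverse temperature beta, and trains the critic toward
the soft Bellman backup r + gamma * V_soft(s').
"""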
from typing import Optional
import gymnasium
import numpy as np
import torch
from Architectures import make_min_discrete_action_critic, make_mlp
from BaseAgent import BaseAgent, get_new_params
from callbacks import AUCCallback
from utils import polyak, check_polyak_tau, prepare_online_and_target
from Logger import WandBLogger, TensorboardLogger


class SoftQAgent(BaseAgent):
    def __init__(self,
                 *args,
                 gamma: float = 0.99,
                 beta: float = 5.0,
                 use_target_network: bool = False,
                 target_update_interval: Optional[int] = None,
                 polyak_tau: Optional[float] = None,
                 architecture_kwargs: dict = {},
                 **kwargs,
                 ):
        super().__init__(*args, **kwargs)
        self.kwargs = get_new_params(self, locals())
        self.algo_name = 'SQL'
        self.gamma = gamma
        self.beta = beta
        self.use_target_network = use_target_network
        self.target_update_interval = target_update_interval
        self.polyak_tau = polyak_tau
        self.nA = self.env.action_space.n

        # TODO: make this more robust so not necessary each time, i.e. put in baseagent
        # Add algo_name and env_str to kwargs for logging:
        self.kwargs['algo_name'] = self.algo_name
        self.kwargs['env_str'] = self.env_str
        self.log_hparams(self.kwargs)

        self.online_softqs, self.target_softqs = prepare_online_and_target(
            use_target_network=self.use_target_network,
            architecture=self.architecture,
            architecture_kwargs=architecture_kwargs)
        check_polyak_tau(self.use_target_network, self.polyak_tau, self.target_update_interval)

        # Make (all) qs learnable:
        self.optimizer = torch.optim.Adam(self.online_softqs.parameters(), lr=self.learning_rate)

        # TODO: allow for non-uniform priors
        self.log_pi0 = -torch.log(torch.tensor(self.nA))
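
    # The exploration policy samples from the Boltzmann distribution
    # pi(a|s) ∝ pi0(a) * exp(beta * Q(s, a)); with the uniform prior
    # log_pi0 = log(1/nA), this is a softmax over beta * Q(s, a) + log_pi0.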
    def exploration_policy(self, state: np.ndarray) -> int:
        with torch.no_grad():
            # return self.env.action_space.sample()
            qvals = self.online_softqs(state)
            # Calculate the Boltzmann policy:
            qvals = qvals.squeeze()
            # Sample from the logits:
            pi = torch.distributions.Categorical(logits=self.beta * qvals + self.log_pi0)
            action = pi.sample()
        return action.item()

    def evaluation_policy(self, state: np.ndarray) -> int:
        # Sample from the same Boltzmann policy at evaluation time
        # (a greedy argmax over the q-values is left commented out below):
        with torch.no_grad():
            qvals = self.online_softqs(state).to(device=self.device) + 1 / self.beta * self.log_pi0
            qvals = qvals.squeeze()
            # return torch.argmax(qvals).item()
            pi = torch.distributions.Categorical(logits=self.beta * qvals + self.log_pi0)
            action = pi.sample()
        return action.item()
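
    # The soft Bellman target uses the soft value of the next state,
    #     V(s') = (1/beta) * log sum_a pi0(a) * exp(beta * Q_target(s', a)),
    # which for the uniform prior equals
    #     (1/beta) * (logsumexp(beta * Q_target(s', .)) + log_pi0),
    # and the regression target is r + gamma * V(s') * (1 - done).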
    def calculate_loss(self, batch):
        states, actions, rewards, next_states, dones = batch
        actions = actions.long()
        dones = dones.float()
        curr_softq = self.online_softqs(states).squeeze().gather(1, actions)

        with torch.no_grad():
            if isinstance(self.env.observation_space, gymnasium.spaces.Discrete):
                states = states.squeeze()
                next_states = next_states.squeeze()
            next_softqs = self.target_softqs(next_states)
            next_v = 1 / self.beta * (torch.logsumexp(self.beta * next_softqs, dim=-1) + self.log_pi0)
            next_v = next_v.reshape(-1, 1)
            # Backup equation:
            expected_curr_softq = rewards + self.gamma * next_v * (1 - dones)

        # Calculate the softq ("critic") loss:
        loss = 0.5 * torch.nn.functional.mse_loss(curr_softq, expected_curr_softq)
        # loss += 0.005*torch.nn.functional.mse_loss(curr_softq, curr_softq.detach())

        self.log_history("train/online_q_mean", curr_softq.mean().item(), self.learn_env_steps)
        # Log the loss:
        self.log_history("train/loss", loss.item(), self.learn_env_steps)
        return loss

    def gradient_step(self, grad_step):
        # Sample a batch from the replay buffer:
        batch = self.buffer.sample(self.batch_size)
        loss = self.calculate_loss(batch)

        self.optimizer.zero_grad()
        loss.backward()
        # Clip the gradient norm (if a max norm was set):
        if self.max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(self.online_softqs.parameters(), self.max_grad_norm)
        self.optimizer.step()

    def _on_step(self) -> None:
        # Periodically update the target network:
        if self.use_target_network and self.learn_env_steps % self.target_update_interval == 0:
            # Use Polyak averaging as specified:
            polyak(self.target_softqs, self.online_softqs, self.polyak_tau)
        super()._on_step()
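

# Example run: train SQL on CartPole-v1 with a small MLP critic, a hard target
# update every 10 steps (polyak_tau=1.0), and TensorBoard logging.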
if __name__ == '__main__':
    import gymnasium as gym

    env = 'CartPole-v1'
    logger = TensorboardLogger('logs/acro')
    # logger = WandBLogger(entity='jacobhadamczyk', project='test')
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # mlp = make_min_discrete_action_critic(env.unwrapped.observation_space.shape[0], env.unwrapped.action_space.n, hidden_dims=[32, 32], device=device)

    agent = SoftQAgent(env,
                       architecture=make_min_discrete_action_critic,
                       architecture_kwargs={'obs_dim': gym.make(env).observation_space.shape[0],
                                            'n_actions': gym.make(env).action_space.n,
                                            'hidden_dims': [32, 32]},
                       loggers=(logger,),
                       learning_rate=0.001,
                       beta=5,
                       train_interval=10,
                       gradient_steps=4,
                       batch_size=256,
                       use_target_network=True,
                       target_update_interval=10,
                       polyak_tau=1.0,
                       eval_callbacks=[AUCCallback],
                       )
    agent.learn(total_timesteps=26000)