-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmain.py
More file actions
390 lines (321 loc) · 14.5 KB
/
main.py
File metadata and controls
390 lines (321 loc) · 14.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
import argparse
import torch
import os
import time
import sys
import numpy as np
from collections import deque
import random
from misc.utils import create_dir, dump_to_json, CSVWriter, get_action_info
from misc.params_info import get_config
from misc.torch_utility import get_state
from oailibs import logger
import gym
import d4rl
# Command-line interface for the training script (parsed in __main__).
parser = argparse.ArgumentParser()
# Optim params
parser.add_argument('--p_lr', type=float, default=0.0003, help='Learning rate for policy network')
parser.add_argument('--q_lr', type=float, default=0.0007, help='Learning rate for Q network')
parser.add_argument('--ptau', type=float, default=0.005, help='Interpolation factor in polyak averaging')
parser.add_argument('--gamma', type=float, default=0.99, help='Discount factor [0,1]')
parser.add_argument("--max_timesteps", default=1e6, type=float, help='Total number of timesteps to train on')
parser.add_argument("--batch_size", default=256, type=int, help='Batch size for both actor and critic')
parser.add_argument('--hidden_sizes', nargs='+', type=int, default=[256, 256, 256], help='indicates hidden size actor/critic')
# General params
parser.add_argument('--env_name', type=str, default='halfcheetah-random-v0')
parser.add_argument('--seed', type=int, default=1)
parser.add_argument('--alg_name', type=str, default='cdc')
parser.add_argument('--disable_cuda', default=False, action='store_true')
parser.add_argument('--cuda_deterministic', default=False, action='store_true')
parser.add_argument("--gpu_id", default=0, type=int)
# --set_num_threads defaults to True, so passing it changes nothing; it is kept
# for backward compatibility. --no_set_num_threads (same dest) is the only way
# to turn the single-thread limit off.
parser.add_argument('--set_num_threads', default=True, action='store_true',
                    help='Limit torch to a single CPU thread (on by default)')
parser.add_argument('--no_set_num_threads', dest='set_num_threads', action='store_false',
                    help='Do not limit torch to a single CPU thread')
parser.add_argument('--log_id', default='dummy')
parser.add_argument('--check_point_dir', default='./ck')
parser.add_argument('--log_dir', default='./log_dir')
parser.add_argument('--save_freq', type=int, default=1000)
parser.add_argument("--eval_freq", default=5e3, type=float, help='How often (time steps) we evaluate')
parser.add_argument("--num_trains_per_train_loop", default=1e3, type=float, help='Number of training iterations')
parser.add_argument('--num_evals', type=int, default=10, help='Number of episodes per evaluation')
parser.add_argument('--fixed_seed_part', type=int, default=100, help='A fixed seed is used for the eval environment')
# sac hyper params
parser.add_argument('--LOG_STD_MAX', type=float, default=2)
parser.add_argument('--LOG_STD_MIN', type=float, default=-5)
# NOTE(review): the __main__ driver in this file reads eta_coef / lambda_coef
# from get_config(env_name), not from these two flags.
parser.add_argument('--eta_coef', type=float, default=1, help='eta coef')
parser.add_argument('--lambda_coef', type=float, default=1, help='lambda coef')
parser.add_argument('--number_of_qs', type=int, default=4, help='Number of qs 2|3|4')
parser.add_argument('--num_samples', type=int, default=15, help='Number of action samples')
parser.add_argument('--nu', default=0.75, type=float, help='nu for convex combination of Qs')
parser.add_argument("--start_more_eval", default=800000, type=float, help='More eval start step')
parser.add_argument("--fast_eval_freq", default=1e3, type=float,
                    help='How often (time steps) we evaluate once past --start_more_eval')
def take_snapshot(args, ck_fname_part, model, update):
    '''
    Save the actor/critic states together with the run arguments.

    Two files share the prefix ``ck_fname_part``:
      * ``<prefix>.pt``   -- torch.save'd dict with keys 'args',
                             'model_states_actor' and 'model_states_critic'
      * ``<prefix>.json`` -- dump_to_json of a dict holding only 'args'

    :param args: argparse namespace; its __dict__ is stored under 'args'.
    :param ck_fname_part: path prefix (no extension) for both files.
    :param model: object exposing ``.actor`` and ``.critic``.
    :param update: iteration counter, used only in the printed message.
    '''
    pt_path, json_path = (ck_fname_part + suffix for suffix in ('.pt', '.json'))
    actor_state = get_state(model.actor)
    critic_state = get_state(model.critic)
    print('Saving a checkpoint for iteration %d in %s' % (update, pt_path))
    torch.save(
        {
            'args': args.__dict__,
            'model_states_actor': actor_state,
            'model_states_critic': critic_state,
        },
        pt_path,
    )
    # The JSON side-car keeps only the run arguments, not the model states.
    dump_to_json(json_path, {'args': args.__dict__})
def setup_logAndCheckpoints(args):
    '''
    Make sure the checkpoint folder exists and derive the per-run paths.

    The run name is ``<lower-cased env_name>_<alg_name>_<log_id>``.

    :return: tuple of (checkpoint path prefix inside args.check_point_dir,
             log directory inside args.log_dir,
             path of 'eval.csv' inside that log directory)
    '''
    # create the checkpoint folder if it is not there yet
    create_dir(args.check_point_dir)
    run_name = '_'.join([args.env_name.lower(), args.alg_name, args.log_id])
    run_log_dir = os.path.join(args.log_dir, run_name)
    return (
        os.path.join(args.check_point_dir, run_name),
        run_log_dir,
        os.path.join(run_log_dir, 'eval.csv'),
    )
def print_model_info(models):
    '''
    Print every model and the parameter counts summed over all of them.

    Only the modules passed in are counted; nothing is multiplied to account
    for target copies.

    :param models: iterable of torch modules (the caller passes [actor, critic]).
    :return: (total_params, total_trainable_params), the same numbers that are
             printed. Existing callers that ignore the return value are unaffected.
    '''
    total_params = 0
    total_trainable = 0
    for model in models:
        for param in model.parameters():
            count = param.numel()
            total_params += count
            if param.requires_grad:
                total_trainable += count
        print(model)
    print("Total number of ALL parameters: %d" % (total_params))
    print("Total number of TRAINABLE parameters: %d" % (total_trainable))
    return total_params, total_trainable
def make_env(eparams, fixed_seed_part=0):
    '''
    Build an environment through ``misc.env.build_env``.

    :param eparams: namespace providing ``seed`` and ``env_name``.
    :param fixed_seed_part: forwarded unchanged to ``build_env``. Evaluation
        passes ``--fixed_seed_part`` here; the training env uses the default 0.
    :return: whatever ``build_env`` returns.
    '''
    import misc.env as env_fn
    return env_fn.build_env(
        env_name=eparams.env_name,
        seed=eparams.seed,
        fixed_seed_part=fixed_seed_part,
    )
def evaluate_policy(policy, eps_num, itr, eparams):
    '''
    Run ``policy`` deterministically for ``eparams.num_evals`` episodes.

    A fresh evaluation env is built on every call with
    ``fixed_seed_part=eparams.fixed_seed_part`` (BCQ-style eval seeding). It is
    closed again before returning, even if an episode raises.

    :param policy: object with ``select_action(obs, deterministic=True)``.
    :param eps_num: training-loop counter, used only in the printed summary.
    :param itr: update counter, used only in the printed summary.
    :param eparams: namespace with ``num_evals``, ``fixed_seed_part`` and the
        fields ``make_env`` reads.
    :return: the total reward over all episodes divided by ``num_evals``.
    :raises ValueError: if ``num_evals`` is not positive.
    '''
    num_evals = eparams.num_evals
    if num_evals <= 0:
        # Previously this produced ZeroDivisionError (0) or a silent -0.0 (< 0).
        raise ValueError("num_evals must be positive, got %r" % (num_evals,))
    avg_reward = 0.
    # Follow BCQ for seeding of eval_env
    eval_env = make_env(eparams, fixed_seed_part=eparams.fixed_seed_part)
    try:
        for _ in range(num_evals):
            obs = eval_env.reset()
            done = False
            while not done:
                action = policy.select_action(np.array(obs), deterministic=True)
                obs, reward, done, _ = eval_env.step(action)
                avg_reward += reward
    finally:
        # A new env is created per evaluation; release it instead of leaking
        # one env per call. Guarded in case build_env returns an object
        # without close().
        if hasattr(eval_env, 'close'):
            eval_env.close()
    avg_reward /= num_evals
    print("---------------------------------------")
    print("Evaluation over %d episodes in episode num %d and update %d: %.3f" % (num_evals, eps_num, itr, avg_reward))
    print("---------------------------------------")
    return avg_reward
if __name__ == "__main__":
    args = parser.parse_args()
    print('------------')
    print(args.__dict__)
    print('------------')

    ##############################
    #### Generic setups
    ##############################
    CUDA_AVAL = torch.cuda.is_available()
    if not args.disable_cuda and CUDA_AVAL:
        gpu_id = "cuda:" + str(args.gpu_id)
        device = torch.device(gpu_id)
        print("**** Yayy we use GPU %s ****" % gpu_id)
    else:
        device = torch.device('cpu')
        print("**** No GPU detected or GPU usage is disabled, sorry! ****")

    if args.set_num_threads:
        torch.set_num_threads(1)

    ####
    # train and evaluation checkpoints, log folders, ck file names
    create_dir(args.log_dir, cleanup = True)
    # create folder for saving checkpoints
    ck_fname_part, log_file_dir, fname_csv_eval = setup_logAndCheckpoints(args)
    logger.configure(dir = log_file_dir)
    wrt_csv_eval = None

    ##############################
    #### Init env, model, alg, batch generator etc
    #### Step 1: build env
    #### Step 2: Build model
    #### Step 3: Initiate Alg e.g. d4pg
    #### Step 4: Initiate batch/rollout generator
    ##############################

    ##### env setup
    env = make_env(args)

    ######### SEED ##########
    # build_env already calls set seed on the env; here seed the Python,
    # NumPy and torch RNGs (both CPU and all CUDA devices).
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    # cuDNN determinism is configured in this single place.
    if not args.disable_cuda and CUDA_AVAL and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        print("****** cudnn.deterministic is set ******")

    ######### Build Networks
    alg_name = str.lower(args.alg_name)
    max_action = float(env.action_space.high[0])
    m_list_p = []
    if alg_name == 'cdc':
        from models.networks import ActorSAC as netActor
        from models.networks import CriticSACMulti as netCritic
        actor_net = netActor(
            action_space = env.action_space,
            hidden_sizes = args.hidden_sizes,
            input_dim = env.observation_space.shape,
            max_action = max_action,
            LOG_STD_MAX = args.LOG_STD_MAX,
            LOG_STD_MIN = args.LOG_STD_MIN
        ).to(device)
        critic_net = netCritic(
            action_space = env.action_space,
            hidden_sizes = args.hidden_sizes,
            input_dim = env.observation_space.shape,
            number_of_qs = args.number_of_qs,
        ).to(device)
        m_list_p.append(actor_net)
        m_list_p.append(critic_net)
    else:
        # Report the algorithm and env names in the matching slots.
        raise ValueError("%s model is not supported for %s env" % (args.alg_name, args.env_name))

    print_model_info(m_list_p)
    print('-----------------------------')
    print("Name of env:", args.env_name)
    print("Observation_space:", env.observation_space )
    print("Action space:", env.action_space )
    print('----------------------------')

    ##### algorithm setup
    if alg_name == 'cdc':
        import algs.CDC.cdc as alg
        # eta/lambda come from the per-env config, not from --eta_coef/--lambda_coef.
        env_config = get_config(args.env_name)
        alg = alg.CDC(
            actor = actor_net,
            critic = critic_net,
            p_lr = args.p_lr,
            q_lr = args.q_lr,
            gamma = args.gamma,
            ptau = args.ptau,
            batch_size = args.batch_size,
            max_action = max_action,
            num_samples = args.num_samples,
            action_dims = env.action_space.shape,
            max_timesteps = args.max_timesteps,
            eta_coef = env_config['eta_coef'],
            nu = args.nu,
            lambda_coef = env_config['lambda_coef'],
            device = device,
        )
    else:
        raise ValueError("%s alg is not supported" % args.alg_name)

    ##############################
    # init replay buffer and load it
    ##############################
    from misc.buffer import Buffer
    from misc.loader_batch import BLoader
    replay_buffer = Buffer(state_dim = env.observation_space.shape[0],
                           action_dim = get_action_info(env.action_space)[0],
                           )
    print("There are %d sample in RB (before loading)" % (replay_buffer.size_rb()))
    loader = BLoader(env=env,
                     replay_buffer=replay_buffer,
                     device=device)
    loader.fillup()
    print("There are %d sample in RB (after loading)" % (replay_buffer.size_rb()))

    ##############################
    # Train and eval
    #############################
    # save once up front to record the run parameters
    take_snapshot(args, ck_fname_part, alg, 0)
    episode_num = 0
    update_iter = 0
    timesteps_since_eval = 0

    # Evaluate untrained policy
    init_eval_res = evaluate_policy(alg, episode_num, update_iter, eparams=args)
    eval_results = [init_eval_res]
    wrt_csv_eval = CSVWriter(fname_csv_eval, {'total_timesteps':update_iter,
                                              'eprewmean':eval_results[0],
                                              'episode_num':episode_num})
    wrt_csv_eval.write({'total_timesteps':update_iter,
                        'eprewmean':eval_results[0],
                        'episode_num':episode_num})

    # Start total timer
    tstart = time.time()

    ##############################
    # Main loop train
    #############################
    print("Start main loop ...")
    try:
        while update_iter < args.max_timesteps:
            # One "episode" here is one training loop of
            # num_trains_per_train_loop gradient updates (offline data only).
            update_iter += args.num_trains_per_train_loop
            episode_num += 1

            #######
            # run training to calculate loss, run backward, and update params
            #######
            alg_stats = alg.train(replay_buffer = replay_buffer,
                                  iterations = int(args.num_trains_per_train_loop)
                                  )

            #######
            # logging
            #######
            nseconds = time.time() - tstart
            # updates per second since the start of the main loop
            fps = int(( update_iter) / nseconds)
            logger.record_tabular("nupdates", update_iter)
            logger.record_tabular("fps", fps)
            logger.record_tabular("critic_loss", float(alg_stats['critic_loss']))
            logger.record_tabular("actor_loss", float(alg_stats['actor_loss']))
            logger.record_tabular("episode_num", episode_num)
            # optional stats: logged only when the algorithm reports them
            for stat_key in ('qs_loss', 'a_loss', 'min_qs', 'max_qs', 'avg_qs', 'std_qs', 'lg_loss'):
                if stat_key in alg_stats:
                    logger.record_tabular(stat_key, float(alg_stats[stat_key]))
            logger.dump_tabular()

            #######
            # run eval
            #######
            # Late in training switch to more frequent evaluation. Note this
            # mutates args, so later snapshots record the new eval_freq.
            if update_iter > args.start_more_eval:
                args.eval_freq = args.fast_eval_freq
            timesteps_since_eval += args.num_trains_per_train_loop
            if timesteps_since_eval >= args.eval_freq:
                timesteps_since_eval %= args.eval_freq
                eval_temp = evaluate_policy(alg, episode_num, update_iter, eparams=args)
                eval_results.append(eval_temp)
                wrt_csv_eval.write({'total_timesteps':update_iter,
                                    'eprewmean':eval_temp,
                                    'episode_num':episode_num})

            #######
            # save every save_freq-th loop and on the final loop
            #######
            # The final-loop test compares update_iter with the timestep budget;
            # comparing episode_num with max_timesteps - 1 could never be true.
            if episode_num % args.save_freq == 0 or update_iter >= args.max_timesteps:
                take_snapshot(args, ck_fname_part, alg, update_iter)
    finally:
        # Close the eval CSV writer even if training raises.
        wrt_csv_eval.close()

    ###############
    # Done
    ###############
    print('Done')