3030from brax .training .agents .ars import train as ars
3131from brax .training .agents .es import train as es
3232from brax .training .agents .ppo import networks as ppo_networks
33+ from brax .training .agents .ppo import networks_vision as ppo_networks_vision
3334from brax .training .agents .ppo import optimizer as ppo_optimizer
3435from brax .training .agents .ppo import train as ppo
3536from brax .training .agents .sac import networks as sac_networks
203204 0.5 ,
204205 'Value function loss coefficient for PPO.' ,
205206)
207+ _PPO_CNN_OUTPUT_CHANNELS = flags .DEFINE_string (
208+ 'ppo_cnn_output_channels' , None ,
209+ 'Comma-separated CNN output channels, e.g. "32,64,64".'
210+ )
211+ _PPO_CNN_KERNEL_SIZE = flags .DEFINE_string (
212+ 'ppo_cnn_kernel_size' , None ,
213+ 'Comma-separated CNN kernel sizes, e.g. "8,4,3".'
214+ )
215+ _PPO_CNN_STRIDE = flags .DEFINE_string (
216+ 'ppo_cnn_stride' , None ,
217+ 'Comma-separated CNN strides, e.g. "4,2,1".'
218+ )
219+ _PPO_CNN_PADDING = flags .DEFINE_string (
220+ 'ppo_cnn_padding' , None , 'CNN padding mode: "zeros" or "valid".'
221+ )
222+ _PPO_CNN_ACTIVATION = flags .DEFINE_string (
223+ 'ppo_cnn_activation' , None ,
224+ 'CNN activation function, e.g. "relu" or "elu".'
225+ )
226+ _PPO_CNN_MAX_POOL = flags .DEFINE_bool (
227+ 'ppo_cnn_max_pool' , None , 'Apply 2x2 max-pool after each conv layer.'
228+ )
229+ _PPO_CNN_GLOBAL_POOL = flags .DEFINE_string (
230+ 'ppo_cnn_global_pool' , None ,
231+ 'Global pooling: "avg", "max", "none", or "spatial_softmax".'
232+ )
233+ _PPO_CNN_KERNEL_INIT_FN = flags .DEFINE_string (
234+ 'ppo_cnn_kernel_init_fn' , None ,
235+ 'CNN kernel initializer, e.g. "orthogonal" or "lecun_normal".'
236+ )
237+ _PPO_CNN_KERNEL_INIT_KWARGS = flags .DEFINE_string (
238+ 'ppo_cnn_kernel_init_kwargs' , None ,
239+ 'JSON dict of kwargs for CNN kernel init, e.g. \' {"scale": 1.414}\' .' ,
240+ )
241+ _PPO_OUTPUT_KERNEL_INIT_FN = flags .DEFINE_string (
242+ 'ppo_output_kernel_init_fn' , None ,
243+ 'Output layer kernel initializer, e.g. "orthogonal".'
244+ )
245+ _PPO_OUTPUT_KERNEL_INIT_KWARGS = flags .DEFINE_string (
246+ 'ppo_output_kernel_init_kwargs' , None ,
247+ 'JSON dict of kwargs for output kernel init, e.g. \' {"scale": 0.01}\' .' ,
248+ )
249+
250+ _VISION = flags .DEFINE_bool (
251+ 'vision' , False , 'Enable vision-based (pixel) training.'
252+ )
253+ _AUGMENT_PIXELS = flags .DEFINE_bool (
254+ 'augment_pixels' , False , 'Apply random-translate augmentation to pixels.'
255+ )
256+
206257
207258# ARS hps.
208259_NUMBER_OF_DIRECTIONS = flags .DEFINE_integer (
241292 None ,
242293 'Overrides for the playground env config.' ,
243294)
295+ _WARP_KERNEL_CACHE_DIR = flags .DEFINE_string (
296+ 'warp_kernel_cache_dir' , None ,
297+ 'Directory for caching compiled Warp kernels.' ,
298+ )
244299
245300
246301def get_env_factory (env_name : str ):
@@ -252,6 +307,9 @@ def get_env_factory(env_name: str):
252307 randomizer_fn = None
253308 if _PLAYGROUND_CONFIG_OVERRIDES .value is not None :
254309 overrides = json .loads (_PLAYGROUND_CONFIG_OVERRIDES .value )
310+ if _VISION .value :
311+ overrides ['vision' ] = True
312+ overrides ['vision_config.nworld' ] = _NUM_ENVS .value
255313 if _PLAYGROUND_DM_CONTROL_SUITE .value :
256314 get_environment = lambda * args , ** kwargs : mjp .dm_control_suite .load ( # pytype: disable=attribute-error
257315 * args , ** kwargs , config_overrides = overrides
@@ -277,6 +335,10 @@ def get_env_factory(env_name: str):
277335
278336
279337def main (unused_argv ):
338+ if _WARP_KERNEL_CACHE_DIR .value is not None :
339+ import warp as wp
340+ wp .config .kernel_cache_dir = _WARP_KERNEL_CACHE_DIR .value
341+
280342 logdir = _LOGDIR .value
281343
282344 ckpt_dir = epath .Path (logdir ) / 'checkpoints'
@@ -339,14 +401,78 @@ def main(unused_argv):
339401 progress_fn = writer .write_scalars ,
340402 )
341403 elif _LEARNER .value == 'ppo' :
342- network_factory = ppo_networks .make_ppo_networks
404+ if _VISION .value :
405+ network_factory = ppo_networks_vision .make_ppo_networks_vision
406+ else :
407+ network_factory = ppo_networks .make_ppo_networks
343408 network_factory = functools .partial (
344409 network_factory ,
345410 distribution_type = _PPO_DISTRIBUTION_TYPE .value ,
346411 noise_std_type = _PPO_NOISE_STD_TYPE .value ,
347412 init_noise_std = _PPO_INIT_NOISE_STD .value ,
348413 activation = brax_networks .ACTIVATION [_PPO_ACTIVATION_FN .value ],
349414 )
415+ if _PPO_CNN_OUTPUT_CHANNELS .value is not None :
416+ network_factory = functools .partial (
417+ network_factory ,
418+ cnn_output_channels = [
419+ int (x ) for x in _PPO_CNN_OUTPUT_CHANNELS .value .split (',' )
420+ ],
421+ )
422+ if _PPO_CNN_KERNEL_SIZE .value is not None :
423+ network_factory = functools .partial (
424+ network_factory ,
425+ cnn_kernel_size = [
426+ int (x ) for x in _PPO_CNN_KERNEL_SIZE .value .split (',' )
427+ ],
428+ )
429+ if _PPO_CNN_STRIDE .value is not None :
430+ network_factory = functools .partial (
431+ network_factory ,
432+ cnn_stride = [
433+ int (x ) for x in _PPO_CNN_STRIDE .value .split (',' )
434+ ],
435+ )
436+ if _PPO_CNN_PADDING .value is not None :
437+ network_factory = functools .partial (
438+ network_factory , cnn_padding = _PPO_CNN_PADDING .value ,
439+ )
440+ if _PPO_CNN_ACTIVATION .value is not None :
441+ network_factory = functools .partial (
442+ network_factory , cnn_activation = _PPO_CNN_ACTIVATION .value ,
443+ )
444+ if _PPO_CNN_MAX_POOL .value is not None :
445+ network_factory = functools .partial (
446+ network_factory , cnn_max_pool = _PPO_CNN_MAX_POOL .value ,
447+ )
448+ if _PPO_CNN_GLOBAL_POOL .value is not None :
449+ network_factory = functools .partial (
450+ network_factory , cnn_global_pool = _PPO_CNN_GLOBAL_POOL .value ,
451+ )
452+ if _PPO_CNN_KERNEL_INIT_FN .value is not None :
453+ network_factory = functools .partial (
454+ network_factory ,
455+ cnn_kernel_init_fn = _PPO_CNN_KERNEL_INIT_FN .value ,
456+ )
457+ if _PPO_CNN_KERNEL_INIT_KWARGS .value is not None :
458+ network_factory = functools .partial (
459+ network_factory ,
460+ cnn_kernel_init_kwargs = json .loads (
461+ _PPO_CNN_KERNEL_INIT_KWARGS .value
462+ ),
463+ )
464+ if _PPO_OUTPUT_KERNEL_INIT_FN .value is not None :
465+ network_factory = functools .partial (
466+ network_factory ,
467+ output_kernel_init_fn = _PPO_OUTPUT_KERNEL_INIT_FN .value ,
468+ )
469+ if _PPO_OUTPUT_KERNEL_INIT_KWARGS .value is not None :
470+ network_factory = functools .partial (
471+ network_factory ,
472+ output_kernel_init_kwargs = json .loads (
473+ _PPO_OUTPUT_KERNEL_INIT_KWARGS .value
474+ ),
475+ )
350476 if _PPO_POLICY_HIDDEN_LAYER_SIZES .value is not None :
351477 policy_hidden_layer_sizes = [
352478 int (x ) for x in _PPO_POLICY_HIDDEN_LAYER_SIZES .value .split (',' )
@@ -401,6 +527,8 @@ def main(unused_argv):
401527 desired_kl = _PPO_DESIRED_KL .value ,
402528 num_resets_per_eval = _NUM_RESETS_PER_EVAL .value ,
403529 progress_fn = writer .write_scalars ,
530+ vision = _VISION .value ,
531+ augment_pixels = _AUGMENT_PIXELS .value ,
404532 save_checkpoint_path = ckpt_dir .as_posix (),
405533 restore_checkpoint_path = _RESTOREDIR .value ,
406534 vf_loss_coefficient = _PPO_VF_LOSS_COEFFICIENT .value ,
@@ -455,49 +583,50 @@ def main(unused_argv):
455583 get_environment , * _ = get_env_factory (_ENV .value )
456584 env = get_environment (_ENV .value )
457585
458- def do_rollout (rng , state ):
586+ def do_rollout_batched (rng , states ):
459587 data_attr_name = 'pipeline_state' if hasattr (env , 'sys' ) else 'data'
460- empty_data = getattr (state , data_attr_name ).__class__ (
461- ** {k : None for k in getattr (state , data_attr_name ).__annotations__ }
588+ empty_data = getattr (states , data_attr_name ).__class__ (
589+ ** {k : None for k in getattr (states , data_attr_name ).__annotations__ }
462590 ) # pytype: disable=attribute-error
463- empty_traj = state .__class__ (** {k : None for k in state .__annotations__ }) # pytype: disable=attribute-error
591+ empty_traj = states .__class__ (** {k : None for k in states .__annotations__ }) # pytype: disable=attribute-error
464592 empty_traj = empty_traj .replace (** {data_attr_name : empty_data })
465593
466594 def step (carry , _ ):
467- state , rng = carry
468- rng , act_key = jax .random .split (rng )
469- act = make_policy (params )(state .obs , act_key )[0 ]
470- state = env .step (state , act )
471- if hasattr (state , 'data' ):
472- # select a sub-set of the data for playground envs
595+ states , rng = carry
596+ rng_split = jax .vmap (jax .random .split )(rng )
597+ rng = rng_split [:, 0 ]
598+ act_key = rng_split [:, 1 ]
599+ act = jax .vmap (make_policy (params ))(states .obs , act_key )[0 ]
600+ states = jax .vmap (env .step )(states , act )
601+ if hasattr (states , 'data' ):
473602 traj_data = empty_traj .tree_replace ({
474- 'data.qpos' : state .data .qpos ,
475- 'data.qvel' : state .data .qvel ,
476- 'data.time' : state .data .time ,
477- 'data.ctrl' : state .data .ctrl ,
478- 'data.mocap_pos' : state .data .mocap_pos ,
479- 'data.mocap_quat' : state .data .mocap_quat ,
480- 'data.xfrc_applied' : state .data .xfrc_applied ,
603+ 'data.qpos' : states .data .qpos ,
604+ 'data.qvel' : states .data .qvel ,
605+ 'data.time' : states .data .time ,
606+ 'data.ctrl' : states .data .ctrl ,
607+ 'data.mocap_pos' : states .data .mocap_pos ,
608+ 'data.mocap_quat' : states .data .mocap_quat ,
609+ 'data.xfrc_applied' : states .data .xfrc_applied ,
481610 })
482- elif hasattr (state , 'pipeline_state' ):
483- # select the entire state for brax envs
611+ elif hasattr (states , 'pipeline_state' ):
484612 traj_data = empty_traj .replace (
485- ** {data_attr_name : getattr (state , data_attr_name )}
613+ ** {data_attr_name : getattr (states , data_attr_name )}
486614 )
487615 else :
488616 raise ValueError (
489- f'Unknown data attribute name: { data_attr_name } on state: { state } .'
617+ f'Unknown data attribute name: { data_attr_name } on state: { states } .'
490618 )
491- return (state , rng ), traj_data
619+ return (states , rng ), traj_data
492620
493621 _ , traj = jax .lax .scan (
494- step , (state , rng ), None , length = _EPISODE_LENGTH .value
622+ step , (states , rng ), None , length = _EPISODE_LENGTH .value
495623 )
496624 return traj
497625
498626 rng = jax .random .split (jax .random .PRNGKey (_SEED .value ), _NUM_VIDEOS .value )
499627 reset_states = jax .jit (jax .vmap (env .reset ))(rng )
500- traj_stacked = jax .jit (jax .vmap (do_rollout ))(rng , reset_states )
628+ traj_stacked = jax .jit (do_rollout_batched )(rng , reset_states )
629+ traj_stacked = jax .tree .map (lambda x : jax .numpy .moveaxis (x , 0 , 1 ), traj_stacked )
501630 trajectories = [None ] * _NUM_VIDEOS .value
502631 for i in range (_NUM_VIDEOS .value ):
503632 t = jax .tree .map (lambda x , i = i : x [i ], traj_stacked )
0 commit comments