Skip to content

Commit 4597b9c

Browse files
btaba authored and copybara-github committed
Add CNN kernel init and spatial softmax (h/t zakka)
PiperOrigin-RevId: 879785561 Change-Id: I0e9dc87128ecad6f63953f06b533713f4e6994c8
1 parent d2cb645 commit 4597b9c

4 files changed

Lines changed: 224 additions & 29 deletions

File tree

brax/training/agents/ppo/networks_vision.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ def make_ppo_networks_vision(
6363
cnn_activation: networks.ActivationFn = linen.relu,
6464
cnn_max_pool: bool = False,
6565
cnn_global_pool: str = 'avg',
66+
cnn_kernel_init_fn: networks.Initializer = jax.nn.initializers.lecun_normal,
67+
cnn_kernel_init_kwargs: Mapping[str, Any] | None = None,
68+
output_kernel_init_fn: networks.Initializer | None = None,
69+
output_kernel_init_kwargs: Mapping[str, Any] | None = None,
6670
) -> PPONetworks:
6771
"""Make Vision PPO networks with preprocessor.
6872
@@ -94,10 +98,13 @@ def make_ppo_networks_vision(
9498
cnn_activation: activation function or name (e.g. 'elu', 'relu').
9599
cnn_max_pool: whether to apply 2x2 max-pool after each conv layer.
96100
cnn_global_pool: pooling over spatial dims — 'avg', 'max', or 'none'.
101+
cnn_kernel_init_fn: kernel initializer factory for CNN conv layers.
102+
cnn_kernel_init_kwargs: kwargs for CNN kernel init factory.
97103
"""
98104
policy_kernel_init_kwargs = policy_network_kernel_init_kwargs or {}
99105
value_kernel_init_kwargs = value_network_kernel_init_kwargs or {}
100106
mean_kernel_init_kwargs_ = mean_kernel_init_kwargs or {}
107+
cnn_kernel_init_kwargs_ = cnn_kernel_init_kwargs or {}
101108

102109
# Resolve string-based CNN config values.
103110
resolved_padding = _PADDING_MAP.get(
@@ -108,6 +115,19 @@ def make_ppo_networks_vision(
108115
if isinstance(cnn_activation, str)
109116
else cnn_activation
110117
)
118+
resolved_cnn_kernel_init_fn: networks.Initializer = (
119+
networks.KERNEL_INITIALIZER[cnn_kernel_init_fn]
120+
if isinstance(cnn_kernel_init_fn, str)
121+
else cnn_kernel_init_fn
122+
)
123+
output_kernel_init_kwargs_ = output_kernel_init_kwargs or {}
124+
resolved_output_kernel_init_fn = None
125+
if output_kernel_init_fn is not None:
126+
resolved_output_kernel_init_fn = (
127+
networks.KERNEL_INITIALIZER[output_kernel_init_fn]
128+
if isinstance(output_kernel_init_fn, str)
129+
else output_kernel_init_fn
130+
)
111131

112132
parametric_action_distribution: distribution.ParametricDistribution
113133
if distribution_type == 'normal':
@@ -149,6 +169,11 @@ def make_ppo_networks_vision(
149169
cnn_activation=resolved_cnn_activation,
150170
cnn_max_pool=cnn_max_pool,
151171
cnn_global_pool=cnn_global_pool,
172+
cnn_kernel_init=resolved_cnn_kernel_init_fn(**cnn_kernel_init_kwargs_),
173+
output_kernel_init=(
174+
resolved_output_kernel_init_fn(**output_kernel_init_kwargs_)
175+
if resolved_output_kernel_init_fn is not None else None
176+
),
152177
)
153178

154179
value_network = networks.make_value_network_vision(
@@ -166,6 +191,7 @@ def make_ppo_networks_vision(
166191
cnn_activation=resolved_cnn_activation,
167192
cnn_max_pool=cnn_max_pool,
168193
cnn_global_pool=cnn_global_pool,
194+
cnn_kernel_init=resolved_cnn_kernel_init_fn(**cnn_kernel_init_kwargs_),
169195
)
170196

171197
return PPONetworks(

brax/training/learner.py

Lines changed: 154 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from brax.training.agents.ars import train as ars
3131
from brax.training.agents.es import train as es
3232
from brax.training.agents.ppo import networks as ppo_networks
33+
from brax.training.agents.ppo import networks_vision as ppo_networks_vision
3334
from brax.training.agents.ppo import optimizer as ppo_optimizer
3435
from brax.training.agents.ppo import train as ppo
3536
from brax.training.agents.sac import networks as sac_networks
@@ -203,6 +204,56 @@
203204
0.5,
204205
'Value function loss coefficient for PPO.',
205206
)
207+
_PPO_CNN_OUTPUT_CHANNELS = flags.DEFINE_string(
208+
'ppo_cnn_output_channels', None,
209+
'Comma-separated CNN output channels, e.g. "32,64,64".'
210+
)
211+
_PPO_CNN_KERNEL_SIZE = flags.DEFINE_string(
212+
'ppo_cnn_kernel_size', None,
213+
'Comma-separated CNN kernel sizes, e.g. "8,4,3".'
214+
)
215+
_PPO_CNN_STRIDE = flags.DEFINE_string(
216+
'ppo_cnn_stride', None,
217+
'Comma-separated CNN strides, e.g. "4,2,1".'
218+
)
219+
_PPO_CNN_PADDING = flags.DEFINE_string(
220+
'ppo_cnn_padding', None, 'CNN padding mode: "zeros" or "valid".'
221+
)
222+
_PPO_CNN_ACTIVATION = flags.DEFINE_string(
223+
'ppo_cnn_activation', None,
224+
'CNN activation function, e.g. "relu" or "elu".'
225+
)
226+
_PPO_CNN_MAX_POOL = flags.DEFINE_bool(
227+
'ppo_cnn_max_pool', None, 'Apply 2x2 max-pool after each conv layer.'
228+
)
229+
_PPO_CNN_GLOBAL_POOL = flags.DEFINE_string(
230+
'ppo_cnn_global_pool', None,
231+
'Global pooling: "avg", "max", "none", or "spatial_softmax".'
232+
)
233+
_PPO_CNN_KERNEL_INIT_FN = flags.DEFINE_string(
234+
'ppo_cnn_kernel_init_fn', None,
235+
'CNN kernel initializer, e.g. "orthogonal" or "lecun_normal".'
236+
)
237+
_PPO_CNN_KERNEL_INIT_KWARGS = flags.DEFINE_string(
238+
'ppo_cnn_kernel_init_kwargs', None,
239+
'JSON dict of kwargs for CNN kernel init, e.g. \'{"scale": 1.414}\'.',
240+
)
241+
_PPO_OUTPUT_KERNEL_INIT_FN = flags.DEFINE_string(
242+
'ppo_output_kernel_init_fn', None,
243+
'Output layer kernel initializer, e.g. "orthogonal".'
244+
)
245+
_PPO_OUTPUT_KERNEL_INIT_KWARGS = flags.DEFINE_string(
246+
'ppo_output_kernel_init_kwargs', None,
247+
'JSON dict of kwargs for output kernel init, e.g. \'{"scale": 0.01}\'.',
248+
)
249+
250+
_VISION = flags.DEFINE_bool(
251+
'vision', False, 'Enable vision-based (pixel) training.'
252+
)
253+
_AUGMENT_PIXELS = flags.DEFINE_bool(
254+
'augment_pixels', False, 'Apply random-translate augmentation to pixels.'
255+
)
256+
206257

207258
# ARS hps.
208259
_NUMBER_OF_DIRECTIONS = flags.DEFINE_integer(
@@ -241,6 +292,10 @@
241292
None,
242293
'Overrides for the playground env config.',
243294
)
295+
_WARP_KERNEL_CACHE_DIR = flags.DEFINE_string(
296+
'warp_kernel_cache_dir', None,
297+
'Directory for caching compiled Warp kernels.',
298+
)
244299

245300

246301
def get_env_factory(env_name: str):
@@ -252,6 +307,9 @@ def get_env_factory(env_name: str):
252307
randomizer_fn = None
253308
if _PLAYGROUND_CONFIG_OVERRIDES.value is not None:
254309
overrides = json.loads(_PLAYGROUND_CONFIG_OVERRIDES.value)
310+
if _VISION.value:
311+
overrides['vision'] = True
312+
overrides['vision_config.nworld'] = _NUM_ENVS.value
255313
if _PLAYGROUND_DM_CONTROL_SUITE.value:
256314
get_environment = lambda *args, **kwargs: mjp.dm_control_suite.load( # pytype: disable=attribute-error
257315
*args, **kwargs, config_overrides=overrides
@@ -277,6 +335,10 @@ def get_env_factory(env_name: str):
277335

278336

279337
def main(unused_argv):
338+
if _WARP_KERNEL_CACHE_DIR.value is not None:
339+
import warp as wp
340+
wp.config.kernel_cache_dir = _WARP_KERNEL_CACHE_DIR.value
341+
280342
logdir = _LOGDIR.value
281343

282344
ckpt_dir = epath.Path(logdir) / 'checkpoints'
@@ -339,14 +401,78 @@ def main(unused_argv):
339401
progress_fn=writer.write_scalars,
340402
)
341403
elif _LEARNER.value == 'ppo':
342-
network_factory = ppo_networks.make_ppo_networks
404+
if _VISION.value:
405+
network_factory = ppo_networks_vision.make_ppo_networks_vision
406+
else:
407+
network_factory = ppo_networks.make_ppo_networks
343408
network_factory = functools.partial(
344409
network_factory,
345410
distribution_type=_PPO_DISTRIBUTION_TYPE.value,
346411
noise_std_type=_PPO_NOISE_STD_TYPE.value,
347412
init_noise_std=_PPO_INIT_NOISE_STD.value,
348413
activation=brax_networks.ACTIVATION[_PPO_ACTIVATION_FN.value],
349414
)
415+
if _PPO_CNN_OUTPUT_CHANNELS.value is not None:
416+
network_factory = functools.partial(
417+
network_factory,
418+
cnn_output_channels=[
419+
int(x) for x in _PPO_CNN_OUTPUT_CHANNELS.value.split(',')
420+
],
421+
)
422+
if _PPO_CNN_KERNEL_SIZE.value is not None:
423+
network_factory = functools.partial(
424+
network_factory,
425+
cnn_kernel_size=[
426+
int(x) for x in _PPO_CNN_KERNEL_SIZE.value.split(',')
427+
],
428+
)
429+
if _PPO_CNN_STRIDE.value is not None:
430+
network_factory = functools.partial(
431+
network_factory,
432+
cnn_stride=[
433+
int(x) for x in _PPO_CNN_STRIDE.value.split(',')
434+
],
435+
)
436+
if _PPO_CNN_PADDING.value is not None:
437+
network_factory = functools.partial(
438+
network_factory, cnn_padding=_PPO_CNN_PADDING.value,
439+
)
440+
if _PPO_CNN_ACTIVATION.value is not None:
441+
network_factory = functools.partial(
442+
network_factory, cnn_activation=_PPO_CNN_ACTIVATION.value,
443+
)
444+
if _PPO_CNN_MAX_POOL.value is not None:
445+
network_factory = functools.partial(
446+
network_factory, cnn_max_pool=_PPO_CNN_MAX_POOL.value,
447+
)
448+
if _PPO_CNN_GLOBAL_POOL.value is not None:
449+
network_factory = functools.partial(
450+
network_factory, cnn_global_pool=_PPO_CNN_GLOBAL_POOL.value,
451+
)
452+
if _PPO_CNN_KERNEL_INIT_FN.value is not None:
453+
network_factory = functools.partial(
454+
network_factory,
455+
cnn_kernel_init_fn=_PPO_CNN_KERNEL_INIT_FN.value,
456+
)
457+
if _PPO_CNN_KERNEL_INIT_KWARGS.value is not None:
458+
network_factory = functools.partial(
459+
network_factory,
460+
cnn_kernel_init_kwargs=json.loads(
461+
_PPO_CNN_KERNEL_INIT_KWARGS.value
462+
),
463+
)
464+
if _PPO_OUTPUT_KERNEL_INIT_FN.value is not None:
465+
network_factory = functools.partial(
466+
network_factory,
467+
output_kernel_init_fn=_PPO_OUTPUT_KERNEL_INIT_FN.value,
468+
)
469+
if _PPO_OUTPUT_KERNEL_INIT_KWARGS.value is not None:
470+
network_factory = functools.partial(
471+
network_factory,
472+
output_kernel_init_kwargs=json.loads(
473+
_PPO_OUTPUT_KERNEL_INIT_KWARGS.value
474+
),
475+
)
350476
if _PPO_POLICY_HIDDEN_LAYER_SIZES.value is not None:
351477
policy_hidden_layer_sizes = [
352478
int(x) for x in _PPO_POLICY_HIDDEN_LAYER_SIZES.value.split(',')
@@ -401,6 +527,8 @@ def main(unused_argv):
401527
desired_kl=_PPO_DESIRED_KL.value,
402528
num_resets_per_eval=_NUM_RESETS_PER_EVAL.value,
403529
progress_fn=writer.write_scalars,
530+
vision=_VISION.value,
531+
augment_pixels=_AUGMENT_PIXELS.value,
404532
save_checkpoint_path=ckpt_dir.as_posix(),
405533
restore_checkpoint_path=_RESTOREDIR.value,
406534
vf_loss_coefficient=_PPO_VF_LOSS_COEFFICIENT.value,
@@ -455,49 +583,50 @@ def main(unused_argv):
455583
get_environment, *_ = get_env_factory(_ENV.value)
456584
env = get_environment(_ENV.value)
457585

458-
def do_rollout(rng, state):
586+
def do_rollout_batched(rng, states):
459587
data_attr_name = 'pipeline_state' if hasattr(env, 'sys') else 'data'
460-
empty_data = getattr(state, data_attr_name).__class__(
461-
**{k: None for k in getattr(state, data_attr_name).__annotations__}
588+
empty_data = getattr(states, data_attr_name).__class__(
589+
**{k: None for k in getattr(states, data_attr_name).__annotations__}
462590
) # pytype: disable=attribute-error
463-
empty_traj = state.__class__(**{k: None for k in state.__annotations__}) # pytype: disable=attribute-error
591+
empty_traj = states.__class__(**{k: None for k in states.__annotations__}) # pytype: disable=attribute-error
464592
empty_traj = empty_traj.replace(**{data_attr_name: empty_data})
465593

466594
def step(carry, _):
467-
state, rng = carry
468-
rng, act_key = jax.random.split(rng)
469-
act = make_policy(params)(state.obs, act_key)[0]
470-
state = env.step(state, act)
471-
if hasattr(state, 'data'):
472-
# select a sub-set of the data for playground envs
595+
states, rng = carry
596+
rng_split = jax.vmap(jax.random.split)(rng)
597+
rng = rng_split[:, 0]
598+
act_key = rng_split[:, 1]
599+
act = jax.vmap(make_policy(params))(states.obs, act_key)[0]
600+
states = jax.vmap(env.step)(states, act)
601+
if hasattr(states, 'data'):
473602
traj_data = empty_traj.tree_replace({
474-
'data.qpos': state.data.qpos,
475-
'data.qvel': state.data.qvel,
476-
'data.time': state.data.time,
477-
'data.ctrl': state.data.ctrl,
478-
'data.mocap_pos': state.data.mocap_pos,
479-
'data.mocap_quat': state.data.mocap_quat,
480-
'data.xfrc_applied': state.data.xfrc_applied,
603+
'data.qpos': states.data.qpos,
604+
'data.qvel': states.data.qvel,
605+
'data.time': states.data.time,
606+
'data.ctrl': states.data.ctrl,
607+
'data.mocap_pos': states.data.mocap_pos,
608+
'data.mocap_quat': states.data.mocap_quat,
609+
'data.xfrc_applied': states.data.xfrc_applied,
481610
})
482-
elif hasattr(state, 'pipeline_state'):
483-
# select the entire state for brax envs
611+
elif hasattr(states, 'pipeline_state'):
484612
traj_data = empty_traj.replace(
485-
**{data_attr_name: getattr(state, data_attr_name)}
613+
**{data_attr_name: getattr(states, data_attr_name)}
486614
)
487615
else:
488616
raise ValueError(
489-
f'Unknown data attribute name: {data_attr_name} on state: {state}.'
617+
f'Unknown data attribute name: {data_attr_name} on state: {states}.'
490618
)
491-
return (state, rng), traj_data
619+
return (states, rng), traj_data
492620

493621
_, traj = jax.lax.scan(
494-
step, (state, rng), None, length=_EPISODE_LENGTH.value
622+
step, (states, rng), None, length=_EPISODE_LENGTH.value
495623
)
496624
return traj
497625

498626
rng = jax.random.split(jax.random.PRNGKey(_SEED.value), _NUM_VIDEOS.value)
499627
reset_states = jax.jit(jax.vmap(env.reset))(rng)
500-
traj_stacked = jax.jit(jax.vmap(do_rollout))(rng, reset_states)
628+
traj_stacked = jax.jit(do_rollout_batched)(rng, reset_states)
629+
traj_stacked = jax.tree.map(lambda x: jax.numpy.moveaxis(x, 0, 1), traj_stacked)
501630
trajectories = [None] * _NUM_VIDEOS.value
502631
for i in range(_NUM_VIDEOS.value):
503632
t = jax.tree.map(lambda x, i=i: x[i], traj_stacked)

0 commit comments

Comments
 (0)