In SAC `train()`, the workflow for loading a checkpoint is as follows:
- The training state is initialized: `training_state = _init_training_state(...)`
- If `restore_checkpoint_path` is not None, params are loaded from that path and replaced in `training_state`.
Issue: `jax.device_put_replicated(training_state, jax.local_devices()[:local_devices_to_use])` is called at the end of `_init_training_state()`, i.e. before the checkpoint params are loaded. The loaded params then overwrite the replicated ones in `training_state` without being replicated themselves, so they have no leading per-device axis. When we later call `_unpmap((training_state.normalizer_params, training_state.policy_params))`, we get `IndexError: Too many indices: 0-dimensional array indexed with 1 regular index` at `jax.tree_util.tree_map(lambda x: x[0], v)`.
For reference, the PPO implementation calls `jax.device_put_replicated` after the checkpoint params have been replaced in the training state.
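A minimal, self-contained repro sketch of the ordering problem and the PPO-style fix, run on a single local device. `TrainingState`, `init_state` and `load_params` are hypothetical stand-ins (not Brax's actual code): the state is a `NamedTuple` with scalar params, and the "checkpoint" is just unreplicated `jnp.ones(())`. `_unpmap` mirrors the `tree_map(lambda x: x[0], v)` expression from the traceback.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class TrainingState(NamedTuple):
    # Hypothetical stand-in for SAC's TrainingState (only the fields involved).
    policy_params: jnp.ndarray
    normalizer_params: jnp.ndarray


def _unpmap(v):
    # Take the first device's copy of every leaf, as in the traceback.
    return jax.tree_util.tree_map(lambda x: x[0], v)


def init_state() -> TrainingState:
    # Hypothetical unreplicated initialization.
    return TrainingState(policy_params=jnp.zeros(()), normalizer_params=jnp.zeros(()))


def load_params():
    # Hypothetical checkpoint loader: returns unreplicated (0-d) params.
    return jnp.ones(()), jnp.ones(())


devices = jax.local_devices()[:1]
norm, policy = load_params()

# Current SAC order: replicate inside init, then overwrite with loaded params.
buggy = jax.device_put_replicated(init_state(), devices)
buggy = buggy._replace(normalizer_params=norm, policy_params=policy)
buggy_failed = False
try:
    _unpmap((buggy.normalizer_params, buggy.policy_params))
except IndexError as e:
    buggy_failed = True  # 0-d leaves cannot be indexed with x[0]
    print("IndexError:", e)

# PPO order: replace params first, then replicate across devices.
state = init_state()._replace(normalizer_params=norm, policy_params=policy)
state = jax.device_put_replicated(state, devices)
fixed = _unpmap((state.normalizer_params, state.policy_params))
print(fixed)
```

Replicating after the params are replaced gives every leaf a leading device axis of size `len(devices)`, which is exactly what `_unpmap` expects to strip off.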