Skip to content

Commit 90450dd

Browse files
chr1sj0nes and learned_optimization authors
authored and committed
Fix various type annotations.
PiperOrigin-RevId: 572595145
1 parent 463ab9a commit 90450dd

6 files changed

Lines changed: 25 additions & 16 deletions

File tree

learned_optimization/learned_optimizers/nn_adam.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ class NNAdamState:
8282
"""
8383
params: Any
8484
state: Any
85-
iteration: int
85+
iteration: jax.Array
8686
rolling_features: MeanAndMeanSquareAccumulator
8787
per_layer_lr: Any
8888
per_layer_beta1: Any

learned_optimization/optimizers/optax_opts.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -384,11 +384,13 @@ def __init__(self,
384384

385385
# SM3 doesn't support scalars, so we have to reshape the params and grads.
386386

387-
def init(self,
388-
params: Any,
389-
model_state: Optional[Any] = None,
390-
num_steps: Optional[int] = None,
391-
key: chex.PRNGKey = None) -> SM3OptState:
387+
def init(
388+
self,
389+
params: Any,
390+
model_state: Optional[Any] = None,
391+
num_steps: Optional[int] = None,
392+
key: Optional[chex.PRNGKey] = None,
393+
) -> SM3OptState:
392394
should_reshape = jax.tree_util.tree_map(lambda x: len(x.shape) == 0, params) # pylint: disable=g-explicit-length-test
393395
params = jax.tree_util.tree_map(_expand_scalar, params, should_reshape)
394396
out = super().init(params, model_state, num_steps, key)

learned_optimization/outer_train.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def metrics_and_info_from_gradients(
207207
gathered_grads: Sequence[GradientsFromWorker],
208208
steps: Sequence[int],
209209
current_step: int,
210-
) -> Tuple[Mapping[str, float], Sequence[int], int]:
210+
) -> Tuple[Mapping[str, float], Sequence[Union[int, jax.Array]], int]:
211211
"""Perform one outer-iteration on a batch of gradients from workers.
212212
213213
Args:
@@ -222,17 +222,17 @@ def metrics_and_info_from_gradients(
222222
applied_inner_steps: number if inner steps performed this outer step.
223223
"""
224224

225-
worker_ids = jnp.asarray([t.worker_id for t in gathered_grads])
226-
inner_steps = onp.asarray([t.total_inner_steps for t in gathered_grads])
225+
worker_ids = [t.worker_id for t in gathered_grads]
226+
inner_steps = [t.total_inner_steps for t in gathered_grads]
227227

228-
applied_inner_steps = onp.sum(inner_steps)
228+
applied_inner_steps = int(onp.sum(inner_steps))
229229
metrics = {}
230230
metrics["unique_worker"] = float(len(onp.unique(worker_ids)))
231231

232-
avg_stale = current_step - onp.mean(steps)
232+
avg_stale = current_step - float(onp.mean(steps))
233233
metrics["avg_staleness"] = avg_stale
234234

235-
max_stale = current_step - onp.min(steps)
235+
max_stale = current_step - float(onp.min(steps))
236236
metrics["max_staleness"] = max_stale
237237

238238
return metrics, worker_ids, applied_inner_steps

learned_optimization/outer_trainers/full_es.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ def last_recompute_antithetic_es(
249249
recompute_samples: int,
250250
clip_loss_diff: Optional[float] = None,
251251
sign_delta_loss_scalar: Optional[float] = None,
252-
) -> Tuple[float, MetaParams]:
252+
) -> Tuple[jax.Array, MetaParams]:
253253
"""Compute an ES gradient estimate by recomputing the loss on both unrolls.
254254
255255
Args:

learned_optimization/outer_trainers/truncated_es.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,9 @@ def compute_es_grad(
4646
vec_pos: MetaParams,
4747
std: float,
4848
sign_delta_loss_scalar: Optional[float] = None,
49-
) -> Tuple[float, MetaParams, truncated_step_mod.TruncatedUnrollOut, float]:
49+
) -> Tuple[
50+
jax.Array, MetaParams, truncated_step_mod.TruncatedUnrollOut, jax.Array
51+
]:
5052
"""Compute the ES gradient estimate from the outputs of many unrolls.
5153
5254
Args:

learned_optimization/outer_trainers/truncated_pes.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,13 @@ def compute_pes_grad(
5353
vec_pos: MetaParams,
5454
std: float,
5555
sign_delta_loss_scalar: Optional[float] = None,
56-
) -> Tuple[float, MetaParams, MetaParams, truncated_step_mod.TruncatedUnrollOut,
57-
float]:
56+
) -> Tuple[
57+
jax.Array,
58+
MetaParams,
59+
MetaParams,
60+
truncated_step_mod.TruncatedUnrollOut,
61+
jax.Array,
62+
]:
5863
"""Compute the PES gradient estimate from the outputs of many unrolls.
5964
6065
Args:

0 commit comments

Comments (0)