@@ -207,7 +207,7 @@ def metrics_and_info_from_gradients(
207207 gathered_grads : Sequence [GradientsFromWorker ],
208208 steps : Sequence [int ],
209209 current_step : int ,
210- ) -> Tuple [Mapping [str , float ], Sequence [int ], int ]:
210+ ) -> Tuple [Mapping [str , float ], Sequence [Union [ int , jax . Array ] ], int ]:
211211 """Perform one outer-iteration on a batch of gradients from workers.
212212
213213 Args:
@@ -222,17 +222,17 @@ def metrics_and_info_from_gradients(
222222 applied_inner_steps: number if inner steps performed this outer step.
223223 """
224224
225- worker_ids = jnp . asarray ( [t .worker_id for t in gathered_grads ])
226- inner_steps = onp . asarray ( [t .total_inner_steps for t in gathered_grads ])
225+ worker_ids = [t .worker_id for t in gathered_grads ]
226+ inner_steps = [t .total_inner_steps for t in gathered_grads ]
227227
228- applied_inner_steps = onp .sum (inner_steps )
228+ applied_inner_steps = int ( onp .sum (inner_steps ) )
229229 metrics = {}
230230 metrics ["unique_worker" ] = float (len (onp .unique (worker_ids )))
231231
232- avg_stale = current_step - onp .mean (steps )
232+ avg_stale = current_step - float ( onp .mean (steps ) )
233233 metrics ["avg_staleness" ] = avg_stale
234234
235- max_stale = current_step - onp .min (steps )
235+ max_stale = current_step - float ( onp .min (steps ) )
236236 metrics ["max_staleness" ] = max_stale
237237
238238 return metrics , worker_ids , applied_inner_steps
0 commit comments