From e643e18ec71f7b480808246e76c6e01089c69459 Mon Sep 17 00:00:00 2001 From: Chaitanya Mishra Date: Sat, 24 Jan 2026 14:03:02 +0530 Subject: [PATCH] math: scan solve_pgs iterations --- brax/math.py | 6 ++++-- brax/math_test.py | 7 +++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/brax/math.py b/brax/math.py index bf2cf058c..fb9026535 100644 --- a/brax/math.py +++ b/brax/math.py @@ -268,9 +268,11 @@ def get_x(x, xs): return x, None - # TODO(brax-team): turn this into a scan - for _ in range(num_iters): + def iterate(x, _): x, _ = jax.lax.scan(get_x, x, (jp.arange(num_rows), a, b)) + return x, None + + x, _ = jax.lax.scan(iterate, x, None, length=num_iters) return x diff --git a/brax/math_test.py b/brax/math_test.py index 703a78689..3868b0da8 100644 --- a/brax/math_test.py +++ b/brax/math_test.py @@ -65,6 +65,13 @@ def test_from_to(self): rot = math.from_to(v1, v2) np.testing.assert_array_almost_equal(v2, math.rotate(v1, rot)) + def test_solve_pgs_diagonal(self): + a = jp.diag(jp.array([2.0, 4.0, 8.0])) + b = jp.array([-2.0, 2.0, -8.0]) + expected = jp.array([1.0, 0.0, 1.0]) + x = math.solve_pgs(a, b, num_iters=2) + np.testing.assert_allclose(x, expected, rtol=1e-6, atol=1e-6) + class OrthoganalsTest(parameterized.TestCase): """Tests the orthogonals function."""