From b0c51057d12ba0aeeefe7ed9f18551f745a313d4 Mon Sep 17 00:00:00 2001 From: Juehang Qin Date: Thu, 20 Nov 2025 09:02:10 -0600 Subject: [PATCH] test: consolidate test fixtures for consistency - Add grad_precision_tolerance fixture to conftest.py (1e-3) - Update sample_data to use shared rng_key and sample_size fixtures - Change sample size from 5 to 1000 for consistency with bijection tests - Use grad_precision_tolerance parameter in test_round_trip_grad_precision All test configurations now centralized in conftest.py. --- tests/conftest.py | 8 +++++++- tests/test_dataloader_jax_grad.py | 13 ++++++++----- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index a81b989..e28f01d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -49,4 +49,10 @@ def rms_roundtrip_error(): @pytest.fixture def statistical_tolerance(): """Tolerance for statistical tests (likelihood variance).""" - return 1e-10 \ No newline at end of file + return 1e-10 + + +@pytest.fixture +def grad_precision_tolerance(): + """Tolerance for gradient precision tests.""" + return 1e-3 diff --git a/tests/test_dataloader_jax_grad.py b/tests/test_dataloader_jax_grad.py index 647830b..94b8a81 100644 --- a/tests/test_dataloader_jax_grad.py +++ b/tests/test_dataloader_jax_grad.py @@ -67,10 +67,11 @@ def loader(config): @pytest.fixture -def sample_data(): +def sample_data(rng_key, sample_size): """Sample data in [-1, 1] range.""" - key = jax.random.PRNGKey(42) - return jax.random.uniform(key, shape=(5, 2), minval=-0.99, maxval=0.99) + return jax.random.uniform( + rng_key, shape=(sample_size, 2), minval=-0.99, maxval=0.99 + ) @pytest.mark.parametrize( @@ -89,7 +90,9 @@ def loss_fn(data): assert grads.shape == sample_data.shape -def test_round_trip_grad_precision(loader, sample_data): +def test_round_trip_grad_precision( + loader, sample_data, grad_precision_tolerance +): """Test round-trip transformation gradient precision.""" def round_trip_loss(data): @@ -103,4 +106,4 @@ def round_trip_loss(data): # Gradients should be well-behaved assert jnp.all(jnp.isfinite(grads)) # For perfect round-trip, gradient magnitude should be small - assert jnp.max(jnp.abs(grads)) < 0.1 + assert jnp.max(jnp.abs(grads)) < grad_precision_tolerance