diff --git a/tests/conftest.py b/tests/conftest.py index a81b989..e28f01d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -49,4 +49,10 @@ def rms_roundtrip_error(): @pytest.fixture def statistical_tolerance(): """Tolerance for statistical tests (likelihood variance).""" - return 1e-10 \ No newline at end of file + return 1e-10 + + +@pytest.fixture +def grad_precision_tolerance(): + """Tolerance for gradient precision tests.""" + return 1e-3 diff --git a/tests/test_dataloader_jax_grad.py b/tests/test_dataloader_jax_grad.py index 647830b..94b8a81 100644 --- a/tests/test_dataloader_jax_grad.py +++ b/tests/test_dataloader_jax_grad.py @@ -67,10 +67,11 @@ def loader(config): @pytest.fixture -def sample_data(): +def sample_data(rng_key, sample_size): """Sample data in [-1, 1] range.""" - key = jax.random.PRNGKey(42) - return jax.random.uniform(key, shape=(5, 2), minval=-0.99, maxval=0.99) + return jax.random.uniform( + rng_key, shape=(sample_size, 2), minval=-0.99, maxval=0.99 + ) @pytest.mark.parametrize( @@ -89,7 +90,9 @@ def loss_fn(data): assert grads.shape == sample_data.shape -def test_round_trip_grad_precision(loader, sample_data): +def test_round_trip_grad_precision( + loader, sample_data, grad_precision_tolerance +): """Test round-trip transformation gradient precision.""" def round_trip_loss(data): @@ -103,4 +106,4 @@ def round_trip_loss(data): # Gradients should be well-behaved assert jnp.all(jnp.isfinite(grads)) # For perfect round-trip, gradient magnitude should be small - assert jnp.max(jnp.abs(grads)) < 0.1 + assert jnp.max(jnp.abs(grads)) < grad_precision_tolerance