From 542e91db00c4354939c3dfe5ce7c3dd0656c106d Mon Sep 17 00:00:00 2001 From: Juehang Qin Date: Thu, 20 Nov 2025 07:58:02 -0600 Subject: [PATCH 1/2] test(dataloader): add tests for JAX grad compatibility in transformation methods --- tests/test_dataloader_jax_grad.py | 106 ++++++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) create mode 100644 tests/test_dataloader_jax_grad.py diff --git a/tests/test_dataloader_jax_grad.py b/tests/test_dataloader_jax_grad.py new file mode 100644 index 0000000..647830b --- /dev/null +++ b/tests/test_dataloader_jax_grad.py @@ -0,0 +1,106 @@ +""" +Test JAX grad compatibility for dataloader transformation methods. +""" + +from __future__ import annotations + +import jax +import jax.numpy as jnp +import pytest + +from probabilistic_posrec.config import ( + Config, + ModelConfig, + TrainingConfig, + DataConfig, + OutputConfig, + SimulationConfig, +) +from probabilistic_posrec.dataloader.base import Loader + + +class MockLoader(Loader): + """Mock loader for testing transformations without loading real data.""" + + def load(self) -> None: + pass + + def preprocess(self) -> None: + pass + + def get_data_array(self): + return jnp.array([]), jnp.array([]) + + def _get_hitpatterns_and_targets(self): + return jnp.array([]), jnp.array([]) + + def _prepare_hitpatterns(self, hitpatterns): + return hitpatterns, jnp.ones((len(hitpatterns), 1)) + + +@pytest.fixture +def config(): + """Minimal configuration for testing.""" + return Config( + model=ModelConfig( + r_max=66.4, + buffer=20.0, + scale=1.0, + eps=1e-7, + log_area_scale=10.0, + transform_type="tanh", + ), + training=TrainingConfig(random_seed=42), + data=DataConfig( + disabled_pmts=[], + randomize_off_pmt=False, + randomize_off_pmt_max=50, + ), + output=OutputConfig(), + simulation=SimulationConfig(), + ) + + +@pytest.fixture +def loader(config): + return MockLoader(config) + + +@pytest.fixture +def sample_data(): + """Sample data in [-1, 1] range.""" + key = jax.random.PRNGKey(42) + return jax.random.uniform(key, shape=(5, 2), minval=-0.99, maxval=0.99) + + +@pytest.mark.parametrize( + "transform_fn", ["data_transformation", "data_inv_transformation"] +) +def test_transformation_grad_compatible(loader, sample_data, transform_fn): + """Test that transformation works with jax.grad.""" + transform = getattr(loader, transform_fn) + + def loss_fn(data): + transformed = transform(data) + return jnp.sum(transformed) + + grad_fn = jax.grad(loss_fn) + grads = grad_fn(sample_data) + assert grads.shape == sample_data.shape + + +def test_round_trip_grad_precision(loader, sample_data): + """Test round-trip transformation gradient precision.""" + + def round_trip_loss(data): + transformed = loader.data_transformation(data) + restored = loader.data_inv_transformation(transformed) + return jnp.sum((restored - data) ** 2) + + grad_fn = jax.grad(round_trip_loss) + grads = grad_fn(sample_data) + + # Gradients should be well-behaved + assert jnp.all(jnp.isfinite(grads)) + # For perfect round-trip, gradient magnitude should be small + assert jnp.max(jnp.abs(grads)) < 0.1 From 34060d79688a4ca52d3ee0e83558a851aba3ff99 Mon Sep 17 00:00:00 2001 From: Juehang Qin Date: Thu, 20 Nov 2025 08:34:42 -0600 Subject: [PATCH 2/2] refactor(dataloader): improve JAX transformation compatibility with functional updates Improves upon the dataloader transformations to ensure robust JAX compatibility: - Fixes both data_transformation and data_inv_transformation for jax.grad - Uses purely functional approach with jnp.concatenate instead of in-place mutations - Removes unnecessary copy.deepcopy calls - Works uniformly with NumPy and JAX arrays without type checking - Passes all jax.grad compatibility tests --- src/probabilistic_posrec/dataloader/base.py | 27 +++++++++------------ 1 file changed, 11 insertions(+), 16 deletions(-) diff --git a/src/probabilistic_posrec/dataloader/base.py b/src/probabilistic_posrec/dataloader/base.py index 36b116f..f0beb38 100644 --- a/src/probabilistic_posrec/dataloader/base.py +++ b/src/probabilistic_posrec/dataloader/base.py @@ -155,8 +155,6 @@ def data_transformation( if max_pred is None: max_pred = self.model_config.r_max + self.model_config.buffer - data = copy.deepcopy(data) - # Only transform first 2 dimensions (positions) positions = data[:, :2] positions_normalized = positions / max_pred @@ -164,11 +162,13 @@ def data_transformation( positions_normalized ) - # Update positions in-place - data[:, :2] = positions_transformed - # Dimensions 2+ pass through unchanged (already unbounded) - - return data + # Concatenate transformed positions with unchanged dimensions + if data.shape[1] > 2: + return jnp.concatenate( + [positions_transformed, data[:, 2:]], axis=1 + ) + else: + return positions_transformed def data_inv_transformation( self, @@ -187,8 +187,6 @@ def data_inv_transformation( if max_pred is None: max_pred = self.model_config.r_max + self.model_config.buffer - data = copy.deepcopy(data) - # Only inverse transform first 2 dimensions positions_transformed = data[:, :2] positions_normalized = self.get_constrain_vec()( @@ -196,14 +194,11 @@ def data_inv_transformation( ) positions = positions_normalized * max_pred - # Update positions in-place - if type(data) == np.ndarray: - data[:, :2] = positions + # Concatenate transformed positions with unchanged dimensions + if data.shape[1] > 2: + return jnp.concatenate([positions, data[:, 2:]], axis=1) else: - data = data.at[:, :2].set(positions) - # Dimensions 2+ pass through unchanged - - return data + return positions def shuffle_data( self,