Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 11 additions & 16 deletions src/probabilistic_posrec/dataloader/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,20 +155,20 @@ def data_transformation(
if max_pred is None:
max_pred = self.model_config.r_max + self.model_config.buffer

data = copy.deepcopy(data)

# Only transform first 2 dimensions (positions)
positions = data[:, :2]
positions_normalized = positions / max_pred
positions_transformed = self.get_unconstrain_vec()(
positions_normalized
)

# Update positions in-place
data[:, :2] = positions_transformed
# Dimensions 2+ pass through unchanged (already unbounded)

return data
# Concatenate transformed positions with unchanged dimensions
if data.shape[1] > 2:
return jnp.concatenate(
[positions_transformed, data[:, 2:]], axis=1
)
else:
return positions_transformed

def data_inv_transformation(
self,
Expand All @@ -187,23 +187,18 @@ def data_inv_transformation(
if max_pred is None:
max_pred = self.model_config.r_max + self.model_config.buffer

data = copy.deepcopy(data)

# Only inverse transform first 2 dimensions
positions_transformed = data[:, :2]
positions_normalized = self.get_constrain_vec()(
positions_transformed
)
positions = positions_normalized * max_pred

# Update positions in-place
if type(data) == np.ndarray:
data[:, :2] = positions
# Concatenate transformed positions with unchanged dimensions
if data.shape[1] > 2:
return jnp.concatenate([positions, data[:, 2:]], axis=1)
else:
data = data.at[:, :2].set(positions)
# Dimensions 2+ pass through unchanged

return data
return positions

def shuffle_data(
self,
Expand Down
106 changes: 106 additions & 0 deletions tests/test_dataloader_jax_grad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
"""
Test JAX grad compatibility for dataloader transformation methods.
"""

from __future__ import annotations

import jax
import jax.numpy as jnp
import pytest

from probabilistic_posrec.config import (
Config,
ModelConfig,
TrainingConfig,
DataConfig,
OutputConfig,
SimulationConfig,
)
from probabilistic_posrec.dataloader.base import Loader


class MockLoader(Loader):
"""Mock loader for testing transformations without loading real data."""

def load(self) -> None:
pass

def preprocess(self) -> None:
pass

def get_data_array(self):
return jnp.array([]), jnp.array([])

def _get_hitpatterns_and_targets(self):
return jnp.array([]), jnp.array([])

def _prepare_hitpatterns(self, hitpatterns):
return hitpatterns, jnp.ones((len(hitpatterns), 1))


@pytest.fixture
def config():
"""Minimal configuration for testing."""
return Config(
model=ModelConfig(
r_max=66.4,
buffer=20.0,
scale=1.0,
eps=1e-7,
log_area_scale=10.0,
transform_type="tanh",
),
training=TrainingConfig(random_seed=42),
data=DataConfig(
disabled_pmts=[],
randomize_off_pmt=False,
randomize_off_pmt_max=50,
),
output=OutputConfig(),
simulation=SimulationConfig(),
)


@pytest.fixture
def loader(config):
return MockLoader(config)


@pytest.fixture
def sample_data():
"""Sample data in [-1, 1] range."""
key = jax.random.PRNGKey(42)
return jax.random.uniform(key, shape=(5, 2), minval=-0.99, maxval=0.99)


@pytest.mark.parametrize(
"transform_fn", ["data_transformation", "data_inv_transformation"]
)
def test_transformation_grad_compatible(loader, sample_data, transform_fn):
"""Test that transformation works with jax.grad."""
transform = getattr(loader, transform_fn)

def loss_fn(data):
transformed = transform(data)
return jnp.sum(transformed)

grad_fn = jax.grad(loss_fn)
grads = grad_fn(sample_data)
assert grads.shape == sample_data.shape


def test_round_trip_grad_precision(loader, sample_data):
"""Test round-trip transformation gradient precision."""

def round_trip_loss(data):
transformed = loader.data_transformation(data)
restored = loader.data_inv_transformation(transformed)
return jnp.sum((restored - data) ** 2)

grad_fn = jax.grad(round_trip_loss)
grads = grad_fn(sample_data)

# Gradients should be well-behaved
assert jnp.all(jnp.isfinite(grads))
# For perfect round-trip, gradient magnitude should be small
assert jnp.max(jnp.abs(grads)) < 0.1
Loading