From afa10a08e906d2aa6b68478d4cf516893d807e9b Mon Sep 17 00:00:00 2001 From: Sebastian Vetter Date: Thu, 20 Nov 2025 02:52:14 -0600 Subject: [PATCH] Added option for dataloader inverse transform to work on jax arrays --- src/probabilistic_posrec/dataloader/base.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/probabilistic_posrec/dataloader/base.py b/src/probabilistic_posrec/dataloader/base.py index 1a65e10..36b116f 100644 --- a/src/probabilistic_posrec/dataloader/base.py +++ b/src/probabilistic_posrec/dataloader/base.py @@ -197,7 +197,10 @@ def data_inv_transformation( positions = positions_normalized * max_pred # Update positions in-place - data[:, :2] = positions + if type(data) == np.ndarray: + data[:, :2] = positions + else: + data = data.at[:, :2].set(positions) # Dimensions 2+ pass through unchanged return data @@ -439,7 +442,7 @@ def preprocess_common(self) -> None: inv_data = self.data_inv_transformation(self.fit_data) if not jnp.all( jnp.abs(inv_data - self.fit_data_untransformed) < 1e-3 - ): + ): raise ValueError( "Data transformation verification failed: inverse " "transformation does not match original data within "