diff --git a/src/probabilistic_posrec/dataloader/base.py b/src/probabilistic_posrec/dataloader/base.py index 1a65e10..36b116f 100644 --- a/src/probabilistic_posrec/dataloader/base.py +++ b/src/probabilistic_posrec/dataloader/base.py @@ -197,7 +197,10 @@ def data_inv_transformation( positions = positions_normalized * max_pred # Update positions in-place - data[:, :2] = positions + if type(data) == np.ndarray: + data[:, :2] = positions + else: + data = data.at[:, :2].set(positions) # Dimensions 2+ pass through unchanged return data @@ -439,7 +442,7 @@ def preprocess_common(self) -> None: inv_data = self.data_inv_transformation(self.fit_data) if not jnp.all( jnp.abs(inv_data - self.fit_data_untransformed) < 1e-3 - ): + ): raise ValueError( "Data transformation verification failed: inverse " "transformation does not match original data within "