Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions sbi/utils/user_input_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,9 @@ def batch_loop_simulator(theta: Tensor) -> Tensor:
return batch_loop_simulator


def process_x(x: Array, x_event_shape: Optional[torch.Size] = None) -> Tensor:
def process_x(
x: Array, x_event_shape: Optional[torch.Size] = None, check_finite: bool = True
) -> Tensor:
"""Return observed data adapted to match sbi's shape and type requirements.

This means that `x` is returned with a `batch_dim`.
Expand All @@ -611,7 +613,8 @@ def process_x(x: Array, x_event_shape: Optional[torch.Size] = None) -> Tensor:
"""

x = atleast_2d(torch.as_tensor(x, dtype=float32))
assert_all_finite(x, "Observed data x_o contains Nans or Infs.")
if check_finite:
assert_all_finite(x, "Observed data x_o contains Nans or Infs.")

if x_event_shape is not None and len(x_event_shape) > len(x.shape):
raise ValueError(
Expand Down
Loading