diff --git a/.github/workflows/basic-install.yml b/.github/workflows/basic-install.yml index 7d8d7a3fe..167268ee6 100644 --- a/.github/workflows/basic-install.yml +++ b/.github/workflows/basic-install.yml @@ -20,9 +20,8 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - # disable windows build test as bilby_cython is currently broken there - os: [ubuntu-latest, macos-latest] - python-version: ["3.10", "3.11", "3.12", "3.13"] + os: [ubuntu-latest, macos-latest, windows-latest] + python-version: ["3.11", "3.12", "3.13"] steps: - uses: actions/checkout@v4 with: @@ -39,6 +38,7 @@ jobs: - name: Test imports run: bash test/ci_test_imports.sh - name: Test entry points + if: matrix.os != 'windows-latest' run: | for script in $(pip show -f bilby | grep "bin\/" | xargs -I {} basename {}); do ${script} --help; diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 3c2eaba5d..c0e3becc2 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -32,9 +32,6 @@ jobs: fail-fast: false matrix: python: - - name: Python 3.10 - version: 3.10 - short-version: 310 - name: Python 3.11 version: 3.11 short-version: 311 @@ -84,3 +81,59 @@ jobs: with: name: pytest-${{ matrix.python.short-version }} path: pytest.xml + + array-backend: + + name: ${{ matrix.python.name }} array backend (${{ matrix.backend.name }}) + runs-on: ubuntu-latest + container: ghcr.io/bilby-dev/bilby-python${{ matrix.python.short-version }}:latest + strategy: + fail-fast: false + matrix: + python: + - name: Python 3.13 + version: 3.13 + short-version: 313 + backend: + - name: numpy + install-args: . + - name: jax + install-args: .[jax] + - name: torch + install-args: . 
+ env: + BILBY_ARRAY_API: 1 + SCIPY_ARRAY_API: 1 + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + fetch-tags: true + - name: Install package + run: | + # activate env so that conda list shows the correct environment + apt-get update + apt-get install -y gnupg curl + source $CONDA_PATH/bin/activate python${{ matrix.python.short-version }} + conda install -c conda-forge -y liblal!=7.7.0 python-lal!=7.7.0 + python -m pip install ${{ matrix.backend.install-args }} + python -m pip install orng + conda list --show-channel-urls + shell: bash + - name: Run array-backend unit tests + run: | + pytest --array-backend ${{ matrix.backend.name }} --durations 10 --junitxml=pytest-array.xml + - name: Publish coverage to Codecov + uses: codecov/codecov-action@v5 + with: + token: ${{ secrets.CODECOV_TOKEN }} + files: coverage.xml + flags: python${{ matrix.backend.name }} + slug: bilby-dev/bilby + - name: Upload array-backend test results + if: always() + uses: actions/upload-artifact@v4 + with: + name: pytest-array-${{ matrix.python.short-version }}-${{ matrix.backend.name }} + path: pytest-array.xml diff --git a/.gitignore b/.gitignore index 437f5654e..6ac816d92 100644 --- a/.gitignore +++ b/.gitignore @@ -16,3 +16,4 @@ MANIFEST **/outdir .idea/* bilby/_version.py +uv.lock diff --git a/bilby/compat/__init__.py b/bilby/compat/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bilby/compat/jax.py b/bilby/compat/jax.py new file mode 100644 index 000000000..af0699147 --- /dev/null +++ b/bilby/compat/jax.py @@ -0,0 +1,40 @@ +import jax +import jax.numpy as jnp +from ..core.likelihood import Likelihood + + +class JittedLikelihood(Likelihood): + """ + A wrapper to just-in-time compile a :code:`Bilby` likelihood for use with :code:`jax`. + + Parameters + ========== + likelihood: bilby.core.likelihood.Likelihood + The likelihood to wrap. + cast_to_float: bool + Whether to return a float instead of a :code:`jax.Array`. 
+ """ + + def __init__(self, likelihood, cast_to_float=True): + self._likelihood = likelihood + self._ll = jax.jit(likelihood.log_likelihood) + self._llr = jax.jit(likelihood.log_likelihood_ratio) + self.cast_to_float = cast_to_float + super().__init__() + + def __getattr__(self, name): + return getattr(self._likelihood, name) + + def log_likelihood(self, parameters): + parameters = {k: jnp.array(v) for k, v in parameters.items()} + ln_l = self._ll(parameters) + if self.cast_to_float: + ln_l = float(ln_l) + return ln_l + + def log_likelihood_ratio(self, parameters): + parameters = {k: jnp.array(v) for k, v in parameters.items()} + ln_l = self._llr(parameters) + if self.cast_to_float: + ln_l = float(ln_l) + return ln_l diff --git a/bilby/compat/patches.py b/bilby/compat/patches.py new file mode 100644 index 000000000..c6c9f544b --- /dev/null +++ b/bilby/compat/patches.py @@ -0,0 +1,104 @@ +import array_api_compat as aac + +from .utils import xp_wrap, BackendNotImplementedError, BILBY_ARRAY_API + + +def multivariate_logpdf(xp, mean, cov): + """ + Return a function to evaluate the log probability density of a multivariate + Gaussian with given mean vector and covariance matrix for the provided + array backend. + + Parameters + ========== + xp: numpy, torch, jax.numpy + A module that will resolve to :code:`numpy`, :code:`torch`, or + :code:`jax.numpy` in :code:`array_api_compat.is_..._namespace`. + mean: array-like + A one-dimensional array providing the mean of the distribution. + cov: array-like + A two-dimensional array providing the covariance matrix of the + distribution. + + Returns + ======= + logpdf: callable + A callable that provides the log probaility density provided an array + of points to evaluate at. 
+ """ + from scipy.stats import multivariate_normal + + if not BILBY_ARRAY_API or aac.is_numpy_namespace(xp): + logpdf = multivariate_normal(mean=mean, cov=cov).logpdf + elif aac.is_jax_namespace(xp): + from functools import partial + from jax.scipy.stats.multivariate_normal import logpdf + + logpdf = partial(logpdf, mean=mean, cov=cov) + elif aac.is_torch_namespace(xp): + from torch.distributions.multivariate_normal import MultivariateNormal + + mvn = MultivariateNormal(loc=mean, covariance_matrix=xp.asarray(cov)) + logpdf = mvn.log_prob + else: + raise BackendNotImplementedError( + f"Unable to import multivariate_logpdf for {xp}" + ) + return logpdf + + +@xp_wrap +def interp(x, xs, fs, /, left=None, right=None, period=None, *, xp=None): + """ + A simple implementation of numpy-style linear interpolation + + The logic is copied from + https://github.com/pytorch/pytorch/issues/50334#issuecomment-1000917964 + + Parameters + ========== + x: array-like + The values to evaluate the interpolant at. + xs: array-like + The x-values for setting up the interpolant. + ys: array-like + The values of the function for setting up the interpolant. + left: float + The value to use for x < xs[0]. Default is fs[0] + right: float + The value to use for x > xs[-1]. Default is fs[-1]. + period: float + The period of the interpolant. + Parameters left and right are ignored if period is specified. + + Notes + ===== + To avoid overlap with the ``xp`` variable, the second and third variable + names from differ from numpy. + These arguments are enforced to be positional only. 
+ """ + if not BILBY_ARRAY_API or hasattr(xp, "interp"): + return xp.interp(x, xs, fs, left=left, right=right, period=period) + + if period is not None: + x = x % period + if left is None: + left = fs[0] + if right is None: + right = fs[-1] + + x = xp.atleast_1d(x) + + m = (fs[1:] - fs[:-1]) / (xs[1:] - xs[:-1]) + b = fs[:-1] - (m * xs[:-1]) + + indices = xp.sum(xp.ge(x[:, None], xs[None, :]), axis=1) - 1 + indices = xp.clip(indices, 0, len(m) - 1) + + ret = m[indices] * x + b[indices] + + if period is None: + ret = xp.where(x < xs[0], xp.asarray(left), ret) + ret = xp.where(x > xs[-1], xp.asarray(right), ret) + + return ret.squeeze() diff --git a/bilby/compat/types.py b/bilby/compat/types.py new file mode 100644 index 000000000..bb83cb478 --- /dev/null +++ b/bilby/compat/types.py @@ -0,0 +1,4 @@ +import numpy as np + +Real = float | int | np.number +ArrayLike = np.ndarray | list | tuple diff --git a/bilby/compat/utils.py b/bilby/compat/utils.py new file mode 100644 index 000000000..e1a77af00 --- /dev/null +++ b/bilby/compat/utils.py @@ -0,0 +1,231 @@ +import inspect +import os +from collections.abc import Iterable + +import numpy as np +from array_api_compat import array_namespace, is_numpy_namespace + +from ..core.utils.log import logger + +__all__ = ["array_module", "promote_to_array"] + +# environment variable to control whether to use the array API or not implementation taken from +# https://github.com/scipy/scipy/blob/514aeea23e1c90cc4d736ef0ee8b5d762dab461a/scipy/_lib/_array_api_override.py#L27 +BILBY_ARRAY_API = os.getenv("BILBY_ARRAY_API", False) +SCIPY_ARRAY_API = os.getenv("SCIPY_ARRAY_API", False) +if BILBY_ARRAY_API and not SCIPY_ARRAY_API: + logger.warning( + "BILBY_ARRAY_API is set but SCIPY_ARRAY_API is not set. " + "This may lead to unexpected behavior since some functions in " + "scipy will not be array API compatible. It is recommended to set " + "both environment variables to ensure consistent behavior." 
+ ) +BILBY_DEVICE = os.getenv("BILBY_DEVICE", None) +SCIPY_DEVICE = os.getenv("SCIPY_DEVICE", None) +if BILBY_DEVICE and SCIPY_DEVICE is None: + logger.warning( + "BILBY_DEVICE is set but SCIPY_DEVICE is not set. " + "This may lead to unexpected behavior since some functions may not" + "set the default device consistently. It is recommended to set " + "both environment variables to ensure consistent behavior." + ) + + +def array_module(arr): + """ + Infer the array module (namespace) from the input argument. + This is a generalization of the :code:`array_api_compat.array_namespace` + function that can handle a wider variety of inputs, including some nested + structures. + + This function determines which array library backend is being used + by inspecting the input argument. It handles various input types and + fallback mechanisms to ensure a valid array module is always returned. + + The inference logic proceeds as follows: + 1. If a single-element tuple is provided, extract the element first. + 2. Attempt to use the array_api_compat.array_namespace() function + directly, which works for most array-like objects. + 3. If that fails, handle special cases: + - Dictionaries: extract values and infer from non-string values + - Builtin iterables (list, tuple, etc.): infer from elements + - Builtin scalars: default to numpy + - Pandas objects: default to numpy (treated as numpy-compatible) + - Unknown types: log a warning and default to numpy + + This is a best-effort function, but will not cover all possible edge cases. + + If support for the general array API is not activated via :code:`BILBY_ARRAY_API=1`, + this always returns numpy. + + Parameters + ========== + arr: array-like, tuple, dict, or other type + The input argument to infer the array module from. Can be: + - An array object (numpy, cupy, jax.numpy, etc.) 
+ - A tuple of arrays (single-element unwrapped) + - A dictionary with array values + - An iterable containing arrays + - A builtin scalar or type + + Returns + ======= + module + The array namespace module (e.g., numpy, cupy, jax.numpy, etc.). + Defaults to numpy if the module cannot be determined. + + Examples + ======== + >>> import numpy as np + >>> import jax.numpy as jnp + >>> array_module(np.array([1, 2, 3])) + + + >>> array_module(jnp.array([1, 2, 3])) + + + >>> array_module({'data': np.array([1, 2, 3])}) + + + >>> array_module([np.array([1]), np.array([2])]) + + + >>> array_module([1, jnp.array([2])]) + + + >>> array_module(5) + + """ + if not BILBY_ARRAY_API: + return np + + # FIXME: remove direct import of orng to avoid hard dependency + import orng + if isinstance(arr, orng.ArrayRNG): + match arr.backend: + case "jax": + import jax.numpy as jnp + return jnp + case _: + return np + if isinstance(arr, tuple) and len(arr) == 1: + arr = arr[0] + try: + return array_namespace(arr) + except TypeError: + if isinstance(arr, dict): + try: + return array_namespace(*[val for val in arr.values() if not isinstance(val, str)]) + except TypeError: + return np + elif arr.__class__.__module__ == "builtins" and isinstance(arr, Iterable): + try: + return array_namespace(*arr) + except TypeError: + return np + elif arr.__class__.__module__ == "builtins": + return np + elif arr.__module__.startswith("pandas"): + return np + else: + logger.warning( + f"Unknown array module for type: {type(arr)} Defaulting to numpy." + ) + return np + + +def promote_to_array(args, xp, skip=None): + """ + Promote arguments to arrays using the specified array module. + + Parameters + ========== + args: tuple + Tuple of arguments to promote. + xp: module + The array module (namespace) to use for promotion. + skip: int, optional + Number of trailing arguments to skip promotion for. + Defaults to None (promote all arguments). 
+ + Returns + ======= + tuple + Arguments with the first (len(args) - skip) elements promoted to + arrays using the specified module. + + Notes + ===== + This function cannot handle manual specification of devices. Arrays + are promoted to the default device of the specified array module unless + the :code:`BILBY_DEVICE` environment variable is set, e.g., to ignore a + GPU when using :code:`pytorch`, users can specify :code:`BILBY_DEVICE=cpu`. + """ + if skip is None: + skip = len(args) + else: + skip = len(args) - skip + if not is_numpy_namespace(xp): + args = tuple( + xp.asarray(arg, device=BILBY_DEVICE) for arg in args[:skip] + ) + args[skip:] + return args + + +def xp_wrap(func, no_xp=False): + """ + A decorator that will figure out the array module from the input + arguments and pass it to the function as the 'xp' keyword argument. + + Parameters + ========== + func: function + The function to be decorated. + no_xp: bool + If True, the decorator will not attempt to add the 'xp' keyword + argument and so the wrapper is a no-op. + + Returns + ======= + function + The decorated function. 
+ """ + def parse_args_kwargs_for_xp(*args, xp=None, **kwargs): + if not no_xp and xp is None: + try: + # if the user specified the target arrays in kwargs + # we need to be able to support this, if there is + # only one kwargs, pass it through alone, this is + # sometimes a dictionary of arrays so this is needed + # to remove a level of nesting + if len(args) > 0: + xp = array_module(args) + elif len(kwargs) == 1: + xp = array_module(next(iter(kwargs.values()))) + elif len(kwargs) > 1: + xp = array_module(kwargs) + else: + xp = np + kwargs["xp"] = xp + except TypeError as e: + print("type failed", e) + kwargs["xp"] = np + elif not no_xp: + kwargs["xp"] = xp + return args, kwargs + + sig = inspect.signature(func) + if any(name in sig.parameters for name in ("self", "cls")): + def wrapped(self, *args, xp=None, **kwargs): + args, kwargs = parse_args_kwargs_for_xp(*args, xp=xp, **kwargs) + return func(self, *args, **kwargs) + else: + def wrapped(*args, xp=None, **kwargs): + args, kwargs = parse_args_kwargs_for_xp(*args, xp=xp, **kwargs) + return func(*args, **kwargs) + + return wrapped + + +class BackendNotImplementedError(NotImplementedError): + pass diff --git a/bilby/core/grid.py b/bilby/core/grid.py index c574a7a3a..4bd55bed7 100644 --- a/bilby/core/grid.py +++ b/bilby/core/grid.py @@ -1,6 +1,8 @@ import json import os +from copy import copy +import array_api_compat as aac import numpy as np from .prior import Prior, PriorDict @@ -9,6 +11,7 @@ BilbyJsonEncoder, load_json, move_old_file ) from .result import FileMovedError +from ..compat.utils import array_module def grid_file_name(outdir, label, gzip=False): @@ -35,8 +38,11 @@ def grid_file_name(outdir, label, gzip=False): class Grid(object): - def __init__(self, likelihood=None, priors=None, grid_size=101, - save=False, label='no_label', outdir='.', gzip=False): + def __init__( + self, likelihood=None, priors=None, grid_size=101, + save=False, label='no_label', outdir='.', gzip=False, + xp=None, + ): """ 
Parameters @@ -57,8 +63,16 @@ def __init__(self, likelihood=None, priors=None, grid_size=101, The output directory to which the grid will be saved gzip: bool Set whether to gzip the output grid file + xp: array module | None + The array module to use for calculations (e.g., :code:`numpy`, + :code:`cupy`). If :code:`None`, defaults to :code:`numpy`. + """ + if xp is None: + xp = np + logger.debug("No array module given for grid, defaulting to numpy.") + if priors is None: priors = dict() self.likelihood = likelihood @@ -67,13 +81,15 @@ def __init__(self, likelihood=None, priors=None, grid_size=101, self.parameter_names = list(self.priors.keys()) self.sample_points = dict() - self._get_sample_points(grid_size) + self._get_sample_points(grid_size, xp=xp) # evaluate the prior on the grid points if self.n_dims > 0: self._ln_prior = self.priors.ln_prob( {key: self.mesh_grid[i].flatten() for i, key in enumerate(self.parameter_names)}, axis=0).reshape( self.mesh_grid[0].shape) + else: + self._ln_prior = xp.asarray(0.0) self._ln_likelihood = None # evaluate the likelihood on the grid points @@ -96,12 +112,14 @@ def ln_prior(self): @property def prior(self): - return np.exp(self.ln_prior) + lnp = self.ln_prior + xp = array_module(lnp) + return xp.exp(lnp) @property def ln_likelihood(self): if self._ln_likelihood is None: - self._evaluate() + self._evaluate(xp=array_module(self._ln_prior)) return self._ln_likelihood @property @@ -115,7 +133,8 @@ def marginalize(self, log_array, parameters=None, not_parameters=None): Parameters ========== log_array: array_like - A :class:`numpy.ndarray` of log likelihood/posterior values. + A :code:`Python` array-api compatible array of log + likelihood/posterior values. parameters: list, str A list, or single string, of parameters to marginalize over. If None then all parameters will be marginalized over. 
@@ -150,7 +169,7 @@ def marginalize(self, log_array, parameters=None, not_parameters=None): else: raise TypeError("Parameters names must be a list or string") - out_array = log_array.copy() + out_array = copy(log_array) names = list(self.parameter_names) for name in params: @@ -165,7 +184,8 @@ def _marginalize_single(self, log_array, name, non_marg_names=None): Parameters ========== log_array: array_like - A :class:`numpy.ndarray` of log likelihood/posterior values. + A :code:`Python` array-api compatible array of log + likelihood/posterior values. name: str The name of the parameter to marginalize over. non_marg_names: list @@ -188,17 +208,26 @@ def _marginalize_single(self, log_array, name, non_marg_names=None): non_marg_names.remove(name) places = self.sample_points[name] + xp = aac.get_namespace(log_array) if len(places) > 1: - dx = np.diff(places) - out = np.apply_along_axis( - logtrapzexp, axis, log_array, dx - ) + dx = xp.diff(places) + if log_array.ndim == 1: + out = logtrapzexp(log_array, dx=dx, xp=xp) + elif aac.is_torch_namespace(xp): + # https://discuss.pytorch.org/t/apply-a-function-along-an-axis/130440 + out = xp.stack([ + logtrapzexp(x_i, dx=dx, xp=xp) for x_i in xp.unbind(log_array, dim=axis) + ], dim=min(axis, log_array.ndim - 2)) + else: + out = xp.apply_along_axis( + logtrapzexp, axis, log_array, dx + ) else: # no marginalisation required, just remove the singleton dimension z = log_array.shape - q = np.arange(0, len(z)).astype(int) != axis - out = np.reshape(log_array, tuple((np.array(list(z)))[q])) + q = xp.arange(0, len(z)).astype(int) != axis + out = xp.reshape(log_array, tuple((xp.asarray(list(z)))[q])) return out @@ -276,8 +305,9 @@ def marginalize_likelihood(self, parameters=None, not_parameters=None): """ ln_like = self.marginalize(self.ln_likelihood, parameters=parameters, not_parameters=not_parameters) + xp = aac.get_namespace(ln_like) # NOTE: the output will not be properly normalised - return np.exp(ln_like - np.max(ln_like)) + return 
xp.exp(ln_like - xp.max(ln_like)) def marginalize_posterior(self, parameters=None, not_parameters=None): """ @@ -300,18 +330,31 @@ def marginalize_posterior(self, parameters=None, not_parameters=None): ln_post = self.marginalize(self.ln_posterior, parameters=parameters, not_parameters=not_parameters) # NOTE: the output will not be properly normalised - return np.exp(ln_post - np.max(ln_post)) + xp = aac.get_namespace(ln_post) + return xp.exp(ln_post - xp.max(ln_post)) def _evaluate(self): - self._ln_likelihood = np.empty(self.mesh_grid[0].shape) - self._evaluate_recursion(0, parameters=dict()) + xp = aac.get_namespace(self.mesh_grid[0]) + if aac.is_torch_namespace(xp) or aac.is_jax_namespace(xp): + if aac.is_torch_namespace(xp): + from torch import vmap + else: + from jax import vmap + self._ln_likelihood = vmap(self.likelihood.log_likelihood)( + {key: self.mesh_grid[i].flatten() for i, key in enumerate(self.parameter_names)} + ).reshape(self.mesh_grid[0].shape) + + else: + self._ln_likelihood = xp.empty(self.mesh_grid[0].shape) + self._evaluate_recursion(0, parameters=dict()) self.ln_noise_evidence = self.likelihood.noise_log_likelihood() def _evaluate_recursion(self, dimension, parameters): if dimension == self.n_dims: - current_point = tuple([[int(np.where( + xp = aac.get_namespace(self.mesh_grid[0]) + current_point = tuple([[xp.where( parameters[name] == - self.sample_points[name])[0])] for name in self.parameter_names]) + self.sample_points[name])[0].item()] for name in self.parameter_names]) self._ln_likelihood[current_point] = self.likelihood.log_likelihood(parameters) else: name = self.parameter_names[dimension] @@ -319,29 +362,29 @@ def _evaluate_recursion(self, dimension, parameters): parameters[name] = self.sample_points[name][ii] self._evaluate_recursion(dimension + 1, parameters) - def _get_sample_points(self, grid_size): + def _get_sample_points(self, grid_size, *, xp=np): for ii, key in enumerate(self.parameter_names): if isinstance(self.priors[key], 
Prior): if isinstance(grid_size, int): self.sample_points[key] = self.priors[key].rescale( - np.linspace(0, 1, grid_size)) + xp.linspace(0, 1, grid_size)) elif isinstance(grid_size, list): if isinstance(grid_size[ii], int): self.sample_points[key] = self.priors[key].rescale( - np.linspace(0, 1, grid_size[ii])) + xp.linspace(0, 1, grid_size[ii])) else: - self.sample_points[key] = grid_size[ii] + self.sample_points[key] = xp.asarray(grid_size[ii]) elif isinstance(grid_size, dict): if isinstance(grid_size[key], int): self.sample_points[key] = self.priors[key].rescale( - np.linspace(0, 1, grid_size[key])) + xp.linspace(0, 1, grid_size[key])) else: - self.sample_points[key] = grid_size[key] + self.sample_points[key] = xp.asarray(grid_size[key]) else: raise TypeError("Unrecognized 'grid_size' type") # set the mesh of points - self.mesh_grid = np.meshgrid( + self.mesh_grid = xp.meshgrid( *(self.sample_points[key] for key in self.parameter_names), indexing='ij') @@ -417,7 +460,7 @@ def save_to_file(self, filename=None, overwrite=False, outdir=None, "following message:\n {} \n\n".format(e)) @classmethod - def read(cls, filename=None, outdir=None, label=None, gzip=False): + def read(cls, filename=None, outdir=None, label=None, gzip=False, xp=None): """ Read in a saved .json grid file Parameters @@ -430,6 +473,9 @@ def read(cls, filename=None, outdir=None, label=None, gzip=False): If given, whether the file is gzipped or not (only required if the file is gzipped, but does not have the standard '.gz' file extension) + xp: array module | None + The array module to use for calculations (e.g., :code:`numpy`, + :code:`jax.numpy`). If :code:`None`, defaults to :code:`numpy`. 
Returns ======= @@ -453,7 +499,7 @@ def read(cls, filename=None, outdir=None, label=None, gzip=False): try: grid = cls(likelihood=None, priors=dictionary['priors'], grid_size=dictionary['sample_points'], - label=dictionary['label'], outdir=dictionary['outdir']) + label=dictionary['label'], outdir=dictionary['outdir'], xp=xp) # set the likelihood grid._ln_likelihood = dictionary['ln_likelihood'] diff --git a/bilby/core/likelihood.py b/bilby/core/likelihood.py index 210dc501f..0d3fd4537 100644 --- a/bilby/core/likelihood.py +++ b/bilby/core/likelihood.py @@ -1,10 +1,13 @@ import copy +import array_api_compat as aac import numpy as np +from array_api_compat import is_array_api_obj from scipy.special import gammaln, xlogy -from scipy.stats import multivariate_normal from .utils import infer_parameters_from_function, infer_args_from_function_except_n_args +from ..compat.patches import multivariate_logpdf +from ..compat.utils import BackendNotImplementedError, array_module class Likelihood: @@ -195,9 +198,10 @@ def __init__(self, x, y, func, sigma=None, **kwargs): self.sigma = sigma def log_likelihood(self, parameters): + xp = array_module(self.x) sigma = parameters.get("sigma", self.sigma) - log_l = np.sum(- (self.residual(parameters) / sigma)**2 / 2 - - np.log(2 * np.pi * sigma**2) / 2) + log_l = xp.sum(- (self.residual(parameters) / sigma)**2 / 2 - + xp.log(xp.asarray(2 * np.pi * sigma**2)) / 2) return log_l def __repr__(self): @@ -253,17 +257,18 @@ def __init__(self, x, y, func, **kwargs): def log_likelihood(self, parameters): rate = self.func(self.x, **self.model_parameters(parameters=parameters), **self.kwargs) - if not isinstance(rate, np.ndarray): + if not is_array_api_obj(rate): raise ValueError( "Poisson rate function returns wrong value type! 
" "Is {} when it should be numpy.ndarray".format(type(rate))) - elif np.any(rate < 0.): + xp = aac.get_namespace(rate) + if xp.any(rate < 0.): raise ValueError(("Poisson rate function returns a negative", " value!")) - elif np.any(rate == 0.): + elif xp.any(rate == 0.): return -np.inf else: - return np.sum(-rate + self.y * np.log(rate) - gammaln(self.y + 1)) + return xp.sum(-rate + self.y * xp.log(rate) - gammaln(self.y + 1)) def __repr__(self): return Analytical1DLikelihood.__repr__(self) @@ -275,10 +280,12 @@ def y(self): @y.setter def y(self, y): - if not isinstance(y, np.ndarray): - y = np.array([y]) + if not is_array_api_obj(y): + y = np.atleast_1d(y) + xp = aac.get_namespace(y) # check array is a non-negative integer array - if y.dtype.kind not in 'ui' or np.any(y < 0): + # torch doesn't support checking dtype kind + if (not aac.is_torch_namespace(xp) and y.dtype.kind not in 'ui') or xp.any(y < 0): raise ValueError("Data must be non-negative integers") self.__y = y @@ -304,9 +311,10 @@ def __init__(self, x, y, func, **kwargs): def log_likelihood(self, parameters): mu = self.func(self.x, **self.model_parameters(parameters=parameters), **self.kwargs) - if np.any(mu < 0.): + xp = array_module(mu) + if xp.any(mu < 0.): return -np.inf - return -np.sum(np.log(mu) + (self.y / mu)) + return -xp.sum(xp.log(mu) + (self.y / mu)) def __repr__(self): return Analytical1DLikelihood.__repr__(self) @@ -318,9 +326,10 @@ def y(self): @y.setter def y(self, y): - if not isinstance(y, np.ndarray): - y = np.array([y]) - if np.any(y < 0): + if not is_array_api_obj(y): + y = np.atleast_1d(y) + xp = aac.get_namespace(y) + if xp.any(y < 0): raise ValueError("Data must be non-negative") self._y = y @@ -366,9 +375,10 @@ def log_likelihood(self, parameters): raise ValueError("Number of degrees of freedom for Student's " "t-likelihood must be positive") + xp = array_module(self.x) log_l =\ - np.sum(- (nu + 1) * np.log1p(self.lam * self.residual(parameters=parameters)**2 / nu) / 2 + - 
np.log(self.lam / (nu * np.pi)) / 2 + + xp.sum(- (nu + 1) * xp.log1p(self.lam * self.residual(parameters=parameters)**2 / nu) / 2 + + xp.log(xp.asarray(self.lam / (nu * np.pi))) / 2 + gammaln((nu + 1) / 2) - gammaln(nu / 2)) return log_l @@ -412,8 +422,10 @@ def __init__(self, data, n_dimensions, base="parameter_"): base: str The base of the parameter labels """ - self.data = np.array(data) - self._total = np.sum(self.data) + if not is_array_api_obj(data): + data = np.array(data) + self.data = data + self._total = self.data.sum() super(Multinomial, self).__init__() self.n = n_dimensions self.base = base @@ -448,7 +460,8 @@ def noise_log_likelihood(self): def _multinomial_ln_pdf(self, probs): """Lifted from scipy.stats.multinomial._logpdf""" - ln_prob = gammaln(self._total + 1) + np.sum( + xp = array_module(self.data) + ln_prob = gammaln(self._total + 1) + xp.sum( xlogy(self.data, probs) - gammaln(self.data + 1), axis=-1) return ln_prob @@ -467,10 +480,17 @@ class AnalyticalMultidimensionalCovariantGaussian(Likelihood): """ def __init__(self, mean, cov): - self.cov = np.atleast_2d(cov) - self.mean = np.atleast_1d(mean) - self.sigma = np.sqrt(np.diag(self.cov)) - self.pdf = multivariate_normal(mean=self.mean, cov=self.cov) + xp = array_module(cov) + self.cov = xp.atleast_2d(cov) + self.mean = xp.atleast_1d(mean) + self.sigma = xp.sqrt(xp.diag(self.cov)) + try: + self.logpdf = multivariate_logpdf(xp, mean=self.mean, cov=self.cov) + except BackendNotImplementedError: + raise NotImplementedError( + f"Multivariate normal likelihood not implemented for {xp.__name__} backend" + ) + super(AnalyticalMultidimensionalCovariantGaussian, self).__init__() @property @@ -478,8 +498,9 @@ def dim(self): return len(self.cov[0]) def log_likelihood(self, parameters): - x = np.array([parameters["x{0}".format(i)] for i in range(self.dim)]) - return self.pdf.logpdf(x) + xp = array_module(self.cov) + x = xp.asarray([parameters["x{0}".format(i)] for i in range(self.dim)]) + return 
self.logpdf(x) class AnalyticalMultidimensionalBimodalCovariantGaussian(Likelihood): @@ -497,12 +518,18 @@ class AnalyticalMultidimensionalBimodalCovariantGaussian(Likelihood): """ def __init__(self, mean_1, mean_2, cov): - self.cov = np.atleast_2d(cov) - self.sigma = np.sqrt(np.diag(self.cov)) - self.mean_1 = np.atleast_1d(mean_1) - self.mean_2 = np.atleast_1d(mean_2) - self.pdf_1 = multivariate_normal(mean=self.mean_1, cov=self.cov) - self.pdf_2 = multivariate_normal(mean=self.mean_2, cov=self.cov) + xp = array_module(cov) + self.cov = xp.atleast_2d(cov) + self.sigma = xp.sqrt(xp.diag(self.cov)) + self.mean_1 = xp.atleast_1d(mean_1) + self.mean_2 = xp.atleast_1d(mean_2) + try: + self.logpdf_1 = multivariate_logpdf(xp, mean=self.mean_1, cov=self.cov) + self.logpdf_2 = multivariate_logpdf(xp, mean=self.mean_2, cov=self.cov) + except BackendNotImplementedError: + raise NotImplementedError( + f"Multivariate normal likelihood not implemented for {xp.__name__} backend" + ) super(AnalyticalMultidimensionalBimodalCovariantGaussian, self).__init__() @property @@ -510,8 +537,9 @@ def dim(self): return len(self.cov[0]) def log_likelihood(self, parameters): - x = np.array([parameters["x{0}".format(i)] for i in range(self.dim)]) - return -np.log(2) + np.logaddexp(self.pdf_1.logpdf(x), self.pdf_2.logpdf(x)) + xp = array_module(self.cov) + x = xp.asarray([parameters["x{0}".format(i)] for i in range(self.dim)]) + return -xp.log(2) + xp.logaddexp(self.logpdf_1(x), self.logpdf_2(x)) class JointLikelihood(Likelihood): diff --git a/bilby/core/prior/analytical.py b/bilby/core/prior/analytical.py index bc47cf680..c3bdf67e3 100644 --- a/bilby/core/prior/analytical.py +++ b/bilby/core/prior/analytical.py @@ -1,21 +1,22 @@ import numpy as np from scipy.special import ( - xlogy, - erf, - erfinv, - log1p, - stdtrit, - gammaln, - stdtr, - betaln, betainc, betaincinv, + betaln, + erf, + erfinv, gammaincinv, gammainc, + gammaln, + stdtr, + stdtrit, + xlogy, + xlog1py, ) from .base import 
Prior from ..utils import logger +from ...compat.utils import array_module, xp_wrap class DeltaFunction(Prior): @@ -41,7 +42,7 @@ def __init__(self, peak, name=None, latex_label=None, unit=None): self._is_fixed = True self.least_recently_sampled = peak - def rescale(self, val): + def rescale(self, val, *, xp=None): """Rescale everything to the peak with the correct shape. Parameters @@ -54,7 +55,7 @@ def rescale(self, val): """ return self.peak * val ** 0 - def prob(self, val): + def prob(self, val, *, xp=None): """Return the prior probability of val Parameters @@ -67,10 +68,11 @@ def prob(self, val): """ at_peak = (val == self.peak) - return np.nan_to_num(np.multiply(at_peak, np.inf)) + # coerce bool to float for some array backends + return at_peak * 1.0 - def cdf(self, val): - return np.ones_like(val) * (val > self.peak) + def cdf(self, val, *, xp=None): + return 1.0 * (val > self.peak) class PowerLaw(Prior): @@ -101,7 +103,8 @@ def __init__(self, alpha, minimum, maximum, name=None, latex_label=None, boundary=boundary) self.alpha = alpha - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the power-law prior. @@ -117,12 +120,13 @@ def rescale(self, val): Union[float, array_like]: Rescaled probability """ if self.alpha == -1: - return self.minimum * np.exp(val * np.log(self.maximum / self.minimum)) + return self.minimum * xp.exp(val * xp.log(xp.asarray(self.maximum / self.minimum))) else: return (self.minimum ** (1 + self.alpha) + val * (self.maximum ** (1 + self.alpha) - self.minimum ** (1 + self.alpha))) ** (1. 
/ (1 + self.alpha)) - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=None): """Return the prior probability of val Parameters @@ -134,13 +138,16 @@ def prob(self, val): float: Prior probability of val """ if self.alpha == -1: - return np.nan_to_num(1 / val / np.log(self.maximum / self.minimum)) * self.is_in_prior_range(val) + return xp.nan_to_num( + 1 / val / xp.log(xp.asarray(self.maximum / self.minimum)) + ) * self.is_in_prior_range(val) else: - return np.nan_to_num(val ** self.alpha * (1 + self.alpha) / + return xp.nan_to_num(val ** self.alpha * (1 + self.alpha) / (self.maximum ** (1 + self.alpha) - self.minimum ** (1 + self.alpha))) * self.is_in_prior_range(val) - def ln_prob(self, val): + @xp_wrap + def ln_prob(self, val, *, xp=None): """Return the logarithmic prior probability of val Parameters @@ -153,28 +160,29 @@ def ln_prob(self, val): """ if self.alpha == -1: - normalising = 1. / np.log(self.maximum / self.minimum) + normalising = 1. / xp.log(xp.asarray(self.maximum / self.minimum)) else: - normalising = (1 + self.alpha) / (self.maximum ** (1 + self.alpha) - - self.minimum ** (1 + self.alpha)) + normalising = (1 + self.alpha) / xp.asarray( + self.maximum ** (1 + self.alpha) - self.minimum ** (1 + self.alpha) + ) with np.errstate(divide='ignore', invalid='ignore'): - ln_in_range = np.log(1. * self.is_in_prior_range(val)) - ln_p = self.alpha * np.nan_to_num(np.log(val)) + np.log(normalising) + ln_in_range = xp.log(1. 
* self.is_in_prior_range(val)) + ln_p = self.alpha * xp.nan_to_num(xp.log(val)) + xp.log(normalising) return ln_p + ln_in_range - def cdf(self, val): + @xp_wrap + def cdf(self, val, *, xp=None): if self.alpha == -1: - _cdf = (np.log(val / self.minimum) / - np.log(self.maximum / self.minimum)) + with np.errstate(invalid="ignore"): + _cdf = xp.log(val / self.minimum) / xp.log(xp.asarray(self.maximum / self.minimum)) else: _cdf = ( (val ** (self.alpha + 1) - self.minimum ** (self.alpha + 1)) / (self.maximum ** (self.alpha + 1) - self.minimum ** (self.alpha + 1)) ) - _cdf = np.minimum(_cdf, 1) - _cdf = np.maximum(_cdf, 0) + _cdf = xp.clip(_cdf, 0, 1) return _cdf @@ -203,7 +211,7 @@ def __init__(self, minimum, maximum, name=None, latex_label=None, minimum=minimum, maximum=maximum, unit=unit, boundary=boundary) - def rescale(self, val): + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the power-law prior. @@ -220,7 +228,7 @@ def rescale(self, val): """ return self.minimum + val * (self.maximum - self.minimum) - def prob(self, val): + def prob(self, val, *, xp=None): """Return the prior probability of val Parameters @@ -233,7 +241,8 @@ def prob(self, val): """ return ((val >= self.minimum) & (val <= self.maximum)) / (self.maximum - self.minimum) - def ln_prob(self, val): + @xp_wrap + def ln_prob(self, val, *, xp=None): """Return the log prior probability of val Parameters @@ -244,13 +253,13 @@ def ln_prob(self, val): ======= float: log probability of val """ - return xlogy(1, (val >= self.minimum) & (val <= self.maximum)) - xlogy(1, self.maximum - self.minimum) + with np.errstate(divide="ignore"): + return xp.log(self.prob(val)) - def cdf(self, val): + @xp_wrap + def cdf(self, val, *, xp=None): _cdf = (val - self.minimum) / (self.maximum - self.minimum) - _cdf = np.minimum(_cdf, 1) - _cdf = np.maximum(_cdf, 0) - return _cdf + return xp.clip(_cdf, 0, 1) class LogUniform(PowerLaw): @@ -310,7 +319,8 @@ def __init__(self, minimum, 
maximum, name=None, latex_label=None, minimum=minimum, maximum=maximum, unit=unit, boundary=boundary) - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the power-law prior. @@ -325,21 +335,14 @@ def rescale(self, val): ======= Union[float, array_like]: Rescaled probability """ - if isinstance(val, (float, int)): - if val < 0.5: - return -self.maximum * np.exp(-2 * val * np.log(self.maximum / self.minimum)) - else: - return self.minimum * np.exp(np.log(self.maximum / self.minimum) * (2 * val - 1)) - else: - vals_less_than_5 = val < 0.5 - rescaled = np.empty_like(val) - rescaled[vals_less_than_5] = -self.maximum * np.exp(-2 * val[vals_less_than_5] * - np.log(self.maximum / self.minimum)) - rescaled[~vals_less_than_5] = self.minimum * np.exp(np.log(self.maximum / self.minimum) * - (2 * val[~vals_less_than_5] - 1)) - return rescaled - - def prob(self, val): + return ( + xp.sign(2 * val - 1) + * self.minimum + * xp.exp(xp.abs(2 * val - 1) * xp.log(xp.asarray(self.maximum / self.minimum))) + ) + + @xp_wrap + def prob(self, val, *, xp=None): """Return the prior probability of val Parameters @@ -350,11 +353,12 @@ def prob(self, val): ======= float: Prior probability of val """ - val = np.abs(val) - return (np.nan_to_num(0.5 / val / np.log(self.maximum / self.minimum)) * + val = xp.abs(val) + return (xp.nan_to_num(0.5 / val / xp.log(xp.asarray(self.maximum / self.minimum))) * self.is_in_prior_range(val)) - def ln_prob(self, val): + @xp_wrap + def ln_prob(self, val, *, xp=None): """Return the logarithmic prior probability of val Parameters @@ -366,19 +370,12 @@ def ln_prob(self, val): float: """ - return np.nan_to_num(- np.log(2 * np.abs(val)) - np.log(np.log(self.maximum / self.minimum))) + return xp.nan_to_num(- xp.log(2 * xp.abs(val)) - xp.log(xp.log(xp.asarray(self.maximum / self.minimum)))) - def cdf(self, val): - norm = 0.5 / np.log(self.maximum / self.minimum) - _cdf = ( - -norm * 
np.log(abs(val) / self.maximum) - * (val <= -self.minimum) * (val >= -self.maximum) - + (0.5 + norm * np.log(abs(val) / self.minimum)) - * (val >= self.minimum) * (val <= self.maximum) - + 0.5 * (val > -self.minimum) * (val < self.minimum) - + 1 * (val > self.maximum) - ) - return _cdf + @xp_wrap + def cdf(self, val, *, xp=None): + asymmetric = xp.log(xp.abs(val) / self.minimum) / xp.log(xp.asarray(self.maximum / self.minimum)) + return xp.clip(0.5 * (1 + xp.sign(val) * asymmetric), 0, 1) class Cosine(Prior): @@ -405,16 +402,18 @@ def __init__(self, minimum=-np.pi / 2, maximum=np.pi / 2, name=None, super(Cosine, self).__init__(minimum=minimum, maximum=maximum, name=name, latex_label=latex_label, unit=unit, boundary=boundary) - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to a uniform in cosine prior. This maps to the inverse CDF. This has been analytically solved for this case. """ - norm = 1 / (np.sin(self.maximum) - np.sin(self.minimum)) - return np.arcsin(val / norm + np.sin(self.minimum)) + norm = 1 / (xp.sin(xp.asarray(self.maximum)) - xp.sin(xp.asarray(self.minimum))) + return xp.arcsin(val / norm + xp.sin(xp.asarray(self.minimum))) - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=None): """Return the prior probability of val. Defined over [-pi/2, pi/2]. 
Parameters @@ -425,15 +424,17 @@ def prob(self, val): ======= float: Prior probability of val """ - return np.cos(val) / 2 * self.is_in_prior_range(val) + return xp.cos(val) / 2 * self.is_in_prior_range(val) - def cdf(self, val): + @xp_wrap + def cdf(self, val, *, xp=None): _cdf = ( - (np.sin(val) - np.sin(self.minimum)) - / (np.sin(self.maximum) - np.sin(self.minimum)) - * (val >= self.minimum) * (val <= self.maximum) - + 1 * (val > self.maximum) + (xp.sin(val) - xp.sin(xp.asarray(self.minimum))) / + (xp.sin(xp.asarray(self.maximum)) - xp.sin(xp.asarray(self.minimum))) ) + _cdf *= val >= self.minimum + _cdf *= val <= self.maximum + _cdf += val > self.maximum return _cdf @@ -461,16 +462,18 @@ def __init__(self, minimum=0, maximum=np.pi, name=None, super(Sine, self).__init__(minimum=minimum, maximum=maximum, name=name, latex_label=latex_label, unit=unit, boundary=boundary) - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to a uniform in sine prior. This maps to the inverse CDF. This has been analytically solved for this case. """ - norm = 1 / (np.cos(self.minimum) - np.cos(self.maximum)) - return np.arccos(np.cos(self.minimum) - val / norm) + norm = 1 / (xp.cos(xp.asarray(self.minimum)) - xp.cos(xp.asarray(self.maximum))) + return xp.arccos(xp.cos(xp.asarray(self.minimum)) - val / norm) - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=None): """Return the prior probability of val. Defined over [0, pi]. 
Parameters @@ -481,15 +484,17 @@ def prob(self, val): ======= Union[float, array_like]: Prior probability of val """ - return np.sin(val) / 2 * self.is_in_prior_range(val) + return xp.sin(val) / 2 * self.is_in_prior_range(val) - def cdf(self, val): + @xp_wrap + def cdf(self, val, *, xp=None): _cdf = ( - (np.cos(val) - np.cos(self.minimum)) - / (np.cos(self.maximum) - np.cos(self.minimum)) - * (val >= self.minimum) * (val <= self.maximum) - + 1 * (val > self.maximum) + (xp.cos(val) - xp.cos(xp.asarray(self.minimum))) + / (xp.cos(xp.asarray(self.maximum)) - xp.cos(xp.asarray(self.minimum))) ) + _cdf *= val >= self.minimum + _cdf *= val <= self.maximum + _cdf += val > self.maximum return _cdf @@ -517,7 +522,8 @@ def __init__(self, mu, sigma, name=None, latex_label=None, unit=None, boundary=N self.mu = mu self.sigma = sigma - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the appropriate Gaussian prior. @@ -529,7 +535,8 @@ def rescale(self, val): """ return self.mu + erfinv(2 * val - 1) * 2 ** 0.5 * self.sigma - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=None): """Return the prior probability of val. Parameters @@ -540,9 +547,10 @@ def prob(self, val): ======= Union[float, array_like]: Prior probability of val """ - return np.exp(-(self.mu - val) ** 2 / (2 * self.sigma ** 2)) / (2 * np.pi) ** 0.5 / self.sigma + return xp.exp(-(self.mu - val) ** 2 / (2 * self.sigma ** 2)) / (2 * np.pi) ** 0.5 / self.sigma - def ln_prob(self, val): + @xp_wrap + def ln_prob(self, val, *, xp=None): """Return the Log prior probability of val. 
Parameters @@ -553,10 +561,9 @@ def ln_prob(self, val): ======= Union[float, array_like]: Prior probability of val """ + return -0.5 * ((self.mu - val) ** 2 / self.sigma ** 2 + xp.log(xp.asarray(2 * np.pi * self.sigma ** 2))) - return -0.5 * ((self.mu - val) ** 2 / self.sigma ** 2 + np.log(2 * np.pi * self.sigma ** 2)) - - def cdf(self, val): + def cdf(self, val, *, xp=None): return (1 - erf((self.mu - val) / 2 ** 0.5 / self.sigma)) / 2 @@ -607,7 +614,8 @@ def normalisation(self): return (erf((self.maximum - self.mu) / 2 ** 0.5 / self.sigma) - erf( (self.minimum - self.mu) / 2 ** 0.5 / self.sigma)) / 2 - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the appropriate truncated Gaussian prior. @@ -616,7 +624,8 @@ def rescale(self, val): return erfinv(2 * val * self.normalisation + erf( (self.minimum - self.mu) / 2 ** 0.5 / self.sigma)) * 2 ** 0.5 * self.sigma + self.mu - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=None): """Return the prior probability of val. 
Parameters @@ -627,17 +636,15 @@ def prob(self, val): ======= float: Prior probability of val """ - return np.exp(-(self.mu - val) ** 2 / (2 * self.sigma ** 2)) / (2 * np.pi) ** 0.5 \ + return xp.exp(-(self.mu - val) ** 2 / (2 * self.sigma ** 2)) / (2 * np.pi) ** 0.5 \ / self.sigma / self.normalisation * self.is_in_prior_range(val) - def cdf(self, val): - _cdf = ( - ( - erf((val - self.mu) / 2 ** 0.5 / self.sigma) - - erf((self.minimum - self.mu) / 2 ** 0.5 / self.sigma) - ) / 2 / self.normalisation * (val >= self.minimum) * (val <= self.maximum) - + 1 * (val > self.maximum) - ) + def cdf(self, val, *, xp=None): + _cdf = (erf((val - self.mu) / 2 ** 0.5 / self.sigma) - erf( + (self.minimum - self.mu) / 2 ** 0.5 / self.sigma)) / 2 / self.normalisation + _cdf *= val >= self.minimum + _cdf *= val <= self.maximum + _cdf += val > self.maximum return _cdf @@ -701,15 +708,17 @@ def __init__(self, mu, sigma, name=None, latex_label=None, unit=None, boundary=N self.mu = mu self.sigma = sigma - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the appropriate LogNormal prior. This maps to the inverse CDF. This has been analytically solved for this case. """ - return np.exp(self.mu + np.sqrt(2 * self.sigma ** 2) * erfinv(2 * val - 1)) + return xp.exp(self.mu + (2 * self.sigma ** 2)**0.5 * erfinv(2 * val - 1)) - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=None): """Returns the prior probability of val. Parameters @@ -720,20 +729,10 @@ def prob(self, val): ======= Union[float, array_like]: Prior probability of val """ - if isinstance(val, (float, int)): - if val <= self.minimum: - _prob = 0. 
- else: - _prob = np.exp(-(np.log(val) - self.mu) ** 2 / self.sigma ** 2 / 2)\ - / np.sqrt(2 * np.pi) / val / self.sigma - else: - _prob = np.zeros(val.size) - idx = (val > self.minimum) - _prob[idx] = np.exp(-(np.log(val[idx]) - self.mu) ** 2 / self.sigma ** 2 / 2)\ - / np.sqrt(2 * np.pi) / val[idx] / self.sigma - return _prob + return xp.exp(self.ln_prob(val, xp=xp)) - def ln_prob(self, val): + @xp_wrap + def ln_prob(self, val, *, xp=None): """Returns the log prior probability of val. Parameters @@ -744,30 +743,18 @@ def ln_prob(self, val): ======= Union[float, array_like]: Prior probability of val """ - if isinstance(val, (float, int)): - if val <= self.minimum: - _ln_prob = -np.inf - else: - _ln_prob = -(np.log(val) - self.mu) ** 2 / self.sigma ** 2 / 2\ - - np.log(np.sqrt(2 * np.pi) * val * self.sigma) - else: - _ln_prob = -np.inf * np.ones(val.size) - idx = (val > self.minimum) - _ln_prob[idx] = -(np.log(val[idx]) - self.mu) ** 2\ - / self.sigma ** 2 / 2 - np.log(np.sqrt(2 * np.pi) * val[idx] * self.sigma) - return _ln_prob - - def cdf(self, val): - if isinstance(val, (float, int)): - if val <= self.minimum: - _cdf = 0. 
- else: - _cdf = 0.5 + erf((np.log(val) - self.mu) / self.sigma / np.sqrt(2)) / 2 - else: - _cdf = np.zeros(val.size) - _cdf[val > self.minimum] = 0.5 + erf(( - np.log(val[val > self.minimum]) - self.mu) / self.sigma / np.sqrt(2)) / 2 - return _cdf + with np.errstate(divide="ignore", invalid="ignore"): + return xp.nan_to_num(( + -(xp.log(xp.maximum(val, xp.asarray(self.minimum))) - self.mu) ** 2 / self.sigma ** 2 / 2 + - xp.log((2 * np.pi)**0.5 * val * self.sigma) + ) + xp.log(val > self.minimum), nan=-xp.inf, neginf=-xp.inf, posinf=-xp.inf) + + @xp_wrap + def cdf(self, val, *, xp=None): + with np.errstate(divide="ignore"): + return 0.5 + erf( + (xp.log(xp.maximum(val, xp.asarray(self.minimum))) - self.mu) / self.sigma / np.sqrt(2) + ) / 2 class LogGaussian(LogNormal): @@ -795,15 +782,18 @@ def __init__(self, mu, name=None, latex_label=None, unit=None, boundary=None): unit=unit, boundary=boundary) self.mu = mu - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the appropriate Exponential prior. This maps to the inverse CDF. This has been analytically solved for this case. """ - return -self.mu * log1p(-val) + with np.errstate(divide="ignore", over="ignore"): + return -self.mu * xp.log1p(-val) - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=None): """Return the prior probability of val. Parameters @@ -814,17 +804,10 @@ def prob(self, val): ======= Union[float, array_like]: Prior probability of val """ - if isinstance(val, (float, int)): - if val < self.minimum: - _prob = 0. - else: - _prob = np.exp(-val / self.mu) / self.mu - else: - _prob = np.zeros(val.size) - _prob[val >= self.minimum] = np.exp(-val[val >= self.minimum] / self.mu) / self.mu - return _prob + return xp.exp(self.ln_prob(val, xp=xp)) - def ln_prob(self, val): + @xp_wrap + def ln_prob(self, val, *, xp=None): """Returns the log prior probability of val. 
Parameters @@ -835,26 +818,13 @@ def ln_prob(self, val): ======= Union[float, array_like]: Prior probability of val """ - if isinstance(val, (float, int)): - if val < self.minimum: - _ln_prob = -np.inf - else: - _ln_prob = -val / self.mu - np.log(self.mu) - else: - _ln_prob = -np.inf * np.ones(val.size) - _ln_prob[val >= self.minimum] = -val[val >= self.minimum] / self.mu - np.log(self.mu) - return _ln_prob - - def cdf(self, val): - if isinstance(val, (float, int)): - if val < self.minimum: - _cdf = 0. - else: - _cdf = 1. - np.exp(-val / self.mu) - else: - _cdf = np.zeros(val.size) - _cdf[val >= self.minimum] = 1. - np.exp(-val[val >= self.minimum] / self.mu) - return _cdf + with np.errstate(divide="ignore"): + return -val / self.mu - xp.log(xp.asarray(self.mu)) + xp.log(val >= self.minimum) + + @xp_wrap + def cdf(self, val, *, xp=None): + with np.errstate(divide="ignore", invalid="ignore", over="ignore"): + return xp.maximum(1. - xp.exp(-val / self.mu), xp.asarray(0.0)) class StudentT(Prior): @@ -891,26 +861,26 @@ def __init__(self, df, mu=0., scale=1., name=None, latex_label=None, self.mu = mu self.scale = scale - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the appropriate Student's t-prior. This maps to the inverse CDF. This has been analytically solved for this case. + + Notes + ===== + This explicitly casts to the requested backend, but the computation will be done by scipy. 
""" - if isinstance(val, (float, int)): - if val == 0: - rescaled = -np.inf - elif val == 1: - rescaled = np.inf - else: - rescaled = stdtrit(self.df, val) * self.scale + self.mu - else: - rescaled = stdtrit(self.df, val) * self.scale + self.mu - rescaled[val == 0] = -np.inf - rescaled[val == 1] = np.inf - return rescaled + with np.errstate(divide="ignore", invalid="ignore"): + return ( + xp.nan_to_num(stdtrit(self.df, val) * self.scale + self.mu) + + xp.log(val > 0) + - xp.log(val < 1) + ) - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=None): """Return the prior probability of val. Parameters @@ -921,9 +891,10 @@ def prob(self, val): ======= Union[float, array_like]: Prior probability of val """ - return np.exp(self.ln_prob(val)) + return xp.exp(self.ln_prob(val)) - def ln_prob(self, val): + @xp_wrap + def ln_prob(self, val, *, xp=None): """Returns the log prior probability of val. Parameters @@ -934,11 +905,13 @@ def ln_prob(self, val): ======= Union[float, array_like]: Prior probability of val """ - return gammaln(0.5 * (self.df + 1)) - gammaln(0.5 * self.df)\ - - np.log(np.sqrt(np.pi * self.df) * self.scale) - (self.df + 1) / 2 *\ - np.log(1 + ((val - self.mu) / self.scale) ** 2 / self.df) + return ( + gammaln(0.5 * (self.df + 1)) - gammaln(0.5 * self.df) + - xp.log(xp.asarray((np.pi * self.df)**0.5 * self.scale)) - (self.df + 1) / 2 + * xp.log(1 + ((val - self.mu) / self.scale) ** 2 / self.df) + ) - def cdf(self, val): + def cdf(self, val, *, xp=None): return stdtr(self.df, (val - self.mu) / self.scale) @@ -980,15 +953,25 @@ def __init__(self, alpha, beta, minimum=0, maximum=1, name=None, self.alpha = alpha self.beta = beta - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the appropriate Beta prior. This maps to the inverse CDF. This has been analytically solved for this case. 
+ + Notes + ===== + This explicitly casts to the requested backend, but the computation will be done by scipy. """ - return betaincinv(self.alpha, self.beta, val) * (self.maximum - self.minimum) + self.minimum + return ( + xp.asarray(betaincinv(xp.asarray(self.alpha), xp.asarray(self.beta), val)) + * (self.maximum - self.minimum) + + self.minimum + ) - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=None): """Return the prior probability of val. Parameters @@ -999,9 +982,10 @@ def prob(self, val): ======= Union[float, array_like]: Prior probability of val """ - return np.exp(self.ln_prob(val)) + return xp.exp(self.ln_prob(val, xp=xp)) - def ln_prob(self, val): + @xp_wrap + def ln_prob(self, val, *, xp=None): """Returns the log prior probability of val. Parameters @@ -1012,37 +996,19 @@ def ln_prob(self, val): ======= Union[float, array_like]: Prior probability of val """ - _ln_prob = xlogy(self.alpha - 1, val - self.minimum) + xlogy(self.beta - 1, self.maximum - val)\ - - betaln(self.alpha, self.beta) - xlogy(self.alpha + self.beta - 1, self.maximum - self.minimum) - - # deal with the fact that if alpha or beta are < 1 you get infinities at 0 and 1 - if isinstance(val, (float, int)): - if np.isfinite(_ln_prob) and self.minimum <= val <= self.maximum: - return _ln_prob - return -np.inf - else: - _ln_prob_sub = np.full_like(val, -np.inf) - idx = np.isfinite(_ln_prob) & (val >= self.minimum) & (val <= self.maximum) - _ln_prob_sub[idx] = _ln_prob[idx] - return _ln_prob_sub - - def cdf(self, val): - if isinstance(val, (float, int)): - if val > self.maximum: - return 1. - elif val < self.minimum: - return 0. - else: - return betainc( - self.alpha, self.beta, - (val - self.minimum) / (self.maximum - self.minimum) - ) - else: - _cdf = np.nan_to_num(betainc(self.alpha, self.beta, - (val - self.minimum) / (self.maximum - self.minimum))) - _cdf[val < self.minimum] = 0. - _cdf[val > self.maximum] = 1. 
- return _cdf + ln_prob = xlog1py(xp.asarray(self.beta - 1.0), -val) + xlogy(xp.asarray(self.alpha - 1.0), val) + ln_prob -= betaln(xp.asarray(self.alpha), xp.asarray(self.beta)) + return xp.nan_to_num(ln_prob, nan=-xp.inf, neginf=-xp.inf, posinf=-xp.inf) + + @xp_wrap + def cdf(self, val, *, xp=None): + return xp.nan_to_num( + betainc( + xp.asarray(self.alpha), + xp.asarray(self.beta), + (val - self.minimum) / (self.maximum - self.minimum) + ) + ) + (val > self.maximum) class Logistic(Prior): @@ -1074,27 +1040,19 @@ def __init__(self, mu, scale, name=None, latex_label=None, unit=None, boundary=N self.mu = mu self.scale = scale - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the appropriate Logistic prior. This maps to the inverse CDF. This has been analytically solved for this case. """ - if isinstance(val, (float, int)): - if val == 0: - rescaled = -np.inf - elif val == 1: - rescaled = np.inf - else: - rescaled = self.mu + self.scale * np.log(val / (1. - val)) - else: - rescaled = np.inf * np.ones(val.size) - rescaled[val == 0] = -np.inf - rescaled[(val > 0) & (val < 1)] = self.mu + self.scale\ - * np.log(val[(val > 0) & (val < 1)] / (1. - val[(val > 0) & (val < 1)])) - return rescaled + with np.errstate(divide="ignore"): + val = xp.asarray(val) + return self.mu + self.scale * xp.log(xp.maximum(val / (1 - val), xp.asarray(0))) - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=None): """Return the prior probability of val. Parameters @@ -1105,9 +1063,10 @@ def prob(self, val): ======= Union[float, array_like]: Prior probability of val """ - return np.exp(self.ln_prob(val)) + return xp.exp(self.ln_prob(val)) - def ln_prob(self, val): + @xp_wrap + def ln_prob(self, val, *, xp=None): """Returns the log prior probability of val. 
Parameters @@ -1118,11 +1077,13 @@ def ln_prob(self, val): ======= Union[float, array_like]: Prior probability of val """ - return -(val - self.mu) / self.scale -\ - 2. * np.log(1. + np.exp(-(val - self.mu) / self.scale)) - np.log(self.scale) + with np.errstate(over="ignore"): + return -(val - self.mu) / self.scale -\ + 2. * xp.log1p(xp.exp(-(val - self.mu) / self.scale)) - xp.log(xp.asarray(self.scale)) - def cdf(self, val): - return 1. / (1. + np.exp(-(val - self.mu) / self.scale)) + @xp_wrap + def cdf(self, val, *, xp=None): + return 1. / (1. + xp.exp(-(val - self.mu) / self.scale)) class Cauchy(Prior): @@ -1154,24 +1115,18 @@ def __init__(self, alpha, beta, name=None, latex_label=None, unit=None, boundary self.alpha = alpha self.beta = beta - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the appropriate Cauchy prior. This maps to the inverse CDF. This has been analytically solved for this case. """ - rescaled = self.alpha + self.beta * np.tan(np.pi * (val - 0.5)) - if isinstance(val, (float, int)): - if val == 1: - rescaled = np.inf - elif val == 0: - rescaled = -np.inf - else: - rescaled[val == 1] = np.inf - rescaled[val == 0] = -np.inf - return rescaled + rescaled = self.alpha + self.beta * xp.tan(np.pi * (val - 0.5)) + with np.errstate(divide="ignore", invalid="ignore"): + return rescaled - xp.log(val < 1) + xp.log(val > 0) - def prob(self, val): + def prob(self, val, *, xp=None): """Return the prior probability of val. Parameters @@ -1184,7 +1139,8 @@ def prob(self, val): """ return 1. / self.beta / np.pi / (1. + ((val - self.alpha) / self.beta) ** 2) - def ln_prob(self, val): + @xp_wrap + def ln_prob(self, val, *, xp=None): """Return the log prior probability of val. Parameters @@ -1195,10 +1151,11 @@ def ln_prob(self, val): ======= Union[float, array_like]: Log prior probability of val """ - return - np.log(self.beta * np.pi) - np.log(1. 
+ ((val - self.alpha) / self.beta) ** 2) + return - xp.log(xp.asarray(self.beta * np.pi)) - xp.log(1. + ((val - self.alpha) / self.beta) ** 2) - def cdf(self, val): - return 0.5 + np.arctan((val - self.alpha) / self.beta) / np.pi + @xp_wrap + def cdf(self, val, *, xp=None): + return 0.5 + xp.arctan((val - self.alpha) / self.beta) / np.pi class Lorentzian(Cauchy): @@ -1235,15 +1192,17 @@ def __init__(self, k, theta=1., name=None, latex_label=None, unit=None, boundary self.k = k self.theta = theta - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the appropriate Gamma prior. This maps to the inverse CDF. This has been analytically solved for this case. """ - return gammaincinv(self.k, val) * self.theta + return xp.asarray(gammaincinv(self.k, val)) * self.theta - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=None): """Return the prior probability of val. Parameters @@ -1254,9 +1213,10 @@ def prob(self, val): ======= Union[float, array_like]: Prior probability of val """ - return np.exp(self.ln_prob(val)) + return xp.exp(self.ln_prob(val)) - def ln_prob(self, val): + @xp_wrap + def ln_prob(self, val, *, xp=None): """Returns the log prior probability of val. Parameters @@ -1267,28 +1227,16 @@ def ln_prob(self, val): ======= Union[float, array_like]: Prior probability of val """ - if isinstance(val, (float, int)): - if val < self.minimum: - _ln_prob = -np.inf - else: - _ln_prob = xlogy(self.k - 1, val) - val / self.theta - xlogy(self.k, self.theta) - gammaln(self.k) - else: - _ln_prob = -np.inf * np.ones(val.size) - idx = (val >= self.minimum) - _ln_prob[idx] = xlogy(self.k - 1, val[idx]) - val[idx] / self.theta\ + with np.errstate(divide="ignore"): + ln_prob = ( + xlogy(xp.asarray(self.k - 1), val) - val / self.theta - xlogy(self.k, self.theta) - gammaln(self.k) - return _ln_prob - - def cdf(self, val): - if isinstance(val, (float, int)): - if val < self.minimum: - _cdf = 0. 
- else: - _cdf = gammainc(self.k, val / self.theta) - else: - _cdf = np.zeros(val.size) - _cdf[val >= self.minimum] = gammainc(self.k, val[val >= self.minimum] / self.theta) - return _cdf + ) + xp.log(val >= self.minimum) + return xp.nan_to_num(ln_prob, nan=-xp.inf, neginf=-xp.inf, posinf=xp.inf) + + @xp_wrap + def cdf(self, val, *, xp=None): + return gammainc(xp.asarray(self.k), xp.maximum(val, xp.asarray(self.minimum)) / self.theta) class ChiSquared(Gamma): @@ -1375,9 +1323,11 @@ def __init__(self, sigma, mu=None, r=None, name=None, latex_label=None, raise ValueError("For the Fermi-Dirac prior the values of sigma and r " "must be positive.") - self.expr = np.exp(self.r) + xp = array_module((mu, sigma, r)) + self.expr = xp.exp(self.r) - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the appropriate Fermi-Dirac prior. @@ -1395,9 +1345,10 @@ def rescale(self, val): `_, 2017. """ inv = -1 / self.expr + (1 + self.expr)**-val + (1 + self.expr)**-val / self.expr - return -self.sigma * np.log(np.maximum(inv, 0)) + return -self.sigma * xp.log(xp.maximum(inv, xp.asarray(0))) - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=None): """Return the prior probability of val. Parameters @@ -1409,12 +1360,13 @@ def prob(self, val): float: Prior probability of val """ return ( - (np.exp((val - self.mu) / self.sigma) + 1)**-1 - / (self.sigma * np.log1p(self.expr)) + (xp.exp((val - self.mu) / self.sigma) + 1)**-1 + / (self.sigma * xp.log1p(xp.asarray(self.expr))) * (val >= self.minimum) ) - def ln_prob(self, val): + @xp_wrap + def ln_prob(self, val, *, xp=None): """Return the log prior probability of val. 
Parameters @@ -1425,9 +1377,10 @@ def ln_prob(self, val): ======= Union[float, array_like]: Log prior probability of val """ - return np.log(self.prob(val)) + return xp.log(self.prob(val)) - def cdf(self, val): + @xp_wrap + def cdf(self, val, *, xp=None): """ Evaluate the CDF of the Fermi-Dirac distribution using a slightly modified form of Equation 23 of [1]_. @@ -1449,10 +1402,10 @@ def cdf(self, val): `_, 2017. """ result = ( - (np.logaddexp(0, -self.r) - np.logaddexp(-val / self.sigma, -self.r)) - / np.logaddexp(0, self.r) + (xp.logaddexp(xp.asarray(0.0), -xp.asarray(self.r)) - xp.logaddexp(-val / self.sigma, -xp.asarray(self.r))) + / xp.logaddexp(xp.asarray(0.0), xp.asarray(self.r)) ) - return np.clip(result, 0, 1) + return xp.clip(result, 0, 1) class WeightedDiscreteValues(Prior): @@ -1482,20 +1435,21 @@ def __init__( The unit of the parameter. Used for plotting. """ + xp = array_module(values) nvalues = len(values) - values = np.array(values) + values = xp.asarray(values) if values.shape != (nvalues,): raise ValueError( f"Shape of argument 'values' must be 1d array-like but has shape {values.shape}" ) - minimum = np.min(values) + minimum = xp.min(values) # Small delta added to help with MCMC walking - maximum = np.max(values) * (1 + 1e-15) + maximum = xp.max(values) * (1 + 1e-15) super(WeightedDiscreteValues, self).__init__( name=name, latex_label=latex_label, minimum=minimum, maximum=maximum, unit=unit, boundary=boundary) self.nvalues = nvalues - sorter = np.argsort(values) + sorter = xp.argsort(values) self._values_array = values[sorter] # inititialization of priors from repr only supports @@ -1503,9 +1457,9 @@ def __init__( self.values = self._values_array.tolist() weights = ( - np.array(weights) / np.sum(weights) + xp.asarray(weights) / xp.sum(weights) if weights is not None - else np.ones(self.nvalues) / self.nvalues + else xp.ones(self.nvalues) / self.nvalues ) # check for consistent shape of input if weights.shape != (self.nvalues,): @@ -1516,14 
+1470,15 @@ def __init__( ) self._weights_array = weights[sorter] self.weights = self._weights_array.tolist() - self._lnweights_array = np.log(self._weights_array) + self._lnweights_array = xp.log(self._weights_array) # save cdf for rescaling - _cumulative_weights_array = np.cumsum(self._weights_array) + _cumulative_weights_array = xp.cumsum(self._weights_array) # insert 0 for values smaller than minimum - self._cumulative_weights_array = np.insert(_cumulative_weights_array, 0, 0) + self._cumulative_weights_array = xp.insert(_cumulative_weights_array, 0, 0) - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the discrete-value prior. @@ -1538,10 +1493,11 @@ def rescale(self, val): ======= Union[float, array_like]: Rescaled probability """ - index = np.searchsorted(self._cumulative_weights_array[1:], val) - return self._values_array[index] + index = xp.searchsorted(xp.asarray(self._cumulative_weights_array[1:]), val) + return xp.asarray(self._values_array)[index] - def cdf(self, val): + @xp_wrap + def cdf(self, val, *, xp=None): """Return the cumulative prior probability of val. Parameters @@ -1552,10 +1508,11 @@ def cdf(self, val): ======= float: cumulative prior probability of val """ - index = np.searchsorted(self._values_array, val, side="right") - return self._cumulative_weights_array[index] + index = xp.searchsorted(xp.asarray(self._values_array), val, side="right") + return xp.asarray(self._cumulative_weights_array)[index] - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=None): """Return the prior probability of val. 
Parameters @@ -1566,13 +1523,18 @@ def prob(self, val): ======= float: Prior probability of val """ - index = np.searchsorted(self._values_array, val) - index = np.clip(index, 0, self.nvalues - 1) - p = np.where(self._values_array[index] == val, self._weights_array[index], 0) + index = xp.searchsorted(xp.asarray(self._values_array), val) + index = xp.clip(index, 0, self.nvalues - 1) + p = xp.where( + xp.asarray(self._values_array[index]) == val, + xp.asarray(self._weights_array[index]), + xp.asarray(0.0), + ) # turn 0d numpy array to scalar return p[()] - def ln_prob(self, val): + @xp_wrap + def ln_prob(self, val, xp=None): """Return the logarithmic prior probability of val Parameters @@ -1584,12 +1546,14 @@ def ln_prob(self, val): float: """ - index = np.searchsorted(self._values_array, val) - index = np.clip(index, 0, self.nvalues - 1) - lnp = np.where( - self._values_array[index] == val, self._lnweights_array[index], -np.inf + index = xp.searchsorted(xp.asarray(self._values_array), val) + index = xp.clip(index, 0, self.nvalues - 1) + lnp = xp.where( + xp.asarray(self._values_array[index]) == val, + xp.asarray(self._lnweights_array[index]), + -np.inf, ) - # turn 0d numpy array to scalar + # turn 0d array to scalar return lnp[()] @@ -1713,7 +1677,7 @@ def __init__(self, mode, minimum, maximum, name=None, latex_label=None, unit=Non self.rescaled_minimum = self.minimum - (self.minimum == self.mode) * self.scale self.rescaled_maximum = self.maximum + (self.maximum == self.mode) * self.scale - def rescale(self, val): + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from standard uniform to a triangular distribution. 
@@ -1735,7 +1699,7 @@ def rescale(self, val): self.maximum - above_mode ) * (val >= self.fractional_mode) - def prob(self, val): + def prob(self, val, *, xp=None): """ Return the prior probability of val @@ -1762,7 +1726,7 @@ def prob(self, val): ) return 2.0 * (between_minimum_and_mode + between_mode_and_maximum) / self.scale - def cdf(self, val): + def cdf(self, val, *, xp=None): """ Return the prior cumulative probability at val diff --git a/bilby/core/prior/base.py b/bilby/core/prior/base.py index 5ca28de28..dc0c9f14e 100644 --- a/bilby/core/prior/base.py +++ b/bilby/core/prior/base.py @@ -2,7 +2,9 @@ import json import os import re +import warnings +import array_api_compat as aac import numpy as np import scipy.stats @@ -12,8 +14,9 @@ decode_bilby_json, logger, get_dict_with_properties, - WrappedInterp1d as interp1d, ) +from ...compat.patches import interp +from ...compat.utils import xp_wrap class Prior(object): @@ -57,6 +60,27 @@ def __init__(self, name=None, latex_label=None, unit=None, minimum=-np.inf, self.boundary = boundary self._is_fixed = False + def __init_subclass__(cls): + for method_name in ["prob", "ln_prob", "rescale", "cdf"]: + method = getattr(cls, method_name, None) + if method is not None: + from inspect import signature + + sig = signature(method) + if "xp" not in sig.parameters: + warnings.warn( + f"The method {method_name} of the prior class " + f"{cls.__name__} does not accept an 'xp' keyword " + "argument. This may cause some behaviour to fail. " + "Please see the bilby documentation for more " + "information: https://bilby-dev.github.io/bilby/" + "array_api.html" + f" {sig}", + DeprecationWarning, + stacklevel=2, + ) + setattr(cls, method_name, xp_wrap(method, no_xp=True)) + def __call__(self): """Overrides the __call__ special method. Calls the sample method. 
@@ -106,7 +130,7 @@ def __eq__(self, other): for key in this_dict: if key == "least_recently_sampled": continue - if isinstance(this_dict[key], np.ndarray): + if aac.is_array_api_obj(this_dict[key]): if not np.array_equal(this_dict[key], other_dict[key]): return False elif isinstance(this_dict[key], type(scipy.stats.beta(1., 1.))): @@ -116,7 +140,7 @@ def __eq__(self, other): return False return True - def sample(self, size=None): + def sample(self, size=None, *, random_state=None): """Draw a sample from the prior Parameters @@ -130,11 +154,16 @@ def sample(self, size=None): """ from ..utils import random + rng = random.resolve_random_state(random_state) + + if isinstance(size, (int, np.integer)): + size = (size,) - self.least_recently_sampled = self.rescale(random.rng.uniform(0, 1, size)) + unit = rng.uniform(low=0, high=1, size=size) + self.least_recently_sampled = self.rescale(unit) return self.least_recently_sampled - def rescale(self, val): + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the prior. 
@@ -152,7 +181,7 @@ def rescale(self, val): """ return None - def prob(self, val): + def prob(self, val, *, xp=None): """Return the prior probability of val, this should be overwritten Parameters @@ -166,24 +195,22 @@ def prob(self, val): """ return np.nan - def cdf(self, val): + @xp_wrap + def cdf(self, val, *, xp=None): """ Generic method to calculate CDF, can be overwritten in subclass """ from scipy.integrate import cumulative_trapezoid if np.any(np.isinf([self.minimum, self.maximum])): raise ValueError( "Unable to use the generic CDF calculation for priors with" "infinite support") - x = np.linspace(self.minimum, self.maximum, 1000) - pdf = self.prob(x) + x = xp.linspace(self.minimum, self.maximum, 1000) + pdf = self.prob(x, xp=xp) cdf = cumulative_trapezoid(pdf, x, initial=0) - interp = interp1d(x, cdf, assume_sorted=True, bounds_error=False, - fill_value=(0, 1)) - output = interp(val) - if isinstance(val, (int, float)): - output = float(output) - return output - - def ln_prob(self, val): + output = interp(val, x, cdf / cdf[-1], left=0, right=1) + return output[()] + + @xp_wrap + def ln_prob(self, val, *, xp=None): """Return the prior ln probability of val, this should be overwritten Parameters @@ -196,7 +223,7 @@ def ln_prob(self, val): """ with np.errstate(divide='ignore'): - return np.log(self.prob(val)) + return xp.log(self.prob(val, xp=xp)) def is_in_prior_range(self, val): """Returns True if val is in the prior boundaries, zero otherwise @@ -473,7 +500,7 @@ def __init__(self, minimum, maximum, name=None, latex_label=None, latex_label=latex_label, unit=unit) self._is_fixed = True - def prob(self, val): + def prob(self, val, *, xp=None): return (val > self.minimum) & (val < self.maximum) diff --git a/bilby/core/prior/conditional.py b/bilby/core/prior/conditional.py index ad142c2a9..b49aecc37 100644 --- a/bilby/core/prior/conditional.py +++ b/bilby/core/prior/conditional.py @@ -1,9 +1,12 @@ +import numpy as np + from .base import Prior, PriorException from 
.interpolated import Interped from .analytical import DeltaFunction, PowerLaw, Uniform, LogUniform, \ SymmetricLogUniform, Cosine, Sine, Gaussian, TruncatedGaussian, HalfGaussian, \ LogNormal, Exponential, StudentT, Beta, Logistic, Cauchy, Gamma, ChiSquared, FermiDirac from ..utils import infer_args_from_method, infer_parameters_from_function +from ...compat.utils import xp_wrap def conditional_prior_factory(prior_class): @@ -59,7 +62,7 @@ def condition_func(reference_params, y): self.__class__.__name__ = 'Conditional{}'.format(prior_class.__name__) self.__class__.__qualname__ = 'Conditional{}'.format(prior_class.__qualname__) - def sample(self, size=None, **required_variables): + def sample(self, size=None, *, random_state=None, **required_variables): """Draw a sample from the prior Parameters @@ -75,11 +78,18 @@ def sample(self, size=None, **required_variables): """ from ..utils import random + rng = random.resolve_random_state(random_state) + + if isinstance(size, int | np.integer): + size = (size,) - self.least_recently_sampled = self.rescale(random.rng.uniform(0, 1, size), **required_variables) + self.least_recently_sampled = self.rescale( + rng.uniform(0, 1, size), **required_variables + ) return self.least_recently_sampled - def rescale(self, val, **required_variables): + @xp_wrap + def rescale(self, val, *, xp=None, **required_variables): """ 'Rescale' a sample from the unit line element to the prior. @@ -93,9 +103,10 @@ def rescale(self, val, **required_variables): """ self.update_conditions(**required_variables) - return super(ConditionalPrior, self).rescale(val) + return super(ConditionalPrior, self).rescale(val, xp=xp) - def prob(self, val, **required_variables): + @xp_wrap + def prob(self, val, *, xp=None, **required_variables): """Return the prior probability of val. 
Parameters @@ -111,9 +122,10 @@ def prob(self, val, **required_variables): float: Prior probability of val """ self.update_conditions(**required_variables) - return super(ConditionalPrior, self).prob(val) + return super(ConditionalPrior, self).prob(val, xp=xp) - def ln_prob(self, val, **required_variables): + @xp_wrap + def ln_prob(self, val, *, xp=None, **required_variables): """Return the natural log prior probability of val. Parameters @@ -129,9 +141,10 @@ def ln_prob(self, val, **required_variables): float: Natural log prior probability of val """ self.update_conditions(**required_variables) - return super(ConditionalPrior, self).ln_prob(val) + return super(ConditionalPrior, self).ln_prob(val, xp=xp) - def cdf(self, val, **required_variables): + @xp_wrap + def cdf(self, val, *, xp=None, **required_variables): """Return the cdf of val. Parameters @@ -147,7 +160,7 @@ def cdf(self, val, **required_variables): float: CDF of val """ self.update_conditions(**required_variables) - return super(ConditionalPrior, self).cdf(val) + return super(ConditionalPrior, self).cdf(val, xp=xp) def update_conditions(self, **required_variables): """ @@ -164,6 +177,7 @@ class depending on the required variables it depends on. self.reference_params will be used. 
""" + required_variables.pop("xp", None) if sorted(list(required_variables)) == sorted(self.required_variables): parameters = self.condition_func(self.reference_params.copy(), **required_variables) for key, value in parameters.items(): diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index 3ac54622e..65688a620 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -5,6 +5,7 @@ from io import open as ioopen from warnings import warn +import array_api_compat as aac import numpy as np from .analytical import DeltaFunction @@ -16,6 +17,8 @@ BilbyJsonEncoder, decode_bilby_json, ) +from ..utils.random import resolve_random_state, random_array_module +from ...compat.utils import array_module, xp_wrap class PriorDict(dict): @@ -54,12 +57,16 @@ def __init__(self, dictionary=None, filename=None, conversion_function=None): else: self.conversion_function = self.default_conversion_function - def evaluate_constraints(self, sample): + def __hash__(self): + return hash(str(self)) + + @xp_wrap + def evaluate_constraints(self, sample, *, xp=None): out_sample = self.conversion_function(sample) try: - prob = np.ones_like(next(iter(out_sample.values()))) + prob = xp.ones_like(next(iter(out_sample.values())), dtype=bool) except TypeError: - prob = np.ones_like(out_sample) + prob = xp.ones_like(out_sample, dtype=bool) for key in self: if isinstance(self[key], Constraint) and key in out_sample: prob *= self[key].prob(out_sample[key]) @@ -349,7 +356,7 @@ def fill_priors(self, likelihood=None, default_priors_file=None): for key in self: self.test_redundancy(key) - def sample(self, size=None): + def sample(self, size=None, *, random_state=None): """Draw samples from the prior set Parameters @@ -361,9 +368,13 @@ def sample(self, size=None): ======= dict: Dictionary of the samples """ - return self.sample_subset_constrained(keys=list(self.keys()), size=size) + return self.sample_subset_constrained( + keys=list(self.keys()), size=size, random_state=random_state 
+ ) - def sample_subset_constrained_as_array(self, keys=iter([]), size=None): + def sample_subset_constrained_as_array( + self, keys=iter([]), size=None, *, random_state=None + ): """Return an array of samples Parameters @@ -378,12 +389,15 @@ def sample_subset_constrained_as_array(self, keys=iter([]), size=None): array: array_like An array of shape (len(key), size) of the samples (ordered by keys) """ - samples_dict = self.sample_subset_constrained(keys=keys, size=size) - samples_dict = {key: np.atleast_1d(val) for key, val in samples_dict.items()} + samples_dict = self.sample_subset_constrained( + keys=keys, size=size, random_state=random_state + ) + xp = random_array_module(random_state) + samples_dict = {key: xp.atleast_1d(val) for key, val in samples_dict.items()} samples_list = [samples_dict[key] for key in keys] - return np.array(samples_list) + return xp.stack(samples_list) - def sample_subset(self, keys=iter([]), size=None): + def sample_subset(self, keys=iter([]), size=None, *, random_state=None): """Draw samples from the prior set for parameters which are not a DeltaFunction Parameters @@ -403,7 +417,7 @@ def sample_subset(self, keys=iter([]), size=None): if isinstance(self[key], Constraint): continue elif isinstance(self[key], Prior): - samples[key] = self[key].sample(size=size) + samples[key] = self[key].sample(size=size, random_state=random_state) else: logger.debug("{} not a known prior.".format(key)) return samples @@ -426,7 +440,7 @@ def fixed_keys(self): def constraint_keys(self): return [k for k, p in self.items() if isinstance(p, Constraint)] - def sample_subset_constrained(self, keys=iter([]), size=None): + def sample_subset_constrained(self, keys=iter([]), size=None, *, random_state=None): """ Sample a subset of priors while ensuring constraints are satisfied. @@ -441,8 +455,10 @@ def sample_subset_constrained(self, keys=iter([]), size=None): ======= dict: Dictionary of valid samples. 
""" + rng = resolve_random_state(random_state) + if not any(isinstance(self[key], Constraint) for key in self): - return self.sample_subset(keys=keys, size=size) + return self.sample_subset(keys=keys, size=size, random_state=rng) efficiency_warning_was_issued = False @@ -458,10 +474,10 @@ def check_efficiency(n_tested, n_valid): n_tested_samples, n_valid_samples = 0, 0 if size is None or size == 1: while True: - sample = self.sample_subset(keys=keys, size=size) + sample = self.sample_subset(keys=keys, size=size, random_state=rng) is_valid = self.evaluate_constraints(sample) n_tested_samples += 1 - n_valid_samples += int(is_valid) + n_valid_samples += int(is_valid.item()) check_efficiency(n_tested_samples, n_valid_samples) if is_valid: return sample @@ -470,20 +486,23 @@ def check_efficiency(n_tested, n_valid): for key in keys.copy(): if isinstance(self[key], Constraint): del keys[keys.index(key)] - all_samples = {key: np.array([]) for key in keys} + xp = random_array_module(random_state) + all_samples = {key: xp.asarray([]) for key in keys} _first_key = list(all_samples.keys())[0] while len(all_samples[_first_key]) < needed: - samples = self.sample_subset(keys=keys, size=needed) - keep = np.array(self.evaluate_constraints(samples), dtype=bool) + samples = self.sample_subset(keys=keys, size=needed, random_state=rng) + keep = self.evaluate_constraints(samples) for key in keys: - all_samples[key] = np.hstack( + all_samples[key] = xp.hstack( [all_samples[key], samples[key][keep].flatten()] ) n_tested_samples += needed - n_valid_samples += np.sum(keep) + n_valid_samples += int(xp.sum(keep)) check_efficiency(n_tested_samples, n_valid_samples) + if not isinstance(size, tuple): + size = (size,) all_samples = { - key: np.reshape(all_samples[key][:needed], size) for key in keys + key: xp.reshape(all_samples[key][:needed], size) for key in keys } return all_samples @@ -523,7 +542,7 @@ def _estimate_normalization(self, keys, min_accept, sampling_chunk): factor = len(keep) / 
np.count_nonzero(keep) return factor - def prob(self, sample, **kwargs): + def prob(self, sample, *, normalized=True, xp=None, **kwargs): """ Parameters @@ -538,29 +557,31 @@ def prob(self, sample, **kwargs): float: Joint probability of all individual sample probabilities """ - prob = np.prod([self[key].prob(sample[key]) for key in sample], **kwargs) + if xp is None: + xp = array_module(sample.values()) + prob = xp.prod(xp.stack([self[key].prob(sample[key], xp=xp) for key in sample]), **kwargs) - return self.check_prob(sample, prob) + return self.check_prob(sample, prob, normalized=normalized, xp=xp) - def check_prob(self, sample, prob): - ratio = self.normalize_constraint_factor(tuple(sample.keys())) - if np.all(prob == 0.0): + def check_prob(self, sample, prob, *, normalized=True, xp=None): + if normalized: + ratio = self.normalize_constraint_factor(tuple(sample.keys())) + else: + ratio = 1 + if not aac.is_jax_namespace(xp) and xp.all(prob == 0.0): return prob * ratio else: if isinstance(prob, float): - if self.evaluate_constraints(sample): + if self.evaluate_constraints(sample, xp=xp): return prob * ratio else: return 0.0 else: - constrained_prob = np.zeros_like(prob) - in_bounds = np.isfinite(prob) - subsample = {key: sample[key][in_bounds] for key in sample} - keep = np.array(self.evaluate_constraints(subsample), dtype=bool) - constrained_prob[in_bounds] = prob[in_bounds] * keep * ratio + keep = self.evaluate_constraints(sample, xp=xp) + constrained_prob = xp.where(keep, prob * ratio, 0.0) return constrained_prob - def ln_prob(self, sample, axis=None, normalized=True): + def ln_prob(self, sample, axis=None, *, normalized=True, xp=None): """ Parameters @@ -579,32 +600,34 @@ def ln_prob(self, sample, axis=None, normalized=True): Joint log probability of all the individual sample probabilities """ - ln_prob = np.sum([self[key].ln_prob(sample[key]) for key in sample], axis=axis) - return self.check_ln_prob(sample, ln_prob, - normalized=normalized) - - def 
check_ln_prob(self, sample, ln_prob, normalized=True): + if xp is None and isinstance(sample, dict): + xp = array_module(sample.values()) + elif xp is None: + # assume input is a dataframe + xp = array_module(sample.values) + ln_prob = xp.sum(xp.stack([self[key].ln_prob(sample[key], xp=xp) for key in sample]), axis=axis) + return self.check_ln_prob(sample, ln_prob, normalized=normalized, xp=xp) + + def check_ln_prob(self, sample, ln_prob, normalized=True, *, xp=None): if normalized: ratio = self.normalize_constraint_factor(tuple(sample.keys())) else: ratio = 1 - if np.all(np.isinf(ln_prob)): + if not aac.is_jax_namespace(xp) and xp.all(xp.isfinite(ln_prob)): return ln_prob else: if isinstance(ln_prob, float): - if np.all(self.evaluate_constraints(sample)): - return ln_prob + np.log(ratio) + if xp.all(self.evaluate_constraints(sample, xp=xp)): + return ln_prob + xp.log(ratio) else: return -np.inf else: - constrained_ln_prob = -np.inf * np.ones_like(ln_prob) - in_bounds = np.isfinite(ln_prob) - subsample = {key: sample[key][in_bounds] for key in sample} - keep = np.log(np.array(self.evaluate_constraints(subsample), dtype=bool)) - constrained_ln_prob[in_bounds] = ln_prob[in_bounds] + keep + np.log(ratio) + keep = self.evaluate_constraints(sample, xp=xp) + constrained_ln_prob = xp.where(keep, ln_prob + xp.log(ratio), -xp.inf) return constrained_ln_prob - def cdf(self, sample): + @xp_wrap + def cdf(self, sample, *, xp=None): """Evaluate the cumulative distribution function at the provided points Parameters @@ -618,10 +641,10 @@ def cdf(self, sample): """ return sample.__class__( - {key: self[key].cdf(sample) for key, sample in sample.items()} + {key: self[key].cdf(sample, xp=xp) for key, sample in sample.items()} ) - def rescale(self, keys, theta): + def rescale(self, keys, theta, *, xp=None): """Rescale samples from unit cube to prior Parameters @@ -635,9 +658,12 @@ def rescale(self, keys, theta): ======= list: List of floats containing the rescaled sample """ - return 
list( - [self[key].rescale(sample) for key, sample in zip(keys, theta)] - ) + if isinstance(theta, {}.values().__class__): + theta = list(theta) + if xp is None: + xp = array_module(theta) + + return xp.asarray([self[key].rescale(sample, xp=xp) for key, sample in zip(keys, theta)]) def test_redundancy(self, key, disable_logging=False): """Empty redundancy test, should be overwritten in subclasses""" @@ -737,7 +763,7 @@ def _check_conditions_resolved(self, key, sampled_keys): conditions_resolved = False return conditions_resolved - def sample_subset(self, keys=iter([]), size=None): + def sample_subset(self, keys=iter([]), size=None, *, random_state=None): self.convert_floats_to_delta_functions() add_delta_keys = [ key @@ -757,18 +783,24 @@ def sample_subset(self, keys=iter([]), size=None): if isinstance(self[key], Prior): try: samples[key] = subset_dict[key].sample( - size=size, **subset_dict.get_required_variables(key) + size=size, + random_state=random_state, + **subset_dict.get_required_variables(key), ) except ValueError: # Some prior classes can not handle an array of conditional parameters (e.g. alpha for PowerLaw) # If that is the case, we sample each sample individually. 
required_variables = subset_dict.get_required_variables(key) - samples[key] = np.zeros(size) - for i in range(size): + if size is None: + shape = () + else: + shape = (size,) + samples[key] = np.zeros(shape) + for i in range(size if size is not None else 1): rvars = { key: value[i] for key, value in required_variables.items() } - samples[key][i] = subset_dict[key].sample(**rvars) + samples[key][i] = subset_dict[key].sample(**rvars, random_state=random_state) else: logger.debug("{} not a known prior.".format(key)) return samples @@ -790,7 +822,7 @@ def get_required_variables(self, key): for k in getattr(self[key], "required_variables", []) } - def prob(self, sample, **kwargs): + def prob(self, sample, *, normalized=True, xp=None, **kwargs): """ Parameters @@ -806,14 +838,16 @@ def prob(self, sample, **kwargs): """ self._prepare_evaluation(*zip(*sample.items())) - res = [ - self[key].prob(sample[key], **self.get_required_variables(key)) + if xp is None: + xp = array_module(sample.values()) + res = xp.asarray([ + self[key].prob(sample[key], **self.get_required_variables(key), xp=xp) for key in sample - ] - prob = np.prod(res, **kwargs) - return self.check_prob(sample, prob) + ]) + prob = xp.prod(res, **kwargs) + return self.check_prob(sample, prob, normalized=normalized, xp=xp) - def ln_prob(self, sample, axis=None, normalized=True): + def ln_prob(self, sample, *, axis=None, normalized=True, xp=None): """ Parameters @@ -832,18 +866,20 @@ def ln_prob(self, sample, axis=None, normalized=True): """ self._prepare_evaluation(*zip(*sample.items())) - res = [ - self[key].ln_prob(sample[key], **self.get_required_variables(key)) + if xp is None: + xp = array_module(sample.values()) + res = xp.asarray([ + self[key].ln_prob(sample[key], **self.get_required_variables(key), xp=xp) for key in sample - ] - ln_prob = np.sum(res, axis=axis) - return self.check_ln_prob(sample, ln_prob, - normalized=normalized) + ]) + ln_prob = xp.sum(res, axis=axis) + return self.check_ln_prob(sample, 
ln_prob, normalized=normalized, xp=xp) - def cdf(self, sample): + @xp_wrap + def cdf(self, sample, *, xp=None): self._prepare_evaluation(*zip(*sample.items())) res = { - key: self[key].cdf(sample[key], **self.get_required_variables(key)) + key: self[key].cdf(sample[key], **self.get_required_variables(key), xp=xp) for key in sample } return sample.__class__(res) @@ -862,8 +898,11 @@ def rescale(self, keys, theta): ======= list: List of floats containing the rescaled sample """ + if isinstance(theta, {}.values().__class__): + theta = list(theta) + xp = array_module(theta) + keys = list(keys) - theta = list(theta) self._check_resolved() self._update_rescale_keys(keys) result = dict() @@ -880,30 +919,14 @@ def rescale(self, keys, theta): elif isinstance(self[key], JointPrior): joint[self[key].dist.distname].append(key) for names in joint.values(): - # this is needed to unpack how joint prior rescaling works - # as an example of a joint prior over {a, b, c, d} we might - # get the following based on the order within the joint prior - # {a: [], b: [], c: [1, 2, 3, 4], d: []} - # -> [1, 2, 3, 4] - # -> {a: 1, b: 2, c: 3, d: 4} - values = list() for key in names: - values = np.concatenate([values, result[key]]) - for key, value in zip(names, values): - result[key] = value - - def safe_flatten(value): - """ - this is gross but can be removed whenever we switch to returning - arrays, flatten converts 0-d arrays to 1-d so has to be special - cased - """ - if isinstance(value, (float, int, np.int64)): - return value - else: - return result[key].flatten() + if result[key] is None: + continue + for subkey, val in zip(self[key].dist.names, result[key]): + self[subkey].least_recently_sampled = val + result[subkey] = val - return [safe_flatten(result[key]) for key in keys] + return xp.asarray([result[key] for key in keys]) def _update_rescale_keys(self, keys): if not keys == self._least_recently_rescaled_keys: diff --git a/bilby/core/prior/interpolated.py 
b/bilby/core/prior/interpolated.py index 5fbf8f9c1..1983877d7 100644 --- a/bilby/core/prior/interpolated.py +++ b/bilby/core/prior/interpolated.py @@ -2,7 +2,9 @@ from scipy.integrate import trapezoid from .base import Prior -from ..utils import logger, WrappedInterp1d as interp1d +from ..utils import logger +from ..utils.calculus import interp1d +from ...compat.utils import xp_wrap class Interped(Prior): @@ -64,7 +66,8 @@ def __eq__(self, other): return True return False - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=None): """Return the prior probability of val. Parameters @@ -75,18 +78,20 @@ def prob(self, val): ======= Union[float, array_like]: Prior probability of val """ - return self.probability_density(val) + return self.probability_density(val)[()] - def cdf(self, val): - return self.cumulative_distribution(val) + @xp_wrap + def cdf(self, val, *, xp=None): + return self.cumulative_distribution(val)[()] - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the prior. This maps to the inverse CDF. This is done using interpolation. 
""" - return self.inverse_cumulative_distribution(val) + return self.inverse_cumulative_distribution(val)[()] @property def minimum(self): diff --git a/bilby/core/prior/joint.py b/bilby/core/prior/joint.py index 43c8913e3..34fffc740 100644 --- a/bilby/core/prior/joint.py +++ b/bilby/core/prior/joint.py @@ -1,5 +1,7 @@ import re +import array_api_compat as aac +import array_api_extra as xpx import numpy as np import scipy.stats from scipy.special import erfinv @@ -7,6 +9,7 @@ from .base import Prior, PriorException from ..utils import logger, infer_args_from_method, get_dict_with_properties from ..utils import random +from ...compat.utils import xp_wrap class BaseJointPriorDist(object): @@ -172,13 +175,14 @@ def _split_repr(cls, string): kwargs[key.strip()] = arg return kwargs - def prob(self, samp): + @xp_wrap + def prob(self, samp, *, xp=None): """ Get the probability of a sample. For bounded priors the probability will not be properly normalised. """ - return np.exp(self.ln_prob(samp)) + return xp.exp(self.ln_prob(samp, xp=xp)) def _check_samp(self, value): """ @@ -216,7 +220,8 @@ def _check_samp(self, value): break return samp, outbounds - def ln_prob(self, value): + @xp_wrap + def ln_prob(self, value, *, xp=None): """ Get the log-probability of a sample. For bounded priors the probability will not be properly normalised. @@ -230,14 +235,12 @@ def ln_prob(self, value): """ samp, outbounds = self._check_samp(value) - lnprob = -np.inf * np.ones(samp.shape[0]) - lnprob = self._ln_prob(samp, lnprob, outbounds) - if samp.shape[0] == 1: - return lnprob[0] - else: - return lnprob + lnprob = -np.inf * xp.ones(samp.shape[0]) + lnprob = self._ln_prob(samp, lnprob, outbounds, xp=xp) + return lnprob[()] - def _ln_prob(self, samp, lnprob, outbounds): + @xp_wrap + def _ln_prob(self, samp, lnprob, outbounds, *, xp=None): """ Get the log-probability of a sample. For bounded priors the probability will not be properly normalised. 
**this method needs overwritten by child class** @@ -261,7 +264,7 @@ def _ln_prob(self, samp, lnprob, outbounds): """ return lnprob - def sample(self, size=1, **kwargs): + def sample(self, size=1, *, random_state=None, **kwargs): """ Draw, and set, a sample from the Dist, accompanying method _sample needs to overwritten @@ -273,14 +276,11 @@ def sample(self, size=1, **kwargs): if size is None: size = 1 - samps = self._sample(size=size, **kwargs) + samps = self._sample(size=size, random_state=random_state, **kwargs) for i, name in enumerate(self.names): - if size == 1: - self.current_sample[name] = samps[:, i].flatten()[0] - else: - self.current_sample[name] = samps[:, i].flatten() + self.current_sample[name] = samps[:, i].flatten()[()] - def _sample(self, size, **kwargs): + def _sample(self, size, *, random_state=None, **kwargs): """ Draw, and set, a sample from the joint dist (**needs to be ovewritten by child class**) @@ -289,13 +289,10 @@ def _sample(self, size, **kwargs): size: int number of samples to generate, defaults to 1 """ - samps = np.zeros((size, len(self))) - """ - Here is where the subclass where overwrite sampling method - """ - return samps + raise NotImplementedError - def rescale(self, value, **kwargs): + @xp_wrap + def rescale(self, value, *, xp=None, **kwargs): """ Rescale from a unit hypercube to JointPriorDist. Note that no bounds are applied in the rescale function. (child classes need to @@ -317,7 +314,7 @@ def rescale(self, value, **kwargs): An vector sample drawn from the multivariate Gaussian distribution. 
""" - samp = np.array(value) + samp = xp.asarray(value) if len(samp.shape) == 1: samp = samp.reshape(1, self.num_vars) @@ -327,7 +324,9 @@ def rescale(self, value, **kwargs): raise ValueError("Array is the wrong shape") samp = self._rescale(samp, **kwargs) - return np.squeeze(samp) + if samp.shape[0] == 1: + samp = xp.squeeze(samp, axis=0) + return samp def _rescale(self, samp, **kwargs): """ @@ -611,7 +610,8 @@ def add_mode(self, mus=None, sigmas=None, corrcoef=None, cov=None, weight=1.0): scipy.stats.multivariate_normal(mean=np.zeros(self.num_vars), cov=self.corrcoefs[-1]) ) - def _rescale(self, samp, **kwargs): + @xp_wrap + def _rescale(self, samp, *, xp=None, **kwargs): try: mode = kwargs["mode"] except KeyError: @@ -626,28 +626,30 @@ def _rescale(self, samp, **kwargs): samp = erfinv(2.0 * samp - 1) * 2.0 ** 0.5 # rotate and scale to the multivariate normal shape - samp = self.mus[mode] + self.sigmas[mode] * np.einsum( - "ij,kj->ik", samp * self.sqeigvalues[mode], self.eigvectors[mode] + samp = xp.asarray(self.mus[mode]) + xp.asarray(self.sigmas[mode]) * xp.einsum( + "ij,kj->ik", samp * self.sqeigvalues[mode], xp.asarray(self.eigvectors[mode]) ) return samp - def _sample(self, size, **kwargs): + def _sample(self, size, *, random_state=None, **kwargs): try: mode = kwargs["mode"] except KeyError: mode = None + rng = random.resolve_random_state(random_state) + if mode is None: if self.nmodes == 1: mode = 0 else: if size == 1: - mode = np.argwhere(self.cumweights - random.rng.uniform(0, 1) > 0)[0][0] + mode = np.argwhere(self.cumweights - rng.uniform(0, 1) > 0)[0][0] else: # pick modes mode = [ np.argwhere(self.cumweights - r > 0)[0][0] - for r in random.rng.uniform(0, 1, size) + for r in rng.uniform(0, 1, size) ] samps = np.zeros((size, len(self))) @@ -655,7 +657,7 @@ def _sample(self, size, **kwargs): inbound = False while not inbound: # sample the multivariate Gaussian keys - vals = random.rng.uniform(0, 1, len(self)) + vals = rng.uniform(0, 1, len(self)) if 
isinstance(mode, list): samp = np.atleast_1d(self.rescale(vals, mode=mode[i])) @@ -673,18 +675,22 @@ def _sample(self, size, **kwargs): if not outbound: inbound = True - return samps + xp = aac.array_namespace(vals) + return xp.asarray(samps) - def _ln_prob(self, samp, lnprob, outbounds): + @xp_wrap + def _ln_prob(self, samp, lnprob, outbounds, *, xp=None): for j in range(samp.shape[0]): # loop over the modes and sum the probabilities for i in range(self.nmodes): # self.mvn[i] is a "standard" multivariate normal distribution; see add_mode() z = (samp[j] - self.mus[i]) / self.sigmas[i] - lnprob[j] = np.logaddexp(lnprob[j], self.mvn[i].logpdf(z) - self.logprodsigmas[i]) + lnprob = xpx.at(lnprob, j).set( + xp.logaddexp(lnprob[j], self.mvn[i].logpdf(z) - xp.asarray(self.logprodsigmas[i])) + ) # set out-of-bounds values to -inf - lnprob[outbounds] = -np.inf + lnprob = xp.where(xp.asarray(outbounds), -np.inf, lnprob) return lnprob def __eq__(self, other): @@ -778,7 +784,8 @@ def maximum(self, maximum): self._maximum = maximum self.dist.bounds[self.name] = (self.dist.bounds[self.name][0], maximum) - def rescale(self, val, **kwargs): + @xp_wrap + def rescale(self, val, *, xp=None, **kwargs): """ Scale a unit hypercube sample to the prior. @@ -793,18 +800,21 @@ def rescale(self, val, **kwargs): float: A sample from the prior parameter. """ - self.dist.rescale_parameters[self.name] = val if self.dist.filled_rescale(): - values = np.array(list(self.dist.rescale_parameters.values())).T + # print(self.dist.rescale_parameters) + values = xp.stack([ + xp.asarray(val) for val in self.dist.rescale_parameters.values() + ]).T + # values = xp.asarray(list(self.dist.rescale_parameters.values())).T samples = self.dist.rescale(values, **kwargs) self.dist.reset_rescale() return samples else: return [] # return empty list - def sample(self, size=1, **kwargs): + def sample(self, size=1, *, random_state=None, **kwargs): """ Draw a sample from the prior. 
@@ -829,7 +839,7 @@ def sample(self, size=1, **kwargs): if len(self.dist.current_sample) == 0: # generate a sample - self.dist.sample(size=size, **kwargs) + self.dist.sample(size=size, random_state=random_state, **kwargs) sample = self.dist.current_sample[self.name] @@ -840,9 +850,10 @@ def sample(self, size=1, **kwargs): # reset samples self.dist.reset_sampled() self.least_recently_sampled = sample - return sample + return sample.squeeze() - def ln_prob(self, val): + @xp_wrap + def ln_prob(self, val, *, xp=None): """ Return the natural logarithm of the prior probability. Note that this will not be correctly normalised if there are bounds on the @@ -864,25 +875,16 @@ def ln_prob(self, val): values = list(self.dist.requested_parameters.values()) # check for the same number of values for each parameter - for i in range(len(self.dist) - 1): - if isinstance(values[i], (list, np.ndarray)) or isinstance( - values[i + 1], (list, np.ndarray) - ): - if isinstance(values[i], (list, np.ndarray)) and isinstance( - values[i + 1], (list, np.ndarray) - ): - if len(values[i]) != len(values[i + 1]): - raise ValueError( - "Each parameter must have the same " - "number of requested values." - ) - else: - raise ValueError( - "Each parameter must have the same " - "number of requested values." - ) + shapes = set() + for v in values: + shapes.add(xp.asarray(v).shape) + if len(shapes) > 1: + raise ValueError( + "Each parameter must have the same " + "number of requested values." 
+ ) - lnp = self.dist.ln_prob(np.asarray(values).T) + lnp = self.dist.ln_prob(xp.stack(values).T) # reset the requested parameters self.dist.reset_request() @@ -901,9 +903,10 @@ def ln_prob(self, val): if len(val) == 1: return 0.0 else: - return np.zeros_like(val) + return xp.zeros_like(val) - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=None): """Return the prior probability of val Parameters @@ -917,7 +920,7 @@ def prob(self, val): the p value for the prior at given sample """ - return np.exp(self.ln_prob(val)) + return xp.exp(self.ln_prob(val, xp=xp)) class MultivariateGaussian(JointPrior): diff --git a/bilby/core/prior/slabspike.py b/bilby/core/prior/slabspike.py index 6910be608..2ac310f55 100644 --- a/bilby/core/prior/slabspike.py +++ b/bilby/core/prior/slabspike.py @@ -1,8 +1,8 @@ -from numbers import Number import numpy as np from .base import Prior from ..utils import logger +from ...compat.utils import xp_wrap class SlabSpikePrior(Prior): @@ -72,7 +72,8 @@ def slab_fraction(self): def _find_inverse_cdf_fraction_before_spike(self): return float(self.slab.cdf(self.spike_location)) * self.slab_fraction - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the prior. @@ -85,28 +86,22 @@ def rescale(self, val): ======= array_like: Associated prior value with input value. 
""" - original_is_number = isinstance(val, Number) - val = np.atleast_1d(val) - lower_indices = val < self.inverse_cdf_below_spike - intermediate_indices = np.logical_and( - self.inverse_cdf_below_spike <= val, - val <= (self.inverse_cdf_below_spike + self.spike_height)) - higher_indices = val > (self.inverse_cdf_below_spike + self.spike_height) - - res = np.zeros(len(val)) - res[lower_indices] = self._contracted_rescale(val[lower_indices]) - res[intermediate_indices] = self.spike_location - res[higher_indices] = self._contracted_rescale(val[higher_indices] - self.spike_height) - if original_is_number: - try: - res = res[0] - except (KeyError, TypeError): - logger.warning("Based on inputs, a number should be output\ - but this could not be accessed from what was computed") + higher_indices = val >= (self.inverse_cdf_below_spike + self.spike_height) + + slab_scaled = self._contracted_rescale( + val - self.spike_height * higher_indices, xp=xp + ) + + res = xp.where( + lower_indices | higher_indices, + slab_scaled, + xp.asarray(self.spike_location), + ) return res - def _contracted_rescale(self, val): + @xp_wrap + def _contracted_rescale(self, val, *, xp=None): """ Contracted version of the rescale function that implements the `rescale` function on the pure slab part of the prior. @@ -120,9 +115,10 @@ def _contracted_rescale(self, val): ======= array_like: Associated prior value with input value. """ - return self.slab.rescale(val / self.slab_fraction) + return self.slab.rescale(val / self.slab_fraction, xp=xp) - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=None): """Return the prior probability of val. 
Returns np.inf for the spike location @@ -134,19 +130,13 @@ def prob(self, val): ======= array_like: Prior probability of val """ - original_is_number = isinstance(val, Number) res = self.slab.prob(val) * self.slab_fraction - res = np.atleast_1d(res) - res[val == self.spike_location] = np.inf - if original_is_number: - try: - res = res[0] - except (KeyError, TypeError): - logger.warning("Based on inputs, a number should be output\ - but this could not be accessed from what was computed") + with np.errstate(invalid="ignore"): + res += xp.nan_to_num(xp.inf * (val == self.spike_location), posinf=xp.inf) return res - def ln_prob(self, val): + @xp_wrap + def ln_prob(self, val, *, xp=None): """Return the Log prior probability of val. Returns np.inf for the spike location @@ -158,19 +148,13 @@ def ln_prob(self, val): ======= array_like: Prior probability of val """ - original_is_number = isinstance(val, Number) res = self.slab.ln_prob(val) + np.log(self.slab_fraction) - res = np.atleast_1d(res) - res[val == self.spike_location] = np.inf - if original_is_number: - try: - res = res[0] - except (KeyError, TypeError): - logger.warning("Based on inputs, a number should be output\ - but this could not be accessed from what was computed") + with np.errstate(divide="ignore"): + res += xp.nan_to_num(xp.inf * (val == self.spike_location), posinf=xp.inf) return res - def cdf(self, val): + @xp_wrap + def cdf(self, val, *, xp=None): """ Return the CDF of the prior. This calls to the slab CDF and adds a discrete step at the spike location. 
@@ -184,6 +168,6 @@ def cdf(self, val): array_like: CDF value of val """ - res = self.slab.cdf(val) * self.slab_fraction - res += self.spike_height * (val > self.spike_location) + res = self.slab.cdf(val, xp=xp) * self.slab_fraction + res += (val > self.spike_location) * self.spike_height return res diff --git a/bilby/core/utils/calculus.py b/bilby/core/utils/calculus.py index e10ce6111..490195930 100644 --- a/bilby/core/utils/calculus.py +++ b/bilby/core/utils/calculus.py @@ -1,10 +1,12 @@ import math +import array_api_compat as aac import numpy as np -from scipy.interpolate import interp1d, RectBivariateSpline +from scipy.interpolate import RectBivariateSpline, interp1d as _interp1d from scipy.special import logsumexp from .log import logger +from ...compat.utils import array_module, xp_wrap, BILBY_ARRAY_API def derivatives( @@ -152,7 +154,8 @@ def derivatives( return grads -def logtrapzexp(lnf, dx): +@xp_wrap +def logtrapzexp(lnf, dx, *, xp=np): """ Perform trapezium rule integration for the logarithm of a function on a grid. 
@@ -171,22 +174,48 @@ def logtrapzexp(lnf, dx): lnfdx1 = lnf[:-1] lnfdx2 = lnf[1:] - if isinstance(dx, (int, float)): - C = np.log(dx / 2.0) - elif isinstance(dx, (list, np.ndarray)): - if len(dx) != len(lnf) - 1: - raise ValueError( - "Step size array must have length one less than the function length" - ) - lndx = np.log(dx) - lnfdx1 = lnfdx1.copy() + lndx - lnfdx2 = lnfdx2.copy() + lndx - C = -np.log(2.0) - else: - raise TypeError("Step size must be a single value or array-like") + try: + dx = xp.asarray(dx) + except TypeError: + raise TypeError(f"Step size dx={dx} could not be converted to an array") + + if dx.ndim > 0 and len(dx) != len(lnf) - 1: + raise ValueError( + "Step size array must have length one less than the function length" + ) + lnfdx1 = lnfdx1 + xp.log(dx) + lnfdx2 = lnfdx2 + xp.log(dx) + + return logsumexp(xp.asarray([logsumexp(lnfdx1), logsumexp(lnfdx2)])) - np.log(2) + - return C + logsumexp([logsumexp(lnfdx1), logsumexp(lnfdx2)]) +class interp1d(_interp1d): + + def __call__(self, x): + if not BILBY_ARRAY_API: + return super().__call__(x) + + import array_api_compat as aac + + xp = array_module(x) + if aac.is_numpy_namespace(xp): + return super().__call__(x) + else: + from ...compat.patches import interp + + if isinstance(self.fill_value, tuple): + left, right = self.fill_value + else: + left = right = self.fill_value + + return interp( + x, + xp.asarray(self.x), + xp.asarray(self.y), + left=left, + right=right, + ) class BoundedRectBivariateSpline(RectBivariateSpline): @@ -202,9 +231,23 @@ def __init__(self, x, y, z, bbox=[None] * 4, kx=3, ky=3, s=0, fill_value=None): if self.y_max is None: self.y_max = max(y) self.fill_value = fill_value + self.x = x + self.y = y + self.z = z super().__init__(x=x, y=y, z=z, bbox=bbox, kx=kx, ky=ky, s=s) def __call__(self, x, y, dx=0, dy=0, grid=False): + xp = array_module([x, y]) + if aac.is_numpy_namespace(xp): + return self._call_scipy(x, y, dx=dx, dy=dy, grid=grid) + elif aac.is_jax_namespace(xp): + 
return self._call_jax(x, y) + else: + raise NotImplementedError( + f"BoundedRectBivariateSpline not implemented for {xp.__name__} backend" + ) + + def _call_scipy(self, x, y, dx=0, dy=0, grid=False): result = super().__call__(x=x, y=y, dx=dx, dy=dy, grid=grid) out_of_bounds_x = (x < self.x_min) | (x > self.x_max) out_of_bounds_y = (y < self.y_min) | (y > self.y_max) @@ -218,6 +261,20 @@ def __call__(self, x, y, dx=0, dy=0, grid=False): else: return result + def _call_jax(self, x, y): + import jax.numpy as jnp + from interpax import interp2d + + return interp2d( + jnp.asarray(x), + jnp.asarray(y), + jnp.asarray(self.x), + jnp.asarray(self.y), + jnp.asarray(self.z), + extrap=self.fill_value if self.fill_value is not None else False, + method="cubic2", + ) + class WrappedInterp1d(interp1d): """ diff --git a/bilby/core/utils/io.py b/bilby/core/utils/io.py index 8299d6816..f4c9bc4e8 100644 --- a/bilby/core/utils/io.py +++ b/bilby/core/utils/io.py @@ -8,6 +8,7 @@ from pathlib import Path from datetime import timedelta +import array_api_compat as aac import numpy as np import pandas as pd @@ -59,8 +60,12 @@ def default(self, obj): return encode_astropy_unit(obj) except ImportError: logger.debug("Cannot import astropy, cannot write cosmological priors") - if isinstance(obj, np.ndarray): - return {"__array__": True, "content": obj.tolist()} + if aac.is_array_api_obj(obj): + return { + "__array__": True, + "__array_namespace__": aac.get_namespace(obj).__name__, + "content": obj.tolist(), + } if isinstance(obj, complex): return {"__complex__": True, "real": obj.real, "imag": obj.imag} if isinstance(obj, pd.DataFrame): @@ -320,7 +325,9 @@ def decode_bilby_json(dct): if dct.get("__astropy_unit__", False): return decode_astropy_unit(dct) if dct.get("__array__", False): - return np.asarray(dct["content"]) + namespace = dct.get("__array_namespace__", "numpy") + xp = import_module(namespace) + return xp.asarray(dct["content"]) if dct.get("__complex__", False): return 
complex(dct["real"], dct["imag"]) if dct.get("__dataframe__", False): @@ -438,6 +445,10 @@ def encode_for_hdf5(key, item): if item.dtype.kind == 'U': logger.debug(f'converting dtype {item.dtype} for hdf5') item = np.array(item, dtype='S') + elif aac.is_array_api_obj(item): + # temporarily dump all arrays as numpy arrays, we should figure out + # how to properly deserialize them + item = np.asarray(item) if isinstance(item, (np.ndarray, int, float, complex, str, bytes)): output = item elif isinstance(item, np.random.Generator): diff --git a/bilby/core/utils/random.py b/bilby/core/utils/random.py index ccb7654c6..74efd9198 100644 --- a/bilby/core/utils/random.py +++ b/bilby/core/utils/random.py @@ -27,8 +27,12 @@ import sys import warnings +import array_api_compat as aac +import numpy as np from numpy.random import default_rng, SeedSequence +from ...compat.utils import BILBY_ARRAY_API + def __getattr__(name): if name == "rng": @@ -104,3 +108,82 @@ def generate_seeds(nseeds): A SeedSequence object containing the generated seeds. """ return SeedSequence(Generator.rng.integers(0, 2**63 - 1, size=4)).spawn(nseeds) + + +def resolve_random_state(random_state): + """ + Resolve the provided random state into a random number generator. + + Parameters + ========== + random_state: None, int, np.random.Generator, or jax.random.KeyArray + The random state to resolve. + If None, the default random generator will be used. + If an int, a new :code:`numpy.random.default_rng` object will be + created with that seed. + If a :code:`numpy.random.Generator`, it will be returned as is. + If a :code:`jax.random.KeyArray`, a corresponding + :code:`orng.ArrayRNG` generator will be created and returned. + + Returns + ======= + np.random.Generator or orng.ArrayRNG + The resolved random number generator.
+ """ + + def _resolve_numpy_generator(random_state): + if isinstance(random_state, np.random.Generator): + return random_state + elif random_state is None: + return Generator.rng + elif isinstance(random_state, int): + return np.random.default_rng(random_state) + else: + raise ValueError( + "Invalid random state. Must be None, int, or np.random.Generator." + ) + + if not BILBY_ARRAY_API: + return _resolve_numpy_generator(random_state) + + import orng + if isinstance(random_state, (np.random.Generator, orng.ArrayRNG)): + return random_state + elif aac.is_jax_array(random_state): + rng = orng.ArrayRNG(generator=random_state, backend="jax") + return rng + elif aac.is_torch_array(random_state): + rng = orng.ArrayRNG(seed=int(random_state), backend="torch") + return rng + else: + return _resolve_numpy_generator(random_state) + + +def random_array_module(random_state): + """ + Return the array module corresponding to the provided random state. + If the random state is a JAX random key, this will return :code:`jax.numpy`. + Otherwise, it will return :code:`numpy`. + + Parameters + ========== + random_state: None, int, np.random.Generator, or jax.random.KeyArray + The random state to resolve. + + Returns + ------- + array module + The array module corresponding to the provided random state.
+ """ + if random_state is None or not BILBY_ARRAY_API: + return np + elif isinstance(random_state, np.random.Generator): + return np + elif aac.is_jax_array(random_state) or getattr(random_state, "backend") == "jax": + import jax.numpy as jnp + return jnp + elif aac.is_torch_array(random_state) or getattr(random_state, "backend") == "torch": + import torch + return torch + else: + return np diff --git a/bilby/core/utils/samples.py b/bilby/core/utils/samples.py index a075d6dcd..93fdac0ac 100644 --- a/bilby/core/utils/samples.py +++ b/bilby/core/utils/samples.py @@ -1,3 +1,4 @@ +import array_api_extra as xpx import numpy as np from scipy.special import logsumexp @@ -135,7 +136,7 @@ def reflect(u): u: array-like The input array, modified in place. """ - idxs_even = np.mod(u, 2) < 1 - u[idxs_even] = np.mod(u[idxs_even], 1) - u[~idxs_even] = 1 - np.mod(u[~idxs_even], 1) + idxs_even = (u % 2) < 1 + u = xpx.at(u)[idxs_even].set(u[idxs_even] % 1) + u = xpx.at(u)[~idxs_even].set(1 - (u[~idxs_even] % 1)) return u diff --git a/bilby/core/utils/series.py b/bilby/core/utils/series.py index 63daebd6e..44f80e053 100644 --- a/bilby/core/utils/series.py +++ b/bilby/core/utils/series.py @@ -1,5 +1,10 @@ +import array_api_compat as aac +import array_api_extra as xpx import numpy as np +from ...compat.utils import array_module +from . 
import random + _TOL = 14 @@ -97,11 +102,14 @@ def create_time_series(sampling_frequency, duration, starting_time=0.): float: An equidistant time series given the parameters """ + xp = array_module(sampling_frequency) _check_legal_sampling_frequency_and_duration(sampling_frequency, duration) number_of_samples = int(duration * sampling_frequency) - return np.linspace(start=starting_time, - stop=duration + starting_time - 1 / sampling_frequency, - num=number_of_samples) + return xp.linspace( + starting_time, + duration + starting_time - 1 / sampling_frequency, + num=number_of_samples, + ) def create_frequency_series(sampling_frequency, duration): @@ -117,13 +125,13 @@ def create_frequency_series(sampling_frequency, duration): array_like: frequency series """ - _check_legal_sampling_frequency_and_duration(sampling_frequency, duration) - number_of_samples = int(np.round(duration * sampling_frequency)) - number_of_frequencies = int(np.round(number_of_samples / 2) + 1) + xp = array_module(sampling_frequency) + if not aac.is_jax_namespace(xp): + _check_legal_sampling_frequency_and_duration(sampling_frequency, duration) + number_of_samples = xp.round(duration * sampling_frequency) + number_of_frequencies = int(xp.round(number_of_samples / 2) + 1) - return np.linspace(start=0, - stop=sampling_frequency / 2, - num=number_of_frequencies) + return xp.linspace(0, sampling_frequency / 2, num=number_of_frequencies) def _check_legal_sampling_frequency_and_duration(sampling_frequency, duration): @@ -139,7 +147,7 @@ def _check_legal_sampling_frequency_and_duration(sampling_frequency, duration): """ num = sampling_frequency * duration - if np.abs(num - np.round(num)) > 10**(-_TOL): + if abs(num % 1) > 10**(-_TOL): raise IllegalDurationAndSamplingFrequencyException( '\nYour sampling frequency and duration must multiply to a number' 'up to (tol = {}) decimals close to an integer number. 
' @@ -150,44 +158,65 @@ def _check_legal_sampling_frequency_and_duration(sampling_frequency, duration): ) -def create_white_noise(sampling_frequency, duration): - """ Create white_noise which is then coloured by a given PSD +def safe_white_noise(number_of_samples: int, duration: float, *, random_state=None): + """ + A JIT-compilable function to generate white noise in the frequency domain. Parameters ========== - sampling_frequency: float + number_of_samples: int + The number of samples in the time domain. duration: float - duration of the data + The duration of the time series. + random_state: None, int, np.random.Generator, or jax.random.KeyArray + The random state to use for noise generation. Returns ======= - array_like: white noise - array_like: frequency array + white_noise: array_like + The generated complex white noise in the frequency domain. + frequencies: array_like + The corresponding frequency array for the white noise. """ - from . import random - - number_of_samples = duration * sampling_frequency - number_of_samples = int(np.round(number_of_samples)) + frequencies = np.arange(number_of_samples) / duration - frequencies = create_frequency_series(sampling_frequency, duration) + rng = random.resolve_random_state(random_state) norm1 = 0.5 * duration**0.5 - re1, im1 = random.rng.normal(0, norm1, (2, len(frequencies))) + re1, im1 = rng.normal(0, norm1, (2, len(frequencies))) + white_noise = re1 + 1j * im1 # set DC and Nyquist = 0 - white_noise[0] = 0 - # no Nyquist frequency when N=odd - if np.mod(number_of_samples, 2) == 0: - white_noise[-1] = 0 + white_noise = xpx.at(white_noise, 0).set(0) + white_noise = xpx.at(white_noise, -1).set(0) # python: transpose for use with infft - white_noise = np.transpose(white_noise) - frequencies = np.transpose(frequencies) + white_noise = white_noise.T + frequencies = frequencies.T return white_noise, frequencies +def create_white_noise(sampling_frequency, duration, random_state=None): + """ Create white_noise which is 
then coloured by a given PSD + + Parameters + ========== + sampling_frequency: float + duration: float + duration of the data + + Returns + ======= + array_like: white noise + array_like: frequency array + """ + number_of_samples = duration * sampling_frequency + number_of_samples = int(np.round(number_of_samples)) // 2 + 1 + return safe_white_noise(number_of_samples, duration, random_state=random_state) + + def nfft(time_domain_strain, sampling_frequency): """ Perform an FFT while keeping track of the frequency bins. Assumes input time series is real (positive frequencies only). @@ -206,10 +235,11 @@ def nfft(time_domain_strain, sampling_frequency): strain / Hz, and the associated frequency_array. """ - frequency_domain_strain = np.fft.rfft(time_domain_strain) + xp = array_module(time_domain_strain) + frequency_domain_strain = xp.fft.rfft(time_domain_strain) frequency_domain_strain /= sampling_frequency - frequency_array = np.linspace( + frequency_array = xp.linspace( 0, sampling_frequency / 2, len(frequency_domain_strain)) return frequency_domain_strain, frequency_array @@ -231,7 +261,8 @@ def infft(frequency_domain_strain, sampling_frequency): time_domain_strain: array_like An array of the time domain strain """ - time_domain_strain_norm = np.fft.irfft(frequency_domain_strain) + xp = array_module(frequency_domain_strain) + time_domain_strain_norm = xp.fft.irfft(frequency_domain_strain) time_domain_strain = time_domain_strain_norm * sampling_frequency return time_domain_strain diff --git a/bilby/gw/__init__.py b/bilby/gw/__init__.py index b5115766b..cd09bc6f6 100644 --- a/bilby/gw/__init__.py +++ b/bilby/gw/__init__.py @@ -3,4 +3,5 @@ from .waveform_generator import WaveformGenerator, LALCBCWaveformGenerator from .likelihood import GravitationalWaveTransient from .detector import calibration +from . 
import compat diff --git a/bilby/gw/compat/__init__.py b/bilby/gw/compat/__init__.py new file mode 100644 index 000000000..245124a71 --- /dev/null +++ b/bilby/gw/compat/__init__.py @@ -0,0 +1,17 @@ +from ...compat.utils import BILBY_ARRAY_API + +try: + from .cython import gps_time_to_utc +except ModuleNotFoundError: + pass + +if BILBY_ARRAY_API: + try: + from .jax import n_leap_seconds + except ModuleNotFoundError: + pass + + try: + from .torch import n_leap_seconds + except ModuleNotFoundError: + pass \ No newline at end of file diff --git a/bilby/gw/compat/cython.py b/bilby/gw/compat/cython.py new file mode 100644 index 000000000..9d0a69af0 --- /dev/null +++ b/bilby/gw/compat/cython.py @@ -0,0 +1,66 @@ +import numpy as np +from bilby_cython import time as _time, geometry as _geometry +from plum import dispatch + +from ...compat.types import Real, ArrayLike + + +@dispatch(precedence=1) +def gps_time_to_utc(gps_time: Real): + return _time.gps_time_to_utc(gps_time) + + +@dispatch(precedence=1) +def greenwich_mean_sidereal_time(gps_time: Real | ArrayLike): + return _time.greenwich_mean_sidereal_time(gps_time) + + +@dispatch(precedence=1) +def greenwich_sidereal_time(gps_time: Real, equation_of_equinoxes: Real): + return _time.greenwich_sidereal_time(gps_time, equation_of_equinoxes) + + +@dispatch(precedence=1) +def n_leap_seconds(gps_time: Real): + return _time.n_leap_seconds(gps_time) + + +@dispatch(precedence=1) +def utc_to_julian_day(utc_time: Real): + return _time.utc_to_julian_day(utc_time) + + +@dispatch(precedence=1) +def calculate_arm(arm_tilt: Real, arm_azimuth: Real, longitude: Real, latitude: Real): + return _geometry.calculate_arm(arm_tilt, arm_azimuth, longitude, latitude) + + +@dispatch(precedence=1) +def detector_tensor(x: ArrayLike, y: ArrayLike): + return _geometry.detector_tensor(x, y) + + +@dispatch(precedence=1) +def get_polarization_tensor(ra: Real, dec: Real, time: Real, psi: Real, mode: str): + return _geometry.get_polarization_tensor(ra, dec, 
time, psi, mode) + + +@dispatch(precedence=1) +def rotation_matrix_from_delta(delta: ArrayLike): + return _geometry.rotation_matrix_from_delta(delta) + + +@dispatch(precedence=1) +def time_delay_geocentric(detector1: ArrayLike, detector2: ArrayLike, ra, dec, time): + return _geometry.time_delay_geocentric(detector1, detector2, ra, dec, time) + + +@dispatch(precedence=1) +def time_delay_from_geocenter(detector1: ArrayLike, ra: Real, dec: Real, time: Real | ArrayLike): + return _geometry.time_delay_from_geocenter(detector1, ra, dec, time) + + +@dispatch(precedence=1) +def zenith_azimuth_to_theta_phi(zenith: Real, azimuth: Real, delta_x: np.ndarray): + theta, phi = _geometry.zenith_azimuth_to_theta_phi(zenith, azimuth, delta_x) + return theta, phi % (2 * np.pi) diff --git a/bilby/gw/compat/jax.py b/bilby/gw/compat/jax.py new file mode 100644 index 000000000..99277e30a --- /dev/null +++ b/bilby/gw/compat/jax.py @@ -0,0 +1,20 @@ +import jax.numpy as jnp +from jax import Array +from plum import dispatch + +from ..time import ( + LEAP_SECONDS as _LEAP_SECONDS, + n_leap_seconds as _n_leap_seconds, +) + +__all__ = ["n_leap_seconds"] + +LEAP_SECONDS = jnp.array(_LEAP_SECONDS) + + +@dispatch +def n_leap_seconds(date: Array): + """ + Find the number of leap seconds required for the specified date. + """ + return _n_leap_seconds(date, LEAP_SECONDS) diff --git a/bilby/gw/compat/torch.py b/bilby/gw/compat/torch.py new file mode 100644 index 000000000..b3958f347 --- /dev/null +++ b/bilby/gw/compat/torch.py @@ -0,0 +1,19 @@ +import torch +from plum import dispatch + +from ..time import ( + LEAP_SECONDS as _LEAP_SECONDS, + n_leap_seconds as _n_leap_seconds, +) + +__all__ = ["n_leap_seconds"] + +LEAP_SECONDS = torch.tensor(_LEAP_SECONDS) + + +@dispatch +def n_leap_seconds(date: torch.Tensor) -> torch.Tensor: + """ + Find the number of leap seconds required for the specified date. 
+ """ + return _n_leap_seconds(date, LEAP_SECONDS) diff --git a/bilby/gw/conversion.py b/bilby/gw/conversion.py index 9bd9cab06..96cd02dd1 100644 --- a/bilby/gw/conversion.py +++ b/bilby/gw/conversion.py @@ -26,6 +26,7 @@ lalsim_SimNeutronStarRadius, lalsim_SimNeutronStarLoveNumberK2) +from ..compat.utils import array_module from ..core.likelihood import MarginalizedLikelihoodReconstructionError from ..core.utils import logger, solar_mass, gravitational_constant, speed_of_light, command_line_args, safe_file_dump from ..core.prior import DeltaFunction @@ -204,9 +205,9 @@ def convert_to_lal_binary_black_hole_parameters(parameters): added_keys: list keys which are added to parameters during function call """ - converted_parameters = parameters.copy() original_keys = list(converted_parameters.keys()) + xp = array_module(parameters) if 'luminosity_distance' not in original_keys: if 'redshift' in converted_parameters.keys(): converted_parameters['luminosity_distance'] = \ @@ -244,7 +245,7 @@ def convert_to_lal_binary_black_hole_parameters(parameters): converted_parameters['a_{}'.format(idx)] = abs( converted_parameters[key]) converted_parameters['cos_tilt_{}'.format(idx)] = \ - np.sign(converted_parameters[key]) + xp.sign(xp.asarray(converted_parameters[key])) else: with np.errstate(invalid="raise"): try: @@ -267,13 +268,13 @@ def convert_to_lal_binary_black_hole_parameters(parameters): cos_angle = str('cos_' + angle) if cos_angle in converted_parameters.keys(): with np.errstate(invalid="ignore"): - converted_parameters[angle] = np.arccos(converted_parameters[cos_angle]) + converted_parameters[angle] = xp.arccos(converted_parameters[cos_angle]) - if "delta_phase" in original_keys: + if "delta_phase" in converted_parameters: with np.errstate(invalid="ignore"): - converted_parameters["phase"] = np.mod( + converted_parameters["phase"] = xp.mod( converted_parameters["delta_phase"] - - np.sign(np.cos(converted_parameters["theta_jn"])) + - 
xp.sign(xp.cos(converted_parameters["theta_jn"])) * converted_parameters["psi"], 2 * np.pi) added_keys = [key for key in converted_parameters.keys() @@ -378,19 +379,19 @@ def convert_to_lal_binary_neutron_star_parameters(parameters): g3pca = converted_parameters['eos_spectral_pca_gamma_3'] m1s = converted_parameters['mass_1_source'] m2s = converted_parameters['mass_2_source'] - all_lambda_1 = np.empty(0) - all_lambda_2 = np.empty(0) - all_eos_check = np.empty(0, dtype=bool) + all_lambda_1 = list() + all_lambda_2 = list() + all_eos_check = list() for (g_0pca, g_1pca, g_2pca, g_3pca, m1_s, m2_s) in zip(g0pca, g1pca, g2pca, g3pca, m1s, m2s): g_0, g_1, g_2, g_3 = spectral_pca_to_spectral(g_0pca, g_1pca, g_2pca, g_3pca) lambda_1, lambda_2, eos_check = \ spectral_params_to_lambda_1_lambda_2(g_0, g_1, g_2, g_3, m1_s, m2_s) - all_lambda_1 = np.append(all_lambda_1, lambda_1) - all_lambda_2 = np.append(all_lambda_2, lambda_2) - all_eos_check = np.append(all_eos_check, eos_check) - converted_parameters['lambda_1'] = all_lambda_1 - converted_parameters['lambda_2'] = all_lambda_2 - converted_parameters['eos_check'] = all_eos_check + all_lambda_1.append(lambda_1) + all_lambda_2.append(lambda_2) + all_eos_check.append(eos_check) + converted_parameters['lambda_1'] = np.array(all_lambda_1) + converted_parameters['lambda_2'] = np.array(all_lambda_2) + converted_parameters['eos_check'] = np.array(all_eos_check) for key in float_eos_params.keys(): converted_parameters[key] = float_eos_params[key] elif 'eos_polytrope_gamma_0' and 'eos_polytrope_log10_pressure_1' in converted_parameters.keys(): @@ -630,8 +631,9 @@ def spectral_pca_to_spectral(gamma_pca_0, gamma_pca_1, gamma_pca_2, gamma_pca_3) array of gamma_0, gamma_1, gamma_2, gamma_3 in model space ''' - sampled_pca_gammas = np.array([gamma_pca_0, gamma_pca_1, gamma_pca_2, gamma_pca_3]) - transformation_matrix = np.array( + xp = array_module(gamma_pca_0) + sampled_pca_gammas = xp.asarray([gamma_pca_0, gamma_pca_1, gamma_pca_2, 
gamma_pca_3]) + transformation_matrix = xp.asarray( [ [0.43801, -0.76705, 0.45143, 0.12646], [-0.53573, 0.17169, 0.67968, 0.47070], @@ -640,10 +642,10 @@ def spectral_pca_to_spectral(gamma_pca_0, gamma_pca_1, gamma_pca_2, gamma_pca_3) ] ) - model_space_mean = np.array([0.89421, 0.33878, -0.07894, 0.00393]) - model_space_standard_deviation = np.array([0.35700, 0.25769, 0.05452, 0.00312]) + model_space_mean = xp.asarray([0.89421, 0.33878, -0.07894, 0.00393]) + model_space_standard_deviation = xp.asarray([0.35700, 0.25769, 0.05452, 0.00312]) converted_gamma_parameters = \ - model_space_mean + model_space_standard_deviation * np.dot(transformation_matrix, sampled_pca_gammas) + model_space_mean + model_space_standard_deviation * (transformation_matrix @ sampled_pca_gammas) return converted_gamma_parameters @@ -958,9 +960,9 @@ def chirp_mass_and_primary_mass_to_mass_ratio(chirp_mass, mass_1): Mass ratio (mass_2/mass_1) of the binary """ a = (chirp_mass / mass_1) ** 5 - t0 = np.cbrt(9 * a + np.sqrt(3) * np.sqrt(27 * a ** 2 - 4 * a ** 3)) - t1 = np.cbrt(2) * 3 ** (2 / 3) - t2 = np.cbrt(2 / 3) * a + t0 = (9 * a + 3**0.5 * (27 * a ** 2 - 4 * a ** 3)**0.5)**(1 / 3) + t1 = (2)**(1 / 3) * 3 ** (2 / 3) + t2 = (2 / 3)**(1 / 3) * a return t2 / t0 + t0 / t1 @@ -1043,8 +1045,8 @@ def component_masses_to_symmetric_mass_ratio(mass_1, mass_2): symmetric_mass_ratio: float Symmetric mass ratio of the binary """ - - return np.minimum((mass_1 * mass_2) / (mass_1 + mass_2) ** 2, 1 / 4) + xp = array_module(mass_1) + return xp.minimum((mass_1 * mass_2) / (mass_1 + mass_2) ** 2, xp.asarray(0.25)) def component_masses_to_mass_ratio(mass_1, mass_2): @@ -1403,17 +1405,17 @@ def binary_love_fit_lambda_symmetric_mass_ratio_to_lambda_antisymmetric(lambda_s lambda_antisymmetric: float Antisymmetric tidal parameter. """ - lambda_symmetric_m1o5 = np.power(lambda_symmetric, -1. / 5.) 
+ lambda_symmetric_m1o5 = lambda_symmetric ** (-1 / 5) lambda_symmetric_m2o5 = lambda_symmetric_m1o5 * lambda_symmetric_m1o5 lambda_symmetric_m3o5 = lambda_symmetric_m2o5 * lambda_symmetric_m1o5 q = mass_ratio - q2 = np.square(mass_ratio) + q2 = mass_ratio ** 2 # Eqn.2 from CHZ, incorporating the dependence on mass ratio n_polytropic = 0.743 # average polytropic index for the EoSs included in the fit - q_for_Fnofq = np.power(q, 10. / (3. - n_polytropic)) + q_for_Fnofq = q ** (10. / (3. - n_polytropic)) Fnofq = (1. - q_for_Fnofq) / (1. + q_for_Fnofq) # b_ij and c_ij coefficients are given in Table I of CHZ @@ -1483,10 +1485,10 @@ def binary_love_lambda_symmetric_to_lambda_1_lambda_2_manual_marginalisation(bin lambda_antisymmetric_fitOnly = binary_love_fit_lambda_symmetric_mass_ratio_to_lambda_antisymmetric(lambda_symmetric, mass_ratio) - lambda_symmetric_sqrt = np.sqrt(lambda_symmetric) + lambda_symmetric_sqrt = lambda_symmetric ** 0.5 q = mass_ratio - q2 = np.square(mass_ratio) + q2 = mass_ratio ** 2 # mu_i and sigma_i coefficients are given in Table II of CHZ @@ -1546,9 +1548,10 @@ def binary_love_lambda_symmetric_to_lambda_1_lambda_2_manual_marginalisation(bin # Eqn 5 from CHZ, averaging the corrections from the # standard deviations of the residual fits - lambda_antisymmetric_stdCorr = \ - np.sqrt(np.square(lambda_antisymmetric_lambda_symmetric_stdCorr) + - np.square(lambda_antisymmetric_mass_ratio_stdCorr)) + lambda_antisymmetric_stdCorr = ( + lambda_antisymmetric_lambda_symmetric_stdCorr ** 2 + + lambda_antisymmetric_mass_ratio_stdCorr ** 2 + ) ** 0.5 # Draw a correction on the fit from a # Gaussian distribution with width lambda_antisymmetric_stdCorr @@ -2066,28 +2069,29 @@ def generate_spin_parameters(sample): output_sample = sample.copy() output_sample = generate_component_spins(output_sample) + xp = array_module(sample) output_sample['chi_eff'] = (output_sample['spin_1z'] + output_sample['spin_2z'] * output_sample['mass_ratio']) /\ (1 + 
output_sample['mass_ratio']) - output_sample['chi_1_in_plane'] = np.sqrt( + output_sample['chi_1_in_plane'] = ( output_sample['spin_1x'] ** 2 + output_sample['spin_1y'] ** 2 - ) - output_sample['chi_2_in_plane'] = np.sqrt( + ) ** 0.5 + output_sample['chi_2_in_plane'] = ( output_sample['spin_2x'] ** 2 + output_sample['spin_2y'] ** 2 - ) + ) ** 0.5 - output_sample['chi_p'] = np.maximum( + output_sample['chi_p'] = xp.maximum( output_sample['chi_1_in_plane'], (4 * output_sample['mass_ratio'] + 3) / (3 * output_sample['mass_ratio'] + 4) * output_sample['mass_ratio'] * output_sample['chi_2_in_plane']) try: - output_sample['cos_tilt_1'] = np.cos(output_sample['tilt_1']) - output_sample['cos_tilt_2'] = np.cos(output_sample['tilt_2']) + output_sample['cos_tilt_1'] = xp.cos(output_sample['tilt_1']) + output_sample['cos_tilt_2'] = xp.cos(output_sample['tilt_2']) except KeyError: pass @@ -2116,12 +2120,13 @@ def generate_component_spins(sample): ['theta_jn', 'phi_jl', 'tilt_1', 'tilt_2', 'phi_12', 'a_1', 'a_2', 'mass_1', 'mass_2', 'reference_frequency', 'phase'] if all(key in output_sample.keys() for key in spin_conversion_parameters): + xp = array_module(output_sample["theta_jn"]) ( output_sample['iota'], output_sample['spin_1x'], output_sample['spin_1y'], output_sample['spin_1z'], output_sample['spin_2x'], output_sample['spin_2y'], output_sample['spin_2z'] - ) = np.vectorize(bilby_to_lalsimulation_spins)( + ) = xp.vectorize(bilby_to_lalsimulation_spins)( output_sample['theta_jn'], output_sample['phi_jl'], output_sample['tilt_1'], output_sample['tilt_2'], output_sample['phi_12'], output_sample['a_1'], output_sample['a_2'], @@ -2131,10 +2136,10 @@ def generate_component_spins(sample): ) output_sample['phi_1'] =\ - np.fmod(2 * np.pi + np.arctan2( + xp.fmod(2 * np.pi + xp.arctan2( output_sample['spin_1y'], output_sample['spin_1x']), 2 * np.pi) output_sample['phi_2'] =\ - np.fmod(2 * np.pi + np.arctan2( + xp.fmod(2 * np.pi + xp.arctan2( output_sample['spin_2y'], 
output_sample['spin_2x']), 2 * np.pi) elif 'chi_1' in output_sample and 'chi_2' in output_sample: diff --git a/bilby/gw/detector/calibration.py b/bilby/gw/detector/calibration.py index 729b9e332..c904c707d 100644 --- a/bilby/gw/detector/calibration.py +++ b/bilby/gw/detector/calibration.py @@ -42,10 +42,13 @@ import copy import os +import array_api_compat as aac import numpy as np import pandas as pd +from array_api_compat import is_jax_namespace from scipy.interpolate import interp1d +from ...compat.utils import array_module, xp_wrap from ...core.utils.log import logger from ...core.prior.dict import PriorDict from ..prior import CalibrationPriorDict @@ -240,7 +243,8 @@ def get_calibration_factor(self, frequency_array, **params): calibration_factor : array-like The factor to multiply the strain by. """ - return np.ones_like(frequency_array) + xp = aac.array_namespace(frequency_array) + return xp.ones_like(frequency_array) def set_calibration_parameters(self, **params): self.params.update({key[len(self.prefix):]: params[key] for key in params @@ -330,9 +334,11 @@ def __repr__(self): def _evaluate_spline(self, kind, a, b, c, d, previous_nodes): """Evaluate Eq. 
(1) in https://dcc.ligo.org/LIGO-T2300140""" - parameters = np.array([self.params[f"{kind}_{ii}"] for ii in range(self.n_points)]) + xp = array_module(self.params[f"{kind}_0"]) + parameters = xp.asarray([self.params[f"{kind}_{ii}"] for ii in range(self.n_points)]) next_nodes = previous_nodes + 1 - spline_coefficients = self.nodes_to_spline_coefficients.dot(parameters) + nodes = xp.asarray(self.nodes_to_spline_coefficients) + spline_coefficients = nodes.dot(parameters) return ( a * parameters[previous_nodes] + b * parameters[next_nodes] @@ -340,7 +346,8 @@ def _evaluate_spline(self, kind, a, b, c, d, previous_nodes): + d * spline_coefficients[next_nodes] ) - def get_calibration_factor(self, frequency_array, **params): + @xp_wrap + def get_calibration_factor(self, frequency_array, *, xp=np, **params): """Apply calibration model Parameters @@ -358,10 +365,11 @@ def get_calibration_factor(self, frequency_array, **params): calibration_factor : array-like The factor to multiply the strain by. 
""" - log10f_per_deltalog10f = ( - np.log10(frequency_array) - self.log_spline_points[0] + log10f_per_deltalog10f = xp.nan_to_num( + xp.log10(frequency_array) - xp.asarray(self.log_spline_points[0]), + neginf=0.0, ) / self.delta_log_spline_points - previous_nodes = np.clip(np.floor(log10f_per_deltalog10f).astype(int), a_min=0, a_max=self.n_points - 2) + previous_nodes = xp.clip(xp.astype(log10f_per_deltalog10f, int), min=0, max=self.n_points - 2) b = log10f_per_deltalog10f - previous_nodes a = 1 - b c = (a**3 - a) / 6 @@ -373,7 +381,7 @@ def get_calibration_factor(self, frequency_array, **params): delta_phase = self._evaluate_spline("phase", a, b, c, d, previous_nodes) calibration_factor = (1 + delta_amplitude) * (2 + 1j * delta_phase) / (2 - 1j * delta_phase) - return calibration_factor + return xp.nan_to_num(calibration_factor) class Precomputed(Recalibrate): @@ -405,8 +413,21 @@ def get_calibration_factor(self, frequency_array, **params): idx = int(params.get(self.prefix, None)) if idx is None: raise KeyError(f"Calibration index for {self.label} not found.") - if not np.array_equal(frequency_array, self.frequency_array): - raise ValueError("Frequency grid passed to calibrator doesn't match.") + + xp = aac.get_namespace(frequency_array) + if not xp.array_equal(frequency_array, self.frequency_array): + intersection, mask, _ = xp.intersect1d( + frequency_array, self.frequency_array, return_indices=True + ) + if len(intersection) != len(self.frequency_array): + raise ValueError("Frequency grid passed to calibrator doesn't match.") + output = xp.ones_like(frequency_array, dtype=complex) + curve = xp.asarray(self.curves[idx]) + if is_jax_namespace(xp): + output = output.at[mask].set(curve) + else: + output[mask] = curve + return output return self.curves[idx] @classmethod diff --git a/bilby/gw/detector/geometry.py b/bilby/gw/detector/geometry.py index d7e1433de..a6c2df168 100644 --- a/bilby/gw/detector/geometry.py +++ b/bilby/gw/detector/geometry.py @@ -1,5 +1,5 @@ 
import numpy as np -from bilby_cython.geometry import calculate_arm, detector_tensor +from ..geometry import calculate_arm, detector_tensor from .. import utils as gwutils @@ -264,7 +264,7 @@ def detector_tensor(self): if not self._x_updated or not self._y_updated: _, _ = self.x, self.y # noqa if not self._detector_tensor_updated: - self._detector_tensor = detector_tensor(x=self.x, y=self.y) + self._detector_tensor = detector_tensor(self.x, self.y) self._detector_tensor_updated = True return self._detector_tensor @@ -290,17 +290,27 @@ def unit_vector_along_arm(self, arm): """ if arm == 'x': return calculate_arm( - arm_tilt=self._xarm_tilt, - arm_azimuth=self._xarm_azimuth, - longitude=self._longitude, - latitude=self._latitude + self._xarm_tilt, + self._xarm_azimuth, + self._longitude, + self._latitude ) elif arm == 'y': return calculate_arm( - arm_tilt=self._yarm_tilt, - arm_azimuth=self._yarm_azimuth, - longitude=self._longitude, - latitude=self._latitude + self._yarm_tilt, + self._yarm_azimuth, + self._longitude, + self._latitude ) else: raise ValueError("Arm must either be 'x' or 'y'.") + + def set_array_backend(self, xp): + self.length = xp.asarray(self.length) + self.latitude = xp.asarray(self.latitude) + self.longitude = xp.asarray(self.longitude) + self.elevation = xp.asarray(self.elevation) + self.xarm_azimuth = xp.asarray(self.xarm_azimuth) + self.yarm_azimuth = xp.asarray(self.yarm_azimuth) + self.xarm_tilt = xp.asarray(self.xarm_tilt) + self.yarm_tilt = xp.asarray(self.yarm_tilt) diff --git a/bilby/gw/detector/interferometer.py b/bilby/gw/detector/interferometer.py index 9e9c23bdf..e9e920095 100644 --- a/bilby/gw/detector/interferometer.py +++ b/bilby/gw/detector/interferometer.py @@ -1,16 +1,17 @@ import os import numpy as np -from bilby_cython.geometry import ( - get_polarization_tensor, - three_by_three_matrix_contraction, - time_delay_from_geocenter, -) from ...core import utils -from ...core.utils import docstring, logger, PropertyAccessor, 
safe_file_dump +from ...core.utils import PropertyAccessor, docstring, logger, safe_file_dump from ...core.utils.env import string_to_boolean +from ...compat.utils import array_module from .. import utils as gwutils +from ..geometry import ( + get_polarization_tensor, + three_by_three_matrix_contraction, + time_delay_from_geocenter, +) from .calibration import Recalibrate from .geometry import InterferometerGeometry from .strain_data import InterferometerStrainData @@ -114,16 +115,19 @@ def __repr__(self): float(self.geometry.yarm_azimuth), float(self.geometry.xarm_tilt), float(self.geometry.yarm_tilt)) - def set_strain_data_from_gwpy_timeseries(self, time_series): + def set_strain_data_from_gwpy_timeseries(self, time_series, *, xp=None): """ Set the `Interferometer.strain_data` from a gwpy TimeSeries Parameters ========== time_series: gwpy.timeseries.timeseries.TimeSeries The data to set. + xp: array module, optional + The array module to use, e.g., :code:`numpy` or :code:`jax.numpy`. + If not specified :code:`numpy` will be used. 
""" - self.strain_data.set_from_gwpy_timeseries(time_series=time_series) + self.strain_data.set_from_gwpy_timeseries(time_series=time_series, xp=xp) def set_strain_data_from_frequency_domain_strain( self, frequency_domain_strain, sampling_frequency=None, @@ -151,7 +155,7 @@ def set_strain_data_from_frequency_domain_strain( start_time=start_time, frequency_array=frequency_array) def set_strain_data_from_power_spectral_density( - self, sampling_frequency, duration, start_time=0): + self, sampling_frequency, duration, start_time=0, *, random_state=None): """ Set the `Interferometer.strain_data` from a power spectal density This uses the `interferometer.power_spectral_density` object to set @@ -170,11 +174,11 @@ def set_strain_data_from_power_spectral_density( """ self.strain_data.set_from_power_spectral_density( self.power_spectral_density, sampling_frequency=sampling_frequency, - duration=duration, start_time=start_time) + duration=duration, start_time=start_time, random_state=random_state) def set_strain_data_from_frame_file( self, frame_file, sampling_frequency, duration, start_time=0, - channel=None, buffer_time=1): + channel=None, buffer_time=1, *, xp=None): """ Set the `Interferometer.strain_data` from a frame file Parameters @@ -192,15 +196,18 @@ def set_strain_data_from_frame_file( buffer_time: float Read in data with `start_time-buffer_time` and `start_time+duration+buffer_time` + xp: array module, optional + The array module to use, e.g., :code:`numpy` or :code:`jax.numpy`. + If not specified :code:`numpy` will be used. 
""" self.strain_data.set_from_frame_file( frame_file=frame_file, sampling_frequency=sampling_frequency, duration=duration, start_time=start_time, - channel=channel, buffer_time=buffer_time) + channel=channel, buffer_time=buffer_time, xp=xp) def set_strain_data_from_channel_name( - self, channel, sampling_frequency, duration, start_time=0): + self, channel, sampling_frequency, duration, start_time=0, *, xp=None): """ Set the `Interferometer.strain_data` by fetching from given channel using strain_data.set_from_channel_name() @@ -215,22 +222,28 @@ def set_strain_data_from_channel_name( The data duration (in s) start_time: float The GPS start-time of the data + xp: array module, optional + The array module to use, e.g., :code:`numpy` or :code:`jax.numpy`. + If not specified :code:`numpy` will be used. """ self.strain_data.set_from_channel_name( channel=channel, sampling_frequency=sampling_frequency, - duration=duration, start_time=start_time) + duration=duration, start_time=start_time, xp=xp) - def set_strain_data_from_csv(self, filename): + def set_strain_data_from_csv(self, filename, *, xp=None): """ Set the `Interferometer.strain_data` from a csv file Parameters ========== filename: str The path to the file to read in + xp: array module, optional + The array module to use, e.g., :code:`numpy` or :code:`jax.numpy`. + If not specified :code:`numpy` will be used. """ - self.strain_data.set_from_csv(filename) + self.strain_data.set_from_csv(filename, xp=xp) def set_strain_data_from_zero_noise( self, sampling_frequency, duration, start_time=0): @@ -312,11 +325,13 @@ def get_detector_response(self, waveform_polarizations, parameters, frequencies= used to set the time at which the antenna response is evaluated, otherwise the provided :code:`Parameters["geocent_time"]` is used. 
""" + xp = array_module(waveform_polarizations) if frequencies is None: - frequencies = self.frequency_array[self.frequency_mask] + frequencies = self.frequency_array mask = self.frequency_mask else: - mask = np.ones(len(frequencies), dtype=bool) + mask = xp.ones(len(frequencies), dtype=bool) + frequencies = xp.asarray(frequencies) if self.reference_time is None: antenna_time = parameters["geocent_time"] @@ -331,8 +346,8 @@ def get_detector_response(self, waveform_polarizations, parameters, frequencies= antenna_time, parameters['psi'], mode) - signal[mode] = waveform_polarizations[mode] * det_response - signal_ifo = sum(signal.values()) * mask + signal[mode] = waveform_polarizations[mode] * mask * det_response + signal_ifo = sum(signal.values()) time_shift = self.time_delay_from_geocenter( parameters['ra'], parameters['dec'], parameters['geocent_time']) @@ -342,10 +357,12 @@ def get_detector_response(self, waveform_polarizations, parameters, frequencies= dt_geocent = parameters['geocent_time'] - self.strain_data.start_time dt = dt_geocent + time_shift - signal_ifo[mask] = signal_ifo[mask] * np.exp(-1j * 2 * np.pi * dt * frequencies) + xp = array_module(signal_ifo) + + signal_ifo = signal_ifo * xp.exp(-1j * 2 * np.pi * dt * frequencies) - signal_ifo[mask] *= self.calibration_model.get_calibration_factor( - frequencies, prefix='recalib_{}_'.format(self.name), **parameters + signal_ifo *= self.calibration_model.get_calibration_factor( + frequencies, prefix=f'recalib_{self.name}_', xp=xp, **parameters ) return signal_ifo @@ -494,7 +511,7 @@ def inject_signal_from_waveform_polarizations(self, parameters, injection_polari self.strain_data.frequency_domain_strain += signal_ifo self.meta_data['optimal_SNR'] = ( - np.sqrt(self.optimal_snr_squared(signal=signal_ifo)).real) + self.optimal_snr_squared(signal=signal_ifo)).real ** 0.5 self.meta_data['matched_filter_SNR'] = ( self.matched_filter_snr(signal=signal_ifo)) self.meta_data['parameters'] = parameters @@ -680,7 +697,7 @@ 
def whiten_frequency_series(self, frequency_series : np.array) -> np.array: frequency_series : np.array The frequency series, whitened by the ASD """ - return frequency_series / (self.amplitude_spectral_density_array * np.sqrt(self.duration / 4)) + return frequency_series / (self.amplitude_spectral_density_array * (self.duration / 4)**0.5) def get_whitened_time_series_from_whitened_frequency_series( self, @@ -711,14 +728,13 @@ def get_whitened_time_series_from_whitened_frequency_series( w = \\sqrt{N W} = \\sqrt{\\sum_{k=0}^N \\Theta(f_{max} - f_k)\\Theta(f_k - f_{min})} """ - frequency_window_factor = ( - np.sum(self.frequency_mask) - / len(self.frequency_mask) - ) + xp = array_module(whitened_frequency_series) + + frequency_window_factor = self.frequency_mask.mean() whitened_time_series = ( - np.fft.irfft(whitened_frequency_series) - * np.sqrt(np.sum(self.frequency_mask)) / frequency_window_factor + xp.fft.irfft(whitened_frequency_series) + * self.frequency_mask.sum()**0.5 / frequency_window_factor ) return whitened_time_series @@ -936,3 +952,11 @@ def from_pickle(cls, filename=None): if res.__class__ != cls: raise TypeError('The loaded object is not an Interferometer') return res + + def set_array_backend(self, xp): + self.geometry.set_array_backend(xp=xp) + self.power_spectral_density.set_array_backend(xp=xp) + + @property + def array_backend(self): + return array_module(self.geometry.length) diff --git a/bilby/gw/detector/networks.py b/bilby/gw/detector/networks.py index 25b3e7e71..50b99e175 100644 --- a/bilby/gw/detector/networks.py +++ b/bilby/gw/detector/networks.py @@ -5,6 +5,7 @@ from ...core import utils from ...core.utils import logger, safe_file_dump +from ..geometry import zenith_azimuth_to_theta_phi from .interferometer import Interferometer from .psd import PowerSpectralDensity @@ -84,7 +85,7 @@ def _check_interferometers(self): logger.warning(e) def set_strain_data_from_power_spectral_densities( - self, sampling_frequency, duration, start_time=0 + 
self, sampling_frequency, duration, start_time=0, *, random_state=None ): """Set the `Interferometer.strain_data` from the power spectral densities of the detectors @@ -107,6 +108,7 @@ def set_strain_data_from_power_spectral_densities( sampling_frequency=sampling_frequency, duration=duration, start_time=start_time, + random_state=random_state, ) def set_strain_data_from_zero_noise( @@ -341,6 +343,14 @@ def from_pickle(cls, filename=None): ) from_pickle.__doc__ = _load_docstring.format(format="pickle") + def set_array_backend(self, xp): + for ifo in self: + ifo.set_array_backend(xp) + + @property + def array_backend(self): + return self[0].array_backend + class TriangularInterferometer(InterferometerList): def __init__( @@ -472,3 +482,9 @@ def load_interferometer(filename): "{} could not be loaded. Invalid parameter 'shape'.".format(filename) ) return ifo + + +@zenith_azimuth_to_theta_phi.dispatch +def zenith_azimuth_to_theta_phi(zenith, azimuth, ifos: InterferometerList | list): + delta_x = ifos[0].geometry.vertex - ifos[1].geometry.vertex + return zenith_azimuth_to_theta_phi(zenith, azimuth, delta_x) diff --git a/bilby/gw/detector/psd.py b/bilby/gw/detector/psd.py index a3948f966..78e5a472b 100644 --- a/bilby/gw/detector/psd.py +++ b/bilby/gw/detector/psd.py @@ -1,8 +1,10 @@ import os +import array_api_compat as aac +import array_api_extra as xpx import numpy as np -from scipy.interpolate import interp1d +from ...core.utils.calculus import interp1d from ...core import utils from ...core.utils import logger from .strain_data import InterferometerStrainData @@ -205,55 +207,59 @@ def from_aligo(): @property def psd_array(self): - return self.__psd_array + return self._psd_array @psd_array.setter def psd_array(self, psd_array): - self.__check_frequency_array_matches_density_array(psd_array) - self.__psd_array = np.array(psd_array) - self.__asd_array = psd_array ** 0.5 - self.__interpolate_power_spectral_density() + 
self._check_frequency_array_matches_density_array(psd_array) + self._psd_array = np.array(psd_array) + self._asd_array = psd_array ** 0.5 + self._interpolate_power_spectral_density() @property def asd_array(self): - return self.__asd_array + return self._asd_array @asd_array.setter def asd_array(self, asd_array): - self.__check_frequency_array_matches_density_array(asd_array) - self.__asd_array = np.array(asd_array) - self.__psd_array = asd_array ** 2 - self.__interpolate_power_spectral_density() + self._check_frequency_array_matches_density_array(asd_array) + self._asd_array = np.array(asd_array) + self._psd_array = asd_array ** 2 + self._interpolate_power_spectral_density() - def __check_frequency_array_matches_density_array(self, density_array): + def _check_frequency_array_matches_density_array(self, density_array): if len(self.frequency_array) != len(density_array): raise ValueError('Provided spectral density does not match frequency array. Not updating.\n' 'Length spectral density {}\n Length frequency array {}\n' .format(density_array, self.frequency_array)) - def __interpolate_power_spectral_density(self): + def _interpolate_power_spectral_density(self): """Interpolate the loaded power spectral density so it can be resampled for arbitrary frequency arrays. 
""" - self.__power_spectral_density_interpolated = interp1d(self.frequency_array, - self.psd_array, - bounds_error=False, - fill_value=np.inf) + self._power_spectral_density_interpolated = interp1d(self.frequency_array, + self.psd_array, + bounds_error=False, + fill_value=np.inf) self._update_cache(self.frequency_array) def get_power_spectral_density_array(self, frequency_array): + if aac.is_jax_array(frequency_array): + return self.power_spectral_density_interpolated(frequency_array) if not np.array_equal(frequency_array, self._cache['frequency_array']): self._update_cache(frequency_array=frequency_array) return self._cache['psd_array'] def get_amplitude_spectral_density_array(self, frequency_array): + if aac.is_jax_array(frequency_array): + return self.power_spectral_density_interpolated(frequency_array)**0.5 if not np.array_equal(frequency_array, self._cache['frequency_array']): self._update_cache(frequency_array=frequency_array) return self._cache['asd_array'] @property def power_spectral_density_interpolated(self): - return self.__power_spectral_density_interpolated + return self._power_spectral_density_interpolated @property def asd_file(self): @@ -261,13 +267,13 @@ def asd_file(self): @asd_file.setter def asd_file(self, asd_file): - asd_file = self.__validate_file_name(file=asd_file) + asd_file = self._validate_file_name(file=asd_file) self._asd_file = asd_file if asd_file is not None: - self.__import_amplitude_spectral_density() - self.__check_file_was_asd_file() + self._import_amplitude_spectral_density() + self._check_file_was_asd_file() - def __check_file_was_asd_file(self): + def _check_file_was_asd_file(self): if min(self.asd_array) < 1e-30: logger.warning("You specified an amplitude spectral density file.") logger.warning("{} WARNING {}".format("*" * 30, "*" * 30)) @@ -280,13 +286,13 @@ def psd_file(self): @psd_file.setter def psd_file(self, psd_file): - psd_file = self.__validate_file_name(file=psd_file) + psd_file = 
self._validate_file_name(file=psd_file) self._psd_file = psd_file if psd_file is not None: - self.__import_power_spectral_density() - self.__check_file_was_psd_file() + self._import_power_spectral_density() + self._check_file_was_psd_file() - def __check_file_was_psd_file(self): + def _check_file_was_psd_file(self): if min(self.psd_array) > 1e-30: logger.warning("You specified a power spectral density file.") logger.warning("{} WARNING {}".format("*" * 30, "*" * 30)) @@ -294,7 +300,7 @@ def __check_file_was_psd_file(self): logger.warning("You may have intended to provide this as an amplitude spectral density.") @staticmethod - def __validate_file_name(file): + def _validate_file_name(file): """ Test if the file exists or is available in the default directory. @@ -333,15 +339,15 @@ def __validate_file_name(file): .format(file)) return file - def __import_amplitude_spectral_density(self): + def _import_amplitude_spectral_density(self): """ Automagically load an amplitude spectral density curve """ self.frequency_array, self.asd_array = np.genfromtxt(self.asd_file).T - def __import_power_spectral_density(self): + def _import_power_spectral_density(self): """ Automagically load a power spectral density curve """ self.frequency_array, self.psd_array = np.genfromtxt(self.psd_file).T - def get_noise_realisation(self, sampling_frequency, duration): + def get_noise_realisation(self, number_of_samples, duration, *, random_state=None): """ Generate frequency Gaussian noise scaled to the power spectral density. 
@@ -358,9 +364,30 @@ def get_noise_realisation(self, sampling_frequency, duration): array_like: frequencies related to the frequency domain strain """ - white_noise, frequencies = utils.create_white_noise(sampling_frequency, duration) + white_noise, frequencies = utils.safe_white_noise(number_of_samples, duration, random_state=random_state) with np.errstate(invalid="ignore"): - frequency_domain_strain = self.__power_spectral_density_interpolated(frequencies) ** 0.5 * white_noise - out_of_bounds = (frequencies < min(self.frequency_array)) | (frequencies > max(self.frequency_array)) - frequency_domain_strain[out_of_bounds] = 0 * (1 + 1j) - return frequency_domain_strain, frequencies + frequency_domain_strain = self._power_spectral_density_interpolated(frequencies) ** 0.5 * white_noise + xp = aac.array_namespace(frequency_domain_strain) + out_of_bounds = ( + (frequencies < self.frequency_array.min()) + | (frequencies > self.frequency_array.max()) + ) + frequency_domain_strain = xpx.at(frequency_domain_strain, out_of_bounds).set(0j) + return xp.nan_to_num(frequency_domain_strain), xp.asarray(frequencies) + + def set_array_backend(self, xp): + """ Set the array backend for the cached arrays + + Parameters + ========== + xp: module + The array backend to use for the cached arrays + + """ + self.frequency_array = xp.asarray(self.frequency_array) + self._asd_array = xp.asarray(self._asd_array) + self._psd_array = xp.asarray(self._psd_array) + self._cache['frequency_array'] = xp.asarray(self._cache['frequency_array']) + self._cache['psd_array'] = xp.asarray(self._cache['psd_array']) + self._cache['asd_array'] = xp.asarray(self._cache['asd_array']) + self._interpolate_power_spectral_density() diff --git a/bilby/gw/detector/strain_data.py b/bilby/gw/detector/strain_data.py index bca7acced..1383d9d8c 100644 --- a/bilby/gw/detector/strain_data.py +++ b/bilby/gw/detector/strain_data.py @@ -1,5 +1,7 @@ +import array_api_compat as aac import numpy as np +from ...compat.utils 
import array_module from ...core import utils from ...core.series import CoupledTimeAndFrequencySeries from ...core.utils import logger, PropertyAccessor @@ -101,8 +103,10 @@ def minimum_frequency(self, minimum_frequency): def maximum_frequency(self): """ Force the maximum frequency be less than the Nyquist frequency """ if self.sampling_frequency is not None: - if 2 * self._maximum_frequency > self.sampling_frequency: - self._maximum_frequency = self.sampling_frequency / 2. + xp = array_module(self._maximum_frequency) + self._maximum_frequency = xp.minimum( + self._maximum_frequency, self.sampling_frequency / 2 + ) return self._maximum_frequency @maximum_frequency.setter @@ -498,7 +502,7 @@ def set_from_time_domain_strain( else: raise ValueError("Data times do not match time array") - def set_from_gwpy_timeseries(self, time_series): + def set_from_gwpy_timeseries(self, time_series, *, xp=np): """ Set the strain data from a gwpy TimeSeries This sets the time_domain_strain attribute, the frequency_domain_strain @@ -509,17 +513,23 @@ def set_from_gwpy_timeseries(self, time_series): ========== time_series: gwpy.timeseries.timeseries.TimeSeries The data to use + xp: array module, optional + The array module to use, e.g., :code:`numpy` or :code:`jax.numpy`. + If not specified :code:`numpy` will be used. 
""" from gwpy.timeseries import TimeSeries logger.debug('Setting data using provided gwpy TimeSeries object') if not isinstance(time_series, TimeSeries): raise ValueError("Input time_series is not a gwpy TimeSeries") + duration = xp.asarray(time_series.duration.value) + sampling_frequency = xp.asarray(time_series.sample_rate.value) + start_time = xp.asarray(time_series.epoch.value) self._times_and_frequencies = \ - CoupledTimeAndFrequencySeries(duration=time_series.duration.value, - sampling_frequency=time_series.sample_rate.value, - start_time=time_series.epoch.value) - self._time_domain_strain = time_series.value + CoupledTimeAndFrequencySeries(duration=duration, + sampling_frequency=sampling_frequency, + start_time=start_time) + self._time_domain_strain = xp.asarray(time_series.value) self._frequency_domain_strain = None self._channel = time_series.channel @@ -529,7 +539,7 @@ def channel(self): def set_from_open_data( self, name, start_time, duration=4, outdir='outdir', cache=True, - **kwargs): + *, xp=None, **kwargs): """ Set the strain data from open LOSC data This sets the time_domain_strain attribute, the frequency_domain_strain @@ -548,30 +558,38 @@ def set_from_open_data( Directory where the psd files are saved cache: bool, optional Whether or not to store/use the acquired data. + xp: array module, optional + The array module to use, e.g., :code:`numpy` or :code:`jax.numpy`. + If not specified :code:`numpy` will be used. **kwargs: All keyword arguments are passed to `gwpy.timeseries.TimeSeries.fetch_open_data()`. 
""" - timeseries = gwutils.get_open_strain_data( - name, start_time, start_time + duration, outdir=outdir, cache=cache, + name, float(start_time), float(start_time + duration), outdir=outdir, cache=cache, **kwargs) - self.set_from_gwpy_timeseries(timeseries) + if xp is None: + xp = array_module((duration, start_time)) + + self.set_from_gwpy_timeseries(timeseries, xp=xp) - def set_from_csv(self, filename): + def set_from_csv(self, filename, xp=None): """ Set the strain data from a csv file Parameters ========== filename: str The path to the file to read in + xp: array module, optional + The array module to use, e.g., :code:`numpy` or :code:`jax.numpy`. + If not specified :code:`numpy` will be used. """ from gwpy.timeseries import TimeSeries timeseries = TimeSeries.read(filename, format='csv') - self.set_from_gwpy_timeseries(timeseries) + self.set_from_gwpy_timeseries(timeseries, xp=xp) def set_from_frequency_domain_strain( self, frequency_domain_strain, sampling_frequency=None, @@ -608,7 +626,7 @@ def set_from_frequency_domain_strain( def set_from_power_spectral_density( self, power_spectral_density, sampling_frequency, duration, - start_time=0): + start_time=0, *, random_state=None): """ Set the `frequency_domain_strain` by generating a noise realisation Parameters @@ -632,9 +650,13 @@ def set_from_power_spectral_density( 'power_spectal_density') frequency_domain_strain, frequency_array = \ power_spectral_density.get_noise_realisation( - self.sampling_frequency, self.duration) + self.frequency_array.shape[0], self.duration, random_state=random_state) + + xp = aac.array_namespace(frequency_domain_strain) + self._frequency_array = xp.asarray(self.frequency_array) + self._time_array = xp.asarray(self.time_array) - if np.array_equal(frequency_array, self.frequency_array): + if self.duration == duration and frequency_array.shape == self.frequency_array.shape: self._frequency_domain_strain = frequency_domain_strain else: raise ValueError("Data frequencies do not match 
frequency_array") @@ -661,12 +683,13 @@ def set_from_zero_noise(self, sampling_frequency, duration, start_time=0): sampling_frequency=sampling_frequency, start_time=start_time) logger.debug('Setting zero noise data') - self._frequency_domain_strain = np.zeros_like(self.frequency_array, + xp = aac.get_namespace(self.frequency_array) + self._frequency_domain_strain = xp.zeros_like(self.frequency_array, dtype=complex) def set_from_frame_file( self, frame_file, sampling_frequency, duration, start_time=0, - channel=None, buffer_time=1): + channel=None, buffer_time=1, *, xp=None): """ Set the `frequency_domain_strain` from a frame fiile Parameters @@ -684,6 +707,10 @@ def set_from_frame_file( buffer_time: float Read in data with `start_time-buffer_time` and `start_time+duration+buffer_time` + xp: array module, optional + The array module to use, e.g., :code:`numpy` or :code:`jax.numpy`. + If not specified, it will be inferred from the provided duration/ + sampling frequency. """ @@ -697,9 +724,12 @@ def set_from_frame_file( buffer_time=buffer_time, channel=channel, resample=sampling_frequency) - self.set_from_gwpy_timeseries(strain) + if xp is None: + xp = aac.get_namespace(self.frequency_array) - def set_from_channel_name(self, channel, duration, start_time, sampling_frequency): + self.set_from_gwpy_timeseries(strain, xp=xp) + + def set_from_channel_name(self, channel, duration, start_time, sampling_frequency, *, xp=None): """ Set the `frequency_domain_strain` by fetching from given channel using gwpy.TimesSeries.get(), which dynamically accesses either frames on disk, or a remote NDS2 server to find and return data. This function @@ -715,6 +745,10 @@ def set_from_channel_name(self, channel, duration, start_time, sampling_frequenc The GPS start-time of the data sampling_frequency: float The sampling frequency (in Hz) + xp: array module, optional + The array module to use, e.g., :code:`numpy` or :code:`jax.numpy`. 
+ If not specified, it will be inferred from the provided duration/ + sampling frequency. """ from gwpy.timeseries import TimeSeries @@ -730,7 +764,10 @@ def set_from_channel_name(self, channel, duration, start_time, sampling_frequenc strain = TimeSeries.get(channel, start_time, start_time + duration) strain = strain.resample(sampling_frequency) - self.set_from_gwpy_timeseries(strain) + if xp is None: + xp = aac.get_namespace(self.frequency_array) + + self.set_from_gwpy_timeseries(strain, xp=xp) class Notch(object): diff --git a/bilby/gw/geometry.py b/bilby/gw/geometry.py new file mode 100644 index 000000000..fa4be3d86 --- /dev/null +++ b/bilby/gw/geometry.py @@ -0,0 +1,377 @@ +from plum import dispatch + +from .time import greenwich_mean_sidereal_time +from ..compat.utils import array_module, promote_to_array + + +__all__ = [ + "antenna_response", + "calculate_arm", + "detector_tensor", + "get_polarization_tensor", + "get_polarization_tensor_multiple_modes", + "rotation_matrix_from_delta", + "three_by_three_matrix_contraction", + "time_delay_geocentric", + "time_delay_from_geocenter", + "zenith_azimuth_to_theta_phi", +] + + +@dispatch +def antenna_response(detector_tensor, ra, dec, time, psi, mode): + """ + Calculate the antenna response for a detector. + + Parameters + ========== + detector_tensor: array-like + The detector tensor (3x3 matrix). + ra: float or array-like + Right ascension of the source in radians. + dec: float or array-like + Declination of the source in radians. + time: float or array-like + GPS time of the observation. + psi: float or array-like + Polarization angle in radians. + mode: str + Polarization mode ('plus', 'cross', 'breathing', 'longitudinal', 'x', 'y'). + + Returns + ======= + array-like + The antenna response (scalar or array depending on input). 
+ """ + xp = array_module(detector_tensor) + polarization_tensor = get_polarization_tensor(*promote_to_array((ra, dec, time, psi), xp), mode) + return three_by_three_matrix_contraction(detector_tensor, polarization_tensor) + + +@dispatch +def calculate_arm(arm_tilt, arm_azimuth, longitude, latitude): + """ + Calculate arm unit vector from tilt, azimuth, and location. + + Parameters + ========== + arm_tilt: float or array-like + Tilt angle of the arm from horizontal in radians. + arm_azimuth: float or array-like + Azimuth angle of the arm in radians. + longitude: float or array-like + Longitude of the detector in radians. + latitude: float or array-like + Latitude of the detector in radians. + + Returns + ======= + array-like + 3D unit vector (shape (3,) or (3, ...)) representing the arm direction. + """ + xp = array_module(arm_tilt) + e_long = xp.asarray([-xp.sin(longitude), xp.cos(longitude), longitude * 0]) + e_lat = xp.asarray( + [ + -xp.sin(latitude) * xp.cos(longitude), + -xp.sin(latitude) * xp.sin(longitude), + xp.cos(latitude), + ] + ) + e_h = xp.asarray( + [ + xp.cos(latitude) * xp.cos(longitude), + xp.cos(latitude) * xp.sin(longitude), + xp.sin(latitude), + ] + ) + + return ( + xp.cos(arm_tilt) * xp.cos(arm_azimuth) * e_long + + xp.cos(arm_tilt) * xp.sin(arm_azimuth) * e_lat + + xp.sin(arm_tilt) * e_h + ) + + +@dispatch +def detector_tensor(x, y): + """ + Calculate the detector tensor from x and y arm components. + + Parameters + ========== + x: array-like + 3D unit vector for the x arm. + y: array-like + 3D unit vector for the y arm. + + Returns + ======= + array-like + 3x3 detector tensor with components + :math:`d_{ij} = (x_i x_j - y_i y_j) / 2`. + """ + xp = array_module(x) + return (xp.outer(x, x) - xp.outer(y, y)) / 2 + + +@dispatch +def get_polarization_tensor(ra, dec, time, psi, mode): + """ + Calculate the polarization tensor for a given sky location and mode. 
+ + Parameters + ========== + ra: float or array-like + Right ascension of the source in radians. + dec: float or array-like + Declination of the source in radians. + time: float or array-like + GPS time of the observation. + psi: float or array-like + Polarization angle in radians. + mode: str + Polarization mode: 'plus', 'cross', 'breathing', 'longitudinal', + 'x', or 'y'. + + Returns + ======= + array-like + 3x3 polarization tensor for the specified mode. + """ + from functools import partial + + xp = array_module(ra) + + gmst = greenwich_mean_sidereal_time(time) % (2 * xp.pi) + phi = ra - gmst + theta = xp.atleast_1d(xp.pi / 2 - dec).squeeze() + u = xp.asarray( + [ + xp.cos(phi) * xp.cos(theta), + xp.cos(theta) * xp.sin(phi), + -xp.sin(theta) * xp.ones_like(phi), + ] + ) + v = xp.asarray([ + -xp.sin(phi), xp.cos(phi), xp.zeros_like(phi) + ]) * xp.ones_like(theta) + omega = xp.asarray([ + xp.sin(xp.pi - theta) * xp.cos(xp.pi + phi), + xp.sin(xp.pi - theta) * xp.sin(xp.pi + phi), + xp.cos(xp.pi - theta) * xp.ones_like(phi), + ]) + m = -u * xp.sin(psi) - v * xp.cos(psi) + n = -u * xp.cos(psi) + v * xp.sin(psi) + if xp.__name__ == "mlx.core": + einsum_shape = "i,j->ij" + else: + einsum_shape = "i...,j...->ij..." + product = partial(xp.einsum, einsum_shape) + + match mode.lower(): + case "plus": + return product(m, m) - product(n, n) + case "cross": + return product(m, n) + product(n, m) + case "breathing": + return product(m, m) + product(n, n) + case "longitudinal": + return product(omega, omega) + case "x": + return product(m, omega) + product(omega, m) + case "y": + return product(n, omega) + product(omega, n) + case _: + raise ValueError(f"{mode} not a polarization mode!") + + +@dispatch +def get_polarization_tensor_multiple_modes(ra, dec, time, psi, modes): + """ + Calculate polarization tensors for multiple modes. + + Parameters + ========== + ra: float or array-like + Right ascension of the source in radians. 
+ dec: float or array-like + Declination of the source in radians. + time: float or array-like + GPS time of the observation. + psi: float or array-like + Polarization angle in radians. + modes: list of str + List of polarization modes to calculate. + + Returns + ======= + list + List of 3x3 polarization tensors, one for each mode. + """ + return [get_polarization_tensor(ra, dec, time, psi, mode) for mode in modes] + + +@dispatch +def rotation_matrix_from_delta(delta_x): + r""" + Calculate rotation matrix from a delta vector. + + Parameters + ========== + delta_x: array-like + 3D vector :math:`\vec{\Delta}x` representing the separation + or orientation. + + Returns + ======= + array-like + 3x3 rotation matrix that rotates the z-axis to align with + :math:`\vec{\Delta}x` direction. + """ + xp = array_module(delta_x) + delta_x = delta_x / (delta_x**2).sum() ** 0.5 + alpha = xp.arctan2(-delta_x[1] * delta_x[2], delta_x[0]) + beta = xp.arccos(delta_x[2]) + gamma = xp.arctan2(delta_x[1], delta_x[0]) + rotation_1 = xp.asarray( + [ + [xp.cos(alpha), -xp.sin(alpha), xp.zeros(alpha.shape)], + [xp.sin(alpha), xp.cos(alpha), xp.zeros(alpha.shape)], + [xp.zeros(alpha.shape), xp.zeros(alpha.shape), xp.ones(alpha.shape)], + ] + ) + rotation_2 = xp.asarray( + [ + [xp.cos(beta), xp.zeros(beta.shape), xp.sin(beta)], + [xp.zeros(beta.shape), xp.ones(beta.shape), xp.zeros(beta.shape)], + [-xp.sin(beta), xp.zeros(beta.shape), xp.cos(beta)], + ] + ) + rotation_3 = xp.asarray( + [ + [xp.cos(gamma), -xp.sin(gamma), xp.zeros(gamma.shape)], + [xp.sin(gamma), xp.cos(gamma), xp.zeros(gamma.shape)], + [xp.zeros(gamma.shape), xp.zeros(gamma.shape), xp.ones(gamma.shape)], + ] + ) + return rotation_3 @ rotation_2 @ rotation_1 + + +@dispatch +def three_by_three_matrix_contraction(a, b): + """ + Perform contraction of two 3x3 matrices. + + Parameters + ========== + a: array-like + First 3x3 matrix. + b: array-like + Second 3x3 matrix. 
+ + Returns + ======= + float or array-like + Scalar result of the einsum contraction :math:`a_{ij} b_{ij}`. + """ + xp = array_module(a) + return xp.einsum("ij,ij->", a, b) + + +@dispatch +def time_delay_geocentric(detector1, detector2, ra, dec, time): + r""" + Calculate time delay between two detectors for a source direction. + + Parameters + ========== + detector1: array-like + 3D position vector of the first detector in meters. + detector2: array-like + 3D position vector of the second detector in meters. + ra: float or array-like + Right ascension of the source in radians. + dec: float or array-like + Declination of the source in radians. + time: float or array-like + GPS time of the observation. + + Returns + ======= + float or array-like + Time delay :math:`\Delta t = \hat{\omega} \cdot (\vec{d}_2 - \vec{d}_1) / c` + in seconds, where :math:`\hat{\omega}` is the unit vector to the + source and :math:`c` is the speed of light. + """ + xp = array_module(detector1) + gmst = greenwich_mean_sidereal_time(time) % (2 * xp.pi) + speed_of_light = 299792458.0 + phi = ra - gmst + theta = xp.pi / 2 - dec + omega = xp.asarray( + [xp.sin(theta) * xp.cos(phi), xp.sin(theta) * xp.sin(phi), xp.cos(theta)] + ) + delta_d = detector2 - detector1 + return omega @ delta_d / speed_of_light + + +@dispatch +def time_delay_from_geocenter(detector1, ra, dec, time): + r""" + Calculate time delay from geocenter to a detector. + + Parameters + ========== + detector1: array-like + 3D position vector of the detector in meters. + ra: float or array-like + Right ascension of the source in radians. + dec: float or array-like + Declination of the source in radians. + time: float or array-like + GPS time of the observation. + + Returns + ======= + float or array-like + Time delay :math:`\Delta t = \hat{\omega} \cdot \vec{d} / c` in + seconds, where :math:`\vec{d}` is the detector position and + :math:`c` is the speed of light. 
+ """ + xp = array_module(detector1) + return time_delay_geocentric(detector1, xp.zeros(3), ra, dec, time) + + +@dispatch +def zenith_azimuth_to_theta_phi(zenith, azimuth, delta_x): + """ + Convert zenith/azimuth angles to theta/phi in a rotated frame. + + Parameters + ========== + zenith: float or array-like + Zenith angle in radians. + azimuth: float or array-like + Azimuth angle in radians. + delta_x: array-like + 3D vector defining the rotation frame. + + Returns + ======= + tuple of array-like + (theta, phi) angles in the rotated frame, both in radians. + """ + xp = array_module(delta_x) + omega_prime = xp.stack( + [ + xp.sin(zenith) * xp.cos(azimuth), + xp.sin(zenith) * xp.sin(azimuth), + xp.cos(zenith), + ] + ) + rotation_matrix = rotation_matrix_from_delta(delta_x) + omega = rotation_matrix @ omega_prime + theta = xp.arccos(omega[2]) + phi = xp.arctan2(omega[1], omega[0]) % (2 * xp.pi) + return theta, phi diff --git a/bilby/gw/likelihood/base.py b/bilby/gw/likelihood/base.py index 8dfbcdbf5..5d028d49f 100644 --- a/bilby/gw/likelihood/base.py +++ b/bilby/gw/likelihood/base.py @@ -2,6 +2,7 @@ import os import copy +import array_api_compat as aac import attr import numpy as np from scipy.special import logsumexp @@ -107,9 +108,13 @@ class GravitationalWaveTransient(Likelihood): @attr.s(slots=True, weakref_slot=False) class _CalculatedSNRs: - d_inner_h = attr.ib(default=0j, converter=complex) - optimal_snr_squared = attr.ib(default=0, converter=float) - complex_matched_filter_snr = attr.ib(default=0j, converter=complex) + # the complex converter breaks JAX compilation + # d_inner_h = attr.ib(default=0j, converter=complex) + # optimal_snr_squared = attr.ib(default=0, converter=float) + # complex_matched_filter_snr = attr.ib(default=0j, converter=complex) + d_inner_h = attr.ib(default=0j) + optimal_snr_squared = attr.ib(default=0) + complex_matched_filter_snr = attr.ib(default=0j) d_inner_h_array = attr.ib(default=None) optimal_snr_squared_array = 
attr.ib(default=None) @@ -153,6 +158,7 @@ def __init__( self.waveform_generator = waveform_generator super(GravitationalWaveTransient, self).__init__() self.interferometers = InterferometerList(interferometers) + self.interferometers.set_array_backend(interferometers.array_backend) self.time_marginalization = time_marginalization self.distance_marginalization = distance_marginalization self.phase_marginalization = phase_marginalization @@ -165,6 +171,7 @@ def __init__( if "geocent" not in time_reference: self.time_reference = time_reference self.reference_ifo = get_empty_interferometer(self.time_reference) + self.reference_ifo.set_array_backend(self.interferometers.array_backend) if self.time_marginalization: logger.info("Cannot marginalise over non-geocenter time.") self.time_marginalization = False @@ -290,49 +297,50 @@ def calculate_snrs(self, waveform_polarizations, interferometer, *, return_array optimal_snr_squared_array = None normalization = 4 / self.waveform_generator.duration + xp = aac.array_namespace(signal) if return_array is False: d_inner_h_array = None optimal_snr_squared_array = None elif self.time_marginalization and self.calibration_marginalization: - d_inner_h_integrand = np.tile( - interferometer.frequency_domain_strain.conjugate() * signal / + d_inner_h_integrand = xp.tile( + interferometer.frequency_domain_strain.conj() * signal / interferometer.power_spectral_density_array, (self.number_of_response_curves, 1)).T d_inner_h_integrand[_mask] *= self.calibration_draws[interferometer.name].T - d_inner_h_array = 4 / self.waveform_generator.duration * np.fft.fft( + d_inner_h_array = 4 / self.waveform_generator.duration * xp.fft.fft( d_inner_h_integrand[0:-1], axis=0 ).T optimal_snr_squared_integrand = ( - normalization * np.abs(signal)**2 / interferometer.power_spectral_density_array + normalization * xp.abs(signal)**2 / interferometer.power_spectral_density_array ) - optimal_snr_squared_array = np.dot( + optimal_snr_squared_array = xp.dot( 
optimal_snr_squared_integrand[_mask], self.calibration_abs_draws[interferometer.name].T ) elif self.time_marginalization and not self.calibration_marginalization: - d_inner_h_array = normalization * np.fft.fft( + d_inner_h_array = normalization * xp.fft.fft( signal[0:-1] - * interferometer.frequency_domain_strain.conjugate()[0:-1] + * interferometer.frequency_domain_strain.conj()[0:-1] / interferometer.power_spectral_density_array[0:-1] ) elif self.calibration_marginalization and ('recalib_index' not in parameters): d_inner_h_integrand = ( normalization * - interferometer.frequency_domain_strain.conjugate() * signal + interferometer.frequency_domain_strain.conj() * signal / interferometer.power_spectral_density_array ) - d_inner_h_array = np.dot(d_inner_h_integrand[_mask], self.calibration_draws[interferometer.name].T) + d_inner_h_array = xp.dot(d_inner_h_integrand[_mask], self.calibration_draws[interferometer.name].T) optimal_snr_squared_integrand = ( - normalization * np.abs(signal)**2 / interferometer.power_spectral_density_array + normalization * xp.abs(signal)**2 / interferometer.power_spectral_density_array ) - optimal_snr_squared_array = np.dot( + optimal_snr_squared_array = xp.dot( optimal_snr_squared_integrand[_mask], self.calibration_abs_draws[interferometer.name].T ) @@ -391,12 +399,12 @@ def _calculate_noise_log_likelihood(self): log_l = 0 for interferometer in self.interferometers: mask = interferometer.frequency_mask - log_l -= noise_weighted_inner_product( + log_l -= abs(noise_weighted_inner_product( interferometer.frequency_domain_strain[mask], interferometer.frequency_domain_strain[mask], interferometer.power_spectral_density_array[mask], - self.waveform_generator.duration) / 2 - return float(np.real(log_l)) + self.waveform_generator.duration) / 2) + return log_l def noise_log_likelihood(self): # only compute likelihood if called for the 1st time @@ -406,7 +414,7 @@ def noise_log_likelihood(self): def log_likelihood_ratio(self, parameters): 
parameters = copy.deepcopy(parameters) - + parameters.update(self.get_sky_frame_parameters(parameters)) waveform_polarizations = \ self.waveform_generator.frequency_domain_strain(parameters) if waveform_polarizations is None: @@ -415,8 +423,6 @@ def log_likelihood_ratio(self, parameters): if self.time_marginalization and self.jitter_time: parameters['geocent_time'] += parameters['time_jitter'] - parameters.update(self.get_sky_frame_parameters(parameters)) - total_snrs = self._CalculatedSNRs() for interferometer in self.interferometers: @@ -433,7 +439,7 @@ def log_likelihood_ratio(self, parameters): if self.time_marginalization and self.jitter_time: parameters['geocent_time'] -= parameters['time_jitter'] - return float(log_l.real) + return log_l.real def compute_log_likelihood_from_snrs(self, total_snrs, parameters): @@ -467,14 +473,13 @@ def compute_log_likelihood_from_snrs(self, total_snrs, parameters): return log_l def compute_per_detector_log_likelihood(self, parameters): + parameters.update(self.get_sky_frame_parameters(parameters)) waveform_polarizations = \ self.waveform_generator.frequency_domain_strain(parameters) if self.time_marginalization and self.jitter_time: parameters['geocent_time'] += parameters['time_jitter'] - parameters.update(self.get_sky_frame_parameters(parameters)) - for interferometer in self.interferometers: per_detector_snr = self.calculate_snrs( waveform_polarizations=waveform_polarizations, @@ -767,12 +772,12 @@ def distance_marginalized_likelihood(self, d_inner_h, h_inner_h, *, parameters): d_inner_h_ref, h_inner_h_ref = self._setup_rho( d_inner_h, h_inner_h, parameters=parameters) if self.phase_marginalization: - d_inner_h_ref = np.abs(d_inner_h_ref) + d_inner_h_ref = abs(d_inner_h_ref) else: - d_inner_h_ref = np.real(d_inner_h_ref) + d_inner_h_ref = d_inner_h_ref.real return self._interp_dist_margd_loglikelihood( - d_inner_h_ref, h_inner_h_ref, grid=False) + d_inner_h_ref, h_inner_h_ref) def phase_marginalized_likelihood(self, 
d_inner_h, h_inner_h): d_inner_h = ln_i0(abs(d_inner_h)) @@ -787,14 +792,15 @@ def time_marginalized_likelihood(self, d_inner_h_tc_array, h_inner_h, *, paramet if self.jitter_time: times = self._times + parameters['time_jitter'] - _time_prior = self.priors['geocent_time'] - time_mask = (times >= _time_prior.minimum) & (times <= _time_prior.maximum) - times = times[time_mask] + if not aac.is_jax_array(d_inner_h_tc_array): + _time_prior = self.priors['geocent_time'] + time_mask = (times >= _time_prior.minimum) & (times <= _time_prior.maximum) + times = times[time_mask] + if self.calibration_marginalization: + d_inner_h_tc_array = d_inner_h_tc_array[:, time_mask] + else: + d_inner_h_tc_array = d_inner_h_tc_array[time_mask] time_prior_array = self.priors['geocent_time'].prob(times) * self._delta_tc - if self.calibration_marginalization: - d_inner_h_tc_array = d_inner_h_tc_array[:, time_mask] - else: - d_inner_h_tc_array = d_inner_h_tc_array[time_mask] if self.distance_marginalization: log_l_tc_array = self.distance_marginalized_likelihood( @@ -804,9 +810,9 @@ def time_marginalized_likelihood(self, d_inner_h_tc_array, h_inner_h, *, paramet d_inner_h=d_inner_h_tc_array, h_inner_h=h_inner_h) elif self.calibration_marginalization: - log_l_tc_array = np.real(d_inner_h_tc_array) - h_inner_h[:, np.newaxis] / 2 + log_l_tc_array = d_inner_h_tc_array.real - h_inner_h[:, np.newaxis] / 2 else: - log_l_tc_array = np.real(d_inner_h_tc_array) - h_inner_h / 2 + log_l_tc_array = d_inner_h_tc_array.real - h_inner_h / 2 return logsumexp(log_l_tc_array, b=time_prior_array, axis=-1) def get_calibration_log_likelihoods(self, signal_polarizations=None, *, parameters): @@ -917,8 +923,11 @@ def _setup_distance_marginalization(self, lookup_table=None): else: self._create_lookup_table() self._interp_dist_margd_loglikelihood = BoundedRectBivariateSpline( - self._d_inner_h_ref_array, self._optimal_snr_squared_ref_array, - self._dist_margd_loglikelihood_array.T, fill_value=-np.inf) + 
self._d_inner_h_ref_array, + self._optimal_snr_squared_ref_array, + self._dist_margd_loglikelihood_array.T, + fill_value=-np.inf, + ) @property def cached_lookup_table_filename(self): @@ -1072,6 +1081,8 @@ def reference_frame(self, frame): self._reference_frame = InterferometerList([frame[:2], frame[2:4]]) else: raise ValueError("Unable to parse reference frame {}".format(frame)) + if isinstance(self._reference_frame, InterferometerList): + self._reference_frame.set_array_backend(self.interferometers.array_backend) def get_sky_frame_parameters(self, parameters): """ diff --git a/bilby/gw/likelihood/basic.py b/bilby/gw/likelihood/basic.py index da67481f0..35b5994a6 100644 --- a/bilby/gw/likelihood/basic.py +++ b/bilby/gw/likelihood/basic.py @@ -43,10 +43,11 @@ def noise_log_likelihood(self): """ log_l = 0 for interferometer in self.interferometers: - log_l -= 2. / self.waveform_generator.duration * np.sum( - abs(interferometer.frequency_domain_strain) ** 2 / - interferometer.power_spectral_density_array) - return log_l.real + log_l -= 2. / self.waveform_generator.duration * ( + abs(interferometer.frequency_domain_strain) ** 2 + / interferometer.power_spectral_density_array + ).sum() + return log_l def log_likelihood(self, parameters): """ Calculates the real part of log-likelihood value @@ -85,8 +86,9 @@ def log_likelihood_interferometer(self, waveform_polarizations, signal_ifo = interferometer.get_detector_response( waveform_polarizations, parameters) - log_l = - 2. / self.waveform_generator.duration * np.vdot( - interferometer.frequency_domain_strain - signal_ifo, - (interferometer.frequency_domain_strain - signal_ifo) / - interferometer.power_spectral_density_array) + residual = interferometer.frequency_domain_strain - signal_ifo + + log_l = - 2. 
/ self.waveform_generator.duration * ( + abs(residual)**2 / interferometer.power_spectral_density_array + ).sum() return log_l.real diff --git a/bilby/gw/likelihood/multiband.py b/bilby/gw/likelihood/multiband.py index 7b746eefb..0fe5d8f47 100644 --- a/bilby/gw/likelihood/multiband.py +++ b/bilby/gw/likelihood/multiband.py @@ -5,6 +5,7 @@ import numpy as np from .base import GravitationalWaveTransient +from ...compat.utils import array_module from ...core.utils import ( logger, speed_of_light, solar_mass, radius_of_earth, gravitational_constant, round_up_to_power_of_two, @@ -532,8 +533,10 @@ def _setup_linear_coefficients(self): for ifo in self.interferometers: logger.info("Pre-computing linear coefficients for {}".format(ifo.name)) fddata = np.zeros(N // 2 + 1, dtype=complex) - fddata[:len(ifo.frequency_domain_strain)][ifo.frequency_mask[:len(fddata)]] += \ + fddata[:len(ifo.frequency_domain_strain)][ifo.frequency_mask[:len(fddata)]] += np.asarray( ifo.frequency_domain_strain[ifo.frequency_mask] / ifo.power_spectral_density_array[ifo.frequency_mask] + ) + for b in range(self.number_of_bands): Ks, Ke = self.Ks_Ke[b] windows = self._get_window_sequence(1. 
/ self.durations[b], Ks, Ke - Ks + 1, b) @@ -550,7 +553,7 @@ def _setup_quadratic_coefficients_linear_interp(self): linear-interpolation algorithm""" logger.info("Linear-interpolation algorithm is used for (h, h).") self.quadratic_coeffs = dict((ifo.name, np.array([])) for ifo in self.interferometers) - original_duration = self.interferometers.duration + original_duration = float(self.interferometers.duration) for b in range(self.number_of_bands): logger.info(f"Pre-computing quadratic coefficients for the {b}-th band") @@ -574,7 +577,7 @@ def _setup_quadratic_coefficients_linear_interp(self): start_idx_in_band + len(window_sequence) - 1, len(ifo.power_spectral_density_array) - 1 ) - _frequency_mask = ifo.frequency_mask[start_idx_in_band:end_idx_in_band + 1] + _frequency_mask = np.asarray(ifo.frequency_mask[start_idx_in_band:end_idx_in_band + 1]) window_over_psd = np.zeros(end_idx_in_band + 1 - start_idx_in_band) window_over_psd[_frequency_mask] = \ 1. / ifo.power_spectral_density_array[start_idx_in_band:end_idx_in_band + 1][_frequency_mask] @@ -709,13 +712,10 @@ def setup_multibanding_from_weights(self, weights): setattr(self, key, value) def _setup_time_marginalization_multiband(self): - """This overwrites attributes set by _setup_time_marginalization of the base likelihood class""" N = self.Nbs[-1] // 2 self._delta_tc = self.durations[0] / N - self._times = \ - self.interferometers.start_time + np.arange(N) * self._delta_tc - self.time_prior_array = \ - self.priors['geocent_time'].prob(self._times) * self._delta_tc + self._times = self.interferometers.start_time + np.arange(N) * self._delta_tc + self.time_prior_array = self.priors['geocent_time'].prob(self._times) * self._delta_tc # allocate array which is FFTed at each likelihood evaluation self._full_d_h = np.zeros(N, dtype=complex) # idxs to convert full frequency points to banded frequency points, used for filling _full_d_h. 
@@ -755,12 +755,14 @@ def calculate_snrs(self, waveform_polarizations, interferometer, *, return_array modes, parameters, frequencies=self.banded_frequency_points ) - d_inner_h = np.conj(np.dot(strain, self.linear_coeffs[interferometer.name])) + xp = array_module(strain) + + d_inner_h = xp.conj(xp.dot(strain, xp.asarray(self.linear_coeffs[interferometer.name]))) if self.linear_interpolation: - optimal_snr_squared = np.vdot( - np.real(strain * np.conjugate(strain)), - self.quadratic_coeffs[interferometer.name] + optimal_snr_squared = xp.vdot( + xp.abs(strain)**2, + xp.asarray(self.quadratic_coeffs[interferometer.name]) ) else: optimal_snr_squared = 0. @@ -769,18 +771,21 @@ def calculate_snrs(self, waveform_polarizations, interferometer, *, return_array start_idx, end_idx = self.start_end_idxs[b] Mb = self.Mbs[b] if b == 0: - optimal_snr_squared += (4. / self.interferometers.duration) * np.vdot( - np.real(strain[start_idx:end_idx + 1] * np.conjugate(strain[start_idx:end_idx + 1])), - interferometer.frequency_mask[Ks:Ke + 1] * self.windows[start_idx:end_idx + 1] + optimal_snr_squared += (4. / self.interferometers.duration) * xp.vdot( + xp.abs(strain[start_idx:end_idx + 1])**2, + interferometer.frequency_mask[Ks:Ke + 1] * xp.asarray(self.windows[start_idx:end_idx + 1]) / interferometer.power_spectral_density_array[Ks:Ke + 1]) else: self.wths[interferometer.name][b][Ks:Ke + 1] = ( - self.square_root_windows[start_idx:end_idx + 1] * strain[start_idx:end_idx + 1] + xp.asarray(self.square_root_windows[start_idx:end_idx + 1]) + * strain[start_idx:end_idx + 1] + ) + self.hbcs[interferometer.name][b][-Mb:] = xp.fft.irfft( + xp.asarray(self.wths[interferometer.name][b]) ) - self.hbcs[interferometer.name][b][-Mb:] = np.fft.irfft(self.wths[interferometer.name][b]) - thbc = np.fft.rfft(self.hbcs[interferometer.name][b]) - optimal_snr_squared += (4. 
/ self.Tbhats[b]) * np.vdot( - np.real(thbc * np.conjugate(thbc)), self.Ibcs[interferometer.name][b]) + thbc = xp.fft.rfft(xp.asarray(self.hbcs[interferometer.name][b])) + optimal_snr_squared += (4. / self.Tbhats[b]) * xp.vdot( + xp.abs(thbc)**2, xp.asarray(self.Ibcs[interferometer.name][b].real)) complex_matched_filter_snr = d_inner_h / (optimal_snr_squared**0.5) @@ -790,7 +795,7 @@ def calculate_snrs(self, waveform_polarizations, interferometer, *, return_array start_idx, end_idx = self.start_end_idxs[b] self._full_d_h[self._full_to_multiband[start_idx:end_idx + 1]] += \ strain[start_idx:end_idx + 1] * self.linear_coeffs[interferometer.name][start_idx:end_idx + 1] - d_inner_h_array = np.fft.fft(self._full_d_h) + d_inner_h_array = xp.fft.fft(self._full_d_h) else: d_inner_h_array = None diff --git a/bilby/gw/likelihood/relative.py b/bilby/gw/likelihood/relative.py index 1928c013d..72c23c958 100644 --- a/bilby/gw/likelihood/relative.py +++ b/bilby/gw/likelihood/relative.py @@ -4,6 +4,7 @@ from scipy.optimize import differential_evolution from .base import GravitationalWaveTransient +from ...compat.utils import array_module from ...core.utils import logger from ...core.prior.base import Constraint from ...core.prior import DeltaFunction @@ -258,7 +259,7 @@ def set_fiducial_waveforms(self, parameters): for interferometer in self.interferometers: logger.debug(f"Maximum Frequency is {interferometer.maximum_frequency}") wf = interferometer.get_detector_response(self.fiducial_polarizations, parameters) - wf[interferometer.frequency_array > self.maximum_frequency] = 0 + wf *= interferometer.frequency_array <= self.maximum_frequency self.per_detector_fiducial_waveforms[interferometer.name] = wf def find_maximum_likelihood_parameters(self, parameter_bounds, @@ -332,7 +333,7 @@ def compute_summary_data(self): masked_bin_inds[-1] += 1 masked_strain = interferometer.frequency_domain_strain[mask] - masked_h0 = self.per_detector_fiducial_waveforms[interferometer.name][mask] + 
masked_h0 = np.asarray(self.per_detector_fiducial_waveforms[interferometer.name][mask]) masked_psd = interferometer.power_spectral_density_array[mask] duration = interferometer.duration a0, b0, a1, b1 = np.zeros((4, self.number_of_bins), dtype=complex) @@ -401,20 +402,21 @@ def calculate_snrs(self, waveform_polarizations, interferometer, *, return_array parameters=parameters, ) a0, a1, b0, b1 = self.summary_data[interferometer.name] - d_inner_h = np.sum(a0 * np.conjugate(r0) + a1 * np.conjugate(r1)) - h_inner_h = np.sum(b0 * np.abs(r0) ** 2 + 2 * b1 * np.real(r0 * np.conjugate(r1))) + d_inner_h = (a0 * r0.conj() + a1 * r1.conj()).sum() + h_inner_h = (b0 * abs(r0) ** 2 + 2 * b1 * (r0 * r1.conj()).real).sum() optimal_snr_squared = h_inner_h complex_matched_filter_snr = d_inner_h / (optimal_snr_squared ** 0.5) if return_array and self.time_marginalization: + xp = array_module(r0) full_waveform = self._compute_full_waveform( signal_polarizations=waveform_polarizations, interferometer=interferometer, parameters=parameters, ) - d_inner_h_array = 4 / self.waveform_generator.duration * np.fft.fft( + d_inner_h_array = 4 / self.waveform_generator.duration * xp.fft.fft( full_waveform[0:-1] - * interferometer.frequency_domain_strain.conjugate()[0:-1] + * interferometer.frequency_domain_strain.conj()[0:-1] / interferometer.power_spectral_density_array[0:-1]) else: diff --git a/bilby/gw/likelihood/roq.py b/bilby/gw/likelihood/roq.py index 46a88d09c..d54091dfd 100644 --- a/bilby/gw/likelihood/roq.py +++ b/bilby/gw/likelihood/roq.py @@ -1,7 +1,9 @@ - +import array_api_compat as aac +import array_api_extra as xpx import numpy as np from .base import GravitationalWaveTransient +from ...compat.utils import array_module from ...core.utils import ( logger, create_frequency_series, speed_of_light, radius_of_earth ) @@ -270,15 +272,16 @@ def _set_unique_frequency_nodes_and_inverse(self): """Set unique frequency nodes and indices to recover linear and quadratic frequency nodes for each 
combination of linear and quadratic bases """ + xp = aac.array_namespace(self.interferometers.frequency_array) self._unique_frequency_nodes_and_inverse = [] for idx_linear in range(self.number_of_bases_linear): tmp = [] - frequency_nodes_linear = self.weights['frequency_nodes_linear'][idx_linear] + frequency_nodes_linear = xp.asarray(self.weights['frequency_nodes_linear'][idx_linear]) size_linear = len(frequency_nodes_linear) for idx_quadratic in range(self.number_of_bases_quadratic): - frequency_nodes_quadratic = self.weights['frequency_nodes_quadratic'][idx_quadratic] - frequency_nodes_unique, original_indices = np.unique( - np.hstack((frequency_nodes_linear, frequency_nodes_quadratic)), + frequency_nodes_quadratic = xp.asarray(self.weights['frequency_nodes_quadratic'][idx_quadratic]) + frequency_nodes_unique, original_indices = xp.unique( + xp.hstack((frequency_nodes_linear, frequency_nodes_quadratic)), return_inverse=True ) linear_indices = original_indices[:size_linear] @@ -454,10 +457,8 @@ def calculate_snrs(self, waveform_polarizations, interferometer, *, return_array frequency_nodes = self.waveform_generator.waveform_arguments['frequency_nodes'] linear_indices = self.waveform_generator.waveform_arguments['linear_indices'] quadratic_indices = self.waveform_generator.waveform_arguments['quadratic_indices'] - size_linear = len(linear_indices) - size_quadratic = len(quadratic_indices) - h_linear = np.zeros(size_linear, dtype=complex) - h_quadratic = np.zeros(size_quadratic, dtype=complex) + h_linear = 0j + h_quadratic = 0j for mode in waveform_polarizations['linear']: response = interferometer.antenna_response( parameters['ra'], parameters['dec'], @@ -468,14 +469,15 @@ def calculate_snrs(self, waveform_polarizations, interferometer, *, return_array h_linear += waveform_polarizations['linear'][mode] * response h_quadratic += waveform_polarizations['quadratic'][mode] * response + xp = array_module(h_linear) calib_factor = 
interferometer.calibration_model.get_calibration_factor( - frequency_nodes, prefix='recalib_{}_'.format(interferometer.name), **parameters) + xp.asarray(frequency_nodes), prefix=f'recalib_{interferometer.name}_', xp=xp, **parameters) h_linear *= calib_factor[linear_indices] h_quadratic *= calib_factor[quadratic_indices] - optimal_snr_squared = np.vdot( - np.abs(h_quadratic)**2, - self.weights[interferometer.name + '_quadratic'][self.basis_number_quadratic] + optimal_snr_squared = xp.vdot( + xp.abs(h_quadratic)**2, + xp.asarray(self.weights[interferometer.name + '_quadratic'][self.basis_number_quadratic]) ) dt = interferometer.time_delay_from_geocenter( @@ -484,21 +486,25 @@ def calculate_snrs(self, waveform_polarizations, interferometer, *, return_array ifo_time = dt_geocent + dt indices, in_bounds = self._closest_time_indices( - ifo_time, self.weights['time_samples']) - if not in_bounds: - logger.debug("SNR calculation error: requested time at edge of ROQ time samples") - d_inner_h = -np.inf - complex_matched_filter_snr = -np.inf - else: - d_inner_h_tc_array = np.einsum( - 'i,ji->j', np.conjugate(h_linear), - self.weights[interferometer.name + '_linear'][self.basis_number_linear][indices]) + ifo_time, xp.asarray(self.weights['time_samples'])) + indices = xp.clip(xp.asarray(indices), 0, len(self.weights['time_samples']) - 1) + d_inner_h_tc_array = xp.einsum( + 'i,ji->j', + xp.conj(h_linear), + xp.asarray( + self.weights[interferometer.name + '_linear'][self.basis_number_linear] + )[indices], + ) + + d_inner_h = self._interp_five_samples( + xp.asarray(self.weights['time_samples'])[indices], d_inner_h_tc_array, ifo_time + ) - d_inner_h = self._interp_five_samples( - self.weights['time_samples'][indices], d_inner_h_tc_array, ifo_time) + with np.errstate(invalid="ignore"): + complex_matched_filter_snr = d_inner_h / (optimal_snr_squared**0.5) - with np.errstate(invalid="ignore"): - complex_matched_filter_snr = d_inner_h / (optimal_snr_squared**0.5) + d_inner_h += 
xp.log(in_bounds) + complex_matched_filter_snr += xp.log(in_bounds) if return_array and self.time_marginalization: ifo_times = self._times - interferometer.strain_data.start_time @@ -535,9 +541,10 @@ def _closest_time_indices(time, samples): in_bounds: bool Whether the indices are for valid times """ - closest = int((time - samples[0]) / (samples[1] - samples[0])) + xp = array_module(time) + closest = xp.astype(xp.floor((time - samples[0]) / (samples[1] - samples[0])), int) indices = [closest + ii for ii in [-2, -1, 0, 1, 2]] - in_bounds = (indices[0] >= 0) & (indices[-1] < samples.size) + in_bounds = (indices[0] >= 0) & (indices[-1] < len(samples)) return indices, in_bounds @staticmethod @@ -562,13 +569,13 @@ def _interp_five_samples(time_samples, values, time): """ r1 = (-values[0] + 8. * values[1] - 14. * values[2] + 8. * values[3] - values[4]) / 4. r2 = values[2] - 2. * values[3] + values[4] - a = (time_samples[3] - time) / (time_samples[1] - time_samples[0]) + a = (time_samples[3] - time) / max(time_samples[1] - time_samples[0], 1e-12) b = 1. - a c = (a**3. - a) / 6. d = (b**3. - b) / 6. return a * values[2] + b * values[3] + c * r1 + d * r2 - def _calculate_d_inner_h_array(self, times, h_linear, ifo_name): + def _calculate_d_inner_h_array(self, times, h_linear, ifo_name, *, xp=None): """ Calculate d_inner_h at regularly-spaced time samples. Each value is interpolated from the nearest 5 samples with the algorithm explained in @@ -586,21 +593,23 @@ def _calculate_d_inner_h_array(self, times, h_linear, ifo_name): ======= d_inner_h_array: array-like """ + if xp is None: + xp = aac.array_namespace(h_linear) roq_time_space = self.weights['time_samples'][1] - self.weights['time_samples'][0] times_per_roq_time_space = (times - self.weights['time_samples'][0]) / roq_time_space - closest_idxs = np.floor(times_per_roq_time_space).astype(int) + closest_idxs = xp.astype(xp.floor(times_per_roq_time_space), int) # Get the nearest 5 samples of d_inner_h. 
Calculate only the required d_inner_h values if the time # spacing is larger than 5 times the ROQ time spacing. weights_linear = self.weights[ifo_name + '_linear'][self.basis_number_linear] - h_linear_conj = np.conjugate(h_linear) + h_linear_conj = h_linear.conj() if (times[1] - times[0]) / roq_time_space > 5: - d_inner_h_m2 = np.dot(weights_linear[closest_idxs - 2], h_linear_conj) - d_inner_h_m1 = np.dot(weights_linear[closest_idxs - 1], h_linear_conj) - d_inner_h_0 = np.dot(weights_linear[closest_idxs], h_linear_conj) - d_inner_h_p1 = np.dot(weights_linear[closest_idxs + 1], h_linear_conj) - d_inner_h_p2 = np.dot(weights_linear[closest_idxs + 2], h_linear_conj) + d_inner_h_m2 = weights_linear[closest_idxs - 2] @ h_linear_conj + d_inner_h_m1 = weights_linear[closest_idxs - 1] @ h_linear_conj + d_inner_h_0 = weights_linear[closest_idxs] @ h_linear_conj + d_inner_h_p1 = weights_linear[closest_idxs + 1] @ h_linear_conj + d_inner_h_p2 = weights_linear[closest_idxs + 2] @ h_linear_conj else: - d_inner_h_at_roq_time_samples = np.dot(weights_linear, h_linear_conj) + d_inner_h_at_roq_time_samples = weights_linear @ h_linear_conj d_inner_h_m2 = d_inner_h_at_roq_time_samples[closest_idxs - 2] d_inner_h_m1 = d_inner_h_at_roq_time_samples[closest_idxs - 1] d_inner_h_0 = d_inner_h_at_roq_time_samples[closest_idxs] @@ -652,17 +661,17 @@ def perform_roq_params_check(self, ifo=None): except ValueError: roq_minimum_component_mass = None - if ifo.maximum_frequency > roq_maximum_frequency: + if float(ifo.maximum_frequency) > roq_maximum_frequency: raise BilbyROQParamsRangeError( "Requested maximum frequency {} larger than ROQ basis fhigh {}" .format(ifo.maximum_frequency, roq_maximum_frequency) ) - if ifo.minimum_frequency < roq_minimum_frequency: + if float(ifo.minimum_frequency) < roq_minimum_frequency: raise BilbyROQParamsRangeError( "Requested minimum frequency {} lower than ROQ basis flow {}" .format(ifo.minimum_frequency, roq_minimum_frequency) ) - if ifo.strain_data.duration 
!= roq_segment_length: + if float(ifo.strain_data.duration) != roq_segment_length: raise BilbyROQParamsRangeError( "Requested duration differs from ROQ basis seglen") @@ -708,6 +717,7 @@ def _set_weights(self, linear_matrix, quadratic_matrix): linear and quadratic basis """ + xp = aac.array_namespace(self.interferometers.frequency_array) time_space = self._get_time_resolution() number_of_time_samples = int(self.interferometers.duration / time_space) earth_light_crossing_time = 2 * radius_of_earth / speed_of_light + 5 * time_space @@ -727,7 +737,7 @@ def _set_weights(self, linear_matrix, quadratic_matrix): - self.interferometers.start_time ) / time_space)) ) - self.weights['time_samples'] = np.arange(start_idx, end_idx + 1) * time_space + self.weights['time_samples'] = xp.arange(start_idx, end_idx + 1) * float(time_space) logger.info("Using {} ROQ time samples".format(len(self.weights['time_samples']))) # select bases to be used, set prior ranges and frequency nodes if exist @@ -780,10 +790,10 @@ def _set_weights(self, linear_matrix, quadratic_matrix): roq_mask = roq_frequencies >= roq_scaled_minimum_frequency roq_frequencies = roq_frequencies[roq_mask] overlap_frequencies, ifo_idxs_this_ifo, roq_idxs_this_ifo = np.intersect1d( - ifo.frequency_array[ifo.frequency_mask], roq_frequencies, + np.asarray(ifo.frequency_array[ifo.frequency_mask]), roq_frequencies, return_indices=True) else: - overlap_frequencies = ifo.frequency_array[ifo.frequency_mask] + overlap_frequencies = np.asarray(ifo.frequency_array[ifo.frequency_mask]) roq_idxs_this_ifo = np.arange( linear_matrix['basis_linear'][str(idxs_in_prior_range['linear'][0])]['basis'].shape[1], dtype=int) @@ -839,32 +849,44 @@ def _set_weights_linear(self, linear_matrix, basis_idxs, roq_idxs, ifo_idxs): data_over_psd = {} for ifo in self.interferometers: nonzero_idxs[ifo.name] = ifo_idxs[ifo.name] + int( - ifo.frequency_array[ifo.frequency_mask][0] * self.interferometers.duration) - data_over_psd[ifo.name] = 
ifo.frequency_domain_strain[ifo.frequency_mask][ifo_idxs[ifo.name]] / \ - ifo.power_spectral_density_array[ifo.frequency_mask][ifo_idxs[ifo.name]] - try: - import pyfftw - ifft_input = pyfftw.empty_aligned(number_of_time_samples, dtype=complex) - ifft_output = pyfftw.empty_aligned(number_of_time_samples, dtype=complex) - ifft = pyfftw.FFTW(ifft_input, ifft_output, direction='FFTW_BACKWARD') - except ImportError: + ifo.minimum_frequency * self.interferometers.duration) + data_over_psd[ifo.name] = ( + ifo.frequency_domain_strain[ifo.frequency_mask][ifo_idxs[ifo.name]] + / ifo.power_spectral_density_array[ifo.frequency_mask][ifo_idxs[ifo.name]] + ) + xp = array_module(data_over_psd) + if aac.is_numpy_namespace(xp): + try: + import pyfftw + ifft_input = pyfftw.empty_aligned(number_of_time_samples, dtype=complex) + ifft_output = pyfftw.empty_aligned(number_of_time_samples, dtype=complex) + ifft = pyfftw.FFTW(ifft_input, ifft_output, direction='FFTW_BACKWARD') + except ImportError: + pyfftw = None + logger.warning("You do not have pyfftw installed, falling back to numpy.fft.") + ifft_input = np.zeros(number_of_time_samples, dtype=complex) + ifft = np.fft.ifft + else: pyfftw = None - logger.warning("You do not have pyfftw installed, falling back to numpy.fft.") - ifft_input = np.zeros(number_of_time_samples, dtype=complex) - ifft = np.fft.ifft + ifft_input = xp.zeros(number_of_time_samples, dtype=complex) + ifft = xp.fft.ifft for basis_idx in basis_idxs: logger.info(f"Building linear ROQ weights for the {basis_idx}-th basis.") - linear_matrix_single = linear_matrix['basis_linear'][str(basis_idx)]['basis'] + linear_matrix_single = xp.asarray(linear_matrix['basis_linear'][str(basis_idx)]['basis']) basis_size = linear_matrix_single.shape[0] for ifo in self.interferometers: - ifft_input[:] *= 0. + if pyfftw: + ifft_input[:] *= 0. 
+ else: + ifft_input *= 0 linear_weights = \ - np.zeros((len(self.weights['time_samples']), basis_size), dtype=complex) + xp.zeros((basis_size, len(self.weights['time_samples'])), dtype=complex) for i in range(basis_size): - basis_element = linear_matrix_single[i][roq_idxs[ifo.name]] - ifft_input[nonzero_idxs[ifo.name]] = data_over_psd[ifo.name] * np.conj(basis_element) - linear_weights[:, i] = ifft(ifft_input)[start_idx:end_idx + 1] - linear_weights *= 4. * number_of_time_samples / self.interferometers.duration + basis_element = xp.asarray(linear_matrix_single[i][roq_idxs[ifo.name]]).conj() + ifft_input = xpx.at(ifft_input, nonzero_idxs[ifo.name]).set(data_over_psd[ifo.name] * basis_element) + linear_weights = xpx.at(linear_weights, i).set(ifft(ifft_input)[start_idx:end_idx + 1]) + linear_weights = linear_weights.T + linear_weights *= 4. * number_of_time_samples / float(self.interferometers.duration) self.weights[ifo.name + '_linear'].append(linear_weights) if pyfftw is not None: pyfftw.forget_wisdom() @@ -883,6 +905,7 @@ def _set_weights_linear_multiband(self, linear_matrix, basis_idxs): """ for ifo in self.interferometers: self.weights[ifo.name + '_linear'] = [] + xp = aac.array_namespace(self.interferometers.frequency_array) Tbs = linear_matrix['durations_s_linear'][()] / self.roq_scale_factor start_end_frequency_bins = linear_matrix['start_end_frequency_bins_linear'][()] basis_dimension = np.sum(start_end_frequency_bins[:, 1] - start_end_frequency_bins[:, 0] + 1) @@ -890,29 +913,39 @@ def _set_weights_linear_multiband(self, linear_matrix, basis_idxs): # prepare time-shifted data, which is multiplied by basis tc_shifted_data = dict() for ifo in self.interferometers: - over_whitened_frequency_data = np.zeros(int(fhigh_basis * ifo.duration) + 1, dtype=complex) - over_whitened_frequency_data[np.arange(len(ifo.frequency_domain_strain))[ifo.frequency_mask]] = \ + over_whitened_frequency_data = xp.zeros(int(fhigh_basis * ifo.duration) + 1, dtype=complex) + 
over_whitened_frequency_data = xpx.at( + over_whitened_frequency_data, xp.arange(len(ifo.frequency_domain_strain))[ifo.frequency_mask] + ).set( ifo.frequency_domain_strain[ifo.frequency_mask] / ifo.power_spectral_density_array[ifo.frequency_mask] - over_whitened_time_data = np.fft.irfft(over_whitened_frequency_data) - tc_shifted_data[ifo.name] = np.zeros((basis_dimension, len(self.weights['time_samples'])), dtype=complex) + ) + over_whitened_time_data = xp.fft.irfft(over_whitened_frequency_data) + tc_shifted_data[ifo.name] = xp.zeros((basis_dimension, len(self.weights['time_samples'])), dtype=complex) start_idx_of_band = 0 for b, Tb in enumerate(Tbs): start_frequency_bin, end_frequency_bin = start_end_frequency_bins[b] - fs = np.arange(start_frequency_bin, end_frequency_bin + 1) / Tb - Db = np.fft.rfft( + fs = xp.arange(start_frequency_bin, end_frequency_bin + 1) / Tb + Db = xp.fft.rfft( over_whitened_time_data[-int(2. * fhigh_basis * Tb):] )[start_frequency_bin:end_frequency_bin + 1] start_idx_of_next_band = start_idx_of_band + end_frequency_bin - start_frequency_bin + 1 - tc_shifted_data[ifo.name][start_idx_of_band:start_idx_of_next_band] = 4. / Tb * Db[:, None] * np.exp( - 2. * np.pi * 1j * fs[:, None] * (self.weights['time_samples'][None, :] - ifo.duration + Tb)) + this_data = xp.zeros(len(self.weights['time_samples']), dtype=complex) + sl = slice(start_idx_of_band, start_idx_of_next_band) + this_data = ( + 4. / Tb * Db[:, None] * xp.exp( + 2. 
* np.pi * 1j * fs[:, None] + * (xp.asarray(self.weights['time_samples'][None, :]) - ifo.duration + Tb) + ) + ) + tc_shifted_data[ifo.name] = xpx.at(tc_shifted_data[ifo.name], sl).set(this_data) + start_idx_of_band = start_idx_of_next_band # compute inner products for basis_idx in basis_idxs: logger.info(f"Building linear ROQ weights for the {basis_idx}-th basis.") - linear_matrix_single = linear_matrix['basis_linear'][str(basis_idx)]['basis'][()] + linear_matrix_single = xp.asarray(linear_matrix['basis_linear'][str(basis_idx)]['basis'][()]) for ifo in self.interferometers: - self.weights[ifo.name + '_linear'].append( - np.dot(np.conj(linear_matrix_single), tc_shifted_data[ifo.name]).T) + self.weights[ifo.name + '_linear'].append((linear_matrix_single.conj() @ tc_shifted_data[ifo.name]).T) def _set_weights_quadratic(self, quadratic_matrix, basis_idxs, roq_idxs, ifo_idxs): """ @@ -934,14 +967,15 @@ def _set_weights_quadratic(self, quadratic_matrix, basis_idxs, roq_idxs, ifo_idx """ for ifo in self.interferometers: self.weights[ifo.name + '_quadratic'] = [] + xp = aac.array_namespace(self.interferometers.frequency_array) for basis_idx in basis_idxs: logger.info(f"Building quadratic ROQ weights for the {basis_idx}-th basis.") - quadratic_matrix_single = quadratic_matrix['basis_quadratic'][str(basis_idx)]['basis'][()].real + quadratic_matrix_single = xp.asarray(quadratic_matrix['basis_quadratic'][str(basis_idx)]['basis'][()].real) for ifo in self.interferometers: + inv_psd = xp.asarray(1 / ifo.power_spectral_density_array[ifo.frequency_mask][ifo_idxs[ifo.name]]) self.weights[ifo.name + '_quadratic'].append( - 4. / ifo.strain_data.duration * np.dot( - quadratic_matrix_single[:, roq_idxs[ifo.name]], - 1 / ifo.power_spectral_density_array[ifo.frequency_mask][ifo_idxs[ifo.name]])) + 4. 
/ ifo.strain_data.duration * quadratic_matrix_single[:, roq_idxs[ifo.name]] @ inv_psd + ) del quadratic_matrix_single def _set_weights_quadratic_multiband(self, quadratic_matrix, basis_idxs): @@ -958,6 +992,7 @@ def _set_weights_quadratic_multiband(self, quadratic_matrix, basis_idxs): """ for ifo in self.interferometers: self.weights[ifo.name + '_quadratic'] = [] + xp = aac.array_namespace(self.interferometers.frequency_array) Tbs = quadratic_matrix['durations_s_quadratic'][()] / self.roq_scale_factor start_end_frequency_bins = quadratic_matrix['start_end_frequency_bins_quadratic'][()] basis_dimension = np.sum(start_end_frequency_bins[:, 1] - start_end_frequency_bins[:, 0] + 1) @@ -965,27 +1000,31 @@ def _set_weights_quadratic_multiband(self, quadratic_matrix, basis_idxs): # prepare coefficients multiplied by basis multibanded_inverse_psd = dict() for ifo in self.interferometers: - inverse_psd_frequency = np.zeros(int(fhigh_basis * ifo.duration) + 1) - inverse_psd_frequency[np.arange(len(ifo.power_spectral_density_array))[ifo.frequency_mask]] = \ - 1. / ifo.power_spectral_density_array[ifo.frequency_mask] - inverse_psd_time = np.fft.irfft(inverse_psd_frequency) - multibanded_inverse_psd[ifo.name] = np.zeros(basis_dimension) + inverse_psd_frequency = xp.zeros(int(fhigh_basis * ifo.duration) + 1) + sl = np.arange(len(ifo.power_spectral_density_array))[ifo.frequency_mask] + inverse_psd_frequency = xpx.at(inverse_psd_frequency, sl).set( + 1. / xp.asarray(ifo.power_spectral_density_array[ifo.frequency_mask]) + ) + inverse_psd_time = xp.fft.irfft(inverse_psd_frequency) + multibanded_inverse_psd[ifo.name] = xp.zeros(basis_dimension) start_idx_of_band = 0 for b, Tb in enumerate(Tbs): start_frequency_bin, end_frequency_bin = start_end_frequency_bins[b] number_of_samples_half = int(fhigh_basis * Tb) start_idx_of_next_band = start_idx_of_band + end_frequency_bin - start_frequency_bin + 1 - multibanded_inverse_psd[ifo.name][start_idx_of_band:start_idx_of_next_band] = 4. 
/ Tb * np.fft.rfft( - np.append(inverse_psd_time[:number_of_samples_half], inverse_psd_time[-number_of_samples_half:]) + sl = slice(start_idx_of_band, start_idx_of_next_band) + this_data = 4. / Tb * xp.fft.rfft( + xp.concat([inverse_psd_time[:number_of_samples_half], inverse_psd_time[-number_of_samples_half:]]) )[start_frequency_bin:end_frequency_bin + 1].real + multibanded_inverse_psd[ifo.name] = xpx.at(multibanded_inverse_psd[ifo.name], sl).set(this_data) start_idx_of_band = start_idx_of_next_band # compute inner products for basis_idx in basis_idxs: logger.info(f"Building quadratic ROQ weights for the {basis_idx}-th basis.") - quadratic_matrix_single = quadratic_matrix['basis_quadratic'][str(basis_idx)]['basis'][()].real + quadratic_matrix_single = xp.asarray(quadratic_matrix['basis_quadratic'][str(basis_idx)]['basis'][()].real) for ifo in self.interferometers: self.weights[ifo.name + '_quadratic'].append( - np.dot(quadratic_matrix_single, multibanded_inverse_psd[ifo.name])) + quadratic_matrix_single @ multibanded_inverse_psd[ifo.name]) def save_weights(self, filename, format='hdf5'): """ @@ -1200,8 +1239,8 @@ def generate_time_sample_from_marginalized_likelihood(self, signal_polarizations times = self._times if self.jitter_time: times = times + parameters["time_jitter"] - time_prior_array = self.priors['geocent_time'].prob(times) - time_post = np.exp(time_log_like - max(time_log_like)) * time_prior_array + time_prior_array = np.asarray(self.priors['geocent_time'].prob(times)) + time_post = np.exp(np.asarray(time_log_like - max(time_log_like))) * time_prior_array time_post /= np.sum(time_post) return random.rng.choice(times, p=time_post) diff --git a/bilby/gw/prior.py b/bilby/gw/prior.py index e262eaaf3..9122a0c79 100644 --- a/bilby/gw/prior.py +++ b/bilby/gw/prior.py @@ -1,12 +1,14 @@ import os import copy +import array_api_extra as xpx import numpy as np from scipy.integrate import cumulative_trapezoid, trapezoid, quad from scipy.interpolate import 
InterpolatedUnivariateSpline from scipy.special import hyp2f1 from scipy.stats import norm +from ..compat.utils import xp_wrap from ..core.prior import ( PriorDict, Uniform, Prior, DeltaFunction, Gaussian, Interped, Constraint, conditional_prior_factory, PowerLaw, ConditionalLogUniform, @@ -430,23 +432,25 @@ def __init__(self, minimum, maximum, name='mass_ratio', latex_label='$q$', def _integral(q): return -5. * q**(-1. / 5.) * hyp2f1(-2. / 5., -1. / 5., 4. / 5., -q) - def cdf(self, val): + def cdf(self, val, *, xp=np): return (self._integral(val) - self._integral(self.minimum)) / self.norm - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=None): if self.equal_mass: - val = 2 * np.minimum(val, 1 - val) + val = 2 * xp.minimum(val, 1 - val) return self.icdf(val) - def prob(self, val): + def prob(self, val, *, xp=None): in_prior = (val >= self.minimum) & (val <= self.maximum) with np.errstate(invalid="ignore"): prob = (1. + val)**(2. / 5.) / (val**(6. / 5.)) / self.norm * in_prior return prob - def ln_prob(self, val): + @xp_wrap + def ln_prob(self, val, *, xp=None): with np.errstate(divide="ignore"): - return np.log(self.prob(val)) + return xp.log(self.prob(val, xp=xp)) class AlignedSpin(Interped): @@ -511,7 +515,7 @@ def integrand(aa, chi): after performing the integral over spin orientation using a delta function identity. 
""" - return a_prior.prob(aa) * z_prior.prob(chi / aa) / aa + return a_prior.prob(aa, xp=None) * z_prior.prob(chi / aa, xp=None) / aa self.num_interp = 10_000 if num_interp is None else num_interp xx = np.linspace(chi_min, chi_max, self.num_interp) @@ -600,21 +604,26 @@ def __init__(self, minimum, maximum, name, latex_label=None, unit=None, boundary self.__class__.__name__ = "ConditionalChiInPlane" self.__class__.__qualname__ = "ConditionalChiInPlane" - def prob(self, val, **required_variables): - self.update_conditions(**required_variables) + @xp_wrap + def prob(self, val, *, xp=np, **required_variables): + parameters = self.condition_func(self.reference_params.copy(), **required_variables) chi_aligned = abs(required_variables[self._required_variables[0]]) + minimum = parameters.get("minimum", self.minimum) + maximum = parameters.get("maximum", self.maximum) return ( - (val >= self.minimum) * (val <= self.maximum) + (val >= minimum) * (val <= maximum) * val / (chi_aligned ** 2 + val ** 2) - / np.log(self._reference_maximum / chi_aligned) + / xp.log(self._reference_maximum / chi_aligned) ) - def ln_prob(self, val, **required_variables): + @xp_wrap + def ln_prob(self, val, *, xp=np, **required_variables): with np.errstate(divide="ignore"): - return np.log(self.prob(val, **required_variables)) + return xp.log(self.prob(val, **required_variables)) - def cdf(self, val, **required_variables): + @xp_wrap + def cdf(self, val, *, xp=np, **required_variables): r""" .. 
math:: \text{CDF}(\chi_\per) = N ln(1 + (\chi_\perp / \chi) ** 2) @@ -634,14 +643,15 @@ def cdf(self, val, **required_variables): """ self.update_conditions(**required_variables) chi_aligned = abs(required_variables[self._required_variables[0]]) - return np.maximum(np.minimum( + return xp.clip( (val >= self.minimum) * (val <= self.maximum) - * np.log(1 + (val / chi_aligned) ** 2) - / 2 / np.log(self._reference_maximum / chi_aligned) - , 1 - ), 0) + * xp.log(1 + (val / chi_aligned) ** 2) + / 2 / xp.log(self._reference_maximum / chi_aligned), + 0, + 1 + ) - def rescale(self, val, **required_variables): + def rescale(self, val, *, xp=np, **required_variables): r""" .. math:: \text{PPF}(\chi_\perp) = ((a_\max / \chi) ** (2x) - 1) ** 0.5 * \chi @@ -664,9 +674,9 @@ def rescale(self, val, **required_variables): def _condition_function(self, reference_params, **kwargs): with np.errstate(invalid="ignore"): - maximum = np.sqrt( + maximum = ( self._reference_maximum ** 2 - kwargs[self._required_variables[0]] ** 2 - ) + )**0.5 return dict(minimum=0, maximum=maximum) def __repr__(self): @@ -690,13 +700,13 @@ def __init__(self, minimum=-np.inf, maximum=np.inf): super().__init__(minimum=minimum, maximum=maximum, name=None, latex_label=None, unit=None) - def prob(self, val): + def prob(self, val, *, xp=np): """ Returns the result of the equation of state check in the conversion function. 
""" return val - def ln_prob(self, val): + def ln_prob(self, val, *, xp=np): if val: result = 0.0 @@ -1516,7 +1526,8 @@ def _check_imports(): raise ImportError("Must have healpy installed on this machine to use HealPixMapPrior") return healpy - def _rescale(self, samp, **kwargs): + @xp_wrap + def _rescale(self, samp, *, xp=None, **kwargs): """ Overwrites the _rescale method of BaseJoint Prior to rescale a single value from the unitcube onto two values (ra, dec) or 3 (ra, dec, dist) if distance is included @@ -1539,17 +1550,19 @@ def _rescale(self, samp, **kwargs): else: samp = samp[:, 0] pix_rescale = self.inverse_cdf(samp) - sample = np.empty((len(pix_rescale), 2)) - dist_samples = np.empty((len(pix_rescale))) + sample = xp.empty((len(pix_rescale), 2)) + dist_samples = xp.empty((len(pix_rescale))) for i, val in enumerate(pix_rescale): theta, ra = self.hp.pix2ang(self.nside, int(round(val))) dec = 0.5 * np.pi - theta - sample[i, :] = self.draw_from_pixel(ra, dec, int(round(val))) + sample = xpx.at(sample, i).set(xp.asarray(self.draw_from_pixel(ra, dec, int(round(val))))) if self.distance: self.update_distance(int(round(val))) - dist_samples[i] = self.distance_icdf(dist_samp[i]) + dist_samples = xpx.at(dist_samples, i).set( + xp.asarray(self.distance_icdf(dist_samp[i])) + ) if self.distance: - sample = np.vstack([sample[:, 0], sample[:, 1], dist_samples]) + sample = xp.vstack([sample[:, 0], sample[:, 1], dist_samples]) return sample.reshape((-1, self.num_vars)) def update_distance(self, pix_idx): @@ -1595,7 +1608,7 @@ def _check_norm(array): norm = np.finfo(array.dtype).eps return array / norm - def _sample(self, size, **kwargs): + def _sample(self, size, *, random_state=None, **kwargs): """ Overwrites the _sample method of BaseJoint Prior. Picks a pixel value according to their probabilities, then uniformly samples ra, and decs that are contained in chosen pixel. 
If the PriorDist includes distance it then @@ -1614,21 +1627,25 @@ def _sample(self, size, **kwargs): sample : array_like sample of ra, and dec (and distance if 3D=True) """ - sample_pix = random.rng.choice(self.npix, size=size, p=self.prob, replace=True) - sample = np.empty((size, self.num_vars)) + rng = random.resolve_random_state(random_state) + xp = random.random_array_module(rng) + + sample_pix = rng.choice(self.npix, size=size, p=self.prob, replace=True) + sample = xp.empty((size, self.num_vars)) for samp in range(size): theta, ra = self.hp.pix2ang(self.nside, sample_pix[samp]) dec = 0.5 * np.pi - theta if self.distance: self.update_distance(sample_pix[samp]) - dist = self.draw_distance(sample_pix[samp]) - ra_dec = self.draw_from_pixel(ra, dec, sample_pix[samp]) - sample[samp, :] = [ra_dec[0], ra_dec[1], dist] + dist = self.draw_distance(sample_pix[samp], random_state=rng) + ra, dec = self.draw_from_pixel(ra, dec, sample_pix[samp], random_state=rng) + new = [ra, dec, dist] else: - sample[samp, :] = self.draw_from_pixel(ra, dec, sample_pix[samp]) - return sample.reshape((-1, self.num_vars)) + new = self.draw_from_pixel(ra, dec, sample_pix[samp]) + sample = xpx.at(sample, samp).set(xp.asarray(new)) + return xp.asarray(sample.reshape((-1, self.num_vars))) - def draw_distance(self, pix): + def draw_distance(self, pix, *, random_state=None): """ Method to recursively draw a distance value from the given set distance distribution and check that it is in the bounds @@ -1644,16 +1661,18 @@ def draw_distance(self, pix): dist : float sample drawn from the distance distribution at set pixel index """ + rng = random.resolve_random_state(random_state) + if self.distmu[pix] == np.inf or self.distmu[pix] <= 0: return 0 - dist = self.distance_icdf(random.rng.uniform(0, 1)) + dist = self.distance_icdf(rng.uniform(0, 1)) name = self.names[-1] if (dist > self.bounds[name][1]) | (dist < self.bounds[name][0]): - self.draw_distance(pix) + self.draw_distance(pix, random_state=rng) 
else: return dist - def draw_from_pixel(self, ra, dec, pix): + def draw_from_pixel(self, ra, dec, pix, *, random_state=None): """ Recursive function to uniformly draw ra, and dec values that are located in the given pixel @@ -1671,12 +1690,14 @@ def draw_from_pixel(self, ra, dec, pix): ra_dec : tuple this returns a tuple of ra, and dec sampled uniformly that are in the pixel given """ + rng = random.resolve_random_state(random_state) + if not self.check_in_pixel(ra, dec, pix): - self.draw_from_pixel(ra, dec, pix) + self.draw_from_pixel(ra, dec, pix, random_state=rng) return np.array( [ - random.rng.uniform(ra - self.pixel_length, ra + self.pixel_length), - random.rng.uniform(dec - self.pixel_length, dec + self.pixel_length), + rng.uniform(ra - self.pixel_length, ra + self.pixel_length), + rng.uniform(dec - self.pixel_length, dec + self.pixel_length), ] ) @@ -1705,7 +1726,8 @@ def check_in_pixel(self, ra, dec, pix): pixel = self.hp.ang2pix(self.nside, theta, phi) return pix == pixel - def _ln_prob(self, samp, lnprob, outbounds): + @xp_wrap + def _ln_prob(self, samp, lnprob, outbounds, *, xp=None): """ Overwrites the _lnprob method of BaseJoint Prior @@ -1731,11 +1753,15 @@ def _ln_prob(self, samp, lnprob, outbounds): phi, dec = samp[0] theta = 0.5 * np.pi - dec pixel = self.hp.ang2pix(self.nside, theta, phi) - lnprob[i] = np.log(self.prob[pixel] / self.pixel_area) + lnprob = xpx.at(lnprob, i).set( + xp.log(xp.asarray(self.prob[pixel] / self.pixel_area)) + ) if self.distance: self.update_distance(pixel) - lnprob[i] += np.log(self.distance_pdf(dist) * dist ** 2) - lnprob[outbounds] = -np.inf + lnprob = xpx.at(lnprob, i).set( + lnprob[i] + xp.log(xp.asarray(self.distance_pdf(dist) * dist ** 2)) + ) + lnprob = xp.where(xp.asarray(outbounds), -np.inf, lnprob) return lnprob def __eq__(self, other): diff --git a/bilby/gw/sampler/proposal.py b/bilby/gw/sampler/proposal.py index 79e1ec92c..2ac84687e 100644 --- a/bilby/gw/sampler/proposal.py +++ b/bilby/gw/sampler/proposal.py 
@@ -13,7 +13,7 @@ class SkyLocationWanderJump(JumpProposal): def __call__(self, sample, **kwargs): temperature = 1 / kwargs.get('inverse_temperature', 1.0) - sigma = np.sqrt(temperature) / 2 / np.pi + sigma = temperature**0.5 / 2 / np.pi sample['ra'] += random.gauss(0, sigma) sample['dec'] += random.gauss(0, sigma) return super(SkyLocationWanderJump, self).__call__(sample) diff --git a/bilby/gw/source.py b/bilby/gw/source.py index 78da709ba..11411b468 100644 --- a/bilby/gw/source.py +++ b/bilby/gw/source.py @@ -1,5 +1,6 @@ import numpy as np +from ..compat.utils import array_module from ..core import utils from ..core.utils import logger from .conversion import bilby_to_lalsimulation_spins @@ -1188,20 +1189,22 @@ def sinegaussian(frequency_array, hrss, Q, frequency, **kwargs): dict: Dictionary containing the plus and cross components of the strain. """ - tau = Q / (np.sqrt(2.0) * np.pi * frequency) - temp = Q / (4.0 * np.sqrt(np.pi) * frequency) + xp = array_module(frequency_array) + tau = Q / (2.0**0.5 * np.pi * frequency) + temp = Q / (4.0 * np.pi**0.5 * frequency) fm = frequency_array - frequency fp = frequency_array + frequency - h_plus = ((hrss / np.sqrt(temp * (1 + np.exp(-Q**2)))) * - ((np.sqrt(np.pi) * tau) / 2.0) * - (np.exp(-fm**2 * np.pi**2 * tau**2) + - np.exp(-fp**2 * np.pi**2 * tau**2))) + negative_term = xp.exp(-fm**2 * np.pi**2 * tau**2) + positive_term = xp.exp(-fp**2 * np.pi**2 * tau**2) - h_cross = (-1j * (hrss / np.sqrt(temp * (1 - np.exp(-Q**2)))) * - ((np.sqrt(np.pi) * tau) / 2.0) * - (np.exp(-fm**2 * np.pi**2 * tau**2) - - np.exp(-fp**2 * np.pi**2 * tau**2))) + h_plus = hrss * np.pi**0.5 * tau / 2 * ( + negative_term + positive_term + ) / (temp * (1 + xp.exp(-Q**2)))**0.5 + + h_cross = -1j * hrss * np.pi**0.5 * tau / 2 * ( + negative_term - positive_term + ) / (temp * (1 - xp.exp(-Q**2)))**0.5 return {'plus': h_plus, 'cross': h_cross} @@ -1284,12 +1287,13 @@ def supernova_pca_model( dict: The plus and cross polarizations of the signal """ + 
xp = array_module(frequency_array) principal_components = kwargs["realPCs"] + 1j * kwargs["imagPCs"] coefficients = [pc_coeff1, pc_coeff2, pc_coeff3, pc_coeff4, pc_coeff5] - strain = np.sum( - [coeff * principal_components[:, ii] for ii, coeff in enumerate(coefficients)], + strain = xp.sum( + xp.asarray([coeff * principal_components[:, ii] for ii, coeff in enumerate(coefficients)]), axis=0 ) diff --git a/bilby/gw/time.py b/bilby/gw/time.py new file mode 100644 index 000000000..367a95634 --- /dev/null +++ b/bilby/gw/time.py @@ -0,0 +1,211 @@ +import numpy as np +from plum import dispatch + +from ..compat.utils import array_module + + +__all__ = [ + "datetime", + "gps_time_to_utc", + "greenwich_mean_sidereal_time", + "greenwich_sidereal_time", + "n_leap_seconds", + "utc_to_julian_day", + "LEAP_SECONDS", +] + + +class datetime: + """ + A barebones datetime class for use in the GPS to GMST conversion. + """ + + def __init__( + self, + year: int = 0, + month: int = 0, + day: int = 0, + hour: int = 0, + minute: int = 0, + second: float = 0, + ): + self.year = year + self.month = month + self.day = day + self.hour = hour + self.minute = minute + self.second = second + + def __repr__(self): + return f"{self.year}-{self.month}-{self.day} {self.hour}:{self.minute}:{self.second}" + + def __add__(self, other): + """ + Add two datetimes together. + Note that this does not handle overflow and can lead to unphysical + values for the various attributes. 
+ """ + return datetime( + self.year + other.year, + self.month + other.month, + self.day + other.day, + self.hour + other.hour, + self.minute + other.minute, + self.second + other.second, + ) + + @property + def julian_day(self): + return ( + 367 * self.year + - 7 * (self.year + (self.month + 9) // 12) // 4 + + 275 * self.month // 9 + + self.day + + self.second / SECONDS_PER_DAY + + JULIAN_GPS_EPOCH + ) + + +GPS_EPOCH = datetime(1980, 1, 6, 0, 0, 0) +JULIAN_GPS_EPOCH = 1721013.5 +EPOCH_J2000_0_JD = 2451545.0 +DAYS_PER_CENTURY = 36525.0 +SECONDS_PER_DAY = 86400.0 +LEAP_SECONDS = [ + 46828800, + 78364801, + 109900802, + 173059203, + 252028804, + 315187205, + 346723206, + 393984007, + 425520008, + 457056009, + 504489610, + 551750411, + 599184012, + 820108813, + 914803214, + 1025136015, + 1119744016, + 1167264017, +] + + +@dispatch +def gps_time_to_utc(gps_time): + """ + Convert GPS time to UTC. + + Parameters + ---------- + gps_time : float + GPS time in seconds. + + Returns + ------- + datetime + UTC time. + """ + return GPS_EPOCH + datetime(second=gps_time - n_leap_seconds(gps_time)) + + +@dispatch +def greenwich_mean_sidereal_time(gps_time): + """ + Calculate the Greenwich Mean Sidereal Time. + + This is a thin wrapper around :py:func:`greenwich_sidereal_time` with the + equation of the equinoxes set to zero. + + Parameters + ---------- + gps_time : float + GPS time in seconds. + + Returns + ------- + float + Greenwich Mean Sidereal Time in radians. + """ + return greenwich_sidereal_time(gps_time, gps_time * 0) + + +@dispatch +def greenwich_sidereal_time(gps_time, equation_of_equinoxes): + """ + Calculate the Greenwich Sidereal Time. + + Parameters + ---------- + gps_time : float + GPS time in seconds. + equation_of_equinoxes : float + Equation of the equinoxes in seconds. 
+ + Returns + ------- + float + """ + julian_day = utc_to_julian_day(gps_time_to_utc(gps_time // 1)) + t_hi = (julian_day - EPOCH_J2000_0_JD) / DAYS_PER_CENTURY + t_lo = (gps_time % 1) / (DAYS_PER_CENTURY * SECONDS_PER_DAY) + + t = t_hi + t_lo + + sidereal_time = ( + equation_of_equinoxes + (-6.2e-6 * t + 0.093104) * t**2 + 67310.54841 + ) + sidereal_time += 8640184.812866 * t_lo + sidereal_time += 3155760000.0 * t_lo + sidereal_time += 8640184.812866 * t_hi + sidereal_time += 3155760000.0 * t_hi + + return sidereal_time * 2 * np.pi / SECONDS_PER_DAY + + +@dispatch +def n_leap_seconds(gps_time, leap_seconds): + """ + Calculate the number of leap seconds that have occurred up to a given GPS time. + + Parameters + ---------- + gps_time : float | np.ndarray | int + GPS time in seconds. + leap_seconds : array_like + GPS time of leap seconds. + + Returns + ------- + float + Number of leap seconds + """ + xp = array_module(gps_time) + return xp.sum(gps_time > leap_seconds[:, None], axis=0).squeeze() + + +@dispatch +def n_leap_seconds(gps_time: np.ndarray | np.number | float | int): # noqa F811 + xp = array_module(gps_time) + return n_leap_seconds(gps_time, xp.asarray(LEAP_SECONDS)) + + +@dispatch +def utc_to_julian_day(utc_time): + """ + Convert UTC time to Julian day. + + Parameters + ---------- + utc_time : datetime + UTC time. + + Returns + ------- + float + Julian day. 
+ + """ + return utc_time.julian_day diff --git a/bilby/gw/utils.py b/bilby/gw/utils.py index 420e1fc04..1a55de975 100644 --- a/bilby/gw/utils.py +++ b/bilby/gw/utils.py @@ -5,18 +5,18 @@ import numpy as np from scipy.interpolate import interp1d from scipy.special import i0e -from bilby_cython.geometry import ( - zenith_azimuth_to_theta_phi as _zenith_azimuth_to_theta_phi, -) -from bilby_cython.time import greenwich_mean_sidereal_time +from .geometry import zenith_azimuth_to_theta_phi +from .time import greenwich_mean_sidereal_time +from ..compat.utils import array_module, xp_wrap from ..core.utils import (logger, run_commandline, check_directory_exists_and_if_not_mkdir, SamplesSummary, theta_phi_to_ra_dec) from ..core.utils.constants import solar_mass -def asd_from_freq_series(freq_data, df): +@xp_wrap +def asd_from_freq_series(freq_data, df, *, xp=None): """ Calculate the ASD from the frequency domain output of gaussian_noise() @@ -32,10 +32,11 @@ def asd_from_freq_series(freq_data, df): array_like: array of real-valued normalized frequency domain ASD data """ - return np.absolute(freq_data) * 2 * df**0.5 + return xp.abs(freq_data) * 2 * df**0.5 -def psd_from_freq_series(freq_data, df): +@xp_wrap +def psd_from_freq_series(freq_data, df, *, xp=None): """ Calculate the PSD from the frequency domain output of gaussian_noise() Calls asd_from_freq_series() and squares the output @@ -52,7 +53,7 @@ def psd_from_freq_series(freq_data, df): array_like: Real-valued normalized frequency domain PSD data """ - return np.power(asd_from_freq_series(freq_data, df), 2) + return asd_from_freq_series(freq_data, df, xp=xp) ** 2 def get_vertex_position_geocentric(latitude, longitude, elevation): @@ -76,14 +77,15 @@ def get_vertex_position_geocentric(latitude, longitude, elevation): array_like: A 3D representation of the geocentric vertex position """ + xp = array_module(latitude) semi_major_axis = 6378137 # for ellipsoid model of Earth, in m semi_minor_axis = 6356752.314 # in m - 
radius = semi_major_axis**2 * (semi_major_axis**2 * np.cos(latitude)**2 + - semi_minor_axis**2 * np.sin(latitude)**2)**(-0.5) - x_comp = (radius + elevation) * np.cos(latitude) * np.cos(longitude) - y_comp = (radius + elevation) * np.cos(latitude) * np.sin(longitude) - z_comp = ((semi_minor_axis / semi_major_axis)**2 * radius + elevation) * np.sin(latitude) - return np.array([x_comp, y_comp, z_comp]) + radius = semi_major_axis**2 * (semi_major_axis**2 * xp.cos(latitude)**2 + + semi_minor_axis**2 * xp.sin(latitude)**2)**(-0.5) + x_comp = (radius + elevation) * xp.cos(latitude) * xp.cos(longitude) + y_comp = (radius + elevation) * xp.cos(latitude) * xp.sin(longitude) + z_comp = ((semi_minor_axis / semi_major_axis)**2 * radius + elevation) * xp.sin(latitude) + return xp.asarray([x_comp, y_comp, z_comp]) def inner_product(aa, bb, frequency, PSD): @@ -106,11 +108,11 @@ def inner_product(aa, bb, frequency, PSD): psd_interp = PSD.power_spectral_density_interpolated(frequency) # calculate the inner product - integrand = np.conj(aa) * bb / psd_interp + integrand = (aa.conj() * bb / psd_interp).real df = frequency[1] - frequency[0] - integral = np.sum(integrand) * df - return 4. * np.real(integral) + integral = integrand.sum() * df + return 4. * integral def noise_weighted_inner_product(aa, bb, power_spectral_density, duration): @@ -132,9 +134,8 @@ def noise_weighted_inner_product(aa, bb, power_spectral_density, duration): ======= Noise-weighted inner product. 
""" - - integrand = np.conj(aa) * bb / power_spectral_density - return 4 / duration * np.sum(integrand) + integrand = aa.conj() * bb / power_spectral_density + return 4 / duration * integrand.sum() def matched_filter_snr(signal, frequency_domain_strain, power_spectral_density, duration): @@ -222,34 +223,12 @@ def overlap(signal_a, signal_b, power_spectral_density=None, delta_frequency=Non """ low_index = int(lower_cut_off / delta_frequency) up_index = int(upper_cut_off / delta_frequency) - integrand = np.conj(signal_a) * signal_b + integrand = signal_a.conj() * signal_b integrand = integrand[low_index:up_index] / power_spectral_density[low_index:up_index] integral = (4 * delta_frequency * integrand) / norm_a / norm_b return sum(integral).real -def zenith_azimuth_to_theta_phi(zenith, azimuth, ifos): - """ - Convert from the 'detector frame' to the Earth frame. - - Parameters - ========== - kappa: float - The zenith angle in the detector frame - eta: float - The azimuthal angle in the detector frame - ifos: list - List of Interferometer objects defining the detector frame - - Returns - ======= - theta, phi: float - The zenith and azimuthal angles in the earth frame. - """ - delta_x = ifos[0].geometry.vertex - ifos[1].geometry.vertex - return _zenith_azimuth_to_theta_phi(zenith, azimuth, delta_x) - - def zenith_azimuth_to_ra_dec(zenith, azimuth, geocent_time, ifos): """ Convert from the 'detector frame' to the Earth frame. @@ -945,7 +924,8 @@ def lalsim_SimNeutronStarLoveNumberK2(mass_in_SI, fam): return SimNeutronStarLoveNumberK2(mass_in_SI, fam) -def spline_angle_xform(delta_psi): +@xp_wrap +def spline_angle_xform(delta_psi, *, xp=None): """ Returns the angle in degrees corresponding to the spline calibration parameters delta_psi. 
@@ -962,7 +942,7 @@ def spline_angle_xform(delta_psi): """ rotation = (2.0 + 1.0j * delta_psi) / (2.0 - 1.0j * delta_psi) - return 180.0 / np.pi * np.arctan2(np.imag(rotation), np.real(rotation)) + return 180.0 / np.pi * xp.arctan2(xp.imag(rotation), xp.real(rotation)) def plot_spline_pos(log_freqs, samples, nfreqs=100, level=0.9, color='k', label=None, xform=None): @@ -1023,7 +1003,8 @@ def plot_spline_pos(log_freqs, samples, nfreqs=100, level=0.9, color='k', label= plt.xlim(freq_points.min() - .5, freq_points.max() + 50) -def ln_i0(value): +@xp_wrap +def ln_i0(value, *, xp=None): """ A numerically stable method to evaluate ln(I_0) a modified Bessel function of order 0 used in the phase-marginalized likelihood. @@ -1038,7 +1019,7 @@ def ln_i0(value): array-like: The natural logarithm of the bessel function """ - return np.log(i0e(value)) + np.abs(value) + return xp.log(i0e(value)) + xp.abs(value) def calculate_time_to_merger(frequency, mass_1, mass_2, chi=0, safety=1.1): @@ -1067,10 +1048,10 @@ def calculate_time_to_merger(frequency, mass_1, mass_2, chi=0, safety=1.1): import lalsimulation return safety * lalsimulation.SimInspiralTaylorF2ReducedSpinChirpTime( - frequency, - mass_1 * solar_mass, - mass_2 * solar_mass, - chi, + float(frequency), + float(mass_1 * solar_mass), + float(mass_2 * solar_mass), + float(chi), -1 ) diff --git a/bilby/gw/waveform_generator.py b/bilby/gw/waveform_generator.py index b42f5d8c8..a78a4d843 100644 --- a/bilby/gw/waveform_generator.py +++ b/bilby/gw/waveform_generator.py @@ -1,3 +1,4 @@ +import array_api_compat as aac import numpy as np from ..core import utils @@ -24,7 +25,8 @@ class WaveformGenerator(object): def __init__(self, duration=None, sampling_frequency=None, start_time=0, frequency_domain_source_model=None, time_domain_source_model=None, parameters=None, parameter_conversion=None, - waveform_arguments=None): + waveform_arguments=None, use_cache=True, + ): """ The base waveform generator class. 
@@ -58,6 +60,10 @@ def __init__(self, duration=None, sampling_frequency=None, start_time=0, frequen Note: the arguments of frequency_domain_source_model (except the first, which is the frequencies at which to compute the strain) will be added to the WaveformGenerator object and initialised to `None`. + use_cache: bool + Whether to attempt caching the waveform between subsequent calls. + This is :code:`True` by default but must be disabled for JIT compilation + with :code:`JAX`. """ self._times_and_frequencies = CoupledTimeAndFrequencySeries(duration=duration, @@ -76,9 +82,11 @@ def __init__(self, duration=None, sampling_frequency=None, start_time=0, frequen self.waveform_arguments = dict() if parameters is not None: logger.warning( - "Non null parameters passed to waveform generator. These will be ignored." + "Setting initial parameters via the 'parameters' argument is " + "deprecated and will be removed in a future release." ) self._cache = dict(parameters=None, waveform=None, model=None) + self.use_cache = use_cache logger.info(f"Waveform generator instantiated: {self}") def __repr__(self): @@ -102,7 +110,7 @@ def __repr__(self): .format(self.duration, self.sampling_frequency, self.start_time, fdsm_name, tdsm_name, param_conv_name, self.waveform_arguments) - def frequency_domain_strain(self, parameters=None): + def frequency_domain_strain(self, parameters=None, *, xp=None): """ Wrapper to source_model. Converts parameters with self.parameter_conversion before handing it off to the source model. @@ -129,9 +137,10 @@ def frequency_domain_strain(self, parameters=None): parameters=parameters, transformation_function=utils.nfft, transformed_model=self.time_domain_source_model, - transformed_model_data_points=self.time_array) + transformed_model_data_points=self.time_array, + xp=xp) - def time_domain_strain(self, parameters=None): + def time_domain_strain(self, parameters=None, *, xp=None): """ Wrapper to source_model. 
Converts parameters with self.parameter_conversion before handing it off to the source model. @@ -140,10 +149,14 @@ def time_domain_strain(self, parameters=None): Parameters ========== - parameters: dict, None + parameters: dict, optional Parameters to evaluate the waveform for. If not passed and the generator has been called previously, the last used parameters will be used. + xp: array module, optional + The array module to use when evaluating the source model, e.g., :code:`numpy`. + This can be used to override the :code:`time_array` stored in the generator. + If :code:`None`, the default will be used. Returns ======= @@ -159,17 +172,21 @@ def time_domain_strain(self, parameters=None): parameters=parameters, transformation_function=utils.infft, transformed_model=self.frequency_domain_source_model, - transformed_model_data_points=self.frequency_array) + transformed_model_data_points=self.frequency_array, + xp=xp) def _calculate_strain(self, model, model_data_points, transformation_function, transformed_model, - transformed_model_data_points, parameters): - if parameters is None and self._cache["parameters"] is not None: - parameters = self._cache["parameters"] - elif parameters is None: - raise ValueError("No parameters passed to waveform generator.") - - if parameters == self._cache['parameters'] and self._cache['model'] == model and \ - self._cache['transformed_model'] == transformed_model: + transformed_model_data_points, parameters, *, xp=None): + if parameters is None: + parameters = self._cache.get('parameters', None) + if parameters is None: + raise ValueError("No parameters given to generate waveform.") + if ( + self.use_cache + and parameters == self._cache.get('parameters', None) + and self._cache['model'] == model + and self._cache['transformed_model'] == transformed_model + ): return self._cache['waveform'] else: self._cache['parameters'] = parameters.copy() @@ -177,26 +194,39 @@ def _calculate_strain(self, model, model_data_points, 
transformation_function, t self._cache['transformed_model'] = transformed_model parameters = self._format_parameters(parameters) if model is not None: - model_strain = self._strain_from_model(model_data_points, model, parameters) + model_strain = self._strain_from_model(model_data_points, model, parameters, xp=xp) elif transformed_model is not None: - model_strain = self._strain_from_transformed_model(transformed_model_data_points, transformed_model, - transformation_function, parameters) + model_strain = self._strain_from_transformed_model( + transformed_model_data_points, + transformed_model, + transformation_function, + parameters, + xp=xp, + ) else: raise RuntimeError("No source model given") self._cache['waveform'] = model_strain return model_strain - def _strain_from_model(self, model_data_points, model, parameters): + def _strain_from_model(self, model_data_points, model, parameters, *, xp=None): + if xp is not None: + model_data_points = xp.asarray(model_data_points) return model(model_data_points, **parameters) def _strain_from_transformed_model( - self, transformed_model_data_points, transformed_model, transformation_function, parameters + self, + transformed_model_data_points, + transformed_model, + transformation_function, + parameters, + *, + xp=None, ): transformed_model_strain = self._strain_from_model( - transformed_model_data_points, transformed_model, parameters + transformed_model_data_points, transformed_model, parameters, xp=xp ) - if isinstance(transformed_model_strain, np.ndarray): + if aac.is_array_api_obj(transformed_model_strain): return transformation_function(transformed_model_strain, self.sampling_frequency) model_strain = dict() diff --git a/bilby/hyper/likelihood.py b/bilby/hyper/likelihood.py index 2ca63bbbd..fa113a048 100644 --- a/bilby/hyper/likelihood.py +++ b/bilby/hyper/likelihood.py @@ -1,8 +1,10 @@ import logging +import array_api_compat as aac import numpy as np +from ..compat.utils import array_module from ..core.likelihood 
import Likelihood from .model import Model from ..core.prior import PriorDict @@ -29,11 +31,13 @@ class HyperparameterLikelihood(Likelihood): the sampling prior and the hyperparameterised model. max_samples: int, optional Maximum number of samples to use from each set. + xp: module + The array backend to use for the data. """ def __init__(self, posteriors, hyper_prior, sampling_prior=None, - log_evidences=None, max_samples=1e100): + log_evidences=None, max_samples=1e100, xp=np): if not isinstance(hyper_prior, Model): hyper_prior = Model([hyper_prior]) if sampling_prior is None: @@ -47,23 +51,27 @@ def __init__(self, posteriors, hyper_prior, sampling_prior=None, self.evidence_factor = np.sum(log_evidences) else: self.evidence_factor = np.nan - self.posteriors = posteriors + if aac.is_jax_namespace(xp): + self.posteriors = None + else: + self.posteriors = posteriors self.hyper_prior = hyper_prior self.sampling_prior = sampling_prior self.max_samples = max_samples super(HyperparameterLikelihood, self).__init__() - self.data = self.resample_posteriors() - self.n_posteriors = len(self.posteriors) + self.data = self.resample_posteriors(posteriors=posteriors, xp=xp) + self.n_posteriors = len(posteriors) self.samples_per_posterior = self.max_samples self.samples_factor =\ - self.n_posteriors * np.log(self.samples_per_posterior) def log_likelihood_ratio(self, parameters): - log_l = np.sum(np.log(np.sum(self.hyper_prior.prob(self.data, **parameters) / - self.data['prior'], axis=-1))) + probs = self.hyper_prior.prob(self.data, **parameters) + xp = array_module(probs) + log_l = xp.sum(xp.log(xp.sum(probs / self.data['prior'], axis=-1))) log_l += self.samples_factor - return np.nan_to_num(log_l) + return xp.nan_to_num(log_l) def noise_log_likelihood(self): return self.evidence_factor @@ -71,7 +79,7 @@ def noise_log_likelihood(self): def log_likelihood(self, parameters): return self.noise_log_likelihood() + self.log_likelihood_ratio(parameters=parameters) - def 
resample_posteriors(self, max_samples=None): + def resample_posteriors(self, posteriors=None, max_samples=None, xp=np): """ Convert list of pandas DataFrame object to dict of arrays. @@ -86,18 +94,26 @@ def resample_posteriors(self, max_samples=None): Dictionary containing arrays of size (n_posteriors, max_samples) There is a key for each shared key in self.posteriors. """ + if posteriors is None: + posteriors = self.posteriors + if isinstance(posteriors, int): + raise ValueError( + "Input posteriors is an integer. This may have been intended for the " + "max_samples argument. The API changed in Bilby v3." + ) + if max_samples is not None: self.max_samples = max_samples - for posterior in self.posteriors: + for posterior in posteriors: self.max_samples = min(len(posterior), self.max_samples) - data = {key: [] for key in self.posteriors[0]} + data = {key: [] for key in posteriors[0]} if 'log_prior' in data.keys(): data.pop('log_prior') if 'prior' not in data.keys(): data['prior'] = [] logging.debug('Downsampling to {} samples per posterior.'.format( self.max_samples)) - for posterior in self.posteriors: + for posterior in posteriors: temp = posterior.sample(self.max_samples) if self.sampling_prior is not None: temp['prior'] = self.sampling_prior.prob(temp, axis=0) @@ -106,5 +122,5 @@ def resample_posteriors(self, max_samples=None): for key in data: data[key].append(temp[key]) for key in data: - data[key] = np.array(data[key]) + data[key] = xp.asarray(data[key]) return data diff --git a/docs/array_api.rst b/docs/array_api.rst new file mode 100644 index 000000000..c6e6190c0 --- /dev/null +++ b/docs/array_api.rst @@ -0,0 +1,552 @@ +===================== +Array API Support +===================== + +Bilby now supports the Python `Array API Standard <https://data-apis.org/array-api/latest/>`_, +enabling the use of different array backends (NumPy, JAX, CuPy, etc.) for improved performance +and hardware acceleration. This page describes how to use this functionality and how it works internally.
+ +For Users and Downstream Developers +==================================== + +Overview +-------- + +The Array API support allows you to use different array libraries with Bilby seamlessly. +This can significantly improve performance, especially when using hardware accelerators like GPUs +or when you need automatic differentiation capabilities. +To activate array API support you need to set the :code:`BILBY_ARRAY_API` environment variable to +:code:`1` before importing Bilby. +You will also need to set the corresponding :code:`scipy` environment variable (:code:`SCIPY_ARRAY_API`) +for most functionality. +This can be most easily done by setting the environment variables in your shell: + +.. code-block:: bash + + export BILBY_ARRAY_API=1 + export SCIPY_ARRAY_API=1 + +**Key principle**: In most cases, you don't need to explicitly specify which array backend to use. +Bilby automatically detects the array type you're working with and uses the appropriate backend. +Simply pass JAX arrays, CuPy arrays, or NumPy arrays to prior methods, and Bilby handles the rest. + +Supported Backends +------------------ + +Bilby is currently tested with the following array backends: + +- **NumPy** (default): Standard CPU-based computations +- **JAX**: GPU/TPU acceleration and automatic differentiation +- **PyTorch**: GPU acceleration and deep learning integration. + :code:`PyTorch` support is not complete, for example, functionality + requiring interpolation is not available. + +While :code:`Bilby` should be compatible with other Array API compliant libraries, +these are not currently tested or officially supported. +If you notice any issues when using other backends, +please report them on the `Bilby GitHub repository <https://github.com/bilby-dev/bilby>`__. + +Using Different Array Backends +------------------------------- + +Basic Prior Usage (Automatic Detection) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The array backend is automatically detected from your input arrays.
You typically don't need +to specify the ``xp`` parameter: + +.. code-block:: python + + import bilby + import jax.numpy as jnp + import numpy as np + + prior = bilby.core.prior.Uniform(minimum=0, maximum=10) + + # Using JAX - backend automatically detected + val_jax = jnp.array([0.5, 1.5, 2.5]) + prob_jax = prior.prob(val_jax) # Returns JAX array + + # Using NumPy - backend automatically detected + val_np = np.array([0.5, 1.5, 2.5]) + prob_np = prior.prob(val_np) # Returns NumPy array + +Sampling with Array Backends (Explicit RNG Required) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +When sampling from priors, you **must** explicitly specify the random state for +the operation using the ``random_state`` parameter, +as there's no input array to infer the backend from: + +.. code-block:: python + + import bilby + import jax + + prior = bilby.core.prior.Uniform(minimum=0, maximum=10) + samples = prior.sample(size=1000, random_state=jax.random.PRNGKey(42)) # Returns JAX array + + # Or with NumPy (default) + samples_np = prior.sample(size=1000) # Or explicitly: random_state=np.random.default_rng(42) + +Prior Dictionaries +~~~~~~~~~~~~~~~~~~ + +Prior dictionaries work the same way - automatic detection for most methods, explicit ``random_state`` for sampling: + +.. code-block:: python + + import bilby + import jax + import jax.numpy as jnp + + priors = bilby.core.prior.PriorDict({ + 'x': bilby.core.prior.Uniform(0, 100), + 'y': bilby.core.prior.Uniform(0, 1) + }) + + # Sampling requires explicit random_state + samples = priors.sample(size=1000, random_state=jax.random.PRNGKey(42)) + + # Evaluation automatically detects backend from input + theta = jnp.array([50.0, 0.5]) + prob = priors.prob(samples) # Automatically uses JAX + +Core Likelihoods and Sampling +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Core :code:`Bilby` likelihoods are compatible with the Array API. 
+When using :code:`JAX` arrays, you can take advantage of :code:`JAX`'s JIT compilation and automatic differentiation. +For :code:`JAX`-compatible samplers (e.g., :code:`numpyro`), +you can pass any :code:`JAX`-compatible :code:`Bilby` likelihood directly. +For non-:code:`JAX` samplers, you should wrap your likelihood with the +:code:`bilby.compat.jax.JittedLikelihood` class to enable JIT compilation. + +.. code-block:: python + + import bilby + import jax.numpy as jnp + from bilby.compat.jax import JittedLikelihood + + class MyLikelihood(bilby.Likelihood): + def log_likelihood(self, parameters): + # model returns a JAX array if passed a dictionary of JAX arrays + return -0.5 * jnp.sum((self.data - model(parameters))**2) + + data = jnp.array([...]) # Your data as a JAX array + + priors = bilby.core.prior.PriorDict({ + 'param1': bilby.core.prior.Uniform(0, 10), + 'param2': bilby.core.prior.Uniform(-5, 5) + }) + + likelihood = MyLikelihood(data) + + # call the likelihood once in case any initial setup is needed + likelihood.log_likelihood(priors.sample()) + + # Wrap with JittedLikelihood for JAX + jitted_likelihood = JittedLikelihood(likelihood) + + # call the jitted likelihood once to trigger JIT compilation + # the JittedLikelihood automatically converts the parameters + # to JAX arrays + jitted_likelihood.log_likelihood(priors.sample()) + + # Use with a JAX-incompatible sampler + sampler = bilby.run_sampler(likelihood=jitted_likelihood, ...) + +Gravitational-Wave Likelihoods +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The :code:`Bilby` implementation of the gravitational-wave likelihood is compatible with the Array API, +however this requires access to waveform models that support the provided array backend. +The desired array backend must be explicitly specified for the data, +using :code:`bilby.gw.detector.networks.InterferometerList.set_array_backend`. +Below is an example using the :code:`ripplegw` package for waveform generation.
+Here, an injection is performed using the standard :code:`LALSimulation` waveform generator, +and the analysis is then performed using the JIT-compiled likelihood. + +.. code-block:: python + + import bilby + import jax.numpy as jnp + import ripplegw + + priors = bilby.gw.prior.BBHPriorDict() + priors["geocent_time"] = bilby.core.prior.Uniform(1126259462.4, 1126259462.6) + injection_parameters = priors.sample() + + # Create interferometers and inject signal using standard waveform generator + ifos = bilby.gw.detector.networks.InterferometerList(['H1', 'L1']) + ifos.set_strain_data_from_power_spectral_densities( + sampling_frequency=2048, + duration=4, + start_time=injection_parameters["geocent_time"] - 2 + ) + injection_wfg = bilby.gw.waveform_generator.WaveformGenerator( + duration=4, + sampling_frequency=2048, + frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole, + waveform_arguments={"approximant": "IMRPhenomXPHM"} + ) + ifos.inject_signal(parameters=injection_parameters, waveform_generator=injection_wfg) + + # set the array backend after the injection + ifos.set_array_backend(jnp) + + ripple_wfg = bilby.gw.waveform_generator.WaveformGenerator( + duration=4, + sampling_frequency=2048, + frequency_domain_source_model=ripplegw.get_fd_waveform + ) + + # Create gravitational-wave likelihood + likelihood = bilby.gw.likelihood.GravitationalWaveTransient( + interferometers=ifos, + waveform_generator=ripple_wfg, + priors=priors, + phase_marginalization=True, + ) + # call the likelihood once to do some initial setup + # this is needed for the gravitational-wave transient likelihoods + likelihood.log_likelihood_ratio(priors.sample()) + + # Wrap with JittedLikelihood for JAX and JIT compile + jitted_likelihood = bilby.compat.jax.JittedLikelihood(likelihood) + jitted_likelihood.log_likelihood_ratio(priors.sample()) + +.. note:: + + All of the likelihood marginalizations implemented in :code:`Bilby` are compatible with the Array API. 
+ However, there is currently a performance issue with the distance marginalized likelihood + using the :code:`JAX` backend. + +.. warning:: + + Some array backends (notably :code:`torch`) are more picky than others about data types. + For maximal consistency, try to always pass zero-dimensional arrays rather than :code:`Python` + scalars, e.g., :code:`torch.tensor(1.0)` instead of :code:`1.0`. + +Performance Considerations +-------------------------- + +**When to use JAX:** + +- GPU/TPU acceleration is available +- You need automatic differentiation +- Working with large datasets or many parameters +- Repeated evaluations benefit from JIT compilation + +**When to use NumPy:** + +- Simple CPU-based computations +- Small datasets +- Maximum compatibility +- Debugging (easier to inspect values) + +**Best Practices:** + +1. Let Bilby detect the array backend automatically - only specify ``random_state`` when sampling +2. Use a single array backend consistently throughout your analysis +3. Avoid mixing array types in the same computation +4. For JAX, consider using ``jax.jit`` for repeated computations +5. Profile your code to ensure the chosen backend provides benefits +6. If you find :code:`xp_wrap` is a bottleneck in your code, you can explicitly pass + :code:`xp` to the function/method to skip the automatic backend detection step. + +Bilby and JIT compilation +~~~~~~~~~~~~~~~~~~~~~~~~~ + +Currently, Bilby functions are not JIT-compiled by default. +Additionally, many Bilby types are not defined as :code:`JAX` :code:`PyTrees`, +and so cannot be passed as arguments to JIT-compiled functions. +We plan to support JIT-compilation for at least some Bilby types in future releases. + +Custom Priors with Array API +----------------------------- + +When creating custom priors, ensure they support the Array API: + +Example Implementation +~~~~~~~~~~~~~~~~~~~~~~ + +Always include the ``xp`` parameter with a default value: + +..
code-block:: python + + from bilby.core.prior import Prior + + class MyCustomPrior(Prior): + def __init__(self, parameter, **kwargs): + super().__init__(**kwargs) + self.parameter = parameter + + def rescale(self, val, *, xp=None): + """Rescale method with xp parameter.""" + return self.minimum + val * (self.maximum - self.minimum) * self.parameter + + def prob(self, val, *, xp=None): + """Probability method with xp parameter.""" + in_range = (val >= self.minimum) & (val <= self.maximum) + return in_range / (self.maximum - self.minimum) * self.parameter + +The ``xp`` parameter should: + +- Be a keyword-only argument (after ``*``) +- Have a default value (``None`` if method is decorated with ``@xp_wrap``, ``np`` otherwise) +- Be passed through to any array operations if used directly + +**Note**: Users of your custom prior won't need to pass ``xp`` explicitly for evaluation methods - +it will be automatically inferred from their input arrays. They only need to specify ``xp`` when sampling. + +Using the :code:`xp_wrap` Decorator +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +For methods that perform array operations, use the ``@xp_wrap`` decorator: + +.. 
code-block:: python + + from bilby.core.prior import Prior + from bilby.compat.utils import xp_wrap + import numpy as np + + class MyCustomPrior(Prior): + @xp_wrap + def prob(self, val, *, xp=None): + """The decorator handles xp=None automatically.""" + return xp.exp(-val) / self.normalization * self.is_in_prior_range(val) + + @xp_wrap + def ln_prob(self, val, *, xp=None): + """Works with logarithmic operations.""" + return -val - xp.log(self.normalization) + xp.log(self.is_in_prior_range(val)) + +The ``@xp_wrap`` decorator: + +- Automatically provides the appropriate array module when ``xp=None`` +- Infers the array backend from input arrays when they are :code:`JAX`/:code:`CuPy`/:code:`PyTorch` arrays +- Falls back to NumPy when the input is a standard Python type or NumPy array +- Handles the conversion seamlessly so users don't need to specify ``xp`` + +Missing functionality +--------------------- + +**JAX pytrees**: Currently, Bilby types are not defined as JAX pytrees, which means they cannot be +passed as arguments to JIT-compiled functions. +This is a known limitation and we plan to add support for JAX pytrees in future releases. + +**Device management**: Bilby does not currently manage device placement for arrays. +When using JAX or PyTorch, you may need to manually ensure that your arrays are on the +correct device (CPU/GPU). We may revisit this in the future. + +For Bilby Developers +===================== + +Architecture Overview +--------------------- + +The Array API support in Bilby is built around several key components: + +1. **The xp parameter**: A keyword-only parameter added to prior methods +2. **The @xp_wrap decorator**: Handles array module selection and injection +3. 
**Compatibility utilities**: Helper functions for array module detection + +Core Changes to Prior Base Class +--------------------------------- + +The ``Prior`` base class in ``bilby/core/prior/base.py`` includes these key changes: + +Method Signature Pattern +~~~~~~~~~~~~~~~~~~~~~~~~ + +All array-processing methods in prior classes follow this pattern: + +**For methods with @xp_wrap decorator**: + +.. code-block:: python + + @xp_wrap + def prob(self, val, *, xp=None): + """Method that uses xp for array operations.""" + return xp.some_operation(val) * self.is_in_prior_range(val) + +Key rules: + +- ``xp`` is always keyword-only (after ``*``) +- Methods with ``@xp_wrap`` use ``xp=None`` as default +- Methods without ``@xp_wrap`` that use ``xp`` use ``xp=np`` as default +- Methods that don't use ``xp`` have ``xp=None`` as default + +The :code:`@xp_wrap` Decorator +------------------------------ + +Located in ``bilby/compat/utils.py``, this decorator: + +1. **Inspects input arguments** to determine the array module in use +2. **Provides the appropriate xp** when ``xp=None`` +3. **Maintains backward compatibility** with code that doesn't pass ``xp`` + +Example implementation pattern: + +.. code-block:: python + + from bilby.compat.utils import xp_wrap + + @xp_wrap + def my_function(val, *, xp=None): + # When called: + # - If xp=None, decorator infers from val + # - If xp is provided, uses that + # - Returns results in the same array type as input + return xp.exp(val) / xp.mean(val) + +Testing Array API Support +------------------------- + +Test Structure +~~~~~~~~~~~~~~ + +When appropriate, tests should verify functionality across different +backends using the ``array_backend`` marker: + +.. 
code-block:: python + + @pytest.mark.array_backend + @pytest.mark.usefixtures("xp_class") + class TestMyPrior: + def test_prob(self): + prior = MyPrior() + val = self.xp.asarray([0.5, 1.5, 2.5]) + # No need to pass xp - automatically detected + prob = prior.prob(val) + assert self.xp.all(prob >= 0) + assert aac.get_namespace(prob) == self.xp + + def test_sample(self): + prior = MyPrior() + # Sampling requires explicit xp + samples = prior.sample(size=100, random_state=self.rng) + assert aac.get_namespace(samples) == self.xp + +The array_backend Marker +~~~~~~~~~~~~~~~~~~~~~~~~ + +The ``@pytest.mark.array_backend`` marker is used to indicate that a test or test class should be run +with multiple array backends. When you run pytest with the ``--array-backend`` flag, only tests marked +with ``array_backend`` will be executed with that specific backend. + +Without the marker, tests run with the default NumPy backend only. With the marker: + +- Tests are parametrized to run with different backends +- The ``xp_class`` fixture is available, providing access to the array module via ``self.xp`` + and the random state via ``self.rng`` +- Tests verify that code works correctly regardless of the array backend + +Running Tests with Different Backends +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Use the ``--array-backend`` flag to test with specific backends:: + + # Test with NumPy (default) + pytest test/core/prior/analytical_test.py + + # Test with JAX backend + pytest --array-backend jax test/core/prior/analytical_test.py + + # Test with CuPy backend + pytest --array-backend cupy test/core/prior/analytical_test.py + +You need to set both ``BILBY_ARRAY_API=1`` and ``SCIPY_ARRAY_API=1`` environment variables +to enable array API support in testing. +The ``--array-backend`` flag controls which backend the ``xp_class`` fixture provides to your tests. + +Migration Guide from Previous Versions +-------------------------------------- + +Key Differences +~~~~~~~~~~~~~~~ + +1. 
**Method signatures changed**: All prior methods now include ``xp`` parameter +2. **Decorator added**: Many methods now use ``@xp_wrap`` +3. **Default values differ**: Methods with ``@xp_wrap`` use ``xp=None``, others use ``xp=np`` +4. **Validation added**: Custom priors are checked for ``xp`` support +5. **Explicit random state**: Sampling methods accept a ``random_state`` argument + +Best Practices for Contributors +-------------------------------- + +When adding or modifying prior methods: + +1. **Always include xp parameter** in prob, ln_prob, rescale, cdf, sample methods +2. **Use @xp_wrap decorator** for methods doing array operations +3. **Set correct default**: ``xp=None`` with decorator, ``xp=np`` without (for methods that use xp directly) +4. **Pass xp through**: When calling other methods, pass ``xp=xp`` +5. **Test with multiple backends**: Use ``@pytest.mark.array_backend`` and test with ``--array-backend jax`` +6. **Document xp parameter**: Note it in docstrings, but emphasize it's usually auto-detected +7. **Use array module functions**: Use ``xp.function()`` not ``np.function()`` in wrapped methods + +Handling Array Updates with :code:`array_api_extra.at` +------------------------------------------------------ + +One key difference between array backends is how they handle array updates. +NumPy allows in-place modification of array slices, +while JAX requires functional updates since arrays are immutable. +The ``array_api_extra.at`` function provides a unified interface for array updates across backends. + +Usage Examples +~~~~~~~~~~~~~~ + +**Conditional update**: + +.. code-block:: python + + @xp_wrap + def conditional_update(vals, *, xp=None): + """Update array elements where mask is True.""" + arr = vals**2 + mask = arr > 0.5 + # Instead of: arr[mask] = value + arr = xpx.at(arr)[mask].set(value) + return arr + +**Increment operation**: + +.. 
code-block:: python + + @xp_wrap + def increment_slice(arr, *, xp=None): + """Add values to a slice of an array.""" + # Instead of: arr[2:5] += values + arr = xpx.at(arr)[2:5].add(values) + return arr + +Available Operations +~~~~~~~~~~~~~~~~~~~~ + +The ``at`` function supports several operations: + +- ``set(values)``: Replace values at specified indices +- ``add(values)``: Add values to specified indices +- ``multiply(values)``: Multiply specified indices by values +- ``min(values)``: Take element-wise minimum +- ``max(values)``: Take element-wise maximum + +Important Notes +~~~~~~~~~~~~~~~ + +1. **Return value**: Always use the returned array. The operation may create a new array (JAX) or modify in-place (NumPy). + +2. **Import**: Import ``array_api_extra`` at the module level: + +.. code-block:: python + + import array_api_extra as xpx + +Further Resources +----------------- + +- `Array API Standard <https://data-apis.org/array-api/latest/>`_ +- `JAX Documentation <https://docs.jax.dev/en/latest/>`_ +- `array-api-compat Package <https://data-apis.org/array-api-compat/>`_ +- `array-api-extra Package <https://data-apis.org/array-api-extra/>`_ diff --git a/docs/index.txt b/docs/index.txt index ff6e12c85..d8fabb550 100644 --- a/docs/index.txt +++ b/docs/index.txt @@ -16,6 +16,7 @@ Welcome to bilby's documentation! prior likelihood samplers + array_api dynesty-guide bilby-mcmc-guide rng diff --git a/docs/installation.txt b/docs/installation.txt index ae2820a0c..8d4fcee5c 100644 --- a/docs/installation.txt +++ b/docs/installation.txt @@ -47,7 +47,7 @@ wave inference, please additionally run the following commands. Install bilby from source ------------------------- -:code:`bilby` is developed and tested with Python 3.10-3.12. In the +:code:`bilby` is developed and tested with Python 3.11+. In the following, we assume you have a working python installation, `python pip `_, and `git `_. 
See :ref:`installing-python` for our diff --git a/docs/rng.rst b/docs/rng.rst index 6024d5b2b..9feefc8a2 100644 --- a/docs/rng.rst +++ b/docs/rng.rst @@ -27,6 +27,10 @@ The random number generation can be seeded using the >>> from bilby.core.utils import random >>> random.seed(1234) +For more fine-grained control, every function/method that relies on random number +generation supports a :code:`random_state` argument that can be used to specify +the random number generator to use for that function/method. + ---------------- Seeding samplers ---------------- @@ -57,4 +61,29 @@ For example: .. note:: Some sampler interfaces do not support seeding. +-------------------------------------------------------- +Random number generation and non-:code:`NumPy` backends +-------------------------------------------------------- + +To support random number generation with non-:code:`NumPy` array backends, +any :code:`bilby` function or method that relies on random number generation accepts a +:code:`random_state` argument. +This argument should be one of the following types: + +- :code:`None` (the default): the function will use the :code:`bilby` global + :code:`numpy` random number generator (set using :code:`bilby.core.utils.random.seed`). +- :code:`numpy.random.Generator`: the function will use the provided generator. +- :code:`orng.ArrayRNG`: the function will use the provided :code:`orng` random number generator. +- :code:`int`: the function will create a new :code:`numpy` random number generator seeded with + the provided integer and use it for random number generation. +- :code:`jax.random.PRNGKey`: the function will create a new :code:`orng` random number generator + with the "jax" backend seeded with the provided key and use it for random number generation. + +For example, + +.. 
code:: python + >>> import orng + >>> rng = orng.ArrayRNG("jax", seed=1234) + >>> x = rng.uniform() + >>> priors.sample(xp=jnp, random_state=rng) diff --git a/examples/gw_examples/injection_examples/jax_fast_tutorial.py b/examples/gw_examples/injection_examples/jax_fast_tutorial.py new file mode 100644 index 000000000..22a64c935 --- /dev/null +++ b/examples/gw_examples/injection_examples/jax_fast_tutorial.py @@ -0,0 +1,229 @@ +#!/usr/bin/env python +""" +Tutorial to demonstrate running parameter estimation on a reduced parameter +space for an injected signal. + +This example estimates the masses using a uniform prior in both component masses +and distance using a uniform in comoving volume prior on luminosity distance +between luminosity distances of 100Mpc and 5Gpc, the cosmology is Planck15. + +We optionally use ripple waveforms and a JIT-compiled likelihood. +""" +import os + +# Set OMP_NUM_THREADS to stop lalsimulation taking over my computer +os.environ["OMP_NUM_THREADS"] = "1" + +import bilby +import jax +import jax.numpy as jnp +import numpy as np +from bilby.compat.jax import JittedLikelihood +from ripplegw.waveforms import IMRPhenomPv2 + +jax.config.update("jax_enable_x64", True) + + +def bilby_to_ripple_spins( + theta_jn, + phi_jl, + tilt_1, + tilt_2, + phi_12, + a_1, + a_2, +): + """ + A simplified spherical to cartesian spin conversion function. + This is not equivalent to the method used in `bilby.gw.conversion` + which comes from `lalsimulation` and is not `JAX` compatible. 
+ """ + iota = theta_jn + spin_1x = a_1 * jnp.sin(tilt_1) * jnp.cos(phi_jl) + spin_1y = a_1 * jnp.sin(tilt_1) * jnp.sin(phi_jl) + spin_1z = a_1 * jnp.cos(tilt_1) + spin_2x = a_2 * jnp.sin(tilt_2) * jnp.cos(phi_jl + phi_12) + spin_2y = a_2 * jnp.sin(tilt_2) * jnp.sin(phi_jl + phi_12) + spin_2z = a_2 * jnp.cos(tilt_2) + return iota, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z + + +def ripple_bbh( + frequency, + mass_1, + mass_2, + luminosity_distance, + theta_jn, + phase, + a_1, + a_2, + tilt_1, + tilt_2, + phi_12, + phi_jl, + **kwargs, +): + """ + Source function wrapper to ripple's IMRPhenomPv2 waveform generator. + This function cannot be jitted directly as the Bilby waveform generator + relies on inspecting the function signature. + + Parameters + ---------- + frequency: jnp.ndarray + Frequencies at which to compute the waveform. + mass_1: float | jnp.ndarray + Mass of the primary component in solar masses. + mass_2: float | jnp.ndarray + Mass of the secondary component in solar masses. + luminosity_distance: float | jnp.ndarray + Luminosity distance to the source in Mpc. + theta_jn: float | jnp.ndarray + Angle between total angular momentum and line of sight in radians. + phase: float | jnp.ndarray + Phase at coalescence in radians. + a_1: float | jnp.ndarray + Dimensionless spin magnitude of the primary component. + a_2: float | jnp.ndarray + Dimensionless spin magnitude of the secondary component. + tilt_1: float | jnp.ndarray + Tilt angle of the primary component spin in radians. + tilt_2: float | jnp.ndarray + Tilt angle of the secondary component spin in radians. + phi_12: float | jnp.ndarray + Azimuthal angle between the two spin vectors in radians. + phi_jl: float | jnp.ndarray + Azimuthal angle of the total angular momentum vector in radians. + **kwargs + Additional keyword arguments. Must include 'minimum_frequency'. + + Returns + ------- + dict + Dictionary containing the plus and cross polarizations of the waveform. 
+ """ + iota, *cartesian_spins = bilby_to_ripple_spins( + theta_jn, phi_jl, tilt_1, tilt_2, phi_12, a_1, a_2 + ) + frequencies = jnp.maximum(frequency, kwargs["minimum_frequency"]) + theta = jnp.array( + [ + mass_1, + mass_2, + *cartesian_spins, + luminosity_distance, + jnp.array(0.0), + phase, + iota, + ] + ) + wf_func = jax.jit(IMRPhenomPv2.gen_IMRPhenomPv2) + hp, hc = wf_func(frequencies, theta, jnp.array(20.0)) + return dict(plus=hp, cross=hc) + + +def main(): + # Set the duration and sampling frequency of the data segment that we're + # going to inject the signal into + duration = 64.0 + sampling_frequency = 2048.0 + minimum_frequency = 20.0 + duration = jnp.array(duration) + sampling_frequency = jnp.array(sampling_frequency) + minimum_frequency = jnp.array(minimum_frequency) + + # Specify the output directory and the name of the simulation. + outdir = "outdir" + label = "jax_fast_tutorial" + + # Set up a random seed for result reproducibility. This is optional! + bilby.core.utils.random.seed(88170235) + + priors = bilby.gw.prior.BBHPriorDict() + injection_parameters = priors.sample() + injection_parameters["geocent_time"] = 1000000000.0 + injection_parameters["luminosity_distance"] = 400.0 + del priors["ra"], priors["dec"] + priors["zenith"] = bilby.core.prior.Cosine() + priors["azimuth"] = bilby.core.prior.Uniform(minimum=0, maximum=2 * np.pi) + priors["L1_time"] = bilby.core.prior.Uniform( + injection_parameters["geocent_time"] - 0.1, + injection_parameters["geocent_time"] + 0.1, + ) + + # Fixed arguments passed into the source model + waveform_arguments = dict( + waveform_approximant="IMRPhenomPv2", + reference_frequency=50.0, + minimum_frequency=minimum_frequency, + ) + + # Create the waveform_generator using a LAL BinaryBlackHole source function + waveform_generator = bilby.gw.WaveformGenerator( + duration=duration, + sampling_frequency=sampling_frequency, + frequency_domain_source_model=ripple_bbh, + waveform_arguments=waveform_arguments, + 
use_cache=False, + ) + + # Set up interferometers. In this case we'll use two interferometers + # (LIGO-Hanford (H1), LIGO-Livingston (L1). These default to their design + # sensitivity + ifos = bilby.gw.detector.InterferometerList(["H1", "L1"]) + ifos.set_strain_data_from_power_spectral_densities( + sampling_frequency=sampling_frequency, + duration=duration, + start_time=injection_parameters["geocent_time"] - duration + 2, + ) + ifos.inject_signal( + waveform_generator=waveform_generator, + parameters=injection_parameters, + raise_error=False, + ) + ifos.set_array_backend(jnp) + + # Initialise the likelihood by passing in the interferometer data (ifos) and + # the waveform generator + likelihood = bilby.gw.likelihood.GravitationalWaveTransient( + interferometers=ifos, + waveform_generator=waveform_generator, + priors=priors, + phase_marginalization=True, + distance_marginalization=True, + reference_frame=ifos, + time_reference="L1", + ) + # Do an initial likelihood evaluation to trigger any internal setup + likelihood.log_likelihood_ratio(priors.sample()) + # Wrap the likelihood with the JittedLikelihood to JIT compile the likelihood + # evaluation + likelihood = JittedLikelihood(likelihood) + # Evaluate the likelihood once to trigger the JIT compilation, this will take + # a few seconds as compiling the waveform takes some time + likelihood.log_likelihood_ratio(priors.sample()) + + # use the log_compiles context so we can make sure there aren't recompilations + # inside the sampling loop + with jax.log_compiles(): + result = bilby.run_sampler( + likelihood=likelihood, + priors=priors, + sampler="dynesty", + nlive=100, + sample="acceptance-walk", + naccept=5, + injection_parameters=injection_parameters, + outdir=outdir, + label=label, + npool=None, + save="hdf5", + rseed=np.random.randint(0, 100000), + ) + + # Make a corner plot. 
+ result.plot_corner() + + +if __name__ == "__main__": + main() diff --git a/jax_requirements.txt b/jax_requirements.txt new file mode 100644 index 000000000..013e0f955 --- /dev/null +++ b/jax_requirements.txt @@ -0,0 +1,3 @@ +interpax +jax +orng diff --git a/optional_requirements.txt b/optional_requirements.txt index c10d7908b..f0f2205f6 100644 --- a/optional_requirements.txt +++ b/optional_requirements.txt @@ -1,5 +1,6 @@ celerite george +parameterized plotly pytest-requires pytest-rerunfailures diff --git a/pyproject.toml b/pyproject.toml index 319e6e4d3..3c5a36e57 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,9 @@ classifiers = [ "Intended Audience :: Science/Research", "License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Topic :: Scientific/Engineering :: Physics" ] description = "A user-friendly Bayesian inference library" @@ -29,7 +32,7 @@ maintainers = [ ] name = "bilby" readme = "README.rst" -requires-python = ">=3.10" +requires-python = ">=3.11" [project.entry-points."bilby.samplers"] "bilby.bilby_mcmc" = "bilby.bilby_mcmc.sampler:Bilby_MCMC" @@ -102,11 +105,13 @@ addopts = [ packages = [ "bilby", "bilby.bilby_mcmc", + "bilby.compat", "bilby.core", "bilby.core.prior", "bilby.core.sampler", "bilby.core.utils", "bilby.gw", + "bilby.gw.compat", "bilby.gw.detector", "bilby.gw.eos", "bilby.gw.likelihood", @@ -121,10 +126,12 @@ dependencies = {file = ["requirements.txt"]} [tool.setuptools.dynamic.optional-dependencies] all = {file = [ "gw_requirements.txt", + "jax_requirements.txt", "mcmc_requirements.txt", "optional_requirements.txt" ]} gw = {file = ["gw_requirements.txt"]} +jax = {file = ["jax_requirements.txt"]} mcmc = {file = ["mcmc_requirements.txt"]} [tool.setuptools.package-data] diff --git a/requirements.txt b/requirements.txt index b045db212..f1a91484d 100644 --- 
a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,6 @@ -bilby.cython>=0.3.0 +# see https://github.com/data-apis/array-api-compat/pull/341 +array_api_compat>=1.13 +array_api_extra dynesty>=2.0.1 emcee corner @@ -11,4 +13,4 @@ dill tqdm h5py attrs -importlib-metadata>=3.6; python_version < '3.10' +plum-dispatch diff --git a/test/conftest.py b/test/conftest.py index d08c38604..733601d0e 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,19 +1,99 @@ +import importlib + +import array_api_compat as aac import pytest +from bilby.compat.utils import BILBY_DEVICE + def pytest_addoption(parser): parser.addoption( "--skip-roqs", action="store_true", default=False, help="Skip all tests that require ROQs" ) + parser.addoption( + "--array-backend", + default=None, + help="Which array to use for testing", + ) def pytest_configure(config): config.addinivalue_line("markers", "requires_roqs: mark a test that requires ROQs") + config.addinivalue_line("markers", "array_backend: mark that a test uses all array backends") def pytest_collection_modifyitems(config, items): if config.getoption("--skip-roqs"): skip_roqs = pytest.mark.skip(reason="Skipping tests that require ROQs") - for item in items: - if "requires_roqs" in item.keywords: - item.add_marker(skip_roqs) + else: + skip_roqs = None + if config.getoption("--array-backend") is not None: + array_only = pytest.mark.skip(reason="Only running backend dependent tests") + else: + array_only = None + for item in items: + if "requires_roqs" in item.keywords and config.getoption("--skip-roqs"): + item.add_marker(skip_roqs) + elif "array_backend" not in item.keywords and array_only is not None: + item.add_marker(array_only) + + +def _xp(request): + # The configuration here loosely follows scipy + # https://github.com/scipy/scipy/blob/b167cae18888a34fc43a439e729383b50f4d373e/scipy/conftest.py#L186 + backend = request.config.getoption("--array-backend") + match backend: + case None | "numpy": + import numpy as xp + case "jax" | 
"jax.numpy": + import jax + + jax.config.update("jax_enable_x64", True) + jax.config.update("jax_default_device", jax.devices(BILBY_DEVICE)[0]) + xp = jax.numpy + case "torch": + import torch + # torch starts a lot of threads, so disable this on the first import + # to avoid segfaults + try: + torch.set_default_device(BILBY_DEVICE) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + torch.set_default_dtype(torch.float64) + except RuntimeError: + pass + xp = torch + case _: + try: + + xp = importlib.import_module(backend) + except ImportError: + raise ValueError(f"Unknown backend for testing: {backend}") + return aac.get_namespace(xp.ones(1)) + + +def _rng(xp): + import array_api_compat as aac + from bilby.core.utils.random import resolve_random_state + + if aac.is_numpy_namespace(xp): + return resolve_random_state(12345) + elif aac.is_jax_namespace(xp): + import jax.random + return resolve_random_state(jax.random.PRNGKey(12345)) + elif aac.is_torch_namespace(xp): + import torch + return resolve_random_state(torch.Tensor([12345])) + else: + raise ValueError(f"Unknown array namespace {xp} for RNG") + + +@pytest.fixture +def xp(request): + return _xp(request) + + +@pytest.fixture(scope="class") +def xp_class(request): + request.cls.xp = _xp(request) + request.cls.rng = _rng(request.cls.xp) diff --git a/test/core/grid_test.py b/test/core/grid_test.py index f14a95134..781077f34 100644 --- a/test/core/grid_test.py +++ b/test/core/grid_test.py @@ -1,30 +1,33 @@ import unittest -import numpy as np import shutil import os -from scipy.stats import multivariate_normal + +import array_api_compat as aac +import numpy as np +import pytest import bilby -# set 2D multivariate Gaussian likelihood class MultiGaussian(bilby.Likelihood): - def __init__(self, mean, cov): + # set 2D multivariate Gaussian likelihood + def __init__(self, mean, cov, *, xp=np): super(MultiGaussian, self).__init__() - self.cov = np.array(cov) - self.mean = np.array(mean) - self.sigma = 
np.sqrt(np.diag(self.cov)) - self.pdf = multivariate_normal(mean=self.mean, cov=self.cov) + self.xp = xp + self.cov = xp.asarray(cov) + self.mean = xp.asarray(mean) + self.sigma = xp.sqrt(xp.diag(self.cov)) @property def dim(self): return len(self.cov[0]) def log_likelihood(self, parameters): - x = np.array([parameters["x{0}".format(i)] for i in range(self.dim)]) - return self.pdf.logpdf(x) + return -parameters["x0"]**2 / 2 - parameters["x1"]**2 / 2 - np.log(2 * np.pi) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestGrid(unittest.TestCase): def setUp(self): bilby.core.utils.random.seed(7) @@ -33,7 +36,7 @@ def setUp(self): self.mus = [0.0, 0.0] self.cov = [[1.0, 0.0], [0.0, 1.0]] dim = len(self.mus) - self.likelihood = MultiGaussian(self.mus, self.cov) + self.likelihood = MultiGaussian(self.mus, self.cov, xp=self.xp) # set priors out to +/- 5 sigma self.priors = bilby.core.prior.PriorDict() @@ -61,6 +64,7 @@ def setUp(self): grid_size=self.grid_size, likelihood=self.likelihood, save=True, + xp=self.xp, ) self.grid = grid @@ -140,7 +144,9 @@ def test_max_marginalized_likelihood(self): self.assertEqual(1.0, self.grid.marginalize_likelihood(self.grid.parameter_names[1]).max()) def test_ln_evidence(self): - self.assertAlmostEqual(self.expected_ln_evidence, self.grid.ln_evidence, places=5) + ln_z = self.grid.ln_evidence + self.assertEqual(aac.get_namespace(ln_z), self.xp) + self.assertAlmostEqual(self.expected_ln_evidence, float(ln_z), places=5) def test_fail_grid_size(self): with self.assertRaises(TypeError): @@ -151,6 +157,7 @@ def test_fail_grid_size(self): grid_size=2.3, likelihood=self.likelihood, save=True, + xp=self.xp, ) def test_mesh_grid(self): @@ -165,7 +172,8 @@ def test_grid_integer_points(self): outdir="outdir", priors=self.priors, grid_size=n_points, - likelihood=self.likelihood + likelihood=self.likelihood, + xp=self.xp, ) self.assertTupleEqual(tuple(n_points), grid.mesh_grid[0].shape) @@ -179,7 +187,8 @@ def 
test_grid_dict_points(self): outdir="outdir", priors=self.priors, grid_size=n_points, - likelihood=self.likelihood + likelihood=self.likelihood, + xp=self.xp, ) self.assertTupleEqual((n_points["x0"], n_points["x1"]), grid.mesh_grid[0].shape) self.assertEqual(grid.mesh_grid[0][0, 0], self.priors[self.grid.parameter_names[0]].minimum) @@ -196,6 +205,7 @@ def test_grid_from_array(self): priors=self.priors, grid_size=n_points, likelihood=self.likelihood, + xp=self.xp, ) self.assertTupleEqual((len(x0s), len(x1s)), grid.mesh_grid[0].shape) @@ -208,7 +218,7 @@ def test_grid_from_array(self): def test_save_and_load_from_filename(self): filename = os.path.join("outdir", "test_output.json") self.grid.save_to_file(filename=filename) - new_grid = bilby.core.grid.Grid.read(filename=filename) + new_grid = bilby.core.grid.Grid.read(filename=filename, xp=self.xp) self.assertListEqual(new_grid.parameter_names, self.grid.parameter_names) self.assertEqual(new_grid.n_dims, self.grid.n_dims) @@ -221,7 +231,7 @@ def test_save_and_load_from_filename(self): def test_save_and_load_from_outdir_label(self): self.grid.save_to_file(overwrite=True, outdir="outdir") - new_grid = bilby.core.grid.Grid.read(outdir="outdir", label="label") + new_grid = bilby.core.grid.Grid.read(outdir="outdir", label="label", xp=self.xp) self.assertListEqual(self.grid.parameter_names, new_grid.parameter_names) self.assertEqual(self.grid.n_dims, new_grid.n_dims) @@ -238,7 +248,7 @@ def test_save_and_load_from_outdir_label(self): def test_save_and_load_gzip(self): filename = os.path.join("outdir", "test_output.json.gz") self.grid.save_to_file(filename=filename) - new_grid = bilby.core.grid.Grid.read(filename=filename) + new_grid = bilby.core.grid.Grid.read(filename=filename, xp=self.xp) self.assertListEqual(self.grid.parameter_names, new_grid.parameter_names) self.assertEqual(self.grid.n_dims, new_grid.n_dims) diff --git a/test/core/likelihood_test.py b/test/core/likelihood_test.py index 3c4c71c26..2d7425bb2 100644 
--- a/test/core/likelihood_test.py +++ b/test/core/likelihood_test.py @@ -1,9 +1,13 @@ import unittest from unittest import mock +import array_api_compat as aac import numpy as np +import pytest +import array_api_extra as xpx import bilby.core.likelihood +from bilby.compat.utils import array_module from bilby.core.likelihood import ( Likelihood, GaussianLikelihood, @@ -51,10 +55,12 @@ def test_meta_data(self): self.assertEqual(self.likelihood.meta_data, meta_data) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestAnalytical1DLikelihood(unittest.TestCase): def setUp(self): - self.x = np.arange(start=0, stop=100, step=1) - self.y = np.arange(start=0, stop=100, step=1) + self.x = self.xp.arange(0, 100, step=1) + self.y = self.xp.arange(0, 100, step=1) def test_func(x, parameter1, parameter2): return parameter1 * x + parameter2 @@ -78,7 +84,7 @@ def test_init_x(self): self.assertTrue(np.array_equal(self.x, self.analytical_1d_likelihood.x)) def test_set_x_to_array(self): - new_x = np.arange(start=0, stop=50, step=2) + new_x = self.xp.arange(0, 50, step=2) self.analytical_1d_likelihood.x = new_x self.assertTrue(np.array_equal(new_x, self.analytical_1d_likelihood.x)) @@ -98,7 +104,7 @@ def test_init_y(self): self.assertTrue(np.array_equal(self.y, self.analytical_1d_likelihood.y)) def test_set_y_to_array(self): - new_y = np.arange(start=0, stop=50, step=2) + new_y = self.xp.arange(0, 50, step=2) self.analytical_1d_likelihood.y = new_y self.assertTrue(np.array_equal(new_y, self.analytical_1d_likelihood.y)) @@ -154,17 +160,20 @@ def test_repr(self): self.assertEqual(expected, repr(self.analytical_1d_likelihood)) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestGaussianLikelihood(unittest.TestCase): def setUp(self): self.N = 100 self.sigma = 0.1 - self.x = np.linspace(0, 1, self.N) - self.y = 2 * self.x + 1 + np.random.normal(0, self.sigma, self.N) + self.x = self.xp.linspace(0, 1, self.N) + self.y = 2 * self.x + 1 + 
self.xp.asarray(np.random.normal(0, self.sigma, self.N)) def test_function(x, m, c): return m * x + c self.function = test_function + self.parameters = dict(m=self.xp.asarray(2.0), c=self.xp.asarray(0.0)) def tearDown(self): del self.N @@ -211,19 +220,27 @@ def test_repr(self): ) self.assertEqual(expected, repr(likelihood)) + def test_return_class(self): + likelihood = GaussianLikelihood(self.x, self.y, self.function, self.sigma) + logl = likelihood.log_likelihood(self.parameters) + self.assertEqual(aac.get_namespace(logl), self.xp) + +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestStudentTLikelihood(unittest.TestCase): def setUp(self): self.N = 100 self.nu = self.N - 2 self.sigma = 1 - self.x = np.linspace(0, 1, self.N) - self.y = 2 * self.x + 1 + np.random.normal(0, self.sigma, self.N) + self.x = self.xp.linspace(0, 1, self.N) + self.y = 2 * self.x + 1 + self.xp.asarray(np.random.normal(0, self.sigma, self.N)) def test_function(x, m, c): return m * x + c self.function = test_function + self.parameters = dict(m=self.xp.asarray(2.0), c=self.xp.asarray(0.0)) def tearDown(self): del self.N @@ -262,6 +279,11 @@ def test_log_likelihood_nu_negative(self): with self.assertRaises(ValueError): likelihood.log_likelihood(parameters) + def test_setting_nu_positive_does_not_change_class_attribute(self): + likelihood = StudentTLikelihood(self.x, self.y, self.function, nu=None) + likelihood.nu = 98 + self.assertEqual(likelihood.nu, 98) + def test_lam(self): likelihood = StudentTLikelihood(self.x, self.y, self.function, nu=0, sigma=0.5) @@ -279,25 +301,28 @@ def test_repr(self): self.assertEqual(expected, repr(likelihood)) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestPoissonLikelihood(unittest.TestCase): def setUp(self): self.N = 100 self.mu = 5 - self.x = np.linspace(0, 1, self.N) - self.y = np.random.poisson(self.mu, self.N) - self.yfloat = np.copy(self.y) * 1.0 - self.yneg = np.copy(self.y) - self.yneg[0] = -1 + self.x 
= self.xp.linspace(0, 1, self.N) + self.y = self.xp.asarray(np.random.poisson(self.mu, self.N)) + self.yfloat = self.y * 1.0 + self.yneg = self.y * 1.0 + self.yneg = xpx.at(self.yneg, 0).set(-1) def test_function(x, c): return c def test_function_array(x, c): - return np.ones(len(x)) * c + return self.xp.ones(len(x)) * c self.function = test_function self.function_array = test_function_array self.poisson_likelihood = PoissonLikelihood(self.x, self.y, self.function) + self.bad_parameters = dict(c=self.xp.asarray(-2.0)) def tearDown(self): del self.N @@ -311,6 +336,8 @@ def tearDown(self): del self.poisson_likelihood def test_init_y_non_integer(self): + if aac.is_torch_namespace(self.xp): + pytest.skip("Torch tensor dtype does not have a 'kind' attribute") with self.assertRaises(ValueError): PoissonLikelihood(self.x, self.yfloat, self.function) @@ -330,12 +357,14 @@ def test_neg_rate_array(self): likelihood.log_likelihood(parameters) def test_init_y(self): - self.assertTrue(np.array_equal(self.y, self.poisson_likelihood.y)) + self.assertEqual(aac.get_namespace(self.y), aac.get_namespace(self.poisson_likelihood.y)) + np.testing.assert_array_equal(np.asarray(self.y), np.asarray(self.poisson_likelihood.y)) def test_set_y_to_array(self): - new_y = np.arange(start=0, stop=50, step=2) + new_y = self.xp.arange(0, 50, step=2) self.poisson_likelihood.y = new_y - self.assertTrue(np.array_equal(new_y, self.poisson_likelihood.y)) + self.assertEqual(aac.get_namespace(new_y), aac.get_namespace(self.poisson_likelihood.y)) + np.testing.assert_array_equal(np.asarray(new_y), np.asarray(self.poisson_likelihood.y)) def test_set_y_to_positive_int(self): new_y = 5 @@ -360,25 +389,25 @@ def test_log_likelihood_wrong_func_return_type(self): def test_log_likelihood_negative_func_return_element(self): poisson_likelihood = PoissonLikelihood( - x=self.x, y=self.y, func=lambda x: np.array([3, 6, -2]) + x=self.x, y=self.y, func=lambda x: self.xp.asarray([3, 6, -2]) ) with 
self.assertRaises(ValueError): poisson_likelihood.log_likelihood(dict()) def test_log_likelihood_zero_func_return_element(self): poisson_likelihood = PoissonLikelihood( - x=self.x, y=self.y, func=lambda x: np.array([3, 6, 0]) + x=self.x, y=self.y, func=lambda x: self.xp.asarray([3, 6, 0]) ) self.assertEqual(-np.inf, poisson_likelihood.log_likelihood(dict())) def test_log_likelihood_dummy(self): """ Merely tests if it goes into the right if else bracket """ poisson_likelihood = PoissonLikelihood( - x=self.x, y=self.y, func=lambda x: np.linspace(1, 100, self.N) + x=self.x, y=self.y, func=lambda x: self.xp.linspace(1, 100, self.N) ) - with mock.patch("numpy.sum") as m: + with mock.patch(f"{self.xp.__name__}.sum") as m: m.return_value = 1 - self.assertEqual(1, poisson_likelihood.log_likelihood(dict())) + self.assertEqual(1, poisson_likelihood.log_likelihood(dict(c=5))) def test_repr(self): likelihood = PoissonLikelihood(self.x, self.y, self.function) @@ -388,26 +417,29 @@ def test_repr(self): self.assertEqual(expected, repr(likelihood)) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestExponentialLikelihood(unittest.TestCase): def setUp(self): self.N = 100 self.mu = 5 - self.x = np.linspace(0, 1, self.N) - self.y = np.random.exponential(self.mu, self.N) - self.yneg = np.copy(self.y) - self.yneg[0] = -1.0 + self.x = self.xp.linspace(0, 1, self.N) + self.y = self.xp.asarray(np.random.exponential(self.mu, self.N)) + self.yneg = self.y * 1.0 + self.yneg = xpx.at(self.yneg, 0).set(-1.0) def test_function(x, c): return c def test_function_array(x, c): - return c * np.ones(len(x)) + return c * self.xp.ones(len(x)) self.function = test_function self.function_array = test_function_array self.exponential_likelihood = ExponentialLikelihood( x=self.x, y=self.y, func=self.function ) + self.bad_parameters = dict(c=self.xp.asarray(-1.0)) def tearDown(self): del self.N @@ -436,7 +468,7 @@ def test_init_y(self): self.assertTrue(np.array_equal(self.y, 
self.exponential_likelihood.y)) def test_set_y_to_array(self): - new_y = np.arange(start=0, stop=50, step=2) + new_y = self.xp.arange(0, 50, step=2) self.exponential_likelihood.y = new_y self.assertTrue(np.array_equal(new_y, self.exponential_likelihood.y)) @@ -461,14 +493,17 @@ def test_set_y_to_negative_float(self): def test_set_y_to_nd_array_with_negative_element(self): with self.assertRaises(ValueError): - self.exponential_likelihood.y = np.array([4.3, -1.2, 4]) + self.exponential_likelihood.y = self.xp.asarray([4.3, -1.2, 4]) def test_log_likelihood_default(self): """ Merely tests that it ends up at the right place in the code """ exponential_likelihood = ExponentialLikelihood( - x=self.x, y=self.y, func=lambda x: np.array([4.2]) + x=self.x, y=self.y, func=lambda x: self.xp.asarray([4.2]) ) - with mock.patch("numpy.sum") as m: + # xp is not always the same as self.xp as array_api_compat uses its + # own version of numpy + xp = array_module(exponential_likelihood.func(None)) + with mock.patch(f"{xp.__name__}.sum") as m: m.return_value = 3 self.assertEqual(-3, exponential_likelihood.log_likelihood(dict())) @@ -479,11 +514,17 @@ def test_repr(self): self.assertEqual(expected, repr(self.exponential_likelihood)) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestAnalyticalMultidimensionalCovariantGaussian(unittest.TestCase): def setUp(self): self.cov = [[1, 0, 0], [0, 4, 0], [0, 0, 9]] self.sigma = [1, 2, 3] self.mean = [10, 11, 12] + if self.xp != np: + self.cov = self.xp.asarray(self.cov, dtype=float) + self.sigma = self.xp.asarray(self.sigma, dtype=float) + self.mean = self.xp.asarray(self.mean, dtype=float) self.likelihood = AnalyticalMultidimensionalCovariantGaussian( mean=self.mean, cov=self.cov ) @@ -507,16 +548,30 @@ def test_dim(self): self.assertEqual(3, self.likelihood.dim) def test_log_likelihood(self): - likelihood = AnalyticalMultidimensionalCovariantGaussian(mean=[0], cov=[1]) - self.assertEqual(-np.log(2 * np.pi) / 2, 
likelihood.log_likelihood(dict(x0=0))) + likelihood = AnalyticalMultidimensionalCovariantGaussian( + mean=self.xp.asarray([0.0]), cov=self.xp.asarray([1.0]) + ) + logl = likelihood.log_likelihood(dict(x0=self.xp.asarray(0.0))) + self.assertEqual( + -np.log(2 * np.pi) / 2, + likelihood.log_likelihood(dict(x0=self.xp.asarray(0.0))), + ) + self.assertEqual(aac.get_namespace(logl), self.xp) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestAnalyticalMultidimensionalBimodalCovariantGaussian(unittest.TestCase): def setUp(self): self.cov = [[1, 0, 0], [0, 4, 0], [0, 0, 9]] self.sigma = [1, 2, 3] self.mean_1 = [10, 11, 12] self.mean_2 = [20, 21, 22] + if self.xp != np: + self.cov = self.xp.asarray(self.cov, dtype=float) + self.sigma = self.xp.asarray(self.sigma, dtype=float) + self.mean_1 = self.xp.asarray(self.mean_1, dtype=float) + self.mean_2 = self.xp.asarray(self.mean_2, dtype=float) self.likelihood = AnalyticalMultidimensionalBimodalCovariantGaussian( mean_1=self.mean_1, mean_2=self.mean_2, cov=self.cov ) @@ -547,7 +602,10 @@ def test_log_likelihood(self): likelihood = AnalyticalMultidimensionalBimodalCovariantGaussian( mean_1=[0], mean_2=[0], cov=[1] ) - self.assertEqual(-np.log(2 * np.pi) / 2, likelihood.log_likelihood(dict(x0=0))) + self.assertEqual( + -np.log(2 * np.pi) / 2, + likelihood.log_likelihood(dict(x0=self.xp.asarray(0.0))), + ) class TestJointLikelihood(unittest.TestCase): diff --git a/test/core/prior/analytical_test.py b/test/core/prior/analytical_test.py index 12892aca1..cbcadd880 100644 --- a/test/core/prior/analytical_test.py +++ b/test/core/prior/analytical_test.py @@ -1,16 +1,24 @@ import unittest -import numpy as np +import array_api_compat as aac import bilby +import numpy as np +import pytest +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestDiscreteValuesPrior(unittest.TestCase): + def setUp(self): + if aac.is_torch_namespace(self.xp): + pytest.skip("DiscreteValues prior is unstable for torch 
backend") + def test_single_sample(self): values = [1.1, 1.2, 1.3] discrete_value_prior = bilby.core.prior.DiscreteValues(values) in_prior = True for _ in range(1000): - s = discrete_value_prior.sample() + s = discrete_value_prior.sample(random_state=self.rng) if s not in values: in_prior = False self.assertTrue(in_prior) @@ -20,7 +28,7 @@ def test_array_sample(self): nvalues = 4 discrete_value_prior = bilby.core.prior.DiscreteValues(values) N = 100000 - s = discrete_value_prior.sample(N) + s = discrete_value_prior.sample(N, random_state=self.rng) zeros = np.sum(s == 1.0) ones = np.sum(s == 1.1) twos = np.sum(s == 1.2) @@ -35,60 +43,64 @@ def test_single_probability(self): N = 3 values = [1.1, 2.2, 300.0] discrete_value_prior = bilby.core.prior.DiscreteValues(values) - self.assertEqual(discrete_value_prior.prob(1.1), 1 / N) - self.assertEqual(discrete_value_prior.prob(2.2), 1 / N) - self.assertEqual(discrete_value_prior.prob(300.0), 1 / N) - self.assertEqual(discrete_value_prior.prob(0.5), 0) - self.assertEqual(discrete_value_prior.prob(200), 0) + self.assertEqual(discrete_value_prior.prob(self.xp.asarray(1.1)), 1 / N) + self.assertEqual(discrete_value_prior.prob(self.xp.asarray(2.2)), 1 / N) + self.assertEqual(discrete_value_prior.prob(self.xp.asarray(300.0)), 1 / N) + self.assertEqual(discrete_value_prior.prob(self.xp.asarray(0.5)), 0) + self.assertEqual(discrete_value_prior.prob(self.xp.asarray(200)), 0) def test_single_probability_unsorted(self): N = 3 values = [1.1, 300, 2.2] discrete_value_prior = bilby.core.prior.DiscreteValues(values) - self.assertEqual(discrete_value_prior.prob(1.1), 1 / N) - self.assertEqual(discrete_value_prior.prob(2.2), 1 / N) - self.assertEqual(discrete_value_prior.prob(300.0), 1 / N) - self.assertEqual(discrete_value_prior.prob(0.5), 0) - self.assertEqual(discrete_value_prior.prob(200), 0) + self.assertEqual(discrete_value_prior.prob(self.xp.asarray(1.1)), 1 / N) + self.assertEqual(discrete_value_prior.prob(self.xp.asarray(2.2)), 1 / 
N) + self.assertEqual(discrete_value_prior.prob(self.xp.asarray(300.0)), 1 / N) + self.assertEqual(discrete_value_prior.prob(self.xp.asarray(0.5)), 0) + self.assertEqual(discrete_value_prior.prob(self.xp.asarray(200)), 0) + self.assertEqual( + aac.get_namespace(discrete_value_prior.prob(self.xp.asarray(0.5))), + self.xp, + ) def test_array_probability(self): N = 3 values = [1.1, 2.2, 300.0] discrete_value_prior = bilby.core.prior.DiscreteValues(values) - self.assertTrue( - np.all( - discrete_value_prior.prob([1.1, 2.2, 2.2, 300.0, 200.0]) - == np.array([1 / N, 1 / N, 1 / N, 1 / N, 0]) - ) - ) + probs = discrete_value_prior.prob(self.xp.asarray([1.1, 2.2, 2.2, 300.0, 200.0])) + self.assertEqual(aac.get_namespace(probs), self.xp) + np.testing.assert_array_equal(np.asarray(probs), np.array([1 / N] * 4 + [0])) def test_single_lnprobability(self): N = 3 values = [1.1, 2.2, 300.0] discrete_value_prior = bilby.core.prior.DiscreteValues(values) - self.assertEqual(discrete_value_prior.ln_prob(1.1), -np.log(N)) - self.assertEqual(discrete_value_prior.ln_prob(2.2), -np.log(N)) - self.assertEqual(discrete_value_prior.ln_prob(300), -np.log(N)) - self.assertEqual(discrete_value_prior.ln_prob(150), -np.inf) + self.assertEqual(discrete_value_prior.ln_prob(self.xp.asarray(1.1)), -np.log(N)) + self.assertEqual(discrete_value_prior.ln_prob(self.xp.asarray(2.2)), -np.log(N)) + self.assertEqual(discrete_value_prior.ln_prob(self.xp.asarray(300)), -np.log(N)) + self.assertEqual(discrete_value_prior.ln_prob(self.xp.asarray(150)), -np.inf) + self.assertEqual( + aac.get_namespace(discrete_value_prior.ln_prob(self.xp.asarray(0.5))), + self.xp, + ) def test_array_lnprobability(self): N = 3 values = [1.1, 2.2, 300.0] discrete_value_prior = bilby.core.prior.DiscreteValues(values) - self.assertTrue( - np.all( - discrete_value_prior.ln_prob([1.1, 2.2, 2.2, 300, 150]) - == np.array([-np.log(N), -np.log(N), -np.log(N), -np.log(N), -np.inf]) - ) - ) + ln_probs = 
discrete_value_prior.ln_prob(self.xp.asarray([1.1, 2.2, 2.2, 300, 150])) + self.assertEqual(aac.get_namespace(ln_probs), self.xp) + np.testing.assert_array_equal(np.asarray(ln_probs), np.array([-np.log(N)] * 4 + [-np.inf])) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestCategoricalPrior(unittest.TestCase): def test_single_sample(self): categorical_prior = bilby.core.prior.Categorical(3) in_prior = True for _ in range(1000): - s = categorical_prior.sample() + s = categorical_prior.sample(random_state=self.rng) if s not in [0, 1, 2]: in_prior = False self.assertTrue(in_prior) @@ -97,7 +109,9 @@ def test_array_sample(self): ncat = 4 categorical_prior = bilby.core.prior.Categorical(ncat) N = 100000 - s = categorical_prior.sample(N) + s = categorical_prior.sample(N, random_state=self.rng) + self.assertEqual(aac.get_namespace(s), self.xp) + s = np.asarray(s) zeros = np.sum(s == 0) ones = np.sum(s == 1) twos = np.sum(s == 2) @@ -111,37 +125,55 @@ def test_array_sample(self): def test_single_probability(self): N = 3 categorical_prior = bilby.core.prior.Categorical(N) - self.assertEqual(categorical_prior.prob(0), 1 / N) - self.assertEqual(categorical_prior.prob(1), 1 / N) - self.assertEqual(categorical_prior.prob(2), 1 / N) - self.assertEqual(categorical_prior.prob(0.5), 0) + self.assertEqual(categorical_prior.prob(self.xp.asarray(0)), 1 / N) + self.assertEqual(categorical_prior.prob(self.xp.asarray(1)), 1 / N) + self.assertEqual(categorical_prior.prob(self.xp.asarray(2)), 1 / N) + self.assertEqual(categorical_prior.prob(self.xp.asarray(0.5)), 0) + self.assertEqual( + aac.get_namespace(categorical_prior.prob(self.xp.asarray(0.5))), + self.xp, + ) def test_array_probability(self): N = 3 categorical_prior = bilby.core.prior.Categorical(N) - self.assertTrue(np.all(categorical_prior.prob([0, 1, 1, 2, 3]) == np.array([1 / N, 1 / N, 1 / N, 1 / N, 0]))) + probs = categorical_prior.prob(self.xp.asarray([0, 1, 1, 2, 3])) + 
self.assertEqual(aac.get_namespace(probs), self.xp) + + self.assertTrue(np.all( + np.asarray(probs) == np.array([1 / N, 1 / N, 1 / N, 1 / N, 0]) + )) def test_single_lnprobability(self): N = 3 categorical_prior = bilby.core.prior.Categorical(N) - self.assertEqual(categorical_prior.ln_prob(0), -np.log(N)) - self.assertEqual(categorical_prior.ln_prob(1), -np.log(N)) - self.assertEqual(categorical_prior.ln_prob(2), -np.log(N)) - self.assertEqual(categorical_prior.ln_prob(0.5), -np.inf) + self.assertEqual(categorical_prior.ln_prob(self.xp.asarray(0)), -np.log(N)) + self.assertEqual(categorical_prior.ln_prob(self.xp.asarray(1)), -np.log(N)) + self.assertEqual(categorical_prior.ln_prob(self.xp.asarray(2)), -np.log(N)) + self.assertEqual(categorical_prior.ln_prob(self.xp.asarray(0.5)), -np.inf) + self.assertEqual( + aac.get_namespace(categorical_prior.ln_prob(self.xp.asarray(0.5))), + self.xp, + ) def test_array_lnprobability(self): N = 3 categorical_prior = bilby.core.prior.Categorical(N) - self.assertTrue(np.all(categorical_prior.ln_prob([0, 1, 1, 2, 3]) == np.array( - [-np.log(N), -np.log(N), -np.log(N), -np.log(N), -np.inf]))) + ln_prob = categorical_prior.ln_prob(self.xp.asarray([0, 1, 1, 2, 3])) + self.assertEqual(aac.get_namespace(ln_prob), self.xp) + self.assertTrue(np.all( + np.asarray(ln_prob) == np.array([-np.log(N)] * 4 + [-np.inf]) + )) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestWeightedCategoricalPrior(unittest.TestCase): def test_single_sample(self): categorical_prior = bilby.core.prior.WeightedCategorical(3, [1, 2, 3]) in_prior = True for _ in range(1000): - s = categorical_prior.sample() + s = categorical_prior.sample(random_state=self.rng) if s not in [0, 1, 2]: in_prior = False self.assertTrue(in_prior) @@ -157,7 +189,9 @@ def test_array_sample(self): weights = np.arange(1, ncat + 1) categorical_prior = bilby.core.prior.WeightedCategorical(ncat, weights=weights) N = 100000 - s = categorical_prior.sample(N) + s = 
categorical_prior.sample(N, random_state=self.rng) + self.assertEqual(aac.get_namespace(s), self.xp) + s = np.asarray(s) cases = 0 for i in categorical_prior.values: case = np.sum(s == i) @@ -170,26 +204,35 @@ def test_single_probability(self): N = 3 weights = np.arange(1, N + 1) categorical_prior = bilby.core.prior.WeightedCategorical(N, weights=weights) - for i in categorical_prior.values: + for i in self.xp.asarray(categorical_prior.values): self.assertEqual(categorical_prior.prob(i), weights[i] / np.sum(weights)) - self.assertEqual(categorical_prior.prob(0.5), 0) + prob = categorical_prior.prob(self.xp.asarray(0.5)) + self.assertEqual(prob, 0) + self.assertEqual(aac.get_namespace(prob), self.xp) def test_array_probability(self): N = 3 - test_cases = [0, 1, 1, 2, 3] + test_cases = self.xp.asarray([0, 1, 1, 2, 3]) weights = np.arange(1, N + 1) categorical_prior = bilby.core.prior.WeightedCategorical(N, weights=weights) probs = np.arange(1, N + 2) / np.sum(weights) probs[-1] = 0 - self.assertTrue(np.all(categorical_prior.prob(test_cases) == probs[test_cases])) + new = categorical_prior.prob(test_cases) + self.assertEqual(aac.get_namespace(new), self.xp) + self.assertTrue(np.all(np.asarray(new) == probs[test_cases])) def test_single_lnprobability(self): N = 3 weights = np.arange(1, N + 1) categorical_prior = bilby.core.prior.WeightedCategorical(N, weights=weights) - for i in categorical_prior.values: - self.assertEqual(categorical_prior.ln_prob(i), np.log(weights[i] / np.sum(weights))) - self.assertEqual(categorical_prior.prob(0.5), 0) + for i in self.xp.asarray(categorical_prior.values): + self.assertEqual( + categorical_prior.ln_prob(self.xp.asarray(i)), + np.log(weights[i] / np.sum(weights)), + ) + prob = categorical_prior.prob(self.xp.asarray(0.5)) + self.assertEqual(prob, 0) + self.assertEqual(aac.get_namespace(prob), self.xp) def test_array_lnprobability(self): N = 3 @@ -200,7 +243,9 @@ def test_array_lnprobability(self): ln_probs = np.log(np.arange(1, N + 2) 
/ np.sum(weights)) ln_probs[-1] = -np.inf - self.assertTrue(np.all(categorical_prior.ln_prob(test_cases) == ln_probs[test_cases])) + new = categorical_prior.ln_prob(self.xp.asarray(test_cases)) + self.assertEqual(aac.get_namespace(new), self.xp) + self.assertTrue(np.all(np.asarray(new) == ln_probs[test_cases])) def test_cdf(self): """ @@ -212,12 +257,13 @@ def test_cdf(self): weights = np.arange(1, N + 1) categorical_prior = bilby.core.prior.WeightedCategorical(N, weights=weights) - sample = categorical_prior.sample(size=10) - original = np.asarray(sample) - new = np.array(categorical_prior.rescale( + sample = categorical_prior.sample(size=10, random_state=self.rng) + original = self.xp.asarray(sample) + new = self.xp.asarray(categorical_prior.rescale( categorical_prior.cdf(sample) )) np.testing.assert_array_equal(original, new) + self.assertEqual(type(new), type(original)) if __name__ == "__main__": diff --git a/test/core/prior/base_test.py b/test/core/prior/base_test.py index c9b788732..e61bbbf11 100644 --- a/test/core/prior/base_test.py +++ b/test/core/prior/base_test.py @@ -1,7 +1,9 @@ import unittest from unittest.mock import Mock +import array_api_compat as aac import numpy as np +import pytest import bilby @@ -56,7 +58,7 @@ def test_base_prob(self): self.assertTrue(np.isnan(self.prior.prob(5))) def test_base_ln_prob(self): - self.prior.prob = lambda val: val + self.prior.prob = lambda val, *, xp=None: val self.assertEqual(np.log(5), self.prior.ln_prob(5)) def test_is_in_prior(self): @@ -139,6 +141,8 @@ def test_prob_inside(self): self.assertEqual(1, self.prior.prob(0.5)) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestConstraintPriorNormalisation(unittest.TestCase): def setUp(self): self.priors = dict( @@ -154,8 +158,10 @@ def conversion_func(parameters): def test_prob_integrate_to_one(self): keys = ["a", "b", "c"] n_samples = 1000000 - samples = self.priors.sample_subset(keys=keys, size=n_samples) + samples = 
self.priors.sample_subset(keys=keys, size=n_samples, random_state=self.rng) prob = self.priors.prob(samples, axis=0) + self.assertEqual(aac.get_namespace(prob), self.xp) + prob = np.asarray(prob) dm1 = self.priors["a"].maximum - self.priors["a"].minimum dm2 = self.priors["b"].maximum - self.priors["b"].minimum prior_volume = (dm1 * dm2) @@ -169,5 +175,24 @@ def test_prob_integrate_to_one(self): self.assertAlmostEqual(1, integral, delta=7 * sigma_integral) +class TestPriorSubclassWithoutXpWarning(unittest.TestCase): + def test_custom_subclass_without_xp_issues_warning(self): + """Test that a custom prior subclass without xp parameter in rescale method issues a warning.""" + with pytest.warns( + DeprecationWarning, + match=r"rescale.*CustomPriorWithoutXp.*xp.*keyword argument", + ): + # Define a custom prior subclass that doesn't include xp in rescale method + class CustomPriorWithoutXp(bilby.core.prior.Prior): + def rescale(self, val): + """Custom rescale without xp parameter""" + return val * 2 + + prior = CustomPriorWithoutXp(name="custom_prior") + import jax.numpy as jnp + rescaled = prior.rescale(jnp.array([0.1, 0.2, 3])) + self.assertEqual(aac.get_namespace(rescaled), jnp) + + if __name__ == "__main__": unittest.main() diff --git a/test/core/prior/conditional_test.py b/test/core/prior/conditional_test.py index 20c0cda93..2d3a874f0 100644 --- a/test/core/prior/conditional_test.py +++ b/test/core/prior/conditional_test.py @@ -3,9 +3,11 @@ import unittest from unittest import mock +import array_api_compat as aac import numpy as np import pandas as pd import pickle +import pytest import bilby @@ -172,6 +174,8 @@ def test_cond_prior_instantiation_no_boundary_prior(self): self.assertIsNone(prior.boundary) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestConditionalPriorDict(unittest.TestCase): def setUp(self): def condition_func_1(reference_parameters, var_0): @@ -208,7 +212,12 @@ def condition_func_3(reference_parameters, var_1, var_2): 
self.conditional_priors_manually_set_items = ( bilby.core.prior.ConditionalPriorDict() ) - self.test_sample = dict(var_0=0.7, var_1=0.6, var_2=0.5, var_3=0.4) + self.test_sample = dict( + var_0=self.xp.asarray(0.7), + var_1=self.xp.asarray(0.6), + var_2=self.xp.asarray(0.5), + var_3=self.xp.asarray(0.4), + ) self.test_value = 1 / np.prod([self.test_sample[f"var_{ii}"] for ii in range(3)]) for key, value in dict( var_0=self.prior_0, @@ -260,12 +269,14 @@ def test_conditional_keys_setting_items(self): ) def test_prob(self): - self.assertEqual(self.test_value, self.conditional_priors.prob(sample=self.test_sample)) + prob = self.conditional_priors.prob(sample=self.test_sample) + self.assertEqual(self.test_value, prob) + self.assertEqual(aac.get_namespace(prob), self.xp) def test_prob_illegal_conditions(self): del self.conditional_priors["var_0"] with self.assertRaises(bilby.core.prior.IllegalConditionsException): - self.conditional_priors.prob(sample=self.test_sample) + self.conditional_priors.prob(sample=self.test_sample, xp=self.xp) def test_ln_prob(self): self.assertEqual(np.log(self.test_value), self.conditional_priors.ln_prob(sample=self.test_sample)) @@ -324,7 +335,7 @@ def test_rescale(self): expected = [self.test_sample["var_0"]] for ii in range(1, 4): expected.append(expected[-1] * self.test_sample[f"var_{ii}"]) - self.assertListEqual(expected, res) + np.testing.assert_array_equal(expected, res) def test_rescale_with_joint_prior(self): """ @@ -349,19 +360,20 @@ def test_rescale_with_joint_prior(self): ) ) - ref_variables = list(self.test_sample.values()) + [0.4, 0.1] - keys = list(self.test_sample.keys()) + names + ref_variables = list(self.test_sample.values()) + ref_variables = ref_variables[:2] + [0.1] + ref_variables[2:] + [0.4] + keys = list(self.test_sample.keys()) + keys = keys[:2] + ["mvgvar_0"] + keys[2:] + ["mvgvar_1"] res = priordict.rescale(keys=keys, theta=ref_variables) - self.assertIsInstance(res, list) self.assertEqual(np.shape(res), (6,)) - 
self.assertListEqual([isinstance(r, float) for r in res], 6 * [True]) + self.assertEqual(aac.get_namespace(res), self.xp) # check conditional values are still as expected expected = [self.test_sample["var_0"]] for ii in range(1, 4): expected.append(expected[-1] * self.test_sample[f"var_{ii}"]) - self.assertListEqual(expected, res[0:4]) + np.testing.assert_array_equal(expected, list(res)[:2] + list(res)[3:5]) def test_cdf(self): """ @@ -370,11 +382,11 @@ def test_cdf(self): Note that the format of inputs/outputs is different between the two methods. """ sample = self.conditional_priors.sample() - self.assertEqual( + np.testing.assert_array_equal( self.conditional_priors.rescale( sample.keys(), self.conditional_priors.cdf(sample=sample).values() - ), list(sample.values()) + ), np.array(list(sample.values())) ) def test_rescale_illegal_conditions(self): @@ -446,6 +458,8 @@ def _tp_conditional_uniform(ref_params, period): prior.sample_subset(["tp"], 1000) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestDirichletPrior(unittest.TestCase): def setUp(self): @@ -455,6 +469,10 @@ def tearDown(self): if os.path.isdir("priors"): shutil.rmtree("priors") + def test_samples_correct_type(self): + samples = self.priors.sample(10, random_state=self.rng) + self.assertEqual(aac.get_namespace(samples["dirichlet_1"]), self.xp) + def test_samples_sum_to_less_than_one(self): """ Test that the samples sum to less than one as required for the diff --git a/test/core/prior/dict_test.py b/test/core/prior/dict_test.py index 089611aee..cdd996f19 100644 --- a/test/core/prior/dict_test.py +++ b/test/core/prior/dict_test.py @@ -2,7 +2,9 @@ import unittest from unittest.mock import Mock, patch +import array_api_compat as aac import numpy as np +import pytest import bilby @@ -22,6 +24,8 @@ def __init__(self, names, bounds=None): setattr(bilby.core.prior, "FakeJointPriorDist", FakeJointPriorDist) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class 
TestPriorDict(unittest.TestCase): def setUp(self): @@ -268,30 +272,41 @@ def test_dict_argument_is_not_string_or_dict(self): def test_sample_subset_correct_size(self): size = 7 samples = self.prior_set_from_dict.sample_subset( - keys=self.prior_set_from_dict.keys(), size=size + keys=self.prior_set_from_dict.keys(), size=size, + random_state=self.rng, ) self.assertEqual(len(self.prior_set_from_dict), len(samples)) for key in samples: self.assertEqual(size, len(samples[key])) + self.assertEqual(aac.get_namespace(samples[key]), self.xp) def test_sample_subset_correct_size_when_non_priors_in_dict(self): self.prior_set_from_dict["asdf"] = "not_a_prior" samples = self.prior_set_from_dict.sample_subset( - keys=self.prior_set_from_dict.keys() + keys=self.prior_set_from_dict.keys(), + random_state=self.rng, ) self.assertEqual(len(self.prior_set_from_dict) - 1, len(samples)) + for key in samples: + if not isinstance(samples[key], (int, float)): + self.assertIsNotNone(aac.get_namespace(samples[key]), self.xp) def test_sample_subset_with_actual_subset(self): size = 3 - samples = self.prior_set_from_dict.sample_subset(keys=["length"], size=size) - expected = dict(length=np.array([42.0, 42.0, 42.0])) + samples = self.prior_set_from_dict.sample_subset( + keys=["length"], size=size, random_state=self.rng + ) + expected = dict(length=self.xp.asarray([42.0, 42.0, 42.0])) self.assertTrue(np.array_equal(expected["length"], samples["length"])) + self.assertEqual(aac.get_namespace(samples["length"]), self.xp) def test_sample_subset_constrained_as_array(self): size = 3 keys = ["mass", "speed"] - out = self.prior_set_from_dict.sample_subset_constrained_as_array(keys, size) - self.assertTrue(isinstance(out, np.ndarray)) + out = self.prior_set_from_dict.sample_subset_constrained_as_array( + keys, size, random_state=self.rng + ) + self.assertEqual(aac.get_namespace(out), self.xp) self.assertTrue(out.shape == (len(keys), size)) def test_sample_subset_constrained(self): @@ -312,7 +327,7 @@ def 
conversion_function(parameters): with patch("bilby.core.prior.logger.warning") as mock_warning: samples1 = priors1.sample_subset_constrained( - keys=list(priors1.keys()), size=N + keys=list(priors1.keys()), size=N, random_state=self.rng ) self.assertEqual(len(priors1) - 1, len(samples1)) for key in samples1: @@ -325,14 +340,17 @@ def conversion_function(parameters): with patch("bilby.core.prior.logger.warning") as mock_warning: samples2 = priors2.sample_subset_constrained( - keys=list(priors2.keys()), size=N + keys=list(priors2.keys()), size=N, random_state=self.rng ) self.assertEqual(len(priors2), len(samples2)) for key in samples2: self.assertEqual(N, len(samples2[key])) mock_warning.assert_not_called() - def test_sample(self): + def test_sample_with_random_seed(self): + """ + This test uses the default RNG, so don't specify random_state. + """ size = 7 bilby.core.utils.random.seed(42) samples1 = self.prior_set_from_dict.sample_subset( @@ -342,21 +360,34 @@ def test_sample(self): samples2 = self.prior_set_from_dict.sample(size=size) self.assertEqual(set(samples1.keys()), set(samples2.keys())) for key in samples1: - self.assertTrue(np.array_equal(samples1[key], samples2[key])) + np.testing.assert_array_equal(samples1[key], samples2[key]) + + def test_sample_returns_correct_type(self): + """ + This test passes an explicit random_state and checks the array namespace of the returned samples.
+ """ + size = 7 + samples = self.prior_set_from_dict.sample_subset( + keys=self.prior_set_from_dict.keys(), size=size, random_state=self.rng + ) + for key in samples: + self.assertEqual(aac.get_namespace(samples[key]), self.xp) def test_prob(self): - samples = self.prior_set_from_dict.sample_subset(keys=["mass", "speed"]) + samples = self.prior_set_from_dict.sample_subset(keys=["mass", "speed"], random_state=self.rng) expected = self.first_prior.prob(samples["mass"]) * self.second_prior.prob( samples["speed"] ) self.assertEqual(expected, self.prior_set_from_dict.prob(samples)) + self.assertEqual(aac.get_namespace(expected), self.xp) def test_ln_prob(self): - samples = self.prior_set_from_dict.sample_subset(keys=["mass", "speed"]) + samples = self.prior_set_from_dict.sample_subset(keys=["mass", "speed"], random_state=self.rng) expected = self.first_prior.ln_prob( samples["mass"] ) + self.second_prior.ln_prob(samples["speed"]) self.assertEqual(expected, self.prior_set_from_dict.ln_prob(samples)) + self.assertEqual(aac.get_namespace(expected), self.xp) def test_rescale(self): theta = [0.5, 0.5, 0.5] @@ -380,13 +411,14 @@ def test_cdf(self): Note that the format of inputs/outputs is different between the two methods. 
""" - sample = self.prior_set_from_dict.sample() - original = np.array(list(sample.values())) - new = np.array(self.prior_set_from_dict.rescale( + sample = self.prior_set_from_dict.sample(random_state=self.rng) + original = self.xp.asarray(list(sample.values())) + new = self.xp.asarray(self.prior_set_from_dict.rescale( sample.keys(), self.prior_set_from_dict.cdf(sample=sample).values() )) self.assertLess(max(abs(original - new)), 1e-10) + self.assertEqual(aac.get_namespace(new), self.xp) def test_redundancy(self): for key in self.prior_set_from_dict.keys(): diff --git a/test/core/prior/prior_test.py b/test/core/prior/prior_test.py index 17d360d0c..300310616 100644 --- a/test/core/prior/prior_test.py +++ b/test/core/prior/prior_test.py @@ -1,14 +1,37 @@ +import array_api_compat as aac import bilby import unittest import numpy as np import os +import pytest import scipy.stats as ss from scipy.integrate import trapezoid +aligned_prior_complex = bilby.gw.prior.AlignedSpin( + a_prior=bilby.core.prior.Beta(alpha=2.0, beta=2.0), + z_prior=bilby.core.prior.Beta(alpha=2.0, beta=2.0, minimum=-1), + name="test", + unit="unit", + num_interp=1000, +) + +hp_map_file = os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "prior_files/GW150914_testing_skymap.fits", +) +hp_dist = bilby.gw.prior.HealPixMapPriorDist( + hp_map_file, names=["testra", "testdec"] +) +hp_3d_dist = bilby.gw.prior.HealPixMapPriorDist( + hp_map_file, names=["testra", "testdec", "testdistance"], distance=True +) + + +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestPriorClasses(unittest.TestCase): def setUp(self): - # set multivariate Gaussian mvg = bilby.core.prior.MultivariateGaussianDist( names=["testa", "testb"], @@ -22,16 +45,10 @@ def setUp(self): covs=np.array([[2.0, 0.5], [0.5, 2.0]]), weights=1.0, ) - hp_map_file = os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "prior_files/GW150914_testing_skymap.fits", - ) - hp_dist = 
bilby.gw.prior.HealPixMapPriorDist( - hp_map_file, names=["testra", "testdec"] - ) - hp_3d_dist = bilby.gw.prior.HealPixMapPriorDist( - hp_map_file, names=["testra", "testdec", "testdistance"], distance=True - ) + + # need to reset this for the repr test to get equality correct + hp_dist.requested_parameters = {"testra": None, "testdec": None} + hp_3d_dist.requested_parameters = {"testra": None, "testdec": None, "testdistance": None} def condition_func(reference_params, test_param): return reference_params.copy() @@ -102,13 +119,7 @@ def condition_func(reference_params, test_param): name="test", unit="unit", minimum=1e-2, maximum=1e2 ), bilby.gw.prior.AlignedSpin(name="test", unit="unit"), - bilby.gw.prior.AlignedSpin( - a_prior=bilby.core.prior.Beta(alpha=2.0, beta=2.0), - z_prior=bilby.core.prior.Beta(alpha=2.0, beta=2.0, minimum=-1), - name="test", - unit="unit", - num_interp=1000, - ), + aligned_prior_complex, bilby.core.prior.MultivariateGaussian(dist=mvg, name="testa", unit="unit"), bilby.core.prior.MultivariateGaussian(dist=mvg, name="testb", unit="unit"), bilby.core.prior.MultivariateNormal(dist=mvn, name="testa", unit="unit"), @@ -243,10 +254,24 @@ def condition_func(reference_params, test_param): dist=hp_3d_dist, name="testdistance", unit="unit" ), ] + if aac.is_torch_namespace(self.xp): + self.priors = [ + p for p in self.priors + if not isinstance(p, bilby.core.prior.Interped) + ] + elif aac.is_jax_namespace(self.xp): + self.priors = [ + p for p in self.priors + if not isinstance(p, bilby.core.prior.StudentT) + ] def tearDown(self): del self.priors + def _validate_return_type(self, val): + if not isinstance(val, (int, float)): + self.assertEqual(aac.get_namespace(val), self.xp) + def test_minimum_rescaling(self): """Test the the rescaling works as expected.""" for prior in self.priors: @@ -257,26 +282,37 @@ def test_minimum_rescaling(self): # the edge of the prior is extremely suppressed for these priors # and so the rescale function doesn't quite 
return the lower bound continue - elif bilby.core.prior.JointPrior in prior.__class__.__mro__: - minimum_sample = prior.rescale(0) - if prior.dist.filled_rescale(): - self.assertAlmostEqual(minimum_sample[0], prior.minimum) - self.assertAlmostEqual(minimum_sample[1], prior.minimum) - else: - minimum_sample = prior.rescale(0) - self.assertAlmostEqual(minimum_sample, prior.minimum) + if isinstance(prior, bilby.gw.prior.HealPixPrior) and aac.is_torch_namespace(self.xp): + # HealPix rescaling requires interpolation + continue + with self.subTest(prior=prior): + if bilby.core.prior.JointPrior in prior.__class__.__mro__: + minimum_sample = prior.rescale(self.xp.asarray(0)) + if prior.dist.filled_rescale(): + self.assertAlmostEqual(np.asarray(minimum_sample[0]), prior.minimum) + self.assertAlmostEqual(np.asarray(minimum_sample[1]), prior.minimum) + else: + minimum_sample = prior.rescale(self.xp.asarray(0)) + self.assertAlmostEqual(np.asarray(minimum_sample), prior.minimum) def test_maximum_rescaling(self): """Test the the rescaling works as expected.""" for prior in self.priors: - if bilby.core.prior.JointPrior in prior.__class__.__mro__: - maximum_sample = prior.rescale(0) - if prior.dist.filled_rescale(): - self.assertAlmostEqual(maximum_sample[0], prior.maximum) - self.assertAlmostEqual(maximum_sample[1], prior.maximum) - else: - maximum_sample = prior.rescale(1) - self.assertAlmostEqual(maximum_sample, prior.maximum) + if isinstance(prior, bilby.gw.prior.HealPixPrior) and aac.is_torch_namespace(self.xp): + # HealPix rescaling requires interpolation + continue + with self.subTest(prior=prior): + if bilby.core.prior.JointPrior in prior.__class__.__mro__: + maximum_sample = prior.rescale(self.xp.asarray(0)) + if prior.dist.filled_rescale(): + self.assertAlmostEqual(np.asarray(maximum_sample[0]), prior.maximum) + self.assertAlmostEqual(np.asarray(maximum_sample[1]), prior.maximum) + elif isinstance(prior, bilby.gw.prior.AlignedSpin): + maximum_sample = 
prior.rescale(self.xp.asarray(1)) + self.assertGreater(np.asarray(maximum_sample), 0.997) + else: + maximum_sample = prior.rescale(self.xp.asarray(1)) + self.assertAlmostEqual(np.asarray(maximum_sample), prior.maximum) def test_many_sample_rescaling(self): """Test the the rescaling works as expected.""" @@ -284,20 +320,27 @@ def test_many_sample_rescaling(self): if isinstance(prior, bilby.core.prior.analytical.SymmetricLogUniform): # SymmetricLogUniform has support down to -maximum continue - many_samples = prior.rescale(np.random.uniform(0, 1, 1000)) + if isinstance(prior, bilby.gw.prior.HealPixPrior) and aac.is_torch_namespace(self.xp): + # HealPix rescaling requires interpolation + continue + many_samples = prior.rescale(self.xp.asarray(np.random.uniform(0, 1, 1000))) if bilby.core.prior.JointPrior in prior.__class__.__mro__: if not prior.dist.filled_rescale(): continue - self.assertTrue( - all((many_samples >= prior.minimum) & (many_samples <= prior.maximum)) - ) + with self.subTest(prior=prior): + self.assertTrue( + all((many_samples >= prior.minimum) & (many_samples <= prior.maximum)) + ) + self._validate_return_type(many_samples) def test_least_recently_sampled(self): for prior in self.priors: - least_recently_sampled_expected = prior.sample() - self.assertEqual( - least_recently_sampled_expected, prior.least_recently_sampled - ) + with self.subTest(prior=prior): + least_recently_sampled_expected = prior.sample(random_state=self.rng) + self.assertEqual( + least_recently_sampled_expected, prior.least_recently_sampled + ) + self._validate_return_type(least_recently_sampled_expected) def test_sampling_single(self): """Test that sampling from the prior always returns values within its domain.""" @@ -305,10 +348,11 @@ def test_sampling_single(self): if isinstance(prior, bilby.core.prior.analytical.SymmetricLogUniform): # SymmetricLogUniform has support down to -maximum continue - single_sample = prior.sample() - self.assertTrue( - (single_sample >= prior.minimum) 
& (single_sample <= prior.maximum) - ) + with self.subTest(prior=prior): + single_sample = prior.sample(random_state=self.rng) + self.assertGreaterEqual(single_sample, prior.minimum) + self.assertLessEqual(single_sample, prior.maximum) + self._validate_return_type(single_sample) def test_sampling_many(self): """Test that sampling from the prior always returns values within its domain.""" @@ -316,17 +360,17 @@ def test_sampling_many(self): if isinstance(prior, bilby.core.prior.analytical.SymmetricLogUniform): # SymmetricLogUniform has support down to -maximum continue - many_samples = prior.sample(5000) - self.assertTrue( - (all(many_samples >= prior.minimum)) - & (all(many_samples <= prior.maximum)) - ) + with self.subTest(prior=prior): + many_samples = prior.sample(5000, random_state=self.rng) + self.assertGreaterEqual(min(many_samples), prior.minimum) + self.assertLessEqual(max(many_samples), prior.maximum) + self._validate_return_type(many_samples) def test_probability_above_domain(self): """Test that the prior probability is non-negative in domain of validity and zero outside.""" for prior in self.priors: if prior.maximum != np.inf: - outside_domain = np.linspace( + outside_domain = self.xp.linspace( prior.maximum + 1, prior.maximum + 1e4, 1000 ) if bilby.core.prior.JointPrior in prior.__class__.__mro__: @@ -342,7 +386,7 @@ def test_probability_below_domain(self): # SymmetricLogUniform has support down to -maximum continue if prior.minimum != -np.inf: - outside_domain = np.linspace( + outside_domain = self.xp.linspace( prior.minimum - 1e4, prior.minimum - 1, 1000 ) if bilby.core.prior.JointPrior in prior.__class__.__mro__: @@ -353,31 +397,48 @@ def test_probability_below_domain(self): def test_least_recently_sampled_2(self): for prior in self.priors: - lrs = prior.sample() - self.assertEqual(lrs, prior.least_recently_sampled) + with self.subTest(prior=prior): + lrs = prior.sample(random_state=self.rng) + self.assertEqual(lrs, prior.least_recently_sampled) + 
self._validate_return_type(lrs) def test_prob_and_ln_prob(self): for prior in self.priors: - sample = prior.sample() - if not bilby.core.prior.JointPrior in prior.__class__.__mro__: # noqa + sample = prior.sample(random_state=self.rng) + if bilby.core.prior.JointPrior in prior.__class__.__mro__: # noqa # due to the way that the Multivariate Gaussian prior must sequentially call # the prob and ln_prob functions, it must be ignored in this test. - self.assertAlmostEqual( - np.log(prior.prob(sample)), prior.ln_prob(sample), 12 - ) + continue + with self.subTest(prior=prior): + lnprob = prior.ln_prob(sample) + prob = prior.prob(sample) + self._validate_return_type(lnprob) + self._validate_return_type(prob) + # lower precision for jax running tests with float32 + lnprob = np.asarray(lnprob) + prob = np.asarray(prob) + self.assertAlmostEqual(np.log(prob), lnprob, 6) def test_many_prob_and_many_ln_prob(self): for prior in self.priors: - samples = prior.sample(10) - if not bilby.core.prior.JointPrior in prior.__class__.__mro__: # noqa + samples = prior.sample(10, random_state=self.rng) + if bilby.core.prior.JointPrior in prior.__class__.__mro__: # noqa + continue + with self.subTest(prior=prior): ln_probs = prior.ln_prob(samples) probs = prior.prob(samples) + self._validate_return_type(ln_probs) + self._validate_return_type(probs) + ln_probs = np.asarray(ln_probs) + probs = np.asarray(probs) for sample, logp, p in zip(samples, ln_probs, probs): - self.assertAlmostEqual(prior.ln_prob(sample), logp) - self.assertAlmostEqual(prior.prob(sample), p) + new_lnprob = np.asarray(prior.ln_prob(sample)) + new_prob = np.asarray(prior.prob(sample)) + self.assertAlmostEqual(new_lnprob, logp, 6) + self.assertAlmostEqual(new_prob, p, 6) def test_cdf_is_inverse_of_rescaling(self): - domain = np.linspace(0, 1, 100) + domain = self.xp.linspace(0, 1, 100) threshold = 1e-9 for prior in self.priors: if ( @@ -385,22 +446,34 @@ def test_cdf_is_inverse_of_rescaling(self): or 
bilby.core.prior.JointPrior in prior.__class__.__mro__ ): continue - elif isinstance(prior, bilby.core.prior.WeightedDiscreteValues): - rescaled = prior.rescale(domain) - cdf_vals = prior.cdf(rescaled) - rescaled_2 = prior.rescale(cdf_vals) - cdf_vals_2 = prior.cdf(rescaled_2) - self.assertTrue(np.array_equal(rescaled, rescaled_2)) - max_difference = max(np.abs(cdf_vals - cdf_vals_2)) - else: - rescaled = prior.rescale(domain) - max_difference = max(np.abs(domain - prior.cdf(rescaled))) - self.assertLess(max_difference, threshold) + elif isinstance(prior, bilby.core.prior.StudentT) and "jax" in str(self.xp): + # JAX implementation of StudentT prior rescale is not accurate enough + continue + with self.subTest(prior=prior): + if isinstance(prior, bilby.core.prior.WeightedDiscreteValues): + rescaled = prior.rescale(domain) + cdf_vals = prior.cdf(rescaled) + rescaled_2 = prior.rescale(cdf_vals) + cdf_vals_2 = prior.cdf(rescaled_2) + self.assertTrue(np.array_equal(rescaled, rescaled_2)) + max_difference = max(np.abs(cdf_vals - cdf_vals_2)) + for arr in [rescaled, rescaled_2, cdf_vals, cdf_vals_2]: + self._validate_return_type(arr) + else: + rescaled = prior.rescale(domain) + max_difference = max(np.abs(domain - prior.cdf(rescaled))) + self._validate_return_type(rescaled) + self.assertLess(max_difference, threshold) def test_cdf_one_above_domain(self): for prior in self.priors: - if prior.maximum != np.inf: - outside_domain = np.linspace( + if isinstance(prior, bilby.gw.prior.HealPixPrior) and aac.is_torch_namespace(self.xp): + # HealPix rescaling requires interpolation + continue + if prior.maximum == np.inf: + continue + with self.subTest(prior=prior): + outside_domain = self.xp.linspace( prior.maximum + 1, prior.maximum + 1e4, 1000 ) self.assertTrue(all(prior.cdf(outside_domain) == 1)) @@ -410,13 +483,18 @@ def test_cdf_zero_below_domain(self): if isinstance(prior, bilby.core.prior.analytical.SymmetricLogUniform): # SymmetricLogUniform has support down to -maximum 
continue + if isinstance(prior, bilby.gw.prior.HealPixPrior) and aac.is_torch_namespace(self.xp): + # HealPix rescaling requires interpolation + continue if ( bilby.core.prior.JointPrior in prior.__class__.__mro__ and prior.maximum == np.inf ): continue - if prior.minimum != -np.inf: - outside_domain = np.linspace( + if prior.minimum == -np.inf: + continue + with self.subTest(prior=prior): + outside_domain = self.xp.linspace( prior.minimum - 1e4, prior.minimum - 1, 1000 ) self.assertTrue(all(np.nan_to_num(prior.cdf(outside_domain)) == 0)) @@ -425,7 +503,8 @@ def test_cdf_float_with_float_input(self): for prior in self.priors: if bilby.core.prior.JointPrior in prior.__class__.__mro__: continue - self.assertIsInstance(prior.cdf(prior.sample()), float) + with self.subTest(prior=prior): + self.assertIsInstance(prior.cdf(prior.sample()), float) def test_log_normal_fail(self): with self.assertRaises(ValueError): @@ -563,12 +642,20 @@ def test_fermidirac_fail(self): def test_probability_in_domain(self): """Test that the prior probability is non-negative in domain of validity and zero outside.""" for prior in self.priors: - if prior.minimum == -np.inf: - prior.minimum = -1e5 - if prior.maximum == np.inf: - prior.maximum = 1e5 - domain = np.linspace(prior.minimum, prior.maximum, 1000) - self.assertTrue(all(prior.prob(domain) >= 0)) + with self.subTest(prior=prior): + if prior.minimum == -np.inf: + minimum = -1e5 + else: + minimum = prior.minimum + if prior.maximum == np.inf: + maximum = 1e5 + else: + maximum = prior.maximum + domain = self.xp.linspace(minimum, maximum, 1000) + prob = prior.prob(domain) + self._validate_return_type(prob) + prob = np.asarray(prob) + self.assertTrue(all(prob >= 0)) def test_probability_surrounding_domain(self): """Test that the prior probability is non-negative in domain of validity and zero outside.""" @@ -579,18 +666,20 @@ def test_probability_surrounding_domain(self): if isinstance(prior, bilby.core.prior.analytical.SymmetricLogUniform): # 
SymmetricLogUniform has support down to -maximum continue - surround_domain = np.linspace(prior.minimum - 1, prior.maximum + 1, 1000) - indomain = (surround_domain >= prior.minimum) | ( - surround_domain <= prior.maximum - ) - outdomain = (surround_domain < prior.minimum) | ( - surround_domain > prior.maximum - ) + with np.errstate(invalid="ignore"): + surround_domain = self.xp.linspace(prior.minimum - 1, prior.maximum + 1, 1000) + indomain = (surround_domain >= prior.minimum) | ( + surround_domain <= prior.maximum + ) + outdomain = (surround_domain < prior.minimum) | ( + surround_domain > prior.maximum + ) if bilby.core.prior.JointPrior in prior.__class__.__mro__: if not prior.dist.filled_request(): continue - self.assertTrue(all(prior.prob(surround_domain[indomain]) >= 0)) - self.assertTrue(all(prior.prob(surround_domain[outdomain]) == 0)) + with self.subTest(prior=prior): + self.assertTrue(all(prior.prob(surround_domain[indomain]) >= 0)) + self.assertTrue(all(prior.prob(surround_domain[outdomain]) == 0)) def test_normalized(self): """ @@ -633,11 +722,18 @@ def test_normalized(self): domain = np.linspace(prior.minimum, prior.maximum, 10000) elif isinstance(prior, bilby.core.prior.WeightedDiscreteValues): domain = prior.values - self.assertTrue(np.sum(prior.prob(domain)) == 1) continue else: domain = np.linspace(prior.minimum, prior.maximum, 1000) - self.assertAlmostEqual(trapezoid(prior.prob(domain), domain), 1, 3) + with self.subTest(prior=prior): + if isinstance(prior, bilby.core.prior.WeightedDiscreteValues): + probs = prior.prob(self.xp.asarray(domain)) + self._validate_return_type(probs) + self.assertTrue(np.sum(np.asarray(probs)) == 1) + else: + probs = prior.prob(self.xp.asarray(domain)) + self.assertAlmostEqual(trapezoid(np.array(probs), domain), 1, 3) + self._validate_return_type(probs) def test_accuracy(self): """Test that each of the priors' functions is calculated accurately, as compared to scipy's calculations""" @@ -732,27 +828,33 @@ def 
test_accuracy(self): bilby.core.prior.WeightedDiscreteValues, ) if isinstance(prior, (testTuple)): - np.testing.assert_almost_equal(prior.prob(domain), scipy_prob) - np.testing.assert_almost_equal(prior.ln_prob(domain), scipy_lnprob) - np.testing.assert_almost_equal(prior.cdf(domain), scipy_cdf) - np.testing.assert_almost_equal( - prior.rescale(rescale_domain), scipy_rescale - ) + with self.subTest(prior=prior): + np.testing.assert_almost_equal(prior.prob(self.xp.asarray(domain)), scipy_prob) + np.testing.assert_almost_equal(prior.ln_prob(self.xp.asarray(domain)), scipy_lnprob) + np.testing.assert_almost_equal(prior.cdf(self.xp.asarray(domain)), scipy_cdf) + if isinstance(prior, bilby.core.prior.StudentT) and "jax" in str(self.xp): + # JAX implementation of StudentT prior rescale is not accurate enough + continue + np.testing.assert_almost_equal( + prior.rescale(self.xp.asarray(rescale_domain)), scipy_rescale + ) def test_unit_setting(self): for prior in self.priors: - if isinstance(prior, bilby.gw.prior.Cosmological): - self.assertEqual(None, prior.unit) - else: - self.assertEqual("unit", prior.unit) + with self.subTest(prior=prior): + if isinstance(prior, bilby.gw.prior.Cosmological): + self.assertEqual(None, prior.unit) + else: + self.assertEqual("unit", prior.unit) def test_eq_different_classes(self): for i in range(len(self.priors)): for j in range(len(self.priors)): - if i == j: - self.assertEqual(self.priors[i], self.priors[j]) - else: - self.assertNotEqual(self.priors[i], self.priors[j]) + with self.subTest(i=self.priors[i], j=self.priors[j]): + if i == j: + self.assertEqual(self.priors[i], self.priors[j]) + else: + self.assertNotEqual(self.priors[i], self.priors[j]) def test_eq_other_condition(self): prior_1 = bilby.core.prior.PowerLaw( @@ -788,6 +890,7 @@ def test_repr(self): repr_prior_string = repr_prior_string.replace( "HealPixMapPriorDist", "bilby.gw.prior.HealPixMapPriorDist" ) + prior.dist.rescale_parameters = {key: None for key in prior.dist.names} 
elif isinstance(prior, bilby.gw.prior.UniformComovingVolume): repr_prior_string = "bilby.gw.prior." + repr(prior) elif "Conditional" in prior.__class__.__name__: @@ -795,8 +898,9 @@ def test_repr(self): else: repr_prior_string = "bilby.core.prior." + repr(prior) - repr_prior = eval(repr_prior_string, None, dict(inf=np.inf)) - self.assertEqual(prior, repr_prior) + with self.subTest(prior=prior): + repr_prior = eval(repr_prior_string, None, dict(inf=np.inf)) + self.assertEqual(prior, repr_prior) def test_set_maximum_setting(self): for prior in self.priors: @@ -820,8 +924,9 @@ def test_set_maximum_setting(self): ), ): continue - prior.maximum = (prior.maximum + prior.minimum) / 2 - self.assertTrue(max(prior.sample(10000)) < prior.maximum) + with self.subTest(prior=prior): + prior.maximum = (prior.maximum + prior.minimum) / 2 + self.assertTrue(max(prior.sample(10000, random_state=self.rng)) < prior.maximum) def test_set_minimum_setting(self): for prior in self.priors: @@ -846,8 +951,9 @@ def test_set_minimum_setting(self): ), ): continue - prior.minimum = (prior.maximum + prior.minimum) / 2 - self.assertTrue(min(prior.sample(10000)) > prior.minimum) + with self.subTest(prior=prior): + prior.minimum = (prior.maximum + prior.minimum) / 2 + self.assertTrue(min(prior.sample(10000, random_state=self.rng)) > prior.minimum) if __name__ == "__main__": diff --git a/test/core/prior/slabspike_test.py b/test/core/prior/slabspike_test.py index d2cdcc55a..7c5716b8a 100644 --- a/test/core/prior/slabspike_test.py +++ b/test/core/prior/slabspike_test.py @@ -1,6 +1,9 @@ -import numpy as np import unittest +import array_api_compat as aac +import numpy as np +import pytest + import bilby from bilby.core.prior.slabspike import SlabSpikePrior from bilby.core.prior.analytical import Uniform, PowerLaw, LogUniform, TruncatedGaussian, \ @@ -60,13 +63,15 @@ def test_set_spike_height_domain_edge(self): self.prior.spike_height = 1 +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") 
class TestSlabSpikeClasses(unittest.TestCase): def setUp(self): - self.minimum = 0.4 - self.maximum = 2.4 - self.spike_loc = 1.5 - self.spike_height = 0.3 + self.minimum = self.xp.asarray(0.4) + self.maximum = self.xp.asarray(2.4) + self.spike_loc = self.xp.asarray(1.5) + self.spike_height = self.xp.asarray(0.3) self.slabs = [ Uniform(minimum=self.minimum, maximum=self.maximum), @@ -75,20 +80,22 @@ def setUp(self): TruncatedGaussian(minimum=self.minimum, maximum=self.maximum, mu=0, sigma=1), Beta(minimum=self.minimum, maximum=self.maximum, alpha=1, beta=1), Gaussian(mu=0, sigma=1), - Cosine(), - Sine(), + Cosine(minimum=self.xp.asarray(-np.pi / 2), maximum=self.xp.asarray(np.pi / 2)), + Sine(minimum=self.xp.asarray(0), maximum=self.xp.asarray(np.pi)), HalfGaussian(sigma=1), LogNormal(mu=1, sigma=2), Exponential(mu=2), - StudentT(df=2), Logistic(mu=2, scale=1), Cauchy(alpha=1, beta=2), Gamma(k=1, theta=1.), - ChiSquared(nu=2)] + ChiSquared(nu=2), + ] + if not aac.is_jax_namespace(self.xp): + StudentT(df=2), self.slab_spikes = [SlabSpikePrior(slab, spike_height=self.spike_height, spike_location=self.spike_loc) for slab in self.slabs] - self.test_nodes_finite_support = np.linspace(self.minimum, self.maximum, 1000) - self.test_nodes_infinite_support = np.linspace(-10, 10, 1000) + self.test_nodes_finite_support = self.xp.linspace(self.minimum, self.maximum, 1000) + self.test_nodes_infinite_support = self.xp.linspace(-10, 10, 1000) self.test_nodes = [self.test_nodes_finite_support if np.isinf(slab.minimum) or np.isinf(slab.maximum) else self.test_nodes_finite_support for slab in self.slabs] @@ -107,6 +114,7 @@ def test_prob_on_slab(self): expected = slab.prob(test_nodes) * slab_spike.slab_fraction actual = slab_spike.prob(test_nodes) self.assertTrue(np.allclose(expected, actual, rtol=1e-5)) + self.assertEqual(aac.get_namespace(actual), self.xp) def test_prob_on_spike(self): for slab_spike in self.slab_spikes: @@ -117,10 +125,13 @@ def test_ln_prob_on_slab(self): expected 
= slab.ln_prob(test_nodes) + np.log(slab_spike.slab_fraction) actual = slab_spike.ln_prob(test_nodes) self.assertTrue(np.array_equal(expected, actual)) + self.assertEqual(aac.get_namespace(actual), self.xp) def test_ln_prob_on_spike(self): for slab_spike in self.slab_spikes: - self.assertEqual(np.inf, slab_spike.ln_prob(self.spike_loc)) + actual = slab_spike.ln_prob(self.spike_loc) + self.assertEqual(np.inf, actual) + self.assertEqual(aac.get_namespace(actual), self.xp) def test_inverse_cdf_below_spike_with_spike_at_minimum(self): for slab in self.slabs: @@ -143,19 +154,22 @@ def test_cdf_below_spike(self): expected = slab.cdf(test_nodes) * slab_spike.slab_fraction actual = slab_spike.cdf(test_nodes) self.assertTrue(np.allclose(expected, actual, rtol=1e-5)) + self.assertEqual(aac.get_namespace(actual), self.xp) def test_cdf_at_spike(self): for slab, slab_spike in zip(self.slabs, self.slab_spikes): expected = slab.cdf(self.spike_loc) * slab_spike.slab_fraction - actual = slab_spike.cdf(self.spike_loc) + actual = slab_spike.cdf(self.xp.asarray(self.spike_loc)) self.assertTrue(np.allclose(expected, actual, rtol=1e-5)) + self.assertEqual(aac.get_namespace(actual), self.xp) def test_cdf_above_spike(self): for slab, slab_spike, test_nodes in zip(self.slabs, self.slab_spikes, self.test_nodes): test_nodes = test_nodes[np.where(test_nodes > self.spike_loc)] expected = slab.cdf(test_nodes) * slab_spike.slab_fraction + self.spike_height actual = slab_spike.cdf(test_nodes) - self.assertTrue(np.array_equal(expected, actual)) + np.testing.assert_allclose(expected, actual, rtol=1e-12) + self.assertEqual(aac.get_namespace(actual), self.xp) def test_cdf_at_minimum(self): for slab_spike in self.slab_spikes: @@ -172,31 +186,39 @@ def test_cdf_at_maximum(self): def test_rescale_no_spike(self): for slab in self.slabs: slab_spike = SlabSpikePrior(slab=slab, spike_height=0, spike_location=slab.minimum) - vals = np.linspace(0, 1, 1000) + vals = self.xp.linspace(0, 1, 1000) expected = 
slab.rescale(vals) actual = slab_spike.rescale(vals) - print(slab) + self.assertEqual(aac.get_namespace(actual), self.xp) + self.assertEqual(aac.get_namespace(expected), self.xp) + actual = np.asarray(actual) + expected = np.asarray(expected) self.assertTrue(np.allclose(expected, actual, rtol=1e-5)) def test_rescale_below_spike(self): for slab, slab_spike in zip(self.slabs, self.slab_spikes): - vals = np.linspace(0, slab_spike.inverse_cdf_below_spike, 1000) + vals = self.xp.linspace(0, slab_spike.inverse_cdf_below_spike, 1000) expected = slab.rescale(vals / slab_spike.slab_fraction) actual = slab_spike.rescale(vals) self.assertTrue(np.allclose(expected, actual, rtol=1e-5)) + self.assertEqual(aac.get_namespace(actual), self.xp) def test_rescale_at_spike(self): for slab, slab_spike in zip(self.slabs, self.slab_spikes): - vals = np.linspace(slab_spike.inverse_cdf_below_spike, - slab_spike.inverse_cdf_below_spike + slab_spike.spike_height, 1000) - expected = np.ones(len(vals)) * slab.rescale(vals[0] / slab_spike.slab_fraction) + vals = self.xp.linspace( + slab_spike.inverse_cdf_below_spike, + slab_spike.inverse_cdf_below_spike + slab_spike.spike_height, 1000 + ) + expected = self.xp.ones(len(vals)) * slab.rescale(vals[0] / slab_spike.slab_fraction) actual = slab_spike.rescale(vals) self.assertTrue(np.allclose(expected, actual, rtol=1e-5)) + self.assertEqual(aac.get_namespace(actual), self.xp) def test_rescale_above_spike(self): for slab, slab_spike in zip(self.slabs, self.slab_spikes): - vals = np.linspace(slab_spike.inverse_cdf_below_spike + self.spike_height, 1, 1000) - expected = np.ones(len(vals)) * slab.rescale( + vals = self.xp.linspace(slab_spike.inverse_cdf_below_spike + self.spike_height, 1, 1000) + expected = self.xp.ones(len(vals)) * slab.rescale( (vals - self.spike_height) / slab_spike.slab_fraction) actual = slab_spike.rescale(vals) self.assertTrue(np.allclose(expected, actual, rtol=1e-5)) + self.assertEqual(aac.get_namespace(actual), self.xp) diff --git 
a/test/core/result_test.py b/test/core/result_test.py index b8c1106a0..96095c0bb 100644 --- a/test/core/result_test.py +++ b/test/core/result_test.py @@ -13,6 +13,8 @@ from bilby.core.utils import logger +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestJson(unittest.TestCase): def setUp(self): @@ -28,12 +30,12 @@ def test_list_encoding(self): self.assertTrue(np.all(data["x"] == decoded["x"])) def test_array_encoding(self): - data = dict(x=np.array([1, 2, 3.4])) + data = dict(x=self.xp.asarray([1, 2, 3.4])) encoded = json.dumps(data, cls=self.encoder) decoded = json.loads(encoded, object_hook=self.decoder) self.assertEqual(data.keys(), decoded.keys()) self.assertEqual(type(data["x"]), type(decoded["x"])) - self.assertTrue(np.all(data["x"] == decoded["x"])) + self.assertTrue(self.xp.all(data["x"] == decoded["x"])) def test_complex_encoding(self): data = dict(x=1 + 3j) @@ -918,6 +920,8 @@ def test_reweight_different_likelihood_weights_correct(self): self.assertNotEqual(new.log_evidence, self.result.log_evidence) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestResultSaveAndRead(unittest.TestCase): @pytest.fixture(autouse=True) @@ -943,7 +947,11 @@ def setUp(self): search_parameter_keys=["x", "y"], fixed_parameter_keys=["c", "d"], priors=priors, - sampler_kwargs=dict(test="test", func=lambda x: x), + sampler_kwargs=dict( + test="test", + func=lambda x: x, + some_array=self.xp.ones((5, 5)), + ), injection_parameters=dict(x=0.5, y=0.5), meta_data=dict(test="test"), sampling_time=100.0, diff --git a/test/core/series_test.py b/test/core/series_test.py index bf1b19c43..c2b8dccdb 100644 --- a/test/core/series_test.py +++ b/test/core/series_test.py @@ -1,15 +1,20 @@ import unittest + +import array_api_compat as aac import numpy as np +import pytest from bilby.core.utils import create_frequency_series, create_time_series from bilby.core.series import CoupledTimeAndFrequencySeries +@pytest.mark.array_backend 
+@pytest.mark.usefixtures("xp_class") class TestCoupledTimeAndFrequencySeries(unittest.TestCase): def setUp(self): - self.duration = 2 - self.sampling_frequency = 4096 - self.start_time = -1 + self.duration = self.xp.asarray(2.0) + self.sampling_frequency = self.xp.asarray(4096.0) + self.start_time = self.xp.asarray(-1.0) self.series = CoupledTimeAndFrequencySeries( duration=self.duration, sampling_frequency=self.sampling_frequency, @@ -43,10 +48,10 @@ def test_start_time_from_init(self): self.assertEqual(self.start_time, self.series.start_time) def test_frequency_array_type(self): - self.assertIsInstance(self.series.frequency_array, np.ndarray) + self.assertEqual(aac.get_namespace(self.series.frequency_array), self.xp) def test_time_array_type(self): - self.assertIsInstance(self.series.time_array, np.ndarray) + self.assertEqual(aac.get_namespace(self.series.time_array), self.xp) def test_frequency_array_from_init(self): expected = create_frequency_series( @@ -63,8 +68,8 @@ def test_time_array_from_init(self): self.assertTrue(np.array_equal(expected, self.series.time_array)) def test_frequency_array_setter(self): - new_sampling_frequency = 100 - new_duration = 3 + new_sampling_frequency = self.xp.asarray(100.0) + new_duration = self.xp.asarray(3.0) new_frequency_array = create_frequency_series( sampling_frequency=new_sampling_frequency, duration=new_duration ) @@ -79,9 +84,9 @@ def test_frequency_array_setter(self): self.assertAlmostEqual(self.start_time, self.series.start_time) def test_time_array_setter(self): - new_sampling_frequency = 100 - new_duration = 3 - new_start_time = 4 + new_sampling_frequency = self.xp.asarray(100.0) + new_duration = self.xp.asarray(3.0) + new_start_time = self.xp.asarray(4.0) new_time_array = create_time_series( sampling_frequency=new_sampling_frequency, duration=new_duration, @@ -90,31 +95,31 @@ def test_time_array_setter(self): self.series.time_array = new_time_array self.assertTrue(np.array_equal(new_time_array, 
self.series.time_array)) self.assertAlmostEqual( - new_sampling_frequency, self.series.sampling_frequency, places=1 + np.asarray(new_sampling_frequency), np.asarray(self.series.sampling_frequency), places=1 ) - self.assertAlmostEqual(new_duration, self.series.duration, places=1) - self.assertAlmostEqual(new_start_time, self.series.start_time, places=1) + self.assertAlmostEqual(np.asarray(new_duration), np.asarray(self.series.duration), places=1) + self.assertAlmostEqual(np.asarray(new_start_time), np.asarray(self.series.start_time), places=1) def test_time_array_without_sampling_frequency(self): self.series.sampling_frequency = None - self.series.duration = 4 + self.series.duration = self.xp.asarray(4) with self.assertRaises(ValueError): _ = self.series.time_array def test_time_array_without_duration(self): - self.series.sampling_frequency = 4096 + self.series.sampling_frequency = self.xp.asarray(4096) self.series.duration = None with self.assertRaises(ValueError): _ = self.series.time_array def test_frequency_array_without_sampling_frequency(self): self.series.sampling_frequency = None - self.series.duration = 4 + self.series.duration = self.xp.asarray(4) with self.assertRaises(ValueError): _ = self.series.frequency_array def test_frequency_array_without_duration(self): - self.series.sampling_frequency = 4096 + self.series.sampling_frequency = self.xp.asarray(4096) self.series.duration = None with self.assertRaises(ValueError): _ = self.series.frequency_array diff --git a/test/core/utils_test.py b/test/core/utils_test.py index df46d6bb3..d8a78beee 100644 --- a/test/core/utils_test.py +++ b/test/core/utils_test.py @@ -1,6 +1,8 @@ import unittest import os +import array_api_compat as aac +import array_api_extra as xpx import dill import numpy as np from astropy import constants @@ -49,35 +51,42 @@ def test_gravitational_constant(self): self.assertEqual(bilby.core.utils.gravitational_constant, lal.G_SI) +@pytest.mark.array_backend 
+@pytest.mark.usefixtures("xp_class") class TestFFT(unittest.TestCase): def setUp(self): - self.sampling_frequency = 10 + self.sampling_frequency = self.xp.asarray(10) def tearDown(self): del self.sampling_frequency def test_nfft_sine_function(self): - injected_frequency = 2.7324 - duration = 100 - times = utils.create_time_series(self.sampling_frequency, duration) + xp = self.xp + injected_frequency = xp.asarray(2.7324) + duration = xp.asarray(100) + times = utils.create_time_series(xp.asarray(self.sampling_frequency), duration) - time_domain_strain = np.sin(2 * np.pi * times * injected_frequency + 0.4) + time_domain_strain = xp.sin(2 * np.pi * times * injected_frequency + 0.4) frequency_domain_strain, frequencies = bilby.core.utils.nfft( time_domain_strain, self.sampling_frequency ) - frequency_at_peak = frequencies[np.argmax(np.abs(frequency_domain_strain))] + frequency_at_peak = frequencies[xp.argmax(abs(frequency_domain_strain))] + self.assertEqual(aac.get_namespace(frequency_at_peak), xp) + frequency_at_peak = np.asarray(frequency_at_peak) + injected_frequency = np.asarray(injected_frequency) self.assertAlmostEqual(injected_frequency, frequency_at_peak, places=1) def test_nfft_infft(self): - time_domain_strain = np.random.normal(0, 1, 10) + xp = self.xp + time_domain_strain = xp.asarray(np.random.normal(0, 1, 10)) frequency_domain_strain, _ = bilby.core.utils.nfft( time_domain_strain, self.sampling_frequency ) new_time_domain_strain = bilby.core.utils.infft( frequency_domain_strain, self.sampling_frequency ) - self.assertTrue(np.allclose(time_domain_strain, new_time_domain_strain)) + self.assertTrue(xp.allclose(time_domain_strain, new_time_domain_strain)) class TestInferParameters(unittest.TestCase): @@ -119,11 +128,13 @@ def test_self_handling_method_as_function(self): self.assertListEqual(expected, actual) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestTimeAndFrequencyArrays(unittest.TestCase): def setUp(self): - self.start_time 
= 1.3 - self.sampling_frequency = 5 - self.duration = 1.6 + self.start_time = self.xp.asarray(1.3) + self.sampling_frequency = self.xp.asarray(5) + self.duration = self.xp.asarray(1.6) self.frequency_array = utils.create_frequency_series( sampling_frequency=self.sampling_frequency, duration=self.duration ) @@ -141,12 +152,13 @@ def tearDown(self): del self.time_array def test_create_time_array(self): - expected_time_array = np.array([1.3, 1.5, 1.7, 1.9, 2.1, 2.3, 2.5, 2.7]) + expected_time_array = self.xp.asarray([1.3, 1.5, 1.7, 1.9, 2.1, 2.3, 2.5, 2.7]) time_array = utils.create_time_series( sampling_frequency=self.sampling_frequency, duration=self.duration, starting_time=self.start_time, ) + self.assertEqual(aac.get_namespace(time_array), self.xp) self.assertTrue(np.allclose(expected_time_array, time_array)) def test_create_frequency_array(self): @@ -164,7 +176,7 @@ def test_get_sampling_frequency_from_time_array(self): self.assertEqual(self.sampling_frequency, new_sampling_freq) def test_get_sampling_frequency_from_time_array_unequally_sampled(self): - self.time_array[-1] += 0.0001 + self.time_array = xpx.at(self.time_array, -1).set(self.time_array[-1] + 0.0001) with self.assertRaises(ValueError): _, _ = utils.get_sampling_frequency_and_duration_from_time_array( self.time_array @@ -190,7 +202,9 @@ def test_get_sampling_frequency_from_frequency_array(self): self.assertEqual(self.sampling_frequency, new_sampling_freq) def test_get_sampling_frequency_from_frequency_array_unequally_sampled(self): - self.frequency_array[-1] += 0.0001 + self.frequency_array = xpx.at( + self.frequency_array, -1 + ).set(self.frequency_array[-1] + 0.0001) with self.assertRaises(ValueError): _, _ = utils.get_sampling_frequency_and_duration_from_frequency_array( self.frequency_array @@ -233,34 +247,38 @@ def test_consistency_frequency_array_to_frequency_array(self): def test_illegal_sampling_frequency_and_duration(self): with 
self.assertRaises(utils.IllegalDurationAndSamplingFrequencyException): _ = utils.create_time_series( - sampling_frequency=7.7, duration=1.3, starting_time=0 + sampling_frequency=self.xp.asarray(7.7), + duration=self.xp.asarray(1.3), + starting_time=self.xp.asarray(0), ) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestReflect(unittest.TestCase): def test_in_range(self): - xprime = np.array([0.1, 0.5, 0.9]) - x = np.array([0.1, 0.5, 0.9]) + xprime = self.xp.asarray([0.1, 0.5, 0.9]) + x = self.xp.asarray([0.1, 0.5, 0.9]) self.assertTrue(np.testing.assert_allclose(utils.reflect(xprime), x) is None) def test_in_one_to_two(self): - xprime = np.array([1.1, 1.5, 1.9]) - x = np.array([0.9, 0.5, 0.1]) + xprime = self.xp.asarray([1.1, 1.5, 1.9]) + x = self.xp.asarray([0.9, 0.5, 0.1]) self.assertTrue(np.testing.assert_allclose(utils.reflect(xprime), x) is None) def test_in_two_to_three(self): - xprime = np.array([2.1, 2.5, 2.9]) - x = np.array([0.1, 0.5, 0.9]) + xprime = self.xp.asarray([2.1, 2.5, 2.9]) + x = self.xp.asarray([0.1, 0.5, 0.9]) self.assertTrue(np.testing.assert_allclose(utils.reflect(xprime), x) is None) def test_in_minus_one_to_zero(self): - xprime = np.array([-0.9, -0.5, -0.1]) - x = np.array([0.9, 0.5, 0.1]) + xprime = self.xp.asarray([-0.9, -0.5, -0.1]) + x = self.xp.asarray([0.9, 0.5, 0.1]) self.assertTrue(np.testing.assert_allclose(utils.reflect(xprime), x) is None) def test_in_minus_two_to_minus_one(self): - xprime = np.array([-1.9, -1.5, -1.1]) - x = np.array([0.1, 0.5, 0.9]) + xprime = self.xp.asarray([-1.9, -1.5, -1.1]) + x = self.xp.asarray([0.1, 0.5, 0.9]) self.assertTrue(np.testing.assert_allclose(utils.reflect(xprime), x) is None) @@ -325,8 +343,12 @@ def plot(): self.assertTrue(os.path.isfile(self.filename)) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestUnsortedInterp2d(unittest.TestCase): def setUp(self): + if aac.is_torch_namespace(self.xp): + pytest.skip("Skipping Interp2d tests for torch 
backend") self.xx = np.linspace(0, 1, 10) self.yy = np.linspace(0, 1, 10) self.zz = np.random.random((10, 10)) @@ -343,36 +365,42 @@ def test_returns_none_for_floats_outside_range(self): self.assertIsNone(self.interpolant(-0.5, 0.5)) def test_returns_float_for_float_and_array(self): - self.assertIsInstance(self.interpolant(0.5, np.random.random(10)), np.ndarray) - self.assertIsInstance(self.interpolant(np.random.random(10), 0.5), np.ndarray) - self.assertIsInstance( - self.interpolant(np.random.random(10), np.random.random(10)), np.ndarray + input_array = self.xp.asarray(np.random.random(10)) + self.assertEqual(aac.get_namespace(self.interpolant(input_array, 0.5)), self.xp) + self.assertEqual(aac.get_namespace( + self.interpolant(input_array, input_array)), self.xp ) + self.assertEqual(aac.get_namespace(self.interpolant(0.5, input_array)), self.xp) def test_raises_for_mismatched_arrays(self): with self.assertRaises(ValueError): - self.interpolant(np.random.random(10), np.random.random(20)) + self.interpolant( + self.xp.asarray(np.random.random(10)), + self.xp.asarray(np.random.random(20)), + ) def test_returns_fill_in_correct_place(self): - x_data = np.random.random(10) - y_data = np.random.random(10) - x_data[3] = -1 - self.assertTrue(np.isnan(self.interpolant(x_data, y_data)[3])) + x_data = self.xp.asarray(np.random.random(10)) + y_data = self.xp.asarray(np.random.random(10)) + x_data = xpx.at(x_data, 3).set(-1) + self.assertTrue(self.xp.isnan(self.interpolant(x_data, y_data)[3])) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestTrapeziumRuleIntegration(unittest.TestCase): def setUp(self): - self.x = np.linspace(0, 1, 100) - self.dxs = np.diff(self.x) + self.x = self.xp.linspace(0, 1, 100) + self.dxs = self.xp.diff(self.x) self.dx = self.dxs[0] with np.errstate(divide="ignore"): - self.lnfunc1 = np.log(self.x) + self.lnfunc1 = self.xp.log(self.x) self.func1int = (self.x[-1] ** 2 - self.x[0] ** 2) / 2 with np.errstate(divide="ignore"): - 
self.lnfunc2 = np.log(self.x ** 2) + self.lnfunc2 = self.xp.log(self.x ** 2) self.func2int = (self.x[-1] ** 3 - self.x[0] ** 3) / 3 - self.irregularx = np.array( + self.irregularx = self.xp.asarray( [ self.x[0], self.x[12], @@ -390,9 +418,9 @@ def setUp(self): ] ) with np.errstate(divide="ignore"): - self.lnfunc1irregular = np.log(self.irregularx) - self.lnfunc2irregular = np.log(self.irregularx ** 2) - self.irregulardxs = np.diff(self.irregularx) + self.lnfunc1irregular = self.xp.log(self.irregularx) + self.lnfunc2irregular = self.xp.log(self.irregularx ** 2) + self.irregulardxs = self.xp.diff(self.irregularx) def test_incorrect_step_type(self): with self.assertRaises(TypeError): @@ -407,19 +435,19 @@ def test_integral_func1(self): res2 = utils.logtrapzexp(self.lnfunc1, self.dxs) self.assertTrue(np.abs(res1 - res2) < 1e-12) - self.assertTrue(np.abs((np.exp(res1) - self.func1int) / self.func1int) < 1e-12) + self.assertTrue(np.abs((self.xp.exp(res1) - self.func1int) / self.func1int) < 1e-12) def test_integral_func2(self): res = utils.logtrapzexp(self.lnfunc2, self.dxs) - self.assertTrue(np.abs((np.exp(res) - self.func2int) / self.func2int) < 1e-4) + self.assertTrue(np.abs((self.xp.exp(res) - self.func2int) / self.func2int) < 1e-4) def test_integral_func1_irregular_steps(self): res = utils.logtrapzexp(self.lnfunc1irregular, self.irregulardxs) - self.assertTrue(np.abs((np.exp(res) - self.func1int) / self.func1int) < 1e-12) + self.assertTrue(np.abs((self.xp.exp(res) - self.func1int) / self.func1int) < 1e-12) def test_integral_func2_irregular_steps(self): res = utils.logtrapzexp(self.lnfunc2irregular, self.irregulardxs) - self.assertTrue(np.abs((np.exp(res) - self.func2int) / self.func2int) < 1e-2) + self.assertTrue(np.abs((self.xp.exp(res) - self.func2int) / self.func2int) < 1e-2) class TestSavingNumpyRandomGenerator(unittest.TestCase): diff --git a/test/gw/conversion_test.py b/test/gw/conversion_test.py index fc0f4321a..9d2d46b48 100644 --- 
a/test/gw/conversion_test.py +++ b/test/gw/conversion_test.py @@ -1,25 +1,29 @@ import unittest +import array_api_compat as aac import numpy as np import pandas as pd +import pytest import bilby from bilby.gw import conversion +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestBasicConversions(unittest.TestCase): def setUp(self): - self.mass_1 = 1.4 - self.mass_2 = 1.3 - self.mass_ratio = 13 / 14 - self.total_mass = 2.7 - self.chirp_mass = (1.4 * 1.3) ** 0.6 / 2.7 ** 0.2 - self.symmetric_mass_ratio = (1.4 * 1.3) / 2.7 ** 2 - self.cos_angle = -1 - self.angle = np.pi - self.lambda_1 = 300 - self.lambda_2 = 300 * (14 / 13) ** 5 - self.lambda_tilde = ( + self.mass_1 = self.xp.asarray(1.4) + self.mass_2 = self.xp.asarray(1.3) + self.mass_ratio = self.xp.asarray(13 / 14) + self.total_mass = self.xp.asarray(2.7) + self.chirp_mass = (self.mass_1 * self.mass_2) ** 0.6 / self.total_mass ** 0.2 + self.symmetric_mass_ratio = (self.mass_1 * self.mass_2) / self.total_mass ** 2 + self.cos_angle = self.xp.asarray(-1.0) + self.angle = self.xp.pi + self.lambda_1 = self.xp.asarray(300.0) + self.lambda_2 = self.xp.asarray(300.0 * (14 / 13) ** 5) + self.lambda_tilde = self.xp.asarray( 8 / 13 * ( @@ -38,7 +42,7 @@ def setUp(self): * (self.lambda_1 - self.lambda_2) ) ) - self.delta_lambda_tilde = ( + self.delta_lambda_tilde = self.xp.asarray( 1 / 2 * ( @@ -74,30 +78,36 @@ def test_total_mass_and_mass_ratio_to_component_masses(self): self.assertTrue( all([abs(mass_1 - self.mass_1) < 1e-5, abs(mass_2 - self.mass_2) < 1e-5]) ) + self.assertEqual(aac.get_namespace(mass_1), self.xp) + self.assertEqual(aac.get_namespace(mass_2), self.xp) def test_chirp_mass_and_primary_mass_to_mass_ratio(self): mass_ratio = conversion.chirp_mass_and_primary_mass_to_mass_ratio( self.chirp_mass, self.mass_1 ) - self.assertAlmostEqual(self.mass_ratio, mass_ratio) + self.assertAlmostEqual(float(self.mass_ratio), float(mass_ratio)) + self.assertEqual(aac.get_namespace(mass_ratio), self.xp) 
def test_symmetric_mass_ratio_to_mass_ratio(self): mass_ratio = conversion.symmetric_mass_ratio_to_mass_ratio( self.symmetric_mass_ratio ) - self.assertAlmostEqual(self.mass_ratio, mass_ratio) + self.assertAlmostEqual(float(self.mass_ratio), float(mass_ratio)) + self.assertEqual(aac.get_namespace(mass_ratio), self.xp) def test_chirp_mass_and_total_mass_to_symmetric_mass_ratio(self): symmetric_mass_ratio = conversion.chirp_mass_and_total_mass_to_symmetric_mass_ratio( self.chirp_mass, self.total_mass ) - self.assertAlmostEqual(self.symmetric_mass_ratio, symmetric_mass_ratio) + self.assertAlmostEqual(float(self.symmetric_mass_ratio), float(symmetric_mass_ratio)) + self.assertEqual(aac.get_namespace(symmetric_mass_ratio), self.xp) def test_chirp_mass_and_mass_ratio_to_total_mass(self): total_mass = conversion.chirp_mass_and_mass_ratio_to_total_mass( self.chirp_mass, self.mass_ratio ) - self.assertAlmostEqual(self.total_mass, total_mass) + self.assertAlmostEqual(float(self.total_mass), float(total_mass)) + self.assertEqual(aac.get_namespace(total_mass), self.xp) def test_chirp_mass_and_mass_ratio_to_component_masses(self): mass_1, mass_2 = \ @@ -105,30 +115,37 @@ def test_chirp_mass_and_mass_ratio_to_component_masses(self): self.chirp_mass, self.mass_ratio) self.assertAlmostEqual(self.mass_1, mass_1) self.assertAlmostEqual(self.mass_2, mass_2) + self.assertEqual(aac.get_namespace(mass_1), self.xp) + self.assertEqual(aac.get_namespace(mass_2), self.xp) def test_component_masses_to_chirp_mass(self): chirp_mass = conversion.component_masses_to_chirp_mass(self.mass_1, self.mass_2) self.assertAlmostEqual(self.chirp_mass, chirp_mass) + self.assertEqual(aac.get_namespace(chirp_mass), self.xp) def test_component_masses_to_total_mass(self): total_mass = conversion.component_masses_to_total_mass(self.mass_1, self.mass_2) self.assertAlmostEqual(self.total_mass, total_mass) + self.assertEqual(aac.get_namespace(total_mass), self.xp) def 
test_component_masses_to_symmetric_mass_ratio(self): symmetric_mass_ratio = conversion.component_masses_to_symmetric_mass_ratio( self.mass_1, self.mass_2 ) self.assertAlmostEqual(self.symmetric_mass_ratio, symmetric_mass_ratio) + self.assertEqual(aac.get_namespace(symmetric_mass_ratio), self.xp) def test_component_masses_to_mass_ratio(self): mass_ratio = conversion.component_masses_to_mass_ratio(self.mass_1, self.mass_2) - self.assertAlmostEqual(self.mass_ratio, mass_ratio) + self.assertAlmostEqual(float(self.mass_ratio), float(mass_ratio)) + self.assertEqual(aac.get_namespace(mass_ratio), self.xp) def test_mass_1_and_chirp_mass_to_mass_ratio(self): mass_ratio = conversion.mass_1_and_chirp_mass_to_mass_ratio( self.mass_1, self.chirp_mass ) self.assertAlmostEqual(self.mass_ratio, mass_ratio) + self.assertEqual(aac.get_namespace(mass_ratio), self.xp) def test_lambda_tilde_to_lambda_1_lambda_2(self): lambda_1, lambda_2 = conversion.lambda_tilde_to_lambda_1_lambda_2( @@ -142,6 +159,8 @@ def test_lambda_tilde_to_lambda_1_lambda_2(self): ] ) ) + self.assertEqual(aac.get_namespace(lambda_1), self.xp) + self.assertEqual(aac.get_namespace(lambda_2), self.xp) def test_lambda_tilde_delta_lambda_tilde_to_lambda_1_lambda_2(self): ( @@ -158,18 +177,22 @@ def test_lambda_tilde_delta_lambda_tilde_to_lambda_1_lambda_2(self): ] ) ) + self.assertEqual(aac.get_namespace(lambda_1), self.xp) + self.assertEqual(aac.get_namespace(lambda_2), self.xp) def test_lambda_1_lambda_2_to_lambda_tilde(self): lambda_tilde = conversion.lambda_1_lambda_2_to_lambda_tilde( self.lambda_1, self.lambda_2, self.mass_1, self.mass_2 ) self.assertTrue((self.lambda_tilde - lambda_tilde) < 1e-5) + self.assertEqual(aac.get_namespace(lambda_tilde), self.xp) def test_lambda_1_lambda_2_to_delta_lambda_tilde(self): delta_lambda_tilde = conversion.lambda_1_lambda_2_to_delta_lambda_tilde( self.lambda_1, self.lambda_2, self.mass_1, self.mass_2 ) self.assertTrue((self.delta_lambda_tilde - delta_lambda_tilde) < 1e-5) + 
self.assertEqual(aac.get_namespace(delta_lambda_tilde), self.xp) def test_identity_conversion(self): original_samples = dict( @@ -600,18 +623,20 @@ def test_comoving_luminosity_with_cosmology(self): self.assertAlmostEqual(max(abs(dl - self.distances)), 0, 4) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestGenerateMassParameters(unittest.TestCase): def setUp(self): - self.expected_values = {'mass_1': 2.0, - 'mass_2': 1.0, - 'chirp_mass': 1.2167286837864113, - 'total_mass': 3.0, - 'mass_1_source': 4.0, - 'mass_2_source': 2.0, - 'chirp_mass_source': 2.433457367572823, - 'total_mass_source': 6, - 'symmetric_mass_ratio': 0.2222222222222222, - 'mass_ratio': 0.5} + self.expected_values = {'mass_1': self.xp.asarray(2.0), + 'mass_2': self.xp.asarray(1.0), + 'chirp_mass': self.xp.asarray(1.2167286837864113), + 'total_mass': self.xp.asarray(3.0), + 'mass_1_source': self.xp.asarray(4.0), + 'mass_2_source': self.xp.asarray(2.0), + 'chirp_mass_source': self.xp.asarray(2.433457367572823), + 'total_mass_source': self.xp.asarray(6), + 'symmetric_mass_ratio': self.xp.asarray(0.2222222222222222), + 'mass_ratio': self.xp.asarray(0.5)} def helper_generation_from_keys(self, keys, expected_values, source=False): # Explicitly test the helper generate_component_masses @@ -627,8 +652,8 @@ def helper_generation_from_keys(self, keys, expected_values, source=False): self.assertTrue("mass_2" in local_test_vars_with_component_masses.keys()) for key in local_test_vars_with_component_masses.keys(): self.assertAlmostEqual( - local_test_vars_with_component_masses[key], - self.expected_values[key]) + np.asarray(local_test_vars_with_component_masses[key]), + np.asarray(self.expected_values[key])) # Test the function more generally local_all_mass_parameters = \ @@ -658,7 +683,14 @@ def helper_generation_from_keys(self, keys, expected_values, source=False): ) ) for key in local_all_mass_parameters.keys(): - self.assertAlmostEqual(expected_values[key], 
local_all_mass_parameters[key]) + self.assertAlmostEqual( + np.asarray(expected_values[key]), + np.asarray(local_all_mass_parameters[key]), + ) + self.assertEqual( + aac.get_namespace(local_all_mass_parameters[key]), + self.xp, + ) def test_from_mass_1_and_mass_2(self): self.helper_generation_from_keys(["mass_1", "mass_2"], @@ -725,6 +757,8 @@ def test_from_chirp_mass_source_and_symmetric_mass_2(self): self.expected_values, source=True) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestEquationOfStateConversions(unittest.TestCase): ''' Class to test equation of state conversions. @@ -733,48 +767,48 @@ class TestEquationOfStateConversions(unittest.TestCase): ''' def setUp(self): - self.mass_1_source_spectral = [ + self.mass_1_source_spectral = self.xp.asarray([ 4.922542724434885, 4.350626907771598, 4.206155335439082, 1.7822696459661311, 1.3091740103047926 - ] - self.mass_2_source_spectral = [ + ]) + self.mass_2_source_spectral = self.xp.asarray([ 3.459974694590303, 1.2276461777181447, 3.7287707089639976, 0.3724016563531846, 1.055042934805801 - ] - self.spectral_pca_gamma_0 = [ + ]) + self.spectral_pca_gamma_0 = self.xp.asarray([ 0.7074873121348357, 0.05855931126849878, 0.7795329261793462, 1.467907561566463, 2.9066488405635624 - ] - self.spectral_pca_gamma_1 = [ + ]) + self.spectral_pca_gamma_1 = self.xp.asarray([ -0.29807111670823816, 2.027708558522935, -1.4415775226512115, -0.7104870098896858, -0.4913817181089619 - ] - self.spectral_pca_gamma_2 = [ + ]) + self.spectral_pca_gamma_2 = self.xp.asarray([ 0.25625095371021156, -0.19574096643220049, -0.2710238103460012, 0.22815820981582358, -0.1543413205016374 - ] - self.spectral_pca_gamma_3 = [ + ]) + self.spectral_pca_gamma_3 = self.xp.asarray([ -0.04030365100175101, 0.05698030777919032, -0.045595911403040264, -0.023480394227900117, -0.07114492992285618 - ] + ]) self.spectral_gamma_0 = [ 1.1259406796075457, 0.3191335618787259, @@ -875,10 +909,12 @@ def test_spectral_pca_to_spectral(self): 
self.spectral_pca_gamma_2[i], self.spectral_pca_gamma_3[i] ) - self.assertAlmostEqual(spectral_gamma_0, self.spectral_gamma_0[i], places=5) - self.assertAlmostEqual(spectral_gamma_1, self.spectral_gamma_1[i], places=5) - self.assertAlmostEqual(spectral_gamma_2, self.spectral_gamma_2[i], places=5) - self.assertAlmostEqual(spectral_gamma_3, self.spectral_gamma_3[i], places=5) + self.assertAlmostEqual(float(spectral_gamma_0), self.spectral_gamma_0[i], places=5) + self.assertAlmostEqual(float(spectral_gamma_1), self.spectral_gamma_1[i], places=5) + self.assertAlmostEqual(float(spectral_gamma_2), self.spectral_gamma_2[i], places=5) + self.assertAlmostEqual(float(spectral_gamma_3), self.spectral_gamma_3[i], places=5) + for val in [spectral_gamma_0, spectral_gamma_1, spectral_gamma_2, spectral_gamma_3]: + self.assertEqual(aac.get_namespace(val), self.xp) def test_spectral_params_to_lambda_1_lambda_2(self): ''' @@ -906,8 +942,8 @@ def test_spectral_params_to_lambda_1_lambda_2(self): self.mass_1_source_spectral[i], self.mass_2_source_spectral[i] ) - self.assertAlmostEqual(self.lambda_1_spectral[i], lambda_1, places=0) - self.assertAlmostEqual(self.lambda_2_spectral[i], lambda_2, places=0) + self.assertAlmostEqual(self.lambda_1_spectral[i], float(lambda_1), places=0) + self.assertAlmostEqual(self.lambda_2_spectral[i], float(lambda_2), places=0) self.assertAlmostEqual(self.eos_check_spectral[i], eos_check) def test_polytrope_or_causal_params_to_lambda_1_lambda_2_causal(self): diff --git a/test/gw/detector/geometry_test.py b/test/gw/detector/geometry_test.py index 358825b23..7340a5f8d 100644 --- a/test/gw/detector/geometry_test.py +++ b/test/gw/detector/geometry_test.py @@ -1,11 +1,15 @@ import unittest from unittest import mock +import array_api_compat as aac import numpy as np +import pytest import bilby +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestInterferometerGeometry(unittest.TestCase): def setUp(self): self.length = 30 @@ -26,6 +30,7 @@ 
def setUp(self): xarm_tilt=self.xarm_tilt, yarm_tilt=self.yarm_tilt, ) + self.geometry.set_array_backend(self.xp) def tearDown(self): del self.length @@ -40,27 +45,35 @@ def tearDown(self): def test_length_setting(self): self.assertEqual(self.geometry.length, self.length) + self.assertEqual(aac.get_namespace(self.geometry.length), self.xp) def test_latitude_setting(self): self.assertEqual(self.geometry.latitude, self.latitude) + self.assertEqual(aac.get_namespace(self.geometry.latitude), self.xp) def test_longitude_setting(self): self.assertEqual(self.geometry.longitude, self.longitude) + self.assertEqual(aac.get_namespace(self.geometry.longitude), self.xp) def test_elevation_setting(self): self.assertEqual(self.geometry.elevation, self.elevation) + self.assertEqual(aac.get_namespace(self.geometry.elevation), self.xp) def test_xarm_azi_setting(self): self.assertEqual(self.geometry.xarm_azimuth, self.xarm_azimuth) + self.assertEqual(aac.get_namespace(self.geometry.xarm_azimuth), self.xp) def test_yarm_azi_setting(self): self.assertEqual(self.geometry.yarm_azimuth, self.yarm_azimuth) + self.assertEqual(aac.get_namespace(self.geometry.yarm_azimuth), self.xp) def test_xarm_tilt_setting(self): self.assertEqual(self.geometry.xarm_tilt, self.xarm_tilt) + self.assertEqual(aac.get_namespace(self.geometry.xarm_tilt), self.xp) def test_yarm_tilt_setting(self): self.assertEqual(self.geometry.yarm_tilt, self.yarm_tilt) + self.assertEqual(aac.get_namespace(self.geometry.yarm_tilt), self.xp) def test_vertex_without_update(self): _ = self.geometry.vertex @@ -141,32 +154,38 @@ def test_y_with_latitude_update(self): def test_detector_tensor_with_x_azimuth_update(self): original = self.geometry.detector_tensor self.geometry.xarm_azimuth += 1 - self.assertNotEqual(np.max(abs(self.geometry.detector_tensor - original)), 0) + self.assertNotEqual(np.max(np.asarray(abs(self.geometry.detector_tensor - original))), 0) + self.assertEqual(aac.get_namespace(self.geometry.detector_tensor), 
self.xp) def test_detector_tensor_with_y_azimuth_update(self): original = self.geometry.detector_tensor self.geometry.yarm_azimuth += 1 - self.assertNotEqual(np.max(abs(self.geometry.detector_tensor - original)), 0) + self.assertNotEqual(np.max(np.asarray(abs(self.geometry.detector_tensor - original))), 0) + self.assertEqual(aac.get_namespace(self.geometry.detector_tensor), self.xp) def test_detector_tensor_with_x_tilt_update(self): original = self.geometry.detector_tensor self.geometry.xarm_tilt += 1 - self.assertNotEqual(np.max(abs(self.geometry.detector_tensor - original)), 0) + self.assertNotEqual(np.max(np.asarray(abs(self.geometry.detector_tensor - original))), 0) + self.assertEqual(aac.get_namespace(self.geometry.detector_tensor), self.xp) def test_detector_tensor_with_y_tilt_update(self): original = self.geometry.detector_tensor self.geometry.yarm_tilt += 1 - self.assertNotEqual(np.max(abs(self.geometry.detector_tensor - original)), 0) + self.assertNotEqual(np.max(np.asarray(abs(self.geometry.detector_tensor - original))), 0) + self.assertEqual(aac.get_namespace(self.geometry.detector_tensor), self.xp) def test_detector_tensor_with_longitude_update(self): original = self.geometry.detector_tensor self.geometry.longitude += 1 - self.assertNotEqual(np.max(abs(self.geometry.detector_tensor - original)), 0) + self.assertNotEqual(np.max(np.asarray(abs(self.geometry.detector_tensor - original))), 0) + self.assertEqual(aac.get_namespace(self.geometry.detector_tensor), self.xp) def test_detector_tensor_with_latitude_update(self): original = self.geometry.detector_tensor self.geometry.latitude += 1 - self.assertNotEqual(np.max(abs(self.geometry.detector_tensor - original)), 0) + self.assertNotEqual(np.max(np.asarray(abs(self.geometry.detector_tensor - original))), 0) + self.assertEqual(aac.get_namespace(self.geometry.detector_tensor), self.xp) def test_unit_vector_along_arm_default(self): with self.assertRaises(ValueError): @@ -177,17 +196,20 @@ def 
test_unit_vector_along_arm_x(self): self.geometry.latitude = 0 self.geometry.xarm_tilt = 0 self.geometry.xarm_azimuth = 0 + self.geometry.set_array_backend(self.xp) arm = self.geometry.unit_vector_along_arm("x") self.assertTrue(np.allclose(arm, np.array([0, 1, 0]))) + self.assertEqual(aac.get_namespace(arm), self.xp) def test_unit_vector_along_arm_y(self): self.geometry.longitude = 0 self.geometry.latitude = 0 self.geometry.yarm_tilt = 0 self.geometry.yarm_azimuth = 90 + self.geometry.set_array_backend(self.xp) arm = self.geometry.unit_vector_along_arm("y") - print(arm) self.assertTrue(np.allclose(arm, np.array([0, 0, 1]))) + self.assertEqual(aac.get_namespace(arm), self.xp) def test_repr(self): expected = ( diff --git a/test/gw/detector/networks_test.py b/test/gw/detector/networks_test.py index 942bd882f..061f75571 100644 --- a/test/gw/detector/networks_test.py +++ b/test/gw/detector/networks_test.py @@ -203,7 +203,9 @@ def test_set_strain_data_from_power_spectral_density(self, m): self.ifo_list.set_strain_data_from_power_spectral_densities( sampling_frequency=123, duration=6.2, start_time=3 ) - m.assert_called_with(sampling_frequency=123, duration=6.2, start_time=3) + m.assert_called_with( + sampling_frequency=123, duration=6.2, start_time=3, random_state=None + ) self.assertEqual(len(self.ifo_list), m.call_count) def test_inject_signal_pol_and_wg_none(self): diff --git a/test/gw/likelihood/marginalization_test.py b/test/gw/likelihood/marginalization_test.py index 351e516f8..b5da6ba16 100644 --- a/test/gw/likelihood/marginalization_test.py +++ b/test/gw/likelihood/marginalization_test.py @@ -3,6 +3,7 @@ import pytest import unittest from copy import deepcopy +from functools import cached_property from itertools import product from parameterized import parameterized @@ -230,54 +231,63 @@ def setUp(self): maximum=self.parameters["geocent_time"] + 0.1 ) - trial_roq_paths = [ - "/roq_basis", - os.path.join(os.path.expanduser("~"), "ROQ_data/IMRPhenomPv2/4s"), - 
"/home/cbc/ROQ_data/IMRPhenomPv2/4s", - ] - roq_dir = None - for path in trial_roq_paths: - if os.path.isdir(path): - roq_dir = path - break - if roq_dir is None: - raise Exception("Unable to load ROQ basis: cannot proceed with tests") - - self.roq_waveform_generator = bilby.gw.waveform_generator.WaveformGenerator( + self.relbin_waveform_generator = bilby.gw.waveform_generator.WaveformGenerator( duration=self.duration, sampling_frequency=self.sampling_frequency, - frequency_domain_source_model=bilby.gw.source.binary_black_hole_roq, + frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole_relative_binning, start_time=1126259640, waveform_arguments=dict( reference_frequency=20.0, + minimum_frequency=20.0, waveform_approximant="IMRPhenomPv2", - frequency_nodes_linear=np.load(f"{roq_dir}/fnodes_linear.npy"), - frequency_nodes_quadratic=np.load(f"{roq_dir}/fnodes_quadratic.npy"), ) ) - self.roq_linear_matrix_file = f"{roq_dir}/B_linear.npy" - self.roq_quadratic_matrix_file = f"{roq_dir}/B_quadratic.npy" - self.relbin_waveform_generator = bilby.gw.waveform_generator.WaveformGenerator( + self.multiband_waveform_generator = bilby.gw.WaveformGenerator( duration=self.duration, sampling_frequency=self.sampling_frequency, - frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole_relative_binning, + frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence, start_time=1126259640, waveform_arguments=dict( reference_frequency=20.0, - minimum_frequency=20.0, waveform_approximant="IMRPhenomPv2", ) ) - self.multiband_waveform_generator = bilby.gw.WaveformGenerator( + @property + def roq_dir(self): + trial_roq_paths = [ + "/roq_basis", + os.path.join(os.path.expanduser("~"), "ROQ_data/IMRPhenomPv2/4s"), + "/home/cbc/ROQ_data/IMRPhenomPv2/4s", + ] + if "BILBY_TESTING_ROQ_DIR" in os.environ: + trial_roq_paths.insert(0, os.environ["BILBY_TESTING_ROQ_DIR"]) + for path in trial_roq_paths: + if os.path.isdir(path): + return path + raise 
Exception("Unable to load ROQ basis: cannot proceed with tests") + + @property + def roq_linear_matrix_file(self): + return f"{self.roq_dir}/B_linear.npy" + + @property + def roq_quadratic_matrix_file(self): + return f"{self.roq_dir}/B_quadratic.npy" + + @cached_property + def roq_waveform_generator(self): + return bilby.gw.waveform_generator.WaveformGenerator( duration=self.duration, sampling_frequency=self.sampling_frequency, - frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence, + frequency_domain_source_model=bilby.gw.source.binary_black_hole_roq, start_time=1126259640, waveform_arguments=dict( reference_frequency=20.0, waveform_approximant="IMRPhenomPv2", + frequency_nodes_linear=np.load(f"{self.roq_dir}/fnodes_linear.npy"), + frequency_nodes_quadratic=np.load(f"{self.roq_dir}/fnodes_quadratic.npy"), ) ) @@ -287,7 +297,6 @@ def tearDown(self): del self.parameters del self.interferometers del self.waveform_generator - del self.roq_waveform_generator del self.priors @classmethod diff --git a/test/gw/likelihood_test.py b/test/gw/likelihood_test.py index 9d7a7e36f..a2f3c3cf1 100644 --- a/test/gw/likelihood_test.py +++ b/test/gw/likelihood_test.py @@ -1,17 +1,61 @@ import os import unittest import tempfile +from functools import cached_property from itertools import product from parameterized import parameterized import pytest from copy import deepcopy +import array_api_compat as aac import h5py import numpy as np import bilby +from array_api_compat import is_array_api_obj from bilby.gw.likelihood import BilbyROQParamsRangeError +class BackendWaveformGenerator(bilby.gw.waveform_generator.WaveformGenerator): + """ + A thin wrapper to emulate different backends in the waveform generator. + + This ensures that all frequency arrays that might be used inside the + source are cast to numpy for compatibility. The outputs are converted + to the appropriate array type. 
+ """ + def __init__(self, wfg, xp): + self.wfg = wfg + self.xp = xp + + def __getattr__(self, name): + if name == "xp": + return self.xp + return getattr(self.wfg, name) + + def convert_nested_dict(self, data): + if is_array_api_obj(data): + return self.xp.asarray(data) + elif isinstance(data, dict): + return {key: self.convert_nested_dict(value) for key, value in data.items()} + else: + raise ValueError("Input must be an array API object or a dict of such objects.") + + def _strain_from_model(self, model_data_points, model, parameters, *, xp=None): + model_data_points = np.asarray(model_data_points) + return super()._strain_from_model(model_data_points, model, parameters) + + def frequency_domain_strain(self, parameters): + self.wfg.frequency_array = np.asarray(self.wfg.frequency_array) + if "frequency_nodes" in self.wfg.waveform_arguments: + self.wfg.waveform_arguments["frequency_nodes"] = np.asarray( + self.wfg.waveform_arguments["frequency_nodes"] + ) + wf = self.wfg.__class__.frequency_domain_strain(self, parameters) + return self.convert_nested_dict(wf) + + +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestBasicGWTransient(unittest.TestCase): def setUp(self): bilby.core.utils.random.seed(500) @@ -26,21 +70,23 @@ def setUp(self): phi_jl=0.3, luminosity_distance=4000.0, theta_jn=0.4, - psi=2.659, + psi=self.xp.asarray(2.659), phase=1.3, - geocent_time=1126259642.413, - ra=1.375, - dec=-1.2108, + geocent_time=self.xp.asarray(1126259642.413), + ra=self.xp.asarray(1.375), + dec=self.xp.asarray(-1.2108), ) self.interferometers = bilby.gw.detector.InterferometerList(["H1"]) self.interferometers.set_strain_data_from_power_spectral_densities( - sampling_frequency=2048, duration=4 + sampling_frequency=self.xp.asarray(2048.0), duration=self.xp.asarray(4.0) ) - self.waveform_generator = bilby.gw.waveform_generator.GWSignalWaveformGenerator( - duration=4, - sampling_frequency=2048, + self.interferometers.set_array_backend(self.xp) + base_wfg = 
bilby.gw.waveform_generator.GWSignalWaveformGenerator( + duration=self.xp.asarray(4.0), + sampling_frequency=self.xp.asarray(2048.0), waveform_arguments=dict(waveform_approximant="IMRPhenomPv2"), ) + self.waveform_generator = BackendWaveformGenerator(base_wfg, self.xp) self.likelihood = bilby.gw.likelihood.BasicGravitationalWaveTransient( interferometers=self.interferometers, @@ -55,23 +101,27 @@ def tearDown(self): def test_noise_log_likelihood(self): """Test noise log likelihood matches precomputed value""" - self.likelihood.noise_log_likelihood() + nll = self.likelihood.noise_log_likelihood() self.assertAlmostEqual( - -4014.1787704539474, self.likelihood.noise_log_likelihood(), 3 + -4014.1787704539474, float(nll), 3 ) + self.assertEqual(aac.get_namespace(nll), self.xp) def test_log_likelihood(self): """Test log likelihood matches precomputed value""" - self.likelihood.log_likelihood(self.parameters) - self.assertAlmostEqual(self.likelihood.log_likelihood(self.parameters), -4032.4397343470005, 3) + logl = self.likelihood.log_likelihood(self.parameters) + self.assertAlmostEqual(float(logl), -4032.4397343470005, 3) + self.assertEqual(aac.get_namespace(logl), self.xp) def test_log_likelihood_ratio(self): """Test log likelihood ratio returns the correct value""" + llr = self.likelihood.log_likelihood_ratio(self.parameters) self.assertAlmostEqual( self.likelihood.log_likelihood(self.parameters) - self.likelihood.noise_log_likelihood(), - self.likelihood.log_likelihood_ratio(self.parameters), + llr, 3, ) + self.assertEqual(aac.get_namespace(llr), self.xp) def test_likelihood_zero_when_waveform_is_none(self): """Test log likelihood returns np.nan_to_num(-np.inf) when the @@ -86,11 +136,13 @@ def test_repr(self): self.assertEqual(expected, repr(self.likelihood)) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestGWTransient(unittest.TestCase): def setUp(self): bilby.core.utils.random.seed(500) - self.duration = 4 - self.sampling_frequency = 2048 
+ self.duration = self.xp.asarray(4.0) + self.sampling_frequency = self.xp.asarray(2048.0) self.parameters = dict( mass_1=31.0, mass_2=29.0, @@ -102,21 +154,23 @@ def setUp(self): phi_jl=0.3, luminosity_distance=4000.0, theta_jn=0.4, - psi=2.659, - phase=1.3, - geocent_time=1126259642.413, - ra=1.375, - dec=-1.2108, + psi=self.xp.asarray(2.659), + phase=self.xp.asarray(1.3), + geocent_time=self.xp.asarray(1126259642.413), + ra=self.xp.asarray(1.375), + dec=self.xp.asarray(-1.2108), ) self.interferometers = bilby.gw.detector.InterferometerList(["H1"]) self.interferometers.set_strain_data_from_power_spectral_densities( sampling_frequency=self.sampling_frequency, duration=self.duration ) - self.waveform_generator = bilby.gw.waveform_generator.WaveformGenerator( + self.interferometers.set_array_backend(self.xp) + wfg = bilby.gw.waveform_generator.WaveformGenerator( duration=self.duration, sampling_frequency=self.sampling_frequency, frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole, ) + self.waveform_generator = BackendWaveformGenerator(wfg, self.xp) self.prior = bilby.gw.prior.BBHPriorDict() self.prior["geocent_time"] = bilby.prior.Uniform( @@ -139,24 +193,27 @@ def tearDown(self): def test_noise_log_likelihood(self): """Test noise log likelihood matches precomputed value""" - self.likelihood.noise_log_likelihood() + nll = self.likelihood.noise_log_likelihood() + self.assertEqual(aac.get_namespace(nll), self.xp) self.assertAlmostEqual( - -4014.1787704539474, self.likelihood.noise_log_likelihood(), 3 + -4014.1787704539474, float(nll), 3 ) def test_log_likelihood(self): """Test log likelihood matches precomputed value""" - self.likelihood.log_likelihood(self.parameters) - self.assertAlmostEqual(self.likelihood.log_likelihood(self.parameters), - -4032.4397343470005, 3) + logl = self.likelihood.log_likelihood(self.parameters) + self.assertAlmostEqual(float(logl), -4032.4397343470005, 3) + self.assertEqual(aac.get_namespace(logl), self.xp) def 
test_log_likelihood_ratio(self): """Test log likelihood ratio returns the correct value""" + llr = self.likelihood.log_likelihood_ratio(self.parameters) self.assertAlmostEqual( - self.likelihood.log_likelihood(self.parameters) - self.likelihood.noise_log_likelihood(), - self.likelihood.log_likelihood_ratio(self.parameters), + float(self.likelihood.log_likelihood(self.parameters)) - float(self.likelihood.noise_log_likelihood()), + float(llr), 3, ) + self.assertEqual(aac.get_namespace(llr), self.xp) def test_likelihood_zero_when_waveform_is_none(self): """Test log likelihood returns np.nan_to_num(-np.inf) when the @@ -236,14 +293,16 @@ def test_reference_frame_agrees_with_default(self): ) parameters = self.parameters.copy() del parameters["ra"], parameters["dec"] - parameters["zenith"] = 1.0 - parameters["azimuth"] = 1.0 + parameters["zenith"] = self.xp.asarray(1.0) + parameters["azimuth"] = self.xp.asarray(1.0) parameters["ra"], parameters["dec"] = bilby.gw.utils.zenith_azimuth_to_ra_dec( zenith=parameters["zenith"], azimuth=parameters["azimuth"], geocent_time=parameters["geocent_time"], - ifos=bilby.gw.detector.InterferometerList(["H1", "L1"]) + ifos=new_likelihood.reference_frame, ) + self.assertEqual(aac.get_namespace(parameters["ra"]), self.xp) + self.assertEqual(aac.get_namespace(parameters["dec"]), self.xp) self.assertEqual( new_likelihood.log_likelihood_ratio(parameters), self.likelihood.log_likelihood_ratio(parameters) @@ -264,42 +323,39 @@ def test_time_reference_agrees_with_default(self): ) parameters = self.parameters.copy() parameters["H1_time"] = parameters["geocent_time"] + time_delay - self.assertEqual( + self.assertAlmostEqual( new_likelihood.log_likelihood_ratio(parameters), - self.likelihood.log_likelihood_ratio(parameters) + self.likelihood.log_likelihood_ratio(parameters), + 8, ) -@pytest.mark.requires_roqs -class TestROQLikelihood(unittest.TestCase): - def setUp(self): - self.duration = 4 - self.sampling_frequency = 2048 +class ROQBasisMixin: - 
# Possible locations for the ROQ: in the docker image, local, or on CIT + @property + def roq_dir(self): trial_roq_paths = [ "/roq_basis", os.path.join(os.path.expanduser("~"), "ROQ_data/IMRPhenomPv2/4s"), "/home/cbc/ROQ_data/IMRPhenomPv2/4s", ] - roq_dir = None + if "BILBY_TESTING_ROQ_DIR" in os.environ: + trial_roq_paths.insert(0, os.environ["BILBY_TESTING_ROQ_DIR"]) for path in trial_roq_paths: if os.path.isdir(path): - roq_dir = path - break - if roq_dir is None: - raise Exception("Unable to load ROQ basis: cannot proceed with tests") + return path + raise Exception("Unable to load ROQ basis: cannot proceed with tests") - linear_matrix_file = "{}/B_linear.npy".format(roq_dir) - quadratic_matrix_file = "{}/B_quadratic.npy".format(roq_dir) - fnodes_linear_file = "{}/fnodes_linear.npy".format(roq_dir) - fnodes_linear = np.load(fnodes_linear_file).T - fnodes_quadratic_file = "{}/fnodes_quadratic.npy".format(roq_dir) - fnodes_quadratic = np.load(fnodes_quadratic_file).T - self.linear_matrix_file = "{}/B_linear.npy".format(roq_dir) - self.quadratic_matrix_file = "{}/B_quadratic.npy".format(roq_dir) - self.params_file = "{}/params.dat".format(roq_dir) +@pytest.mark.requires_roqs +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") +@pytest.mark.flaky(reruns=3) # pyfftw is flake on some machines +class TestROQLikelihood(ROQBasisMixin, unittest.TestCase): + def setUp(self): + self.duration = self.xp.asarray(4.0) + self.sampling_frequency = self.xp.asarray(2048.0) + bilby.core.utils.random.seed(500) self.test_parameters = dict( mass_1=36.0, @@ -312,17 +368,18 @@ def setUp(self): phi_jl=0.3, luminosity_distance=1000.0, theta_jn=0.4, - psi=0.659, + psi=self.xp.asarray(0.659), phase=1.3, - geocent_time=1.2, - ra=1.3, - dec=-1.2, + geocent_time=self.xp.asarray(1.2), + ra=self.xp.asarray(1.3), + dec=self.xp.asarray(-1.2), ) ifos = bilby.gw.detector.InterferometerList(["H1"]) ifos.set_strain_data_from_power_spectral_densities( 
sampling_frequency=self.sampling_frequency, duration=self.duration ) + ifos.set_array_backend(self.xp) self.priors = bilby.gw.prior.BBHPriorDict() self.priors.pop("mass_1") @@ -342,6 +399,7 @@ def setUp(self): waveform_approximant="IMRPhenomPv2", ), ) + non_roq_wfg = BackendWaveformGenerator(non_roq_wfg, self.xp) ifos.inject_signal( parameters=self.test_parameters, waveform_generator=non_roq_wfg @@ -349,20 +407,6 @@ def setUp(self): self.ifos = ifos - roq_wfg = bilby.gw.waveform_generator.WaveformGenerator( - duration=self.duration, - sampling_frequency=self.sampling_frequency, - frequency_domain_source_model=bilby.gw.source.binary_black_hole_roq, - waveform_arguments=dict( - frequency_nodes_linear=fnodes_linear, - frequency_nodes_quadratic=fnodes_quadratic, - reference_frequency=20.0, - waveform_approximant="IMRPhenomPv2", - ), - ) - - self.roq_wfg = roq_wfg - self.non_roq = bilby.gw.likelihood.GravitationalWaveTransient( interferometers=ifos, waveform_generator=non_roq_wfg ) @@ -374,41 +418,74 @@ def setUp(self): priors=self.priors.copy(), ) - self.roq = bilby.gw.likelihood.ROQGravitationalWaveTransient( - interferometers=ifos, - waveform_generator=roq_wfg, - linear_matrix=linear_matrix_file, - quadratic_matrix=quadratic_matrix_file, - priors=self.priors, - ) - - self.roq_phase = bilby.gw.likelihood.ROQGravitationalWaveTransient( - interferometers=ifos, - waveform_generator=roq_wfg, - linear_matrix=linear_matrix_file, - quadratic_matrix=quadratic_matrix_file, - phase_marginalization=True, - priors=self.priors.copy(), - ) - def tearDown(self): del ( - self.roq, self.non_roq, self.non_roq_phase, - self.roq_phase, self.ifos, self.priors, ) + @property + def linear_matrix_file(self): + return f"{self.roq_dir}/B_linear.npy" + + @property + def quadratic_matrix_file(self): + return f"{self.roq_dir}/B_quadratic.npy" + + @property + def params_file(self): + return f"{self.roq_dir}/params.dat" + + @cached_property + def roq_wfg(self): + fnodes_linear_file = 
f"{self.roq_dir}/fnodes_linear.npy" + fnodes_quadratic_file = f"{self.roq_dir}/fnodes_quadratic.npy" + fnodes_linear = np.load(fnodes_linear_file).T + fnodes_quadratic = np.load(fnodes_quadratic_file).T + wfg = bilby.gw.waveform_generator.WaveformGenerator( + duration=self.duration, + sampling_frequency=self.sampling_frequency, + frequency_domain_source_model=bilby.gw.source.binary_black_hole_roq, + waveform_arguments=dict( + frequency_nodes_linear=fnodes_linear, + frequency_nodes_quadratic=fnodes_quadratic, + reference_frequency=20.0, + waveform_approximant="IMRPhenomPv2", + ), + ) + return BackendWaveformGenerator(wfg, self.xp) + + @cached_property + def roq(self): + return bilby.gw.likelihood.ROQGravitationalWaveTransient( + interferometers=self.ifos, + waveform_generator=self.roq_wfg, + linear_matrix=self.linear_matrix_file, + quadratic_matrix=self.quadratic_matrix_file, + priors=self.priors, + ) + + @cached_property + def roq_phase(self): + return bilby.gw.likelihood.ROQGravitationalWaveTransient( + interferometers=self.ifos, + waveform_generator=self.roq_wfg, + linear_matrix=self.linear_matrix_file, + quadratic_matrix=self.quadratic_matrix_file, + phase_marginalization=True, + priors=self.priors.copy(), + ) + def test_matches_non_roq(self): + roq_llr = self.roq.log_likelihood_ratio(self.test_parameters) + non_roq_llr = self.non_roq.log_likelihood_ratio(self.test_parameters) self.assertLess( - abs( - self.non_roq.log_likelihood_ratio(self.test_parameters) - - self.roq.log_likelihood_ratio(self.test_parameters) - ) / self.non_roq.log_likelihood_ratio(self.test_parameters), + abs(non_roq_llr - roq_llr) / non_roq_llr, 1e-3, ) + self.assertEqual(aac.get_namespace(roq_llr), self.xp) def test_time_prior_out_of_bounds_returns_zero(self): parameters = deepcopy(self.test_parameters) @@ -424,10 +501,12 @@ def test_create_roq_weights_with_params(self): quadratic_matrix=self.quadratic_matrix_file, priors=self.priors, ) + roq_llr = 
roq.log_likelihood_ratio(self.test_parameters) self.assertEqual( - roq.log_likelihood_ratio(self.test_parameters), + roq_llr, self.roq.log_likelihood_ratio(self.test_parameters) ) + self.assertEqual(aac.get_namespace(roq_llr), self.xp) def test_create_roq_weights_frequency_mismatch_works_with_params(self): @@ -537,33 +616,18 @@ def test_create_roq_weights_fails_due_to_duration(self): @pytest.mark.requires_roqs -class TestRescaledROQLikelihood(unittest.TestCase): +class TestRescaledROQLikelihood(unittest.TestCase, ROQBasisMixin): def test_rescaling(self): + linear_matrix_file = f"{self.roq_dir}/B_linear.npy" + quadratic_matrix_file = f"{self.roq_dir}/B_quadratic.npy" - # Possible locations for the ROQ: in the docker image, local, or on CIT - trial_roq_paths = [ - "/roq_basis", - os.path.join(os.path.expanduser("~"), "ROQ_data/IMRPhenomPv2/4s"), - "/home/cbc/ROQ_data/IMRPhenomPv2/4s", - ] - roq_dir = None - for path in trial_roq_paths: - if os.path.isdir(path): - roq_dir = path - break - if roq_dir is None: - raise Exception("Unable to load ROQ basis: cannot proceed with tests") - - linear_matrix_file = "{}/B_linear.npy".format(roq_dir) - quadratic_matrix_file = "{}/B_quadratic.npy".format(roq_dir) - - fnodes_linear_file = "{}/fnodes_linear.npy".format(roq_dir) + fnodes_linear_file = f"{self.roq_dir}/fnodes_linear.npy" fnodes_linear = np.load(fnodes_linear_file).T - fnodes_quadratic_file = "{}/fnodes_quadratic.npy".format(roq_dir) + fnodes_quadratic_file = f"{self.roq_dir}/fnodes_quadratic.npy" fnodes_quadratic = np.load(fnodes_quadratic_file).T - self.linear_matrix_file = "{}/B_linear.npy".format(roq_dir) - self.quadratic_matrix_file = "{}/B_quadratic.npy".format(roq_dir) - self.params_file = "{}/params.dat".format(roq_dir) + self.linear_matrix_file = f"{self.roq_dir}/B_linear.npy" + self.quadratic_matrix_file = f"{self.roq_dir}/B_quadratic.npy" + self.params_file = f"{self.roq_dir}/params.dat" scale_factor = 0.5 params = np.genfromtxt(self.params_file, names=True) 
@@ -611,7 +675,9 @@ def test_rescaling(self): @pytest.mark.requires_roqs -class TestROQLikelihoodHDF5(unittest.TestCase): +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") +class TestROQLikelihoodHDF5(unittest.TestCase, ROQBasisMixin): """ Test ROQ likelihood constructed from .hdf5 basis @@ -619,14 +685,13 @@ class TestROQLikelihoodHDF5(unittest.TestCase): respectively, and 2 quadratic bases constructed over 8Msun= self.priors["chirp_mass"].minimum) * @@ -843,13 +910,14 @@ def assertLess_likelihood_errors( self.priors["chirp_mass"].maximum = mc_max interferometers = bilby.gw.detector.InterferometerList(["H1", "L1"]) + interferometers.set_array_backend(self.xp) for ifo in interferometers: if minimum_frequency is None: ifo.minimum_frequency = self.minimum_frequency else: - ifo.minimum_frequency = minimum_frequency + ifo.minimum_frequency = self.xp.asarray(minimum_frequency) if maximum_frequency is not None: - ifo.maximum_frequency = maximum_frequency + ifo.maximum_frequency = self.xp.asarray(maximum_frequency) interferometers.set_strain_data_from_zero_noise( sampling_frequency=self.sampling_frequency, duration=self.duration, @@ -884,6 +952,7 @@ def assertLess_likelihood_errors( waveform_approximant=self.waveform_approximant ) ) + waveform_generator = BackendWaveformGenerator(waveform_generator, self.xp) interferometers.inject_signal(waveform_generator=waveform_generator, parameters=self.injection_parameters) likelihood = bilby.gw.GravitationalWaveTransient( @@ -901,12 +970,13 @@ def assertLess_likelihood_errors( waveform_approximant=self.waveform_approximant ) ) + search_waveform_generator = BackendWaveformGenerator(search_waveform_generator, self.xp) likelihood_roq = bilby.gw.likelihood.ROQGravitationalWaveTransient( interferometers=interferometers, priors=self.priors, waveform_generator=search_waveform_generator, - linear_matrix=basis_linear, - quadratic_matrix=basis_quadratic, + linear_matrix=f"{self.roq_dir}/{basis_linear}", + 
quadratic_matrix=f"{self.roq_dir}/{basis_quadratic}", roq_scale_factor=roq_scale_factor ) for mc in np.linspace(self.priors["chirp_mass"].minimum, self.priors["chirp_mass"].maximum, 11): @@ -915,10 +985,11 @@ def assertLess_likelihood_errors( llr = likelihood.log_likelihood_ratio(parameters) llr_roq = likelihood_roq.log_likelihood_ratio(parameters) self.assertLess(np.abs(llr - llr_roq), max_llr_error) + self.assertEqual(aac.get_namespace(llr_roq), self.xp) @pytest.mark.requires_roqs -class TestCreateROQLikelihood(unittest.TestCase): +class TestCreateROQLikelihood(unittest.TestCase, ROQBasisMixin): """ Test if ROQ likelihood is constructed without any errors from .hdf5 or .npy basis @@ -926,9 +997,8 @@ class TestCreateROQLikelihood(unittest.TestCase): respectively, and 2 quadratic bases constructed over 8Msun 0.001) + self.assertEqual(aac.get_namespace(samples), self.xp) if __name__ == "__main__": diff --git a/test/gw/utils_test.py b/test/gw/utils_test.py index cf78849c7..9b6815264 100644 --- a/test/gw/utils_test.py +++ b/test/gw/utils_test.py @@ -3,6 +3,7 @@ from shutil import rmtree from importlib.metadata import version +import array_api_compat as aac import numpy as np import lal import lalsimulation as lalsim @@ -15,6 +16,8 @@ from bilby.gw import utils as gwutils +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestGWUtils(unittest.TestCase): def setUp(self): self.outdir = "outdir" @@ -27,44 +30,53 @@ def tearDown(self): pass def test_asd_from_freq_series(self): - freq_data = np.array([1, 2, 3]) + freq_data = self.xp.asarray([1, 2, 3]) df = 0.1 asd = gwutils.asd_from_freq_series(freq_data, df) + self.assertEqual(aac.get_namespace(asd), self.xp) + asd = np.asarray(asd) + freq_data = np.asarray(freq_data) self.assertTrue(np.all(asd == freq_data * 2 * df ** 0.5)) def test_psd_from_freq_series(self): - freq_data = np.array([1, 2, 3]) + freq_data = self.xp.asarray([1, 2, 3]) df = 0.1 psd = gwutils.psd_from_freq_series(freq_data, df) + 
self.assertEqual(aac.get_namespace(psd), self.xp) + psd = np.asarray(psd) + freq_data = np.asarray(freq_data) self.assertTrue(np.all(psd == (freq_data * 2 * df ** 0.5) ** 2)) def test_inner_product(self): - aa = np.array([1, 2, 3]) - bb = np.array([5, 6, 7]) - frequency = np.array([0.2, 0.4, 0.6]) + aa = self.xp.asarray([1, 2, 3]) + bb = self.xp.asarray([5, 6, 7]) + frequency = self.xp.asarray([0.2, 0.4, 0.6]) PSD = bilby.gw.detector.PowerSpectralDensity.from_aligo() ip = gwutils.inner_product(aa, bb, frequency, PSD) self.assertEqual(ip, 0) + self.assertEqual(aac.get_namespace(ip), self.xp) def test_noise_weighted_inner_product(self): - aa = np.array([1e-23, 2e-23, 3e-23]) - bb = np.array([5e-23, 6e-23, 7e-23]) - frequency = np.array([100, 101, 102]) + aa = self.xp.asarray([1e-23, 2e-23, 3e-23]) + bb = self.xp.asarray([5e-23, 6e-23, 7e-23]) + frequency = self.xp.asarray([100, 101, 102]) PSD = bilby.gw.detector.PowerSpectralDensity.from_aligo() psd = PSD.power_spectral_density_interpolated(frequency) duration = 4 nwip = gwutils.noise_weighted_inner_product(aa, bb, psd, duration) - self.assertEqual(nwip, 239.87768033598326) + # torch doesn't have enough precision + self.assertAlmostEqual(float(nwip), 239.87768033598326, 10) self.assertEqual( gwutils.optimal_snr_squared(aa, psd, duration), gwutils.noise_weighted_inner_product(aa, aa, psd, duration), ) + self.assertEqual(aac.get_namespace(nwip), self.xp) def test_matched_filter_snr(self): - signal = np.array([1e-23, 2e-23, 3e-23]) - frequency_domain_strain = np.array([5e-23, 6e-23, 7e-23]) - frequency = np.array([100, 101, 102]) + signal = self.xp.asarray([1e-23, 2e-23, 3e-23]) + frequency_domain_strain = self.xp.asarray([5e-23, 6e-23, 7e-23]) + frequency = self.xp.asarray([100, 101, 102]) PSD = bilby.gw.detector.PowerSpectralDensity.from_aligo() psd = PSD.power_spectral_density_interpolated(frequency) duration = 4 @@ -72,7 +84,29 @@ def test_matched_filter_snr(self): mfsnr = gwutils.matched_filter_snr( signal, 
frequency_domain_strain, psd, duration ) - self.assertEqual(mfsnr, 25.510869054168282) + # torch doesn't have enough precision + self.assertAlmostEqual(float(mfsnr), 25.510869054168282, 10) + self.assertEqual(aac.get_namespace(mfsnr), self.xp) + + def test_overlap(self): + signal = self.xp.linspace(1e-23, 21e-23, 21) + frequency_domain_strain = self.xp.linspace(5e-23, 25e-23, 21) + frequency = self.xp.linspace(100, 120, 21) + PSD = bilby.gw.detector.PowerSpectralDensity.from_aligo() + psd = PSD.power_spectral_density_interpolated(frequency) + duration = 4 + overlap = gwutils.overlap( + signal, + frequency_domain_strain, + psd, + delta_frequency=1 / duration, + lower_cut_off=3, + upper_cut_off=18, + norm_a=gwutils.optimal_snr_squared(signal, psd, duration), + norm_b=gwutils.optimal_snr_squared(frequency_domain_strain, psd, duration), + ) + self.assertEqual(aac.get_namespace(overlap), self.xp) + self.assertAlmostEqual(float(overlap), 2.76914407e-05) @pytest.mark.skip(reason="GWOSC unstable: avoiding this test") def test_get_event_time(self): @@ -264,6 +298,8 @@ def test_safe_cast_mode_to_int(self): gwutils.safe_cast_mode_to_int(None) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestSkyFrameConversion(unittest.TestCase): def setUp(self) -> None: @@ -281,23 +317,39 @@ def tearDown(self) -> None: del self.ifos del self.samples + def test_conversion_single(self) -> None: + sample = self.priors.sample() + zenith = self.xp.asarray(sample["zenith"]) + azimuth = self.xp.asarray(sample["azimuth"]) + time = self.xp.asarray(sample["time"]) + self.ifos.set_array_backend(self.xp) + ra, dec = bilby.gw.utils.zenith_azimuth_to_ra_dec( + zenith, azimuth, time, self.ifos + ) + self.assertEqual(aac.get_namespace(ra), self.xp) + self.assertEqual(aac.get_namespace(dec), self.xp) + def test_conversion_gives_correct_prior(self) -> None: - zeniths = self.samples["zenith"] - azimuths = self.samples["azimuth"] - times = self.samples["time"] - args = zip(*[ - 
(zenith, azimuth, time, self.ifos) - for zenith, azimuth, time in zip(zeniths, azimuths, times) - ]) - ras, decs = zip(*map(bilby.gw.utils.zenith_azimuth_to_ra_dec, *args)) + zeniths = self.xp.asarray(self.samples["zenith"]) + azimuths = self.xp.asarray(self.samples["azimuth"]) + times = self.xp.asarray(self.samples["time"]) + self.ifos.set_array_backend(self.xp) + ras, decs = bilby.gw.utils.zenith_azimuth_to_ra_dec( + zeniths, azimuths, times, self.ifos + ) + self.assertEqual(aac.get_namespace(ras), self.xp) + self.assertEqual(aac.get_namespace(decs), self.xp) + ras = np.asarray(ras) + decs = np.asarray(decs) self.assertGreaterEqual(ks_2samp(self.samples["ra"], ras).pvalue, 0.01) self.assertGreaterEqual(ks_2samp(self.samples["dec"], decs).pvalue, 0.01) -def test_ln_i0_mathces_scipy(): +@pytest.mark.array_backend +def test_ln_i0_mathces_scipy(xp): from scipy.special import i0 - values = np.linspace(-10, 10, 101) - assert max(abs(gwutils.ln_i0(values) - np.log(i0(values)))) < 1e-10 + values = xp.linspace(-10, 10, 101) + assert max(abs(gwutils.ln_i0(values) - xp.log(i0(values)))) < 1e-10 if __name__ == "__main__": diff --git a/test/gw/waveform_generator_test.py b/test/gw/waveform_generator_test.py index a8f942ed0..b8bcbb533 100644 --- a/test/gw/waveform_generator_test.py +++ b/test/gw/waveform_generator_test.py @@ -1,9 +1,12 @@ import unittest from unittest import mock +import array_api_compat as aac import bilby import lalsimulation import numpy as np +import pytest +from bilby.compat.utils import xp_wrap def dummy_func_array_return_value( @@ -36,16 +39,21 @@ def dummy_func_dict_return_value( return ht +@xp_wrap def dummy_func_array_return_value_2( - array, amplitude, mu, sigma, ra, dec, geocent_time, psi + array, amplitude, mu, sigma, ra, dec, geocent_time, psi, *, xp=None ): - return dict(plus=np.array(array), cross=np.array(array)) + return dict(plus=xp.asarray(array), cross=xp.asarray(array)) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") 
class TestWaveformGeneratorInstantiationWithoutOptionalParameters(unittest.TestCase): def setUp(self): self.waveform_generator = bilby.gw.waveform_generator.WaveformGenerator( - 1, 4096, frequency_domain_source_model=dummy_func_dict_return_value + self.xp.asarray(1.0), + self.xp.asarray(4096.0), + frequency_domain_source_model=dummy_func_dict_return_value, ) self.simulation_parameters = dict( amplitude=1e-21, @@ -118,9 +126,11 @@ def conversion_func(): def test_duration(self): self.assertEqual(self.waveform_generator.duration, 1) + self.assertEqual(aac.get_namespace(self.waveform_generator.duration), self.xp) def test_sampling_frequency(self): self.assertEqual(self.waveform_generator.sampling_frequency, 4096) + self.assertEqual(aac.get_namespace(self.waveform_generator.sampling_frequency), self.xp) def test_source_model(self): self.assertEqual( @@ -129,10 +139,10 @@ def test_source_model(self): ) def test_frequency_array_type(self): - self.assertIsInstance(self.waveform_generator.frequency_array, np.ndarray) + self.assertEqual(aac.array_namespace(self.waveform_generator.frequency_array), self.xp) def test_time_array_type(self): - self.assertIsInstance(self.waveform_generator.time_array, np.ndarray) + self.assertEqual(aac.array_namespace(self.waveform_generator.time_array), self.xp) def test_source_model_parameters(self): formatted_parameters = self.waveform_generator._format_parameters( @@ -266,11 +276,13 @@ def conversion_func(): self.assertEqual(conversion_func, self.waveform_generator.parameter_conversion) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestFrequencyDomainStrainMethod(unittest.TestCase): def setUp(self): self.waveform_generator = bilby.gw.waveform_generator.WaveformGenerator( - duration=1, - sampling_frequency=4096, + duration=self.xp.asarray(1.0), + sampling_frequency=self.xp.asarray(4096.0), frequency_domain_source_model=dummy_func_dict_return_value, ) self.simulation_parameters = dict( @@ -312,6 +324,8 @@ def 
test_frequency_domain_source_model_call(self): ) self.assertTrue(np.array_equal(expected["plus"], actual["plus"])) self.assertTrue(np.array_equal(expected["cross"], actual["cross"])) + self.assertEqual(aac.get_namespace(actual["plus"]), self.xp) + self.assertEqual(aac.get_namespace(actual["cross"]), self.xp) def test_time_domain_source_model_call_with_ndarray(self): self.waveform_generator.frequency_domain_source_model = None @@ -329,6 +343,7 @@ def side_effect(value, value2): parameters=self.simulation_parameters ) self.assertTrue(np.array_equal(expected, actual)) + self.assertEqual(aac.get_namespace(actual), self.xp) def test_time_domain_source_model_call_with_dict(self): self.waveform_generator.frequency_domain_source_model = None @@ -347,6 +362,8 @@ def side_effect(value, value2): ) self.assertTrue(np.array_equal(expected["plus"], actual["plus"])) self.assertTrue(np.array_equal(expected["cross"], actual["cross"])) + self.assertEqual(aac.get_namespace(actual["plus"]), self.xp) + self.assertEqual(aac.get_namespace(actual["cross"]), self.xp) def test_no_source_model_given(self): self.waveform_generator.time_domain_source_model = None @@ -456,8 +473,8 @@ def test_frequency_domain_caching_changing_model(self): def test_time_domain_caching_changing_model(self): self.waveform_generator = bilby.gw.waveform_generator.WaveformGenerator( - duration=1, - sampling_frequency=4096, + duration=self.xp.asarray(1.0), + sampling_frequency=self.xp.asarray(4096.0), time_domain_source_model=dummy_func_dict_return_value, ) original_waveform = self.waveform_generator.frequency_domain_strain( @@ -472,12 +489,18 @@ def test_time_domain_caching_changing_model(self): self.assertFalse( np.array_equal(original_waveform["plus"], new_waveform["plus"]) ) + self.assertEqual(aac.get_namespace(new_waveform["plus"]), self.xp) + self.assertEqual(aac.get_namespace(new_waveform["cross"]), self.xp) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class 
TestTimeDomainStrainMethod(unittest.TestCase): def setUp(self): self.waveform_generator = bilby.gw.waveform_generator.WaveformGenerator( - 1, 4096, time_domain_source_model=dummy_func_dict_return_value + self.xp.asarray(1.0), + self.xp.asarray(4096.0), + time_domain_source_model=dummy_func_dict_return_value, ) self.simulation_parameters = dict( amplitude=1e-21, @@ -518,6 +541,27 @@ def test_time_domain_source_model_call(self): ) self.assertTrue(np.array_equal(expected["plus"], actual["plus"])) self.assertTrue(np.array_equal(expected["cross"], actual["cross"])) + self.assertEqual(aac.get_namespace(actual["plus"]), self.xp) + self.assertEqual(aac.get_namespace(actual["cross"]), self.xp) + + def test_time_domain_source_model_call_with_explicit_backend(self): + expected = self.waveform_generator.time_domain_source_model( + self.waveform_generator.time_array, + self.simulation_parameters["amplitude"], + self.simulation_parameters["mu"], + self.simulation_parameters["sigma"], + self.simulation_parameters["ra"], + self.simulation_parameters["dec"], + self.simulation_parameters["geocent_time"], + self.simulation_parameters["psi"], + ) + actual = self.waveform_generator.time_domain_strain( + parameters=self.simulation_parameters, xp=self.xp + ) + self.assertTrue(np.array_equal(expected["plus"], actual["plus"])) + self.assertTrue(np.array_equal(expected["cross"], actual["cross"])) + self.assertEqual(aac.get_namespace(actual["plus"]), self.xp) + self.assertEqual(aac.get_namespace(actual["cross"]), self.xp) def test_frequency_domain_source_model_call_with_ndarray(self): self.waveform_generator.time_domain_source_model = None @@ -537,6 +581,7 @@ def side_effect(value, value2): parameters=self.simulation_parameters ) self.assertTrue(np.array_equal(expected, actual)) + self.assertEqual(aac.get_namespace(actual), self.xp) def test_frequency_domain_source_model_call_with_dict(self): self.waveform_generator.time_domain_source_model = None @@ -557,6 +602,8 @@ def side_effect(value, 
value2): ) self.assertTrue(np.array_equal(expected["plus"], actual["plus"])) self.assertTrue(np.array_equal(expected["cross"], actual["cross"])) + self.assertEqual(aac.get_namespace(actual["plus"]), self.xp) + self.assertEqual(aac.get_namespace(actual["cross"]), self.xp) def test_no_source_model_given(self): self.waveform_generator.time_domain_source_model = None diff --git a/test/hyper/hyper_pe_test.py b/test/hyper/hyper_pe_test.py index 4ca58927d..83b3c0a51 100644 --- a/test/hyper/hyper_pe_test.py +++ b/test/hyper/hyper_pe_test.py @@ -105,7 +105,7 @@ def test_resample_without_max_samples(self): self.sampling_model, log_evidences=self.log_evidences, ) - resampled = like.resample_posteriors(10) + resampled = like.resample_posteriors(max_samples=10) self.assertEqual(resampled["a"].shape, (len(self.lengths), 10))