-
Notifications
You must be signed in to change notification settings - Fork 134
Support non-numpy array backends #886
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
ColmTalbot
wants to merge
140
commits into
bilby-dev:main
Choose a base branch
from
ColmTalbot:bilback
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
140 commits
Select commit
Hold shift + click to select a range
1e3f4af
FEAT: enable backend switching for base gravitational-wave transient …
ColmTalbot cf5c611
FEAT: support multiband and relative binning likelihoods
ColmTalbot 2bfc833
FEAT: make more conversions backend agnostic
ColmTalbot d29b860
FEAT: use more normal conversions
ColmTalbot c5eb323
FEAT: move backend switching code to bilby
ColmTalbot 9bec666
FEAT: make core prior backend agnostic
ColmTalbot 0d9aba6
FEAT: make non-numpy arrays serializable
ColmTalbot aea0ea8
BUG: fix some array conversion methods
ColmTalbot 8e61b7f
DEV: some more prior agnosticism
ColmTalbot b658a02
TEST: make all prior tests run
ColmTalbot 3ed92e0
DEV: move some jax functionality to compat
ColmTalbot c49e6ef
REFACTOR: use array backend for ln_i0
ColmTalbot ef3d482
make distance marginalizatio backend transparent
ColmTalbot 95be371
DEV: some more prior dict array refactoring
ColmTalbot 9f7eb38
fix jax logic for distance marginalization
ColmTalbot 553689f
improve efficiency of setting up multibanding
ColmTalbot 55c3cef
make high-dimensional gaussians jax compatible
ColmTalbot fd7c3c3
make cubic spline calibration work with jax backend
ColmTalbot a062ca3
BUG: fix linspace calls
ColmTalbot a9587a8
ENH: fix bottleneck in relative binning for JAX
ColmTalbot 77b39a1
ENH: make interpolated prior backend friendly
ColmTalbot f18dd0e
REFACTOR: refactor backend-specific interpolation code
ColmTalbot fd28b85
ENH: make sine gaussian model backend independent
ColmTalbot 7b98e7c
ENH: make roq likelihood backend independent
ColmTalbot 221cd57
BUG: fix roq slicing
ColmTalbot a9d3490
FEAT: make condition chi evaluable
ColmTalbot 93fefb9
MAINT: make whitening work for non-numpy
ColmTalbot 997d094
EXAMPLE: update jax example
ColmTalbot 2a6d0f0
BUG: fix interpax interpolation method
ColmTalbot 590de47
REFACTOR: update variable backend for new parameter method
ColmTalbot c7e7525
some simplifications of array transparency
ColmTalbot 8d99e56
HYPER: make hyperparameter likelihood handle array backends
ColmTalbot 53fdb6d
MAINT: switch back to bilby_cython
ColmTalbot cf2a033
TYPO: fix typo in multiband time-marginalized likelihood
ColmTalbot 51b42e6
MAINT: removed unused import
ColmTalbot 0e8739c
BUG: add explicit array cast in conversion
ColmTalbot b1cb48a
REFACTOR: some refactoring of array edge cases
ColmTalbot 4b3d701
MAINT: removed extra ripple code
ColmTalbot bd36cc2
REFACTOR: make bilby_cython an optional dependency
ColmTalbot cd6bccb
FMT: formatting fixes
ColmTalbot 043b3c9
BUG: fix array introspection for conversion
ColmTalbot bc070af
REFACTOR: make parameters for waveform generator more strict
ColmTalbot 2a537b1
BUG: fix core likelihood tests
ColmTalbot b87327a
BUG: fix calibration calculations
ColmTalbot a605c05
EXAMPLE: update jax fast tutorial
ColmTalbot 76c61a5
TST: refactor marginalization tests to be less restrictive
ColmTalbot 9c4bbb0
DOC: update jittedlikelihood docstring
ColmTalbot 18d486b
TEST: speed up initializing prior tests
ColmTalbot c258a62
BUG: fix some test failures
ColmTalbot a0caae9
BUG: fix conditional+joint prior rescaling
ColmTalbot c0e7dc0
BUG: fix some gnarly conversion corner cases
ColmTalbot 6949e7f
BUG: fix multiband likelihood
ColmTalbot 97c87e3
BUG: fix bug in array_namespace check
ColmTalbot c3b6830
TEST: make sure healpix prior doesn't store state between calls
ColmTalbot dca5bd3
FMT: example formatting fixes
ColmTalbot 22f376e
BUG: make sure indices don't overflow in roq
ColmTalbot e991080
BUG: fix multiband time marginalization setup
ColmTalbot 4024d70
BUG: fix roq interpolation for out of bounds sample
ColmTalbot 9549798
TYPO: fix typo in jax example
ColmTalbot 0397ba6
REFACTOR: refactor more roq likelihood tests
ColmTalbot 0ee5f14
MAINT: revert new conversions
ColmTalbot 78e6faa
CI: fix selecting only non-windows os
ColmTalbot 3dc8d63
MAINT: make sure compat subpackages are listed in pyproject
ColmTalbot a13304c
TYPO: Fix package list formatting in pyproject.toml
ColmTalbot f7ba0d3
BUG: readd erroneously removed line
ColmTalbot 3ec5f6c
DOC: remove extraneous docstring
ColmTalbot 79ef95c
TEST: fix test failures
ColmTalbot 19f296b
TEST: start adding jax tests
ColmTalbot f57b86f
CI: add jax tests to CI
ColmTalbot c3c59c1
MAINT: add jax extras option
ColmTalbot 2730baa
Some more jax testing updates
ColmTalbot 083e865
MAINT: actually add jax requirements
ColmTalbot 59b3da8
CI: don't trivially skip all tests...
ColmTalbot 72c8afa
Initial pass at making grid work with jax
ColmTalbot 3538000
TEST: add more jax test coverage
ColmTalbot d9461be
FMT: precommit fixes
ColmTalbot c5f2127
TEST: fix jax tests
ColmTalbot fc97f62
TEST: add basic gw conversion jax tests
ColmTalbot aa8d846
TEST: more debugging slab spike test
ColmTalbot fa71fc6
TEST: jax tests work again
ColmTalbot 1cf8b60
DOC: add initial doc page for array backend
ColmTalbot feafa23
TEST: add a bunch of gw tests
ColmTalbot b41df86
DOC: fix doc page formatting
ColmTalbot 4bf5cd1
FMT: fix formatting
ColmTalbot 9329f75
BUG: fix typo in bilby_cython call
ColmTalbot 4a66829
BUG: fix list input for asd calculation
ColmTalbot 239613d
FMT: fix syntax for array conversion and backend checks
ColmTalbot b330aa7
BUG: fix some broken formatting
ColmTalbot fe76a67
FMT: fix formatting
ColmTalbot faf7535
BUG: fix bugs in testing
ColmTalbot 640d911
Fix some more conversions
ColmTalbot 73f89b4
Add pytorch core testing
ColmTalbot bbb72d9
FMT: run precommits
ColmTalbot fcfabdc
Make torch fully tested
ColmTalbot 9b3b5b8
FMT: pre-commit fix
ColmTalbot 12ba0b5
TEST: fix torch roq tests
ColmTalbot 07a5ebe
CI: prioritize torch tests
ColmTalbot 42d8b07
TEST: another attempt to fix torch tests
ColmTalbot c669bfa
Another attempt at fixing torch ROQ tests
ColmTalbot 70f029d
Fix arrays of data setting
ColmTalbot 0184120
BUG: fix some more roq array issues
ColmTalbot b2cf8aa
Make ROQ calculations use correct array backend
ColmTalbot 86763f8
BUG: fix a missing array case
ColmTalbot a44cdbf
FMT: pre-commit fixes
ColmTalbot 2542fa2
CI: drop torch tests for python 3.10
ColmTalbot a68bd37
FMT: precommit fix
ColmTalbot 4e3ca65
TEST: exclude studentt tests for jax
ColmTalbot a909e24
Add some more explicit array casts
ColmTalbot 9343b45
BUG: bug fixes for prior and gw likelihoods
ColmTalbot b99fc35
BUG: fix array namespace for torch
ColmTalbot a7fbff1
Update patches and backend documentation
ColmTalbot 933dd94
Address some comments and add some docstrings
ColmTalbot 5ffa0d2
Fix precommits
ColmTalbot 618ca4f
MAINT: Don't track uv lock file
ColmTalbot 607059b
MAINT: drop python 3.10 support
ColmTalbot 5254b5f
Remove extra multibanding time marginalization lines
ColmTalbot 310b69a
TEST: fix test failures
ColmTalbot 716696e
FEAT: update random number generation for non-numpy backends
ColmTalbot 92f6b2b
FMT: fix pre commits
ColmTalbot a705b6c
MAINT: make array api compatibility disabled by default
ColmTalbot aee9823
TYPO: typo fixes
ColmTalbot 8089007
BUG: fix nyquist frequency for white noise
ColmTalbot 60590d6
TEST: fix a test
ColmTalbot 781a502
TEST: add default device specification
ColmTalbot 4a46bcc
FMT: remove extra empty line
ColmTalbot 01d839a
BUG: fix failures in tests
ColmTalbot f4c5440
MAINT: move orng into jax dependencies
ColmTalbot 92c8ebc
TEST: update waveform generator test
ColmTalbot 06241cb
TEST: improve visibility of prior test failures
ColmTalbot a41bc0d
DOC: update documentation
ColmTalbot ab67301
TEST: fix broken tests
ColmTalbot 487a074
FMT: fix precommits
ColmTalbot 89d80d1
CI: split array backend tests into a separate job
ColmTalbot 6143c26
TYPO: fix typo in ci
ColmTalbot b151876
BUG: fix up torch tests
ColmTalbot 11b94b6
TEST: fix testing with array api support
ColmTalbot 3ec7c7f
BUG: fix typos in healpix priors
ColmTalbot 6ccaa45
TEST: loosen some numeric tests for torch
ColmTalbot 2d15cf9
CI: fix basic install workflow
ColmTalbot 6ff0ef7
CI: fix testing executables
ColmTalbot File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,3 +16,4 @@ MANIFEST | |
| **/outdir | ||
| .idea/* | ||
| bilby/_version.py | ||
| uv.lock | ||
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,40 @@ | ||
| import jax | ||
| import jax.numpy as jnp | ||
| from ..core.likelihood import Likelihood | ||
|
|
||
|
|
||
| class JittedLikelihood(Likelihood): | ||
| """ | ||
| A wrapper to just-in-time compile a :code:`Bilby` likelihood for use with :code:`jax`. | ||
|
|
||
| Parameters | ||
| ========== | ||
| likelihood: bilby.core.likelihood.Likelihood | ||
| The likelihood to wrap. | ||
| cast_to_float: bool | ||
| Whether to return a float instead of a :code:`jax.Array`. | ||
| """ | ||
|
|
||
| def __init__(self, likelihood, cast_to_float=True): | ||
| self._likelihood = likelihood | ||
| self._ll = jax.jit(likelihood.log_likelihood) | ||
| self._llr = jax.jit(likelihood.log_likelihood_ratio) | ||
| self.cast_to_float = cast_to_float | ||
| super().__init__() | ||
|
|
||
| def __getattr__(self, name): | ||
| return getattr(self._likelihood, name) | ||
|
|
||
| def log_likelihood(self, parameters): | ||
| parameters = {k: jnp.array(v) for k, v in parameters.items()} | ||
| ln_l = self._ll(parameters) | ||
| if self.cast_to_float: | ||
| ln_l = float(ln_l) | ||
| return ln_l | ||
|
|
||
| def log_likelihood_ratio(self, parameters): | ||
| parameters = {k: jnp.array(v) for k, v in parameters.items()} | ||
| ln_l = self._llr(parameters) | ||
| if self.cast_to_float: | ||
| ln_l = float(ln_l) | ||
| return ln_l |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,104 @@ | ||
| import array_api_compat as aac | ||
|
|
||
| from .utils import xp_wrap, BackendNotImplementedError, BILBY_ARRAY_API | ||
|
|
||
|
|
||
| def multivariate_logpdf(xp, mean, cov): | ||
| """ | ||
| Return a function to evaluate the log probability density of a multivariate | ||
| Gaussian with given mean vector and covariance matrix for the provided | ||
| array backend. | ||
|
|
||
| Parameters | ||
| ========== | ||
| xp: numpy, torch, jax.numpy | ||
| A module that will resolve to :code:`numpy`, :code:`torch`, or | ||
| :code:`jax.numpy` in :code:`array_api_compat.is_..._namespace`. | ||
| mean: array-like | ||
| A one-dimensional array providing the mean of the distribution. | ||
| cov: array-like | ||
| A two-dimensional array providing the covariance matrix of the | ||
| distribution. | ||
|
|
||
| Returns | ||
| ======= | ||
| logpdf: callable | ||
| A callable that provides the log probaility density provided an array | ||
| of points to evaluate at. | ||
| """ | ||
| from scipy.stats import multivariate_normal | ||
|
|
||
| if not BILBY_ARRAY_API or aac.is_numpy_namespace(xp): | ||
| logpdf = multivariate_normal(mean=mean, cov=cov).logpdf | ||
| elif aac.is_jax_namespace(xp): | ||
| from functools import partial | ||
| from jax.scipy.stats.multivariate_normal import logpdf | ||
|
|
||
| logpdf = partial(logpdf, mean=mean, cov=cov) | ||
| elif aac.is_torch_namespace(xp): | ||
| from torch.distributions.multivariate_normal import MultivariateNormal | ||
|
|
||
| mvn = MultivariateNormal(loc=mean, covariance_matrix=xp.asarray(cov)) | ||
| logpdf = mvn.log_prob | ||
| else: | ||
| raise BackendNotImplementedError( | ||
| f"Unable to import multivariate_logpdf for {xp}" | ||
| ) | ||
| return logpdf | ||
|
|
||
|
|
||
| @xp_wrap | ||
| def interp(x, xs, fs, /, left=None, right=None, period=None, *, xp=None): | ||
| """ | ||
| A simple implementation of numpy-style linear interpolation | ||
|
|
||
| The logic is copied from | ||
| https://github.com/pytorch/pytorch/issues/50334#issuecomment-1000917964 | ||
|
|
||
| Parameters | ||
| ========== | ||
| x: array-like | ||
| The values to evaluate the interpolant at. | ||
| xs: array-like | ||
| The x-values for setting up the interpolant. | ||
| ys: array-like | ||
| The values of the function for setting up the interpolant. | ||
| left: float | ||
| The value to use for x < xs[0]. Default is fs[0] | ||
| right: float | ||
| The value to use for x > xs[-1]. Default is fs[-1]. | ||
| period: float | ||
| The period of the interpolant. | ||
| Parameters left and right are ignored if period is specified. | ||
|
|
||
| Notes | ||
| ===== | ||
| To avoid overlap with the ``xp`` variable, the second and third variable | ||
| names from differ from numpy. | ||
| These arguments are enforced to be positional only. | ||
| """ | ||
| if not BILBY_ARRAY_API or hasattr(xp, "interp"): | ||
| return xp.interp(x, xs, fs, left=left, right=right, period=period) | ||
|
|
||
| if period is not None: | ||
| x = x % period | ||
| if left is None: | ||
| left = fs[0] | ||
| if right is None: | ||
| right = fs[-1] | ||
|
|
||
| x = xp.atleast_1d(x) | ||
|
|
||
| m = (fs[1:] - fs[:-1]) / (xs[1:] - xs[:-1]) | ||
| b = fs[:-1] - (m * xs[:-1]) | ||
|
|
||
| indices = xp.sum(xp.ge(x[:, None], xs[None, :]), axis=1) - 1 | ||
| indices = xp.clip(indices, 0, len(m) - 1) | ||
|
|
||
| ret = m[indices] * x + b[indices] | ||
|
|
||
| if period is None: | ||
| ret = xp.where(x < xs[0], xp.asarray(left), ret) | ||
| ret = xp.where(x > xs[-1], xp.asarray(right), ret) | ||
|
|
||
| return ret.squeeze() | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| import numpy as np | ||
|
|
||
| Real = float | int | np.number | ||
| ArrayLike = np.ndarray | list | tuple |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.