Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
140 commits
Select commit Hold shift + click to select a range
1e3f4af
FEAT: enable backend switching for base gravitational-wave transient …
ColmTalbot Oct 25, 2024
cf5c611
FEAT: support multiband and relative binning likelihoods
ColmTalbot Oct 25, 2024
2bfc833
FEAT: make more conversions backend agnostic
ColmTalbot Oct 26, 2024
d29b860
FEAT: use more normal conversions
ColmTalbot Oct 28, 2024
c5eb323
FEAT: move backend switching code to bilby
ColmTalbot Nov 13, 2024
9bec666
FEAT: make core prior backend agnostic
ColmTalbot Nov 14, 2024
0d9aba6
FEAT: make non-numpy arrays serializable
ColmTalbot Nov 14, 2024
aea0ea8
BUG: fix some array conversion methods
ColmTalbot Nov 14, 2024
8e61b7f
DEV: some more prior agnosticism
ColmTalbot Dec 11, 2024
b658a02
TEST: make all prior tests run
ColmTalbot Dec 12, 2024
3ed92e0
DEV: move some jax functionality to compat
ColmTalbot Jan 25, 2025
c49e6ef
REFACTOR: use array backend for ln_i0
ColmTalbot Jan 25, 2025
ef3d482
make distance marginalizatio backend transparent
ColmTalbot Jan 25, 2025
95be371
DEV: some more prior dict array refactoring
ColmTalbot Jan 25, 2025
9f7eb38
fix jax logic for distance marginalization
ColmTalbot Jan 29, 2025
553689f
improve efficiency of setting up multibanding
ColmTalbot Jan 29, 2025
55c3cef
make high-dimensional gaussians jax compatible
ColmTalbot Jan 29, 2025
fd7c3c3
make cubic spline calibration work with jax backend
ColmTalbot Jan 30, 2025
a062ca3
BUG: fix linspace calls
ColmTalbot Feb 4, 2025
a9587a8
ENH: fix bottleneck in relative binning for JAX
ColmTalbot Feb 4, 2025
77b39a1
ENH: make interpolated prior backend friendly
ColmTalbot Feb 4, 2025
f18dd0e
REFACTOR: refactor backend-specific interpolation code
ColmTalbot Feb 5, 2025
fd28b85
ENH: make sine gaussian model backend independent
ColmTalbot Feb 5, 2025
7b98e7c
ENH: make roq likelihood backend independent
ColmTalbot Feb 5, 2025
221cd57
BUG: fix roq slicing
ColmTalbot Feb 5, 2025
a9d3490
FEAT: make condition chi evaluable
ColmTalbot Jun 3, 2025
93fefb9
MAINT: make whitening work for non-numpy
ColmTalbot Jun 12, 2025
997d094
EXAMPLE: update jax example
ColmTalbot Aug 20, 2025
2a6d0f0
BUG: fix interpax interpolation method
ColmTalbot Aug 20, 2025
590de47
REFACTOR: update variable backend for new parameter method
ColmTalbot Oct 2, 2025
c7e7525
some simplifications of array transparency
ColmTalbot Oct 2, 2025
8d99e56
HYPER: make hyperparameter likelihood handle array backends
ColmTalbot Oct 2, 2025
53fdb6d
MAINT: switch back to bilby_cython
ColmTalbot Dec 11, 2025
cf2a033
TYPO: fix typo in multiband time-marginalized likelihood
ColmTalbot Dec 22, 2025
51b42e6
MAINT: removed unused import
ColmTalbot Dec 22, 2025
0e8739c
BUG: add explicit array cast in conversion
ColmTalbot Dec 22, 2025
b1cb48a
REFACTOR: some refactoring of array edge cases
ColmTalbot Dec 22, 2025
4b3d701
MAINT: removed extra ripple code
ColmTalbot Dec 22, 2025
bd36cc2
REFACTOR: make bilby_cython an optional dependency
ColmTalbot Dec 22, 2025
cd6bccb
FMT: formatting fixes
ColmTalbot Dec 22, 2025
043b3c9
BUG: fix array introspection for conversion
ColmTalbot Jan 21, 2026
bc070af
REFACTOR: make parameters for waveform generator more strict
ColmTalbot Jan 21, 2026
2a537b1
BUG: fix core likelihood tests
ColmTalbot Jan 21, 2026
b87327a
BUG: fix calibration calculations
ColmTalbot Jan 22, 2026
a605c05
EXAMPLE: update jax fast tutorial
ColmTalbot Jan 22, 2026
76c61a5
TST: refactor marginalization tests to be less restrictive
ColmTalbot Jan 22, 2026
9c4bbb0
DOC: update jittedlikelihood docstring
ColmTalbot Jan 22, 2026
18d486b
TEST: speed up initializing prior tests
ColmTalbot Jan 22, 2026
c258a62
BUG: fix some test failures
ColmTalbot Jan 22, 2026
a0caae9
BUG: fix conditional+joint prior rescaling
ColmTalbot Jan 22, 2026
c0e7dc0
BUG: fix some gnarly conversion corner cases
ColmTalbot Jan 22, 2026
6949e7f
BUG: fix multiband likelihood
ColmTalbot Jan 22, 2026
97c87e3
BUG: fix bug in array_namespace check
ColmTalbot Jan 22, 2026
c3b6830
TEST: make sure healpix prior doesn't store state between calls
ColmTalbot Jan 22, 2026
dca5bd3
FMT: example formatting fixes
ColmTalbot Jan 22, 2026
22f376e
BUG: make sure indices don't overflow in roq
ColmTalbot Jan 22, 2026
e991080
BUG: fix multiband time marginalization setup
ColmTalbot Jan 23, 2026
4024d70
BUG: fix roq interpolation for out of bounds sample
ColmTalbot Jan 23, 2026
9549798
TYPO: fix typo in jax example
ColmTalbot Jan 23, 2026
0397ba6
REFACTOR: refactor more roq likelihood tests
ColmTalbot Jan 23, 2026
0ee5f14
MAINT: revert new conversions
ColmTalbot Jan 23, 2026
78e6faa
CI: fix selecting only non-windows os
ColmTalbot Jan 23, 2026
3dc8d63
MAINT: make sure compat subpackages are listed in pyproject
ColmTalbot Jan 23, 2026
a13304c
TYPO: Fix package list formatting in pyproject.toml
ColmTalbot Jan 23, 2026
f7ba0d3
BUG: readd erroneously removed line
ColmTalbot Jan 23, 2026
3ec5f6c
DOC: remove extraneous docstring
ColmTalbot Jan 23, 2026
79ef95c
TEST: fix test failures
ColmTalbot Jan 29, 2026
19f296b
TEST: start adding jax tests
ColmTalbot Jan 31, 2026
f57b86f
CI: add jax tests to CI
ColmTalbot Jan 31, 2026
c3c59c1
MAINT: add jax extras option
ColmTalbot Jan 31, 2026
2730baa
Some more jax testing updates
ColmTalbot Jan 31, 2026
083e865
MAINT: actually add jax requirements
ColmTalbot Jan 31, 2026
59b3da8
CI: don't trivially skip all tests...
ColmTalbot Jan 31, 2026
72c8afa
Initial pass at making grid work with jax
ColmTalbot Jan 31, 2026
3538000
TEST: add more jax test coverage
ColmTalbot Feb 1, 2026
d9461be
FMT: precommit fixes
ColmTalbot Feb 2, 2026
c5f2127
TEST: fix jax tests
ColmTalbot Feb 2, 2026
fc97f62
TEST: add basic gw conversion jax tests
ColmTalbot Feb 2, 2026
aa8d846
TEST: more debugging slab spike test
ColmTalbot Feb 2, 2026
fa71fc6
TEST: jax tests work again
ColmTalbot Feb 2, 2026
1cf8b60
DOC: add initial doc page for array backend
ColmTalbot Feb 2, 2026
feafa23
TEST: add a bunch of gw tests
ColmTalbot Feb 2, 2026
b41df86
DOC: fix doc page formatting
ColmTalbot Feb 2, 2026
4bf5cd1
FMT: fix formatting
ColmTalbot Feb 2, 2026
9329f75
BUG: fix typo in bilby_cython call
ColmTalbot Feb 2, 2026
4a66829
BUG: fix list input for asd calculation
ColmTalbot Feb 2, 2026
239613d
FMT: fix syntax for array conversion and backend checks
ColmTalbot Feb 2, 2026
b330aa7
BUG: fix some broken formatting
ColmTalbot Feb 2, 2026
fe76a67
FMT: fix formatting
ColmTalbot Feb 2, 2026
faf7535
BUG: fix bugs in testing
ColmTalbot Feb 2, 2026
640d911
Fix some more conversions
ColmTalbot Feb 2, 2026
73f89b4
Add pytorch core testing
ColmTalbot Feb 3, 2026
bbb72d9
FMT: run precommits
ColmTalbot Feb 3, 2026
fcfabdc
Make torch fully tested
ColmTalbot Feb 3, 2026
9b3b5b8
FMT: pre-commit fix
ColmTalbot Feb 3, 2026
12ba0b5
TEST: fix torch roq tests
ColmTalbot Feb 3, 2026
07a5ebe
CI: prioritize torch tests
ColmTalbot Feb 3, 2026
42d8b07
TEST: another attempt to fix torch tests
ColmTalbot Feb 3, 2026
c669bfa
Another attempt at fixing torch ROQ tests
ColmTalbot Feb 3, 2026
70f029d
Fix arrays of data setting
ColmTalbot Feb 3, 2026
0184120
BUG: fix some more roq array issues
ColmTalbot Feb 3, 2026
b2cf8aa
Make ROQ calculations use correct array backend
ColmTalbot Feb 3, 2026
86763f8
BUG: fix a missing array case
ColmTalbot Feb 3, 2026
a44cdbf
FMT: pre-commit fixes
ColmTalbot Feb 3, 2026
2542fa2
CI: drop torch tests for python 3.10
ColmTalbot Feb 3, 2026
a68bd37
FMT: precommit fix
ColmTalbot Feb 3, 2026
4e3ca65
TEST: exclude studentt tests for jax
ColmTalbot Feb 3, 2026
a909e24
Add some more explicit array casts
ColmTalbot Feb 3, 2026
9343b45
BUG: bug fixes for prior and gw likelihoods
ColmTalbot Feb 17, 2026
b99fc35
BUG: fix array namespace for torch
ColmTalbot Feb 17, 2026
a7fbff1
Update patches and backend documentation
ColmTalbot May 14, 2026
933dd94
Address some comments and add some docstrings
ColmTalbot May 14, 2026
5ffa0d2
Fix precommits
ColmTalbot May 14, 2026
618ca4f
MAINT: Don't track uv lock file
ColmTalbot May 15, 2026
607059b
MAINT: drop python 3.10 support
ColmTalbot May 15, 2026
5254b5f
Remove extra multibanding time marginalization lines
ColmTalbot May 15, 2026
310b69a
TEST: fix test failures
ColmTalbot May 15, 2026
716696e
FEAT: update random number generation for non-numpy backends
ColmTalbot May 18, 2026
92f6b2b
FMT: fix pre commits
ColmTalbot May 18, 2026
a705b6c
MAINT: make array api compatibility disabled by default
ColmTalbot May 18, 2026
aee9823
TYPO: typo fixes
ColmTalbot May 18, 2026
8089007
BUG: fix nyquist frequency for white noise
ColmTalbot May 18, 2026
60590d6
TEST: fix a test
ColmTalbot May 18, 2026
781a502
TEST: add default device specification
ColmTalbot May 19, 2026
4a46bcc
FMT: remove extra empty line
ColmTalbot May 19, 2026
01d839a
BUG: fix failures in tests
ColmTalbot May 19, 2026
f4c5440
MAINT: move orng into jax dependencies
ColmTalbot May 19, 2026
92c8ebc
TEST: update waveform generator test
ColmTalbot May 19, 2026
06241cb
TEST: improve visibility of prior test failures
ColmTalbot May 19, 2026
a41bc0d
DOC: update documentation
ColmTalbot May 19, 2026
ab67301
TEST: fix broken tests
ColmTalbot May 19, 2026
487a074
FMT: fix precommits
ColmTalbot May 19, 2026
89d80d1
CI: split array backend tests into a separate job
ColmTalbot May 20, 2026
6143c26
TYPO: fix typo in ci
ColmTalbot May 20, 2026
b151876
BUG: fix up torch tests
ColmTalbot May 20, 2026
11b94b6
TEST: fix testing with array api support
ColmTalbot May 20, 2026
3ec7c7f
BUG: fix typos in healpix priors
ColmTalbot May 20, 2026
6ccaa45
TEST: loosen some numeric tests for torch
ColmTalbot May 20, 2026
2d15cf9
CI: fix basic install workflow
ColmTalbot May 20, 2026
6ff0ef7
CI: fix testing executables
ColmTalbot May 21, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/basic-install.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
# disable windows build test as bilby_cython is currently broken there
os: [ubuntu-latest, macos-latest]
python-version: ["3.10", "3.11", "3.12", "3.13"]
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: ["3.11", "3.12", "3.13"]
steps:
- uses: actions/checkout@v4
with:
Expand All @@ -39,6 +38,7 @@ jobs:
- name: Test imports
run: bash test/ci_test_imports.sh
- name: Test entry points
if: matrix.os != 'windows-latest'
run: |
for script in $(pip show -f bilby | grep "bin\/" | xargs -I {} basename {}); do
${script} --help;
Expand Down
59 changes: 56 additions & 3 deletions .github/workflows/unit-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,6 @@ jobs:
fail-fast: false
matrix:
python:
- name: Python 3.10
version: 3.10
short-version: 310
- name: Python 3.11
version: 3.11
short-version: 311
Expand Down Expand Up @@ -84,3 +81,59 @@ jobs:
with:
name: pytest-${{ matrix.python.short-version }}
path: pytest.xml

array-backend:

name: ${{ matrix.python.name }} array backend (${{ matrix.backend.name }})
runs-on: ubuntu-latest
container: ghcr.io/bilby-dev/bilby-python${{ matrix.python.short-version }}:latest
strategy:
fail-fast: false
matrix:
python:
- name: Python 3.13
version: 3.13
short-version: 313
backend:
- name: numpy
install-args: .
- name: jax
install-args: .[jax]
- name: torch
install-args: .
env:
BILBY_ARRAY_API: 1
SCIPY_ARRAY_API: 1

steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
fetch-tags: true
- name: Install package
run: |
# activate env so that conda list shows the correct environment
apt-get update
apt-get install -y gnupg curl
source $CONDA_PATH/bin/activate python${{ matrix.python.short-version }}
conda install -c conda-forge -y liblal!=7.7.0 python-lal!=7.7.0
python -m pip install ${{ matrix.backend.install-args }}
python -m pip install orng
conda list --show-channel-urls
shell: bash
- name: Run array-backend unit tests
run: |
pytest --array-backend ${{ matrix.backend.name }} --durations 10 --junitxml=pytest-array.xml
- name: Publish coverage to Codecov
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
files: coverage.xml
flags: python${{ matrix.backend.name }}
slug: bilby-dev/bilby
- name: Upload array-backend test results
if: always()
uses: actions/upload-artifact@v4
with:
name: pytest-array-${{ matrix.python.short-version }}-${{ matrix.backend.name }}
path: pytest-array.xml
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ MANIFEST
**/outdir
.idea/*
bilby/_version.py
uv.lock
Empty file added bilby/compat/__init__.py
Empty file.
40 changes: 40 additions & 0 deletions bilby/compat/jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import jax
import jax.numpy as jnp
from ..core.likelihood import Likelihood


class JittedLikelihood(Likelihood):
"""
A wrapper to just-in-time compile a :code:`Bilby` likelihood for use with :code:`jax`.

Parameters
==========
likelihood: bilby.core.likelihood.Likelihood
The likelihood to wrap.
cast_to_float: bool
Whether to return a float instead of a :code:`jax.Array`.
"""

def __init__(self, likelihood, cast_to_float=True):
self._likelihood = likelihood
self._ll = jax.jit(likelihood.log_likelihood)
self._llr = jax.jit(likelihood.log_likelihood_ratio)
self.cast_to_float = cast_to_float
super().__init__()

def __getattr__(self, name):
return getattr(self._likelihood, name)

def log_likelihood(self, parameters):
parameters = {k: jnp.array(v) for k, v in parameters.items()}
ln_l = self._ll(parameters)
if self.cast_to_float:
ln_l = float(ln_l)
return ln_l

def log_likelihood_ratio(self, parameters):
parameters = {k: jnp.array(v) for k, v in parameters.items()}
ln_l = self._llr(parameters)
if self.cast_to_float:
ln_l = float(ln_l)
return ln_l
104 changes: 104 additions & 0 deletions bilby/compat/patches.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import array_api_compat as aac
Comment thread
GregoryAshton marked this conversation as resolved.

from .utils import xp_wrap, BackendNotImplementedError, BILBY_ARRAY_API


def multivariate_logpdf(xp, mean, cov):
"""
Return a function to evaluate the log probability density of a multivariate
Gaussian with given mean vector and covariance matrix for the provided
array backend.

Parameters
==========
xp: numpy, torch, jax.numpy
A module that will resolve to :code:`numpy`, :code:`torch`, or
:code:`jax.numpy` in :code:`array_api_compat.is_..._namespace`.
mean: array-like
A one-dimensional array providing the mean of the distribution.
cov: array-like
A two-dimensional array providing the covariance matrix of the
distribution.

Returns
=======
logpdf: callable
A callable that provides the log probaility density provided an array
of points to evaluate at.
"""
from scipy.stats import multivariate_normal

if not BILBY_ARRAY_API or aac.is_numpy_namespace(xp):
logpdf = multivariate_normal(mean=mean, cov=cov).logpdf
elif aac.is_jax_namespace(xp):
from functools import partial
from jax.scipy.stats.multivariate_normal import logpdf

logpdf = partial(logpdf, mean=mean, cov=cov)
elif aac.is_torch_namespace(xp):
from torch.distributions.multivariate_normal import MultivariateNormal

mvn = MultivariateNormal(loc=mean, covariance_matrix=xp.asarray(cov))
logpdf = mvn.log_prob
else:
raise BackendNotImplementedError(
f"Unable to import multivariate_logpdf for {xp}"
)
return logpdf


@xp_wrap
def interp(x, xs, fs, /, left=None, right=None, period=None, *, xp=None):
"""
A simple implementation of numpy-style linear interpolation

The logic is copied from
https://github.com/pytorch/pytorch/issues/50334#issuecomment-1000917964

Parameters
==========
x: array-like
The values to evaluate the interpolant at.
xs: array-like
The x-values for setting up the interpolant.
ys: array-like
The values of the function for setting up the interpolant.
left: float
The value to use for x < xs[0]. Default is fs[0]
right: float
The value to use for x > xs[-1]. Default is fs[-1].
period: float
The period of the interpolant.
Parameters left and right are ignored if period is specified.

Notes
=====
To avoid overlap with the ``xp`` variable, the second and third variable
names from differ from numpy.
These arguments are enforced to be positional only.
"""
if not BILBY_ARRAY_API or hasattr(xp, "interp"):
return xp.interp(x, xs, fs, left=left, right=right, period=period)

if period is not None:
x = x % period
if left is None:
left = fs[0]
if right is None:
right = fs[-1]

x = xp.atleast_1d(x)

m = (fs[1:] - fs[:-1]) / (xs[1:] - xs[:-1])
b = fs[:-1] - (m * xs[:-1])

indices = xp.sum(xp.ge(x[:, None], xs[None, :]), axis=1) - 1
indices = xp.clip(indices, 0, len(m) - 1)

ret = m[indices] * x + b[indices]

if period is None:
ret = xp.where(x < xs[0], xp.asarray(left), ret)
ret = xp.where(x > xs[-1], xp.asarray(right), ret)

return ret.squeeze()
4 changes: 4 additions & 0 deletions bilby/compat/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
import numpy as np

Real = float | int | np.number
ArrayLike = np.ndarray | list | tuple
Loading
Loading