Skip to content

Support non-numpy array backends#886

Open
ColmTalbot wants to merge 140 commits into
bilby-dev:mainfrom
ColmTalbot:bilback
Open

Support non-numpy array backends#886
ColmTalbot wants to merge 140 commits into
bilby-dev:mainfrom
ColmTalbot:bilback

Conversation

@ColmTalbot
Copy link
Copy Markdown
Collaborator

@ColmTalbot ColmTalbot commented Jan 7, 2025

I've been working on this PR on and off for a few months, it isn't ready yet, but I wanted to share it in case other people had early opinions.

The goal is to make it easier to interface with models/samplers implemented in e.g., JAX, that support GPU/TPU acceleration and JIT compilation.

The general guiding principles are:

  • when possible maintain existing behaviour with numpy/builtin arguments
  • work introspectively so users don't need to specify the target backend, but use input types
  • write as little backend specific code as possible, mostly through using the array-api specification and scipy interoperability

The primary changes so far are:

  • making most priors backend independent, there are a few holdouts where the underlying scipy functionality isn't compatible yet
  • core likelihoods mostly work with data from any backend
  • GW likelihoods work with any backend supported by the source function
  • the GW detector objects don't work via introspection, they need to be manually set
  • GW geometry (currently in bilby_cython) is handled via multiple-dispatch and added back into bilby

Changed behaviour:

Remaining issues:

  • Saving/loading nun-numpy arrays in result files may not work
  • I added some additional parameter conversions that I will remove
  • the bilby.gw.jaxstuff file should be removed and relevant functionality be moved elsewhere, it's currently just used for testing
  • the ROQ likelihood hasn't been ported
  • add more testing with JAX
  • translate some of the hyperparameter functionality, c.f., GWPopulation

@ColmTalbot ColmTalbot added the enhancement New feature or request label Jan 7, 2025
@ColmTalbot ColmTalbot marked this pull request as draft January 7, 2025 19:38
@ColmTalbot ColmTalbot force-pushed the bilback branch 2 times, most recently from ea348fa to 771a8a9 Compare January 22, 2026 17:00
@ColmTalbot ColmTalbot marked this pull request as ready for review January 23, 2026 15:24
@ColmTalbot ColmTalbot changed the title DRAFT: Support non-numpy array backends Support non-numpy array backends Jan 23, 2026
@ColmTalbot ColmTalbot added >100 lines refactoring to discuss To be discussed on an upcoming call labels Jan 23, 2026
@ColmTalbot
Copy link
Copy Markdown
Collaborator Author

This is now ready for review.
There are some things that won't work with JAX at the moment, e.g., various combinations of likelihood marginalization/acceleration.
I think we should accept this at the moment, for at least a bilby v3 alpha/beta release, and keep chipping away at the various subcases over time.

There are a lot of changes, but most of them are essentially np -> xp.
Some things required refactoring to avoid modifying slices of arrays as JAX doesn't like that.

Bilby can once again be installed without bilby.cython.
This should improve our general portability, but when bilby_cython is installed it will be used.

I've managed to keep test changes minimal:

  • I updated the joint prior test to make it more stringent (keys more randomly ordered).
  • I refactored some expensive prior initialization that was dramatically slowing things down.
  • I improved the logic for figuring out when ROQs are available to help my local testing.
  • Some mocks of numpy had to be updated.

@mj-will mj-will added this to the 3.0.0 milestone Jan 27, 2026
Copy link
Copy Markdown
Collaborator

@mj-will mj-will left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some initial comments but I'll need to have another look.

Comment thread bilby/compat/patches.py Outdated
Comment thread bilby/compat/utils.py
Comment thread bilby/compat/utils.py Outdated
Comment thread bilby/compat/utils.py Outdated
Comment thread bilby/core/prior/analytical.py Outdated
This maps to the inverse CDF. This has been analytically solved for this case.
"""
return gammaincinv(self.k, val) * self.theta
return xp.asarray(gammaincinv(self.k, val)) * self.theta
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this mean this is falling back to numpy?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I should update/recheck this, but at least jax doesn't have good support for this, but it looks like tensorflow has a version that numpyro uses (jax-ml/jax#5350). cupy does have this function, so this workaround may have just been for jax. I could add a BackendNotImplementedError.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would this be a candidate for a small patch that uses the TF version for jax until jax supports it natively?

Comment thread bilby/core/prior/analytical.py Outdated
Comment thread bilby/core/prior/dict.py Outdated
Comment thread bilby/core/prior/dict.py
Comment on lines -877 to -902
self[key].least_recently_sampled = result[key]
if isinstance(self[key], JointPrior) and self[key].dist.distname not in joint:
joint[self[key].dist.distname] = [key]
elif isinstance(self[key], JointPrior):
joint[self[key].dist.distname].append(key)
for names in joint.values():
# this is needed to unpack how joint prior rescaling works
# as an example of a joint prior over {a, b, c, d} we might
# get the following based on the order within the joint prior
# {a: [], b: [], c: [1, 2, 3, 4], d: []}
# -> [1, 2, 3, 4]
# -> {a: 1, b: 2, c: 3, d: 4}
values = list()
for key in names:
values = np.concatenate([values, result[key]])
for key, value in zip(names, values):
result[key] = value

def safe_flatten(value):
"""
this is gross but can be removed whenever we switch to returning
arrays, flatten converts 0-d arrays to 1-d so has to be special
cased
"""
if isinstance(value, (float, int)):
return value
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is removing this intentional?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, this is in line with one of the other open PRs to update this logic. I'll dig it out in my next pass.

Comment thread bilby/gw/utils.py Outdated
Copy link
Copy Markdown
Collaborator Author

@ColmTalbot ColmTalbot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the initial comments @mj-will I'll take a pass at them ASAP.

Comment thread bilby/core/prior/analytical.py Outdated
Comment thread bilby/core/prior/analytical.py
Comment thread bilby/core/prior/analytical.py
Comment thread bilby/core/prior/dict.py Outdated
Comment thread bilby/core/prior/dict.py
Comment on lines -877 to -902
self[key].least_recently_sampled = result[key]
if isinstance(self[key], JointPrior) and self[key].dist.distname not in joint:
joint[self[key].dist.distname] = [key]
elif isinstance(self[key], JointPrior):
joint[self[key].dist.distname].append(key)
for names in joint.values():
# this is needed to unpack how joint prior rescaling works
# as an example of a joint prior over {a, b, c, d} we might
# get the following based on the order within the joint prior
# {a: [], b: [], c: [1, 2, 3, 4], d: []}
# -> [1, 2, 3, 4]
# -> {a: 1, b: 2, c: 3, d: 4}
values = list()
for key in names:
values = np.concatenate([values, result[key]])
for key, value in zip(names, values):
result[key] = value

def safe_flatten(value):
"""
this is gross but can be removed whenever we switch to returning
arrays, flatten converts 0-d arrays to 1-d so has to be special
cased
"""
if isinstance(value, (float, int)):
return value
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, this is in line with one of the other open PRs to update this logic. I'll dig it out in my next pass.

Comment thread bilby/gw/utils.py Outdated
Comment thread bilby/gw/utils.py Outdated
Copy link
Copy Markdown
Collaborator

@GregoryAshton GregoryAshton left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, I got through about 60% of the diff and I'm pausing here so will submit the questions so far.

Comment thread bilby/compat/patches.py
Comment thread bilby/compat/patches.py Outdated
Comment thread bilby/core/prior/analytical.py
Comment thread bilby/core/prior/analytical.py
_cdf[val >= self.minimum] = 1. - np.exp(-val[val >= self.minimum] / self.mu)
return _cdf
with np.errstate(divide="ignore"):
return -val / self.mu - xp.log(xp.asarray(self.mu)) + xp.log(val >= self.minimum)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah okay - are the bounds being implemented here? But, I don't see the upper bound being implemented.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is carried over from the existing implementation.

Comment thread bilby/core/likelihood.py
Comment thread bilby/gw/detector/interferometer.py Outdated
Comment thread bilby/gw/detector/interferometer.py
@mj-will mj-will removed the to discuss To be discussed on an upcoming call label May 14, 2026
@ColmTalbot
Copy link
Copy Markdown
Collaborator Author

ColmTalbot commented May 14, 2026

Python 3.10 doesn't have support for a vmappable version of logsumexp through scipy leading to this job failing (https://github.com/bilby-dev/bilby/actions/runs/25883935510/job/76070707573?pr=886).

How do people feel about dropping support for Python 3.10 in Bilby 3? Numpy dropped support about a year ago.

@ColmTalbot ColmTalbot requested review from a team, GregoryAshton and mj-will May 19, 2026 13:56
@mj-will
Copy link
Copy Markdown
Collaborator

mj-will commented May 19, 2026

This looks great @ColmTalbot. i've not a proper look yet but one thought on using orng: I think ArrayRNG may be confusing so I'm considering changing it to something else (see sequince-dev/orng#9). Feel free to comment on the MR if you have thoughts.

Once that's in, I can make a stable release and get in on conda if that would be useful.

@ColmTalbot
Copy link
Copy Markdown
Collaborator Author

i've not a proper look yet but one thought on using orng: I think ArrayRNG may be confusing so I'm considering changing it to something else (see sequince-dev/orng#9). Feel free to comment on the MR if you have thoughts.

No meaningful thoughts, I agree it's probably a positive change and I'm happy to update once it is available.

Once that's in, I can make a stable release and get in on conda if that would be useful.

I think we'll definitely want this once we have a release ready. Although, I think I've refactored things so that orng doesn't need to be a top-level dependency.

Comment thread .github/workflows/basic-install.yml Outdated
Co-authored-by: Michael J. Williams <michaeljw1@googlemail.com>
Copy link
Copy Markdown
Collaborator

@mj-will mj-will left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm unsure if all the priors are consistent with the previous versions. I've tried to go through them but may have missed some things or equally, got some logic wrong and they're fine.


def prob(self, val):
return (
xp.sign(2 * val - 1)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the behaviour at val=0.5 has changed. Previously it would return self.minimum where as now it will be zero since np.sign(0) = 0. I'm not sure, but I don't think this is right for a log-uniform.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This hadn't occurred to me, but I actually think the old version is more incorrect. I'm happy to reinstate the old behaviour with xp.where, but it is worth documenting as I think this is non intuitive.

@xp_wrap
def cdf(self, val, *, xp=None):
asymmetric = xp.log(xp.abs(val) / self.minimum) / xp.log(xp.asarray(self.maximum / self.minimum))
return xp.clip(0.5 * (1 + xp.sign(val) * asymmetric), 0, 1)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this has a similar problem at 0 since sign(0) = 0 and asymmetric=-inf

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I think we need to modify the definition of asymmetric to the following to make the CDF flat between [-min, min].

        asymmetric = xp.log(xp.maximum(xp.abs(val) / self.minimum, 1)) / xp.log(xp.asarray(self.maximum / self.minimum))

Comment on lines +753 to +757
def cdf(self, val, *, xp=None):
with np.errstate(divide="ignore"):
return 0.5 + erf(
(xp.log(xp.maximum(val, xp.asarray(self.minimum))) - self.mu) / self.sigma / np.sqrt(2)
) / 2
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it still return 0 for values below the minimum? I don't think it does but I might be missing something.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, max(val, minimum) is equivalent to a one-sided clip. This seems overly complicated though as minimum is hardcoded to zero, but that's a carry over.

Comment on lines +999 to +1001
ln_prob = xlog1py(xp.asarray(self.beta - 1.0), -val) + xlogy(xp.asarray(self.alpha - 1.0), val)
ln_prob -= betaln(xp.asarray(self.alpha), xp.asarray(self.beta))
return xp.nan_to_num(ln_prob, nan=-xp.inf, neginf=-xp.inf, posinf=-xp.inf)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm sure minimum and maximum are respected here.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, this seems pretty broken. What about this?

Suggested change
ln_prob = xlog1py(xp.asarray(self.beta - 1.0), -val) + xlogy(xp.asarray(self.alpha - 1.0), val)
ln_prob -= betaln(xp.asarray(self.alpha), xp.asarray(self.beta))
return xp.nan_to_num(ln_prob, nan=-xp.inf, neginf=-xp.inf, posinf=-xp.inf)
normalized = (val - self.minimum) / (self.maximum - self.minimum)
ln_prob = xlog1py(xp.asarray(self.beta - 1.0), -normalized) + xlogy(xp.asarray(self.alpha - 1.0), normalized)
ln_prob -= betaln(xp.asarray(self.alpha), xp.asarray(self.beta))
ln_prob -= xlogy(
xp.asarray(self.alpha + self.beta - 1),
xp.asarray(self.maximum - self.minimum),
)
return xp.where(
xp.abs(normalized - 0.5) <= 0.5,
xp.nan_to_num(ln_prob, nan=-xp.inf, neginf=-xp.inf, posinf=-xp.inf),
-xp.inf,
)

Comment on lines +1050 to +1052
with np.errstate(divide="ignore"):
val = xp.asarray(val)
return self.mu + self.scale * xp.log(xp.maximum(val / (1 - val), xp.asarray(0)))
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don' think the upper boundary is consistent with the old implementation. I think it now gives -inf rather than +inf

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so, log(max(1 / 0, 0)) = log(inf) = inf.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I was thinking specifically of things above the bound, so:

log(max(negative, 0)) = log(0) = -inf

I agree the bound itself is still the same.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I don't know that we've been careful about out of bounds behaviour for these rescale methods in general. The behaviour here is the same as the behaviour before this change for scalars, but not arrays.

Comment on lines +178 to +189
if random_state is None or not BILBY_ARRAY_API:
return np
elif isinstance(random_state, np.random.Generator):
return np
elif aac.is_jax_array(random_state) or getattr(random_state, "backend") == "jax":
import jax.numpy as jnp
return jnp
elif aac.is_torch_array(random_state) or getattr(random_state, "backend") == "torch":
import torch
return torch
else:
return np
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be using the wrapped modules, e.g. array_api_compat.numpy?

From the docs I think this is preferred:

Wrapped array namespaces can also be imported directly. For example, array_namespace(np.array(...)) will return array_api_compat.numpy. This function will also work for any array library not wrapped by array-api-compat if it explicitly defines array_namespace (the wrapped namespace is always preferred if it exists).

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not against doing this, but I'm not really sure what the benefit is. I guess it would reduce the number of explicit imports.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the main benefit is the APIs are standardized and the code can therefore be more generic.

For example, the methods will support some additional keywords like device:

import array_api_compat.numpy as np
device = "cpu"
x = [1, 2, 3]
np.asarray(x, device=device)

This doesn't do anything for numpy in this case, but doesn't raise an error like it would for standard numpy. This means the exact same code would work for torch with e.g. a CUDA device.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, I'll try to use this where I can, but I think we should avoid using the array_api_compat version of numpy if BILBY_ARRAY_API == False.

FWIW, this specific example has worked with standard numpy since version 2.0.0.

Comment on lines +191 to +192
white_noise = xpx.at(white_noise, 0).set(0)
white_noise = xpx.at(white_noise, -1).set(0)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May be missing something, but it seems the case where the number of samples is odd is no longer handled.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the answer is that the old function was broken... and always returned yes in this condition.
We could add it back in, but it does break the specific value tests we have.

Comment thread bilby/core/grid.py
"""

if xp is None:
xp = np
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this default to aac.numpy?

Comment thread bilby/core/grid.py
raise TypeError("Parameters names must be a list or string")

out_array = log_array.copy()
out_array = copy(log_array)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure a plain copy works for torch:

>>> import torch
>>> import copy
>>> x = torch.Tensor([1.0])
>>> y = copy.copy(x)
>>> y[0] = 0
>>> x
tensor([0.])
>>> 

I'm not sure what the correct way is to handle this, in other packages I've resorted to helpers but I suspect there's a better way. For example: https://github.com/mj-will/aspire/blob/main/src/aspire/utils.py#L456

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sigh

Comment thread bilby/core/grid.py
Copy link
Copy Markdown
Collaborator Author

@ColmTalbot ColmTalbot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the comments @mj-will I'll try to fold in the changes before next week.


def prob(self, val):
return (
xp.sign(2 * val - 1)
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This hadn't occurred to me, but I actually think the old version is more incorrect. I'm happy to reinstate the old behaviour with xp.where, but it is worth documenting as I think this is non intuitive.

Comment on lines +753 to +757
def cdf(self, val, *, xp=None):
with np.errstate(divide="ignore"):
return 0.5 + erf(
(xp.log(xp.maximum(val, xp.asarray(self.minimum))) - self.mu) / self.sigma / np.sqrt(2)
) / 2
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, max(val, minimum) is equivalent to a one-sided clip. This seems overly complicated though as minimum is hardcoded to zero, but that's a carry over.

Comment on lines +999 to +1001
ln_prob = xlog1py(xp.asarray(self.beta - 1.0), -val) + xlogy(xp.asarray(self.alpha - 1.0), val)
ln_prob -= betaln(xp.asarray(self.alpha), xp.asarray(self.beta))
return xp.nan_to_num(ln_prob, nan=-xp.inf, neginf=-xp.inf, posinf=-xp.inf)
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, this seems pretty broken. What about this?

Suggested change
ln_prob = xlog1py(xp.asarray(self.beta - 1.0), -val) + xlogy(xp.asarray(self.alpha - 1.0), val)
ln_prob -= betaln(xp.asarray(self.alpha), xp.asarray(self.beta))
return xp.nan_to_num(ln_prob, nan=-xp.inf, neginf=-xp.inf, posinf=-xp.inf)
normalized = (val - self.minimum) / (self.maximum - self.minimum)
ln_prob = xlog1py(xp.asarray(self.beta - 1.0), -normalized) + xlogy(xp.asarray(self.alpha - 1.0), normalized)
ln_prob -= betaln(xp.asarray(self.alpha), xp.asarray(self.beta))
ln_prob -= xlogy(
xp.asarray(self.alpha + self.beta - 1),
xp.asarray(self.maximum - self.minimum),
)
return xp.where(
xp.abs(normalized - 0.5) <= 0.5,
xp.nan_to_num(ln_prob, nan=-xp.inf, neginf=-xp.inf, posinf=-xp.inf),
-xp.inf,
)

Comment on lines +1050 to +1052
with np.errstate(divide="ignore"):
val = xp.asarray(val)
return self.mu + self.scale * xp.log(xp.maximum(val / (1 - val), xp.asarray(0)))
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so, log(max(1 / 0, 0)) = log(inf) = inf.

@xp_wrap
def cdf(self, val, *, xp=None):
asymmetric = xp.log(xp.abs(val) / self.minimum) / xp.log(xp.asarray(self.maximum / self.minimum))
return xp.clip(0.5 * (1 + xp.sign(val) * asymmetric), 0, 1)
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I think we need to modify the definition of asymmetric to the following to make the CDF flat between [-min, min].

        asymmetric = xp.log(xp.maximum(xp.abs(val) / self.minimum, 1)) / xp.log(xp.asarray(self.maximum / self.minimum))

Comment on lines +191 to +192
white_noise = xpx.at(white_noise, 0).set(0)
white_noise = xpx.at(white_noise, -1).set(0)
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the answer is that the old function was broken... and always returned yes in this condition.
We could add it back in, but it does break the specific value tests we have.

Comment thread bilby/core/grid.py
raise TypeError("Parameters names must be a list or string")

out_array = log_array.copy()
out_array = copy(log_array)
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sigh

Comment on lines +178 to +189
if random_state is None or not BILBY_ARRAY_API:
return np
elif isinstance(random_state, np.random.Generator):
return np
elif aac.is_jax_array(random_state) or getattr(random_state, "backend") == "jax":
import jax.numpy as jnp
return jnp
elif aac.is_torch_array(random_state) or getattr(random_state, "backend") == "torch":
import torch
return torch
else:
return np
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not against doing this, but I'm not really sure what the benefit is. I guess it would reduce the number of explicit imports.

@ColmTalbot ColmTalbot requested a review from a team May 22, 2026 11:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants