Support non-numpy array backends#886
Conversation
ea348fa to
771a8a9
Compare
|
This is now ready for review. There are a lot of changes, but most of them are essentially Bilby can once again be installed without I've managed to keep test changes minimal:
|
mj-will
left a comment
There was a problem hiding this comment.
Some initial comments but I'll need to have another look.
| This maps to the inverse CDF. This has been analytically solved for this case. | ||
| """ | ||
| return gammaincinv(self.k, val) * self.theta | ||
| return xp.asarray(gammaincinv(self.k, val)) * self.theta |
There was a problem hiding this comment.
Does this mean this is falling back to numpy?
There was a problem hiding this comment.
Yeah, I should update/recheck this, but at least jax doesn't have good support for this, but it looks like tensorflow has a version that numpyro uses (jax-ml/jax#5350). cupy does have this function, so this workaround may have just been for jax. I could add a BackendNotImplementedError.
There was a problem hiding this comment.
Would this be a candidate for a small patch that uses the TF version for jax until jax supports it natively?
| self[key].least_recently_sampled = result[key] | ||
| if isinstance(self[key], JointPrior) and self[key].dist.distname not in joint: | ||
| joint[self[key].dist.distname] = [key] | ||
| elif isinstance(self[key], JointPrior): | ||
| joint[self[key].dist.distname].append(key) | ||
| for names in joint.values(): | ||
| # this is needed to unpack how joint prior rescaling works | ||
| # as an example of a joint prior over {a, b, c, d} we might | ||
| # get the following based on the order within the joint prior | ||
| # {a: [], b: [], c: [1, 2, 3, 4], d: []} | ||
| # -> [1, 2, 3, 4] | ||
| # -> {a: 1, b: 2, c: 3, d: 4} | ||
| values = list() | ||
| for key in names: | ||
| values = np.concatenate([values, result[key]]) | ||
| for key, value in zip(names, values): | ||
| result[key] = value | ||
|
|
||
| def safe_flatten(value): | ||
| """ | ||
| this is gross but can be removed whenever we switch to returning | ||
| arrays, flatten converts 0-d arrays to 1-d so has to be special | ||
| cased | ||
| """ | ||
| if isinstance(value, (float, int)): | ||
| return value |
There was a problem hiding this comment.
Is removing this intentional?
There was a problem hiding this comment.
Yeah, this is in line with one of the other open PRs to update this logic. I'll dig it out in my next pass.
ColmTalbot
left a comment
There was a problem hiding this comment.
Thanks for the initial comments @mj-will I'll take a pass at them ASAP.
| self[key].least_recently_sampled = result[key] | ||
| if isinstance(self[key], JointPrior) and self[key].dist.distname not in joint: | ||
| joint[self[key].dist.distname] = [key] | ||
| elif isinstance(self[key], JointPrior): | ||
| joint[self[key].dist.distname].append(key) | ||
| for names in joint.values(): | ||
| # this is needed to unpack how joint prior rescaling works | ||
| # as an example of a joint prior over {a, b, c, d} we might | ||
| # get the following based on the order within the joint prior | ||
| # {a: [], b: [], c: [1, 2, 3, 4], d: []} | ||
| # -> [1, 2, 3, 4] | ||
| # -> {a: 1, b: 2, c: 3, d: 4} | ||
| values = list() | ||
| for key in names: | ||
| values = np.concatenate([values, result[key]]) | ||
| for key, value in zip(names, values): | ||
| result[key] = value | ||
|
|
||
| def safe_flatten(value): | ||
| """ | ||
| this is gross but can be removed whenever we switch to returning | ||
| arrays, flatten converts 0-d arrays to 1-d so has to be special | ||
| cased | ||
| """ | ||
| if isinstance(value, (float, int)): | ||
| return value |
There was a problem hiding this comment.
Yeah, this is in line with one of the other open PRs to update this logic. I'll dig it out in my next pass.
GregoryAshton
left a comment
There was a problem hiding this comment.
Okay, I got through about 60% of the diff and I'm pausing here so will submit the questions so far.
| _cdf[val >= self.minimum] = 1. - np.exp(-val[val >= self.minimum] / self.mu) | ||
| return _cdf | ||
| with np.errstate(divide="ignore"): | ||
| return -val / self.mu - xp.log(xp.asarray(self.mu)) + xp.log(val >= self.minimum) |
There was a problem hiding this comment.
Ah okay - are the bounds being implemented here? But, I don't see the upper bound being implemented.
There was a problem hiding this comment.
I think this is carried over from the existing implementation.
|
Python 3.10 doesn't have support for a vmappable version of logsumexp through scipy leading to this job failing (https://github.com/bilby-dev/bilby/actions/runs/25883935510/job/76070707573?pr=886). How do people feel about dropping support for Python 3.10 in Bilby 3? Numpy dropped support about a year ago. |
This required making some changes to the tests for conditional dicts as I've changed the output types and the backend introspection doesn't work on dict_items for some reason
|
This looks great @ColmTalbot. i've not a proper look yet but one thought on using Once that's in, I can make a stable release and get in on conda if that would be useful. |
No meaningful thoughts, I agree it's probably a positive change and I'm happy to update once it is available.
I think we'll definitely want this once we have a release ready. Although, I think I've refactored things so that orng doesn't need to be a top-level dependency. |
Co-authored-by: Michael J. Williams <michaeljw1@googlemail.com>
mj-will
left a comment
There was a problem hiding this comment.
I'm unsure if all the priors are consistent with the previous versions. I've tried to go through them but may have missed some things or equally, got some logic wrong and they're fine.
|
|
||
| def prob(self, val): | ||
| return ( | ||
| xp.sign(2 * val - 1) |
There was a problem hiding this comment.
I think the behaviour at val=0.5 has changed. Previously it would return self.minimum where as now it will be zero since np.sign(0) = 0. I'm not sure, but I don't think this is right for a log-uniform.
There was a problem hiding this comment.
This hadn't occurred to me, but I actually think the old version is more incorrect. I'm happy to reinstate the old behaviour with xp.where, but it is worth documenting as I think this is non intuitive.
| @xp_wrap | ||
| def cdf(self, val, *, xp=None): | ||
| asymmetric = xp.log(xp.abs(val) / self.minimum) / xp.log(xp.asarray(self.maximum / self.minimum)) | ||
| return xp.clip(0.5 * (1 + xp.sign(val) * asymmetric), 0, 1) |
There was a problem hiding this comment.
I think this has a similar problem at 0 since sign(0) = 0 and asymmetric=-inf
There was a problem hiding this comment.
Yeah, I think we need to modify the definition of asymmetric to the following to make the CDF flat between [-min, min].
asymmetric = xp.log(xp.maximum(xp.abs(val) / self.minimum, 1)) / xp.log(xp.asarray(self.maximum / self.minimum))| def cdf(self, val, *, xp=None): | ||
| with np.errstate(divide="ignore"): | ||
| return 0.5 + erf( | ||
| (xp.log(xp.maximum(val, xp.asarray(self.minimum))) - self.mu) / self.sigma / np.sqrt(2) | ||
| ) / 2 |
There was a problem hiding this comment.
Does it still return 0 for values below the minimum? I don't think it does but I might be missing something.
There was a problem hiding this comment.
Yeah, max(val, minimum) is equivalent to a one-sided clip. This seems overly complicated though as minimum is hardcoded to zero, but that's a carry over.
| ln_prob = xlog1py(xp.asarray(self.beta - 1.0), -val) + xlogy(xp.asarray(self.alpha - 1.0), val) | ||
| ln_prob -= betaln(xp.asarray(self.alpha), xp.asarray(self.beta)) | ||
| return xp.nan_to_num(ln_prob, nan=-xp.inf, neginf=-xp.inf, posinf=-xp.inf) |
There was a problem hiding this comment.
I'm sure minimum and maximum are respected here.
There was a problem hiding this comment.
Yeah, this seems pretty broken. What about this?
| ln_prob = xlog1py(xp.asarray(self.beta - 1.0), -val) + xlogy(xp.asarray(self.alpha - 1.0), val) | |
| ln_prob -= betaln(xp.asarray(self.alpha), xp.asarray(self.beta)) | |
| return xp.nan_to_num(ln_prob, nan=-xp.inf, neginf=-xp.inf, posinf=-xp.inf) | |
| normalized = (val - self.minimum) / (self.maximum - self.minimum) | |
| ln_prob = xlog1py(xp.asarray(self.beta - 1.0), -normalized) + xlogy(xp.asarray(self.alpha - 1.0), normalized) | |
| ln_prob -= betaln(xp.asarray(self.alpha), xp.asarray(self.beta)) | |
| ln_prob -= xlogy( | |
| xp.asarray(self.alpha + self.beta - 1), | |
| xp.asarray(self.maximum - self.minimum), | |
| ) | |
| return xp.where( | |
| xp.abs(normalized - 0.5) <= 0.5, | |
| xp.nan_to_num(ln_prob, nan=-xp.inf, neginf=-xp.inf, posinf=-xp.inf), | |
| -xp.inf, | |
| ) |
| with np.errstate(divide="ignore"): | ||
| val = xp.asarray(val) | ||
| return self.mu + self.scale * xp.log(xp.maximum(val / (1 - val), xp.asarray(0))) |
There was a problem hiding this comment.
I don' think the upper boundary is consistent with the old implementation. I think it now gives -inf rather than +inf
There was a problem hiding this comment.
I don't think so, log(max(1 / 0, 0)) = log(inf) = inf.
There was a problem hiding this comment.
Ah, I was thinking specifically of things above the bound, so:
log(max(negative, 0)) = log(0) = -inf
I agree the bound itself is still the same.
There was a problem hiding this comment.
Oh, I don't know that we've been careful about out of bounds behaviour for these rescale methods in general. The behaviour here is the same as the behaviour before this change for scalars, but not arrays.
| if random_state is None or not BILBY_ARRAY_API: | ||
| return np | ||
| elif isinstance(random_state, np.random.Generator): | ||
| return np | ||
| elif aac.is_jax_array(random_state) or getattr(random_state, "backend") == "jax": | ||
| import jax.numpy as jnp | ||
| return jnp | ||
| elif aac.is_torch_array(random_state) or getattr(random_state, "backend") == "torch": | ||
| import torch | ||
| return torch | ||
| else: | ||
| return np |
There was a problem hiding this comment.
Should this be using the wrapped modules, e.g. array_api_compat.numpy?
From the docs I think this is preferred:
Wrapped array namespaces can also be imported directly. For example, array_namespace(np.array(...)) will return array_api_compat.numpy. This function will also work for any array library not wrapped by array-api-compat if it explicitly defines array_namespace (the wrapped namespace is always preferred if it exists).
There was a problem hiding this comment.
I'm not against doing this, but I'm not really sure what the benefit is. I guess it would reduce the number of explicit imports.
There was a problem hiding this comment.
I think the main benefit is the APIs are standardized and the code can therefore be more generic.
For example, the methods will support some additional keywords like device:
import array_api_compat.numpy as np
device = "cpu"
x = [1, 2, 3]
np.asarray(x, device=device)
This doesn't do anything for numpy in this case, but doesn't raise an error like it would for standard numpy. This means the exact same code would work for torch with e.g. a CUDA device.
There was a problem hiding this comment.
Okay, I'll try to use this where I can, but I think we should avoid using the array_api_compat version of numpy if BILBY_ARRAY_API == False.
FWIW, this specific example has worked with standard numpy since version 2.0.0.
| white_noise = xpx.at(white_noise, 0).set(0) | ||
| white_noise = xpx.at(white_noise, -1).set(0) |
There was a problem hiding this comment.
May be missing something, but it seems the case where the number of samples is odd is no longer handled.
There was a problem hiding this comment.
I think the answer is that the old function was broken... and always returned yes in this condition.
We could add it back in, but it does break the specific value tests we have.
| """ | ||
|
|
||
| if xp is None: | ||
| xp = np |
There was a problem hiding this comment.
Should this default to aac.numpy?
| raise TypeError("Parameters names must be a list or string") | ||
|
|
||
| out_array = log_array.copy() | ||
| out_array = copy(log_array) |
There was a problem hiding this comment.
I'm not sure a plain copy works for torch:
>>> import torch
>>> import copy
>>> x = torch.Tensor([1.0])
>>> y = copy.copy(x)
>>> y[0] = 0
>>> x
tensor([0.])
>>>
I'm not sure what the correct way is to handle this, in other packages I've resorted to helpers but I suspect there's a better way. For example: https://github.com/mj-will/aspire/blob/main/src/aspire/utils.py#L456
ColmTalbot
left a comment
There was a problem hiding this comment.
Thanks for the comments @mj-will I'll try to fold in the changes before next week.
|
|
||
| def prob(self, val): | ||
| return ( | ||
| xp.sign(2 * val - 1) |
There was a problem hiding this comment.
This hadn't occurred to me, but I actually think the old version is more incorrect. I'm happy to reinstate the old behaviour with xp.where, but it is worth documenting as I think this is non intuitive.
| def cdf(self, val, *, xp=None): | ||
| with np.errstate(divide="ignore"): | ||
| return 0.5 + erf( | ||
| (xp.log(xp.maximum(val, xp.asarray(self.minimum))) - self.mu) / self.sigma / np.sqrt(2) | ||
| ) / 2 |
There was a problem hiding this comment.
Yeah, max(val, minimum) is equivalent to a one-sided clip. This seems overly complicated though as minimum is hardcoded to zero, but that's a carry over.
| ln_prob = xlog1py(xp.asarray(self.beta - 1.0), -val) + xlogy(xp.asarray(self.alpha - 1.0), val) | ||
| ln_prob -= betaln(xp.asarray(self.alpha), xp.asarray(self.beta)) | ||
| return xp.nan_to_num(ln_prob, nan=-xp.inf, neginf=-xp.inf, posinf=-xp.inf) |
There was a problem hiding this comment.
Yeah, this seems pretty broken. What about this?
| ln_prob = xlog1py(xp.asarray(self.beta - 1.0), -val) + xlogy(xp.asarray(self.alpha - 1.0), val) | |
| ln_prob -= betaln(xp.asarray(self.alpha), xp.asarray(self.beta)) | |
| return xp.nan_to_num(ln_prob, nan=-xp.inf, neginf=-xp.inf, posinf=-xp.inf) | |
| normalized = (val - self.minimum) / (self.maximum - self.minimum) | |
| ln_prob = xlog1py(xp.asarray(self.beta - 1.0), -normalized) + xlogy(xp.asarray(self.alpha - 1.0), normalized) | |
| ln_prob -= betaln(xp.asarray(self.alpha), xp.asarray(self.beta)) | |
| ln_prob -= xlogy( | |
| xp.asarray(self.alpha + self.beta - 1), | |
| xp.asarray(self.maximum - self.minimum), | |
| ) | |
| return xp.where( | |
| xp.abs(normalized - 0.5) <= 0.5, | |
| xp.nan_to_num(ln_prob, nan=-xp.inf, neginf=-xp.inf, posinf=-xp.inf), | |
| -xp.inf, | |
| ) |
| with np.errstate(divide="ignore"): | ||
| val = xp.asarray(val) | ||
| return self.mu + self.scale * xp.log(xp.maximum(val / (1 - val), xp.asarray(0))) |
There was a problem hiding this comment.
I don't think so, log(max(1 / 0, 0)) = log(inf) = inf.
| @xp_wrap | ||
| def cdf(self, val, *, xp=None): | ||
| asymmetric = xp.log(xp.abs(val) / self.minimum) / xp.log(xp.asarray(self.maximum / self.minimum)) | ||
| return xp.clip(0.5 * (1 + xp.sign(val) * asymmetric), 0, 1) |
There was a problem hiding this comment.
Yeah, I think we need to modify the definition of asymmetric to the following to make the CDF flat between [-min, min].
asymmetric = xp.log(xp.maximum(xp.abs(val) / self.minimum, 1)) / xp.log(xp.asarray(self.maximum / self.minimum))| white_noise = xpx.at(white_noise, 0).set(0) | ||
| white_noise = xpx.at(white_noise, -1).set(0) |
There was a problem hiding this comment.
I think the answer is that the old function was broken... and always returned yes in this condition.
We could add it back in, but it does break the specific value tests we have.
| raise TypeError("Parameters names must be a list or string") | ||
|
|
||
| out_array = log_array.copy() | ||
| out_array = copy(log_array) |
| if random_state is None or not BILBY_ARRAY_API: | ||
| return np | ||
| elif isinstance(random_state, np.random.Generator): | ||
| return np | ||
| elif aac.is_jax_array(random_state) or getattr(random_state, "backend") == "jax": | ||
| import jax.numpy as jnp | ||
| return jnp | ||
| elif aac.is_torch_array(random_state) or getattr(random_state, "backend") == "torch": | ||
| import torch | ||
| return torch | ||
| else: | ||
| return np |
There was a problem hiding this comment.
I'm not against doing this, but I'm not really sure what the benefit is. I guess it would reduce the number of explicit imports.
I've been working on this PR on and off for a few months, it isn't ready yet, but I wanted to share it in case other people had early opinions.
The goal is to make it easier to interface with models/samplers implemented in e.g., JAX, that support GPU/TPU acceleration and JIT compilation.
The general guiding principles are:
array-apispecification andscipyinteroperabilityThe primary changes so far are:
Changed behaviour:
Remaining issues:
bilby.gw.jaxstufffile should be removed and relevant functionality be moved elsewhere, it's currently just used for testing