Skip to content
Closed
20 changes: 11 additions & 9 deletions gwpopulation/backend.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from importlib import import_module
from types import ModuleType
from typing import Any

__backend__ = ""
SUPPORTED_BACKENDS = ["numpy", "cupy", "jax"]
_np_module = dict(numpy="numpy", cupy="cupy", jax="jax.numpy")
_scipy_module = dict(numpy="scipy", cupy="cupyx.scipy", jax="jax.scipy")
__backend__: str = ""
SUPPORTED_BACKENDS: list[str] = ["numpy", "cupy", "jax"]
_np_module: dict[str, str] = dict(numpy="numpy", cupy="cupy", jax="jax.numpy")
_scipy_module: dict[str, str] = dict(numpy="scipy", cupy="cupyx.scipy", jax="jax.scipy")

__all__ = [
"SUPPORTED_BACKENDS",
Expand Down Expand Up @@ -45,7 +47,7 @@
"""


def modules_to_update():
def modules_to_update() -> tuple[list[str], list[str], dict[str, str]]:
"""
Return all modules that need to be updated with the backend.

Expand Down Expand Up @@ -76,7 +78,7 @@ def modules_to_update():
return all_with_xp, all_with_scs, others


def _configure_jax(xp):
def _configure_jax(xp: ModuleType) -> None:
"""
Configuration requirements for :code:`jax`

Expand All @@ -87,7 +89,7 @@ def _configure_jax(xp):
config.update("jax_enable_x64", True)


def _load_numpy_and_scipy(backend):
def _load_numpy_and_scipy(backend: str) -> tuple[ModuleType, ModuleType]:
try:
xp = import_module(_np_module[backend])
scs = import_module(_scipy_module[backend]).special
Expand All @@ -102,7 +104,7 @@ def _load_numpy_and_scipy(backend):
return xp, scs


def _load_arbitrary(func, backend):
def _load_arbitrary(func: str, backend: str) -> Any:
if func.startswith("scipy"):
func = func.replace("scipy", _scipy_module[backend])
elif func.startswith("numpy"):
Expand All @@ -111,7 +113,7 @@ def _load_arbitrary(func, backend):
return getattr(import_module(module), func)


def set_backend(backend="numpy"):
def set_backend(backend: str = "numpy") -> None:
"""
Set the backend for :code:`GWPopulation` and plugins.

Expand Down
14 changes: 8 additions & 6 deletions gwpopulation/experimental/cosmo_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
can be used to add cosmological functionality to a population model.
"""

from typing import Any

import numpy as xp
from bilby.hyper.model import Model
from wcosmo import z_at_value
Expand All @@ -24,7 +26,7 @@ class CosmoMixin:
Should be of :code:`wcosmo.available.keys()`.
"""

def __init__(self, cosmo_model="Planck15"):
def __init__(self, cosmo_model: str = "Planck15") -> None:
wcosmo_disable_units()
self.cosmo_model = cosmo_model
if self.cosmo_model == "FlatwCDM":
Expand All @@ -35,7 +37,7 @@ def __init__(self, cosmo_model="Planck15"):
self.cosmology_names = []
self._cosmo = available[cosmo_model]

def cosmology_variables(self, parameters):
def cosmology_variables(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
Extract the cosmological parameters from the provided parameters.

Expand All @@ -51,7 +53,7 @@ def cosmology_variables(self, parameters):
"""
return {key: parameters[key] for key in self.cosmology_names}

def cosmology(self, parameters):
def cosmology(self, parameters: dict[str, Any]) -> WCosmoMixin:
"""
Return the cosmology model given the parameters.

Expand All @@ -70,7 +72,7 @@ def cosmology(self, parameters):
else:
return self._cosmo(**self.cosmology_variables(parameters))

def detector_frame_to_source_frame(self, data, **parameters):
def detector_frame_to_source_frame(self, data: dict[str, Any], **parameters: Any) -> tuple[dict[str, Any], Any]:
r"""
Convert detector frame samples to source frame samples given cosmological
parameters. Calculate the corresponding
Expand Down Expand Up @@ -133,11 +135,11 @@ class CosmoModel(Model, CosmoMixin):
Should be of :code:`wcosmo.available.keys()`.
"""

def __init__(self, model_functions=None, cosmo_model="Planck15"):
def __init__(self, model_functions: list[Any] | None = None, cosmo_model: str = "Planck15") -> None:
Model.__init__(self, model_functions=model_functions, cache=False)
CosmoMixin.__init__(self, cosmo_model=cosmo_model)

def prob(self, data, **kwargs):
def prob(self, data: dict[str, Any], **kwargs: Any) -> Any:
"""
Compute the total population probability for the provided data given
the keyword arguments.
Expand Down
12 changes: 7 additions & 5 deletions gwpopulation/experimental/jax.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import warnings
from collections.abc import Callable
from copy import deepcopy
from functools import partial
from typing import Any

import numpy as np
from bilby.core.likelihood import Likelihood


def generic_bilby_likelihood_function(likelihood, parameters, use_ratio=True):
def generic_bilby_likelihood_function(likelihood: Likelihood, parameters: dict[str, Any], use_ratio: bool = True) -> float:
"""
A wrapper to allow a :code:`Bilby` likelihood to be used with :code:`jax`.

Expand Down Expand Up @@ -48,8 +50,8 @@ class JittedLikelihood(Likelihood):
"""

def __init__(
self, likelihood, likelihood_func=generic_bilby_likelihood_function, kwargs=None
):
self, likelihood: Likelihood, likelihood_func: Callable = generic_bilby_likelihood_function, kwargs: dict[str, Any] | None = None
) -> None:
from jax import jit

if kwargs is None:
Expand All @@ -59,8 +61,8 @@ def __init__(
self.likelihood_func = jit(partial(likelihood_func, likelihood))
super().__init__()

def __getattr__(self, name):
def __getattr__(self, name: str) -> Any:
return getattr(self._likelihood, name)

def log_likelihood_ratio(self, parameters):
def log_likelihood_ratio(self, parameters: dict[str, Any]) -> float:
return float(np.nan_to_num(self.likelihood_func(parameters, **self.kwargs)))
20 changes: 11 additions & 9 deletions gwpopulation/experimental/numpyro.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any

import jax
import jax.numpy as jnp
import numpy as np
Expand All @@ -11,13 +13,13 @@


def gwpopulation_likelihood_model(
likelihood,
hyper_params,
varmax=np.inf,
apply_selection=True,
predictive_resample=True,
fit_keys=None,
):
likelihood: Any,
hyper_params: dict[str, Any],
varmax: float = np.inf,
apply_selection: bool = True,
predictive_resample: bool = True,
fit_keys: list[str] | None = None,
) -> Any:
"""
A :code:`numpyro` implementation of :func:`gwpopulation.hyperpe.HyperparameterLikelihood`.

Expand Down Expand Up @@ -106,8 +108,8 @@ def gwpopulation_likelihood_model(


def posterior_predictive_resample(
data, weights, label, event=None, shape=(), fit_keys=None
):
data: dict[str, Any], weights: Any, label: str | int, event: int | None = None, shape: tuple = (), fit_keys: list[str] | None = None
) -> None:
"""
Draw a sample from the predictive distribution given the weights for each sample.

Expand Down
54 changes: 28 additions & 26 deletions gwpopulation/hyperpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
"""

import types
from collections.abc import Callable
from typing import Any

import numpy as np
from bilby.core.likelihood import Likelihood
Expand Down Expand Up @@ -71,14 +73,14 @@ class HyperparameterLikelihood(Likelihood):

def __init__(
self,
posteriors,
hyper_prior,
ln_evidences=None,
max_samples=1e100,
selection_function=lambda args: 1,
conversion_function=lambda args: (args, None),
maximum_uncertainty=xp.inf,
):
posteriors: list[Any],
hyper_prior: Model | Callable,
ln_evidences: list[float] | None = None,
max_samples: float = 1e100,
selection_function: Callable[[dict[str, Any]], float | tuple[float, float]] = lambda args: 1.0,
conversion_function: Callable[[dict[str, Any]], tuple[dict[str, Any], Any]] = lambda args: (args, None),
maximum_uncertainty: float = xp.inf,
) -> None:
"""
Parameters
----------
Expand Down Expand Up @@ -144,7 +146,7 @@ def __init__(
__doc__ += __init__.__doc__

@property
def maximum_uncertainty(self):
def maximum_uncertainty(self) -> float:
"""
The maximum allowed uncertainty in the estimate of the log-likelihood.
If the uncertainty is larger than this value a log likelihood of -inf
Expand All @@ -153,14 +155,14 @@ def maximum_uncertainty(self):
return self._maximum_uncertainty

@maximum_uncertainty.setter
def maximum_uncertainty(self, value):
def maximum_uncertainty(self, value: float) -> None:
self._maximum_uncertainty = value
if value in [xp.inf, np.inf]:
self._max_variance = value
else:
self._max_variance = value**2

def ln_likelihood_and_variance(self, parameters):
def ln_likelihood_and_variance(self, parameters: dict[str, Any]) -> tuple[Any, float]:
"""
Compute the ln likelihood estimator and its variance.
"""
Expand All @@ -177,23 +179,23 @@ def ln_likelihood_and_variance(self, parameters):
ln_l += selection
return ln_l, to_number(variance, float)

def log_likelihood_ratio(self, parameters):
def log_likelihood_ratio(self, parameters: dict[str, Any]) -> float:
ln_l, variance = self.ln_likelihood_and_variance(parameters=parameters)
ln_l = xp.nan_to_num(ln_l, nan=-xp.inf)
ln_l -= xp.nan_to_num(xp.inf * (self.maximum_uncertainty < variance), nan=0)
return to_number(xp.nan_to_num(ln_l), float)

def noise_log_likelihood(self):
def noise_log_likelihood(self) -> float:
return self.total_noise_evidence

def log_likelihood(self, parameters):
def log_likelihood(self, parameters: dict[str, Any]) -> float:
return self.noise_log_likelihood() + self.log_likelihood_ratio(
parameters=parameters
)

def _compute_per_event_ln_bayes_factors(
self, parameters, *, return_uncertainty=True
):
self, parameters: dict[str, Any], *, return_uncertainty: bool = True
) -> Any | tuple[Any, Any]:
weights = self.hyper_prior.prob(self.data, **parameters) / self.sampling_prior
expectation = xp.mean(weights, axis=-1)
if return_uncertainty:
Expand All @@ -205,7 +207,7 @@ def _compute_per_event_ln_bayes_factors(
else:
return xp.log(expectation)

def _get_selection_factor(self, parameters, *, return_uncertainty=True):
def _get_selection_factor(self, parameters: dict[str, Any], *, return_uncertainty: bool = True) -> Any | tuple[Any, Any]:
selection, variance = self._selection_function_with_uncertainty(
parameters=parameters
)
Expand All @@ -216,7 +218,7 @@ def _get_selection_factor(self, parameters, *, return_uncertainty=True):
else:
return total_selection

def _selection_function_with_uncertainty(self, parameters):
def _selection_function_with_uncertainty(self, parameters: dict[str, Any]) -> tuple[Any, Any]:
result = self.selection_function(parameters)
if isinstance(result, tuple):
selection, variance = result
Expand All @@ -225,7 +227,7 @@ def _selection_function_with_uncertainty(self, parameters):
variance = 0.0
return selection, variance

def generate_extra_statistics(self, sample):
def generate_extra_statistics(self, sample: dict[str, Any]) -> dict[str, Any]:
r"""
Given an input sample, add extra statistics

Expand Down Expand Up @@ -274,7 +276,7 @@ def generate_extra_statistics(self, sample):
sample["variance"] = to_number(total_variance, float)
return sample

def generate_rate_posterior_sample(self, parameters):
def generate_rate_posterior_sample(self, parameters: dict[str, Any]) -> float:
r"""
Generate a sample from the posterior distribution for rate assuming a
:math:`1 / R` prior.
Expand Down Expand Up @@ -311,7 +313,7 @@ def generate_rate_posterior_sample(self, parameters):
rate = gamma(a=self.n_posteriors).rvs() / vt
return rate

def resample_posteriors(self, posteriors, max_samples=1e300):
def resample_posteriors(self, posteriors: list[Any], max_samples: float = 1e300) -> dict[str, Any]:
"""
Convert list of pandas DataFrame object to dict of arrays.

Expand Down Expand Up @@ -342,7 +344,7 @@ def resample_posteriors(self, posteriors, max_samples=1e300):
data[key] = xp.array(data[key])
return data

def posterior_predictive_resample(self, samples, return_weights=False):
def posterior_predictive_resample(self, samples: Any, return_weights: bool = False) -> dict[str, Any] | tuple[dict[str, Any], Any]:
"""
Resample the original single event posteriors to use the PPD from each
of the other events as the prior.
Expand Down Expand Up @@ -420,7 +422,7 @@ def posterior_predictive_resample(self, samples, return_weights=False):
return new_samples

@property
def meta_data(self):
def meta_data(self) -> dict[str, Any]:
return dict(
model=[get_name(model) for model in self.hyper_prior.models],
data={key: to_numpy(self.data[key]) for key in self.data},
Expand All @@ -442,7 +444,7 @@ class RateLikelihood(HyperparameterLikelihood):

__doc__ += HyperparameterLikelihood.__init__.__doc__

def _get_selection_factor(self, parameters, *, return_uncertainty=True):
def _get_selection_factor(self, parameters: dict[str, Any], *, return_uncertainty: bool = True) -> Any | tuple[Any, Any]:
r"""
The selection factor for the rate likelihood is

Expand Down Expand Up @@ -472,7 +474,7 @@ def _get_selection_factor(self, parameters, *, return_uncertainty=True):
else:
return total_selection

def generate_rate_posterior_sample(self, parameters):
def generate_rate_posterior_sample(self, parameters: dict[str, Any]) -> Any:
"""
Since the rate is a sampled parameter,
this simply returns the current value of the rate parameter.
Expand All @@ -490,7 +492,7 @@ class NullHyperparameterLikelihood(HyperparameterLikelihood):
`Farr <https://arxiv.org/abs/1904.10879>`_.
"""

def ln_likelihood_and_variance(self, parameters=None):
def ln_likelihood_and_variance(self, parameters: dict[str, Any] | None = None) -> tuple[float, float]:
"""
Compute the ln likelihood estimator and its variance.
"""
Expand Down
Loading
Loading