From 8a9a4ee7131724b42a4617a9e8f97a513232f3d6 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 15 Feb 2026 22:53:21 +0000 Subject: [PATCH 01/11] Initial plan From 16fb1199a456181d56b07ef93406cc48c747e4ef Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 15 Feb 2026 22:57:38 +0000 Subject: [PATCH 02/11] Add type hints to backend.py and utils.py Co-authored-by: ColmTalbot <25602909+ColmTalbot@users.noreply.github.com> --- gwpopulation/backend.py | 20 +++++++++++--------- gwpopulation/utils.py | 30 ++++++++++++++++-------------- 2 files changed, 27 insertions(+), 23 deletions(-) diff --git a/gwpopulation/backend.py b/gwpopulation/backend.py index 3dd127a8..52af8e9c 100644 --- a/gwpopulation/backend.py +++ b/gwpopulation/backend.py @@ -1,9 +1,11 @@ from importlib import import_module +from types import ModuleType +from typing import Any -__backend__ = "" -SUPPORTED_BACKENDS = ["numpy", "cupy", "jax"] -_np_module = dict(numpy="numpy", cupy="cupy", jax="jax.numpy") -_scipy_module = dict(numpy="scipy", cupy="cupyx.scipy", jax="jax.scipy") +__backend__: str = "" +SUPPORTED_BACKENDS: list[str] = ["numpy", "cupy", "jax"] +_np_module: dict[str, str] = dict(numpy="numpy", cupy="cupy", jax="jax.numpy") +_scipy_module: dict[str, str] = dict(numpy="scipy", cupy="cupyx.scipy", jax="jax.scipy") __all__ = [ "SUPPORTED_BACKENDS", @@ -45,7 +47,7 @@ """ -def modules_to_update(): +def modules_to_update() -> tuple[list[str], list[str], dict[str, str]]: """ Return all modules that need to be updated with the backend. @@ -76,7 +78,7 @@ def modules_to_update(): return all_with_xp, all_with_scs, others -def _configure_jax(xp): +def _configure_jax(xp: ModuleType) -> None: """ Configuration requirements for :code:`jax` @@ -87,7 +89,7 @@ def _configure_jax(xp): config.update("jax_enable_x64", True) -def _load_numpy_and_scipy(backend): +def _load_numpy_and_scipy(backend: str) -> tuple[ModuleType, ModuleType]: try: xp = import_module(_np_module[backend]) scs = import_module(_scipy_module[backend]).special @@ -102,7 +104,7 @@ def _load_numpy_and_scipy(backend): return xp, scs -def _load_arbitrary(func, backend): +def _load_arbitrary(func: str, backend: str) -> Any: if func.startswith("scipy"): func = func.replace("scipy", _scipy_module[backend]) elif func.startswith("numpy"): @@ -111,7 +113,7 @@ def _load_arbitrary(func, backend): return getattr(import_module(module), func) -def set_backend(backend="numpy"): +def set_backend(backend: str = "numpy") -> None: """ Set the backend for :code:`GWPopulation` and plugins. diff --git a/gwpopulation/utils.py b/gwpopulation/utils.py index b8a909af..aca9003f 100644 --- a/gwpopulation/utils.py +++ b/gwpopulation/utils.py @@ -2,16 +2,18 @@ Helper functions for probability distributions and backend switching. """ +from collections.abc import Callable from numbers import Number from operator import ge, gt, ne +from typing import Any import numpy as np from scipy import special as scs -xp = np +xp: Any = np -def apply_conditions(conditions): +def apply_conditions(conditions: dict[str, Any]) -> Callable[[Callable[..., Any]], Callable[..., Any]]: """ A decorator to apply conditions to inputs of a function. @@ -29,9 +31,9 @@ def apply_conditions(conditions): """ from functools import wraps - def decorator(func): + def decorator(func: Callable[..., Any]) -> Callable[..., Any]: @wraps(func) - def wrapped_function(*args, **kwargs): + def wrapped_function(*args: Any, **kwargs: Any) -> Any: if "jax" in xp.__name__: return func(*args, **kwargs) for key, condition in conditions.items(): @@ -61,7 +63,7 @@ def wrapped_function(*args, **kwargs): @apply_conditions(dict(alpha=(gt, 0), beta=(gt, 0), scale=(gt, 0))) -def beta_dist(xx, alpha, beta, scale=1): +def beta_dist(xx: Any, alpha: float, beta: float, scale: float | Any = 1) -> Any: r""" Beta distribution probability @@ -95,7 +97,7 @@ def beta_dist(xx, alpha, beta, scale=1): @apply_conditions(dict(low=(ge, 0), alpha=(ne, 1))) -def powerlaw(xx, alpha, high, low): +def powerlaw(xx: Any, alpha: float | Any, high: float | Any, low: float | Any) -> Any: r""" Power-law probability @@ -131,7 +133,7 @@ def powerlaw(xx, alpha, high, low): @apply_conditions(dict(sigma=(gt, 0))) -def truncnorm(xx, mu, sigma, high, low): +def truncnorm(xx: Any, mu: float | Any, sigma: float, high: float | Any, low: float | Any) -> Any: r""" Truncated normal probability @@ -161,7 +163,7 @@ def truncnorm(xx, mu, sigma, high, low): """ - def logsubexp(log_p, log_q): + def logsubexp(log_p: Any, log_q: Any) -> Any: return log_p + xp.log(1 - xp.exp(log_q - log_p)) zz = xp.array(xx - mu) / sigma @@ -183,7 +185,7 @@ def logsubexp(log_p, log_q): return xp.nan_to_num(xp.exp(log_pdf)) * (xx >= low) * (xx <= high) -def unnormalized_2d_gaussian(xx, yy, mu_x, mu_y, sigma_x, sigma_y, covariance): +def unnormalized_2d_gaussian(xx: Any, yy: Any, mu_x: float, mu_y: float, sigma_x: float, sigma_y: float, covariance: float) -> Any: r""" Compute the probability distribution for a correlated 2-dimensional Gaussian neglecting normalization terms. @@ -228,7 +230,7 @@ def unnormalized_2d_gaussian(xx, yy, mu_x, mu_y, sigma_x, sigma_y, covariance): return prob -def von_mises(xx, mu, kappa): +def von_mises(xx: Any, mu: float, kappa: float) -> Any: r""" PDF of the von Mises distribution defined on the standard interval. @@ -258,7 +260,7 @@ def von_mises(xx, mu, kappa): return xp.exp(kappa * (xp.cos(xx - mu) - 1)) / (2 * xp.pi * scs.i0e(kappa)) -def get_name(input): +def get_name(input: Any) -> str: """ Attempt to find the name of the the input. This either returns :code:`input.__name__` or :code:`input.__class__.__name__` @@ -278,7 +280,7 @@ def get_name(input): return input.__class__.__name__ -def to_number(value, func): +def to_number(value: Any, func: type[int] | type[float] | type[complex]) -> int | float | complex: """ Convert a zero-dimensional array to a number. @@ -296,7 +298,7 @@ def to_number(value, func): return func(value) -def to_numpy(array): +def to_numpy(array: Any) -> Any: """ Convert an array to a numpy array. Numeric types and pandas objects are returned unchanged. @@ -318,7 +320,7 @@ def to_numpy(array): raise TypeError(f"Cannot convert {type(array)} to numpy array") -def trapezoid(y, x=None, dx=1.0, axis=-1): +def trapezoid(y: Any, x: Any | None = None, dx: float = 1.0, axis: int = -1) -> Any: """ A wrapper of `trapz` or `trapezoid` that can handle different names in different backends. From fc294259f4e7d63ad4ef346ce7644039b2c7006e Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 15 Feb 2026 23:01:03 +0000 Subject: [PATCH 03/11] Add Python type hints to gwpopulation/vt.py - Added type hints to all functions, methods, and classes - Used Any for array-like types (numpy/cupy/jax arrays) - Used modern Python 3.10+ syntax (dict, list, tuple instead of Dict, List, Tuple) - Used collections.abc.Callable for callable types - Added -> None return type for __init__ methods - Kept existing docstrings unchanged - All tests pass Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- gwpopulation/vt.py | 37 ++++++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/gwpopulation/vt.py b/gwpopulation/vt.py index 1e17d80b..ad3a8020 100644 --- a/gwpopulation/vt.py +++ b/gwpopulation/vt.py @@ -43,6 +43,9 @@ Note that the computational cost of this approach scales exponentially with the number of parameters. """ +from collections.abc import Callable +from typing import Any + import numpy as np from bilby.hyper.model import Model @@ -59,7 +62,9 @@ class _BaseVT: - def __init__(self, model, data): + def __init__( + self, model: Callable | list[Callable] | Model, data: dict[str, Any] + ) -> None: self.data = {key: xp.asarray(value) for key, value in data.items()} if isinstance(model, list): model = Model(model) @@ -67,7 +72,7 @@ def __init__(self, model, data): model = Model([model]) self.model = model - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any) -> Any: raise NotImplementedError @@ -84,7 +89,9 @@ class GridVT(_BaseVT): parameter to be marginalized over. """ - def __init__(self, model, data): + def __init__( + self, model: Callable | list[Callable] | Model, data: dict[str, Any] + ) -> None: self.vts = xp.asarray(data.pop("vt")) super(GridVT, self).__init__(model=model, data=data) self.values = {key: xp.unique(self.data[key]) for key in self.data} @@ -93,7 +100,7 @@ def __init__(self, model, data): self.axes = {int(np.where(shape == lens[key])[0][0]): key for key in self.data} self.ndim = len(self.axes) - def __call__(self, parameters): + def __call__(self, parameters: dict[str, Any]) -> Any: vt_fac = self.model.prob(self.data, **parameters) * self.vts for ii in range(self.ndim): vt_fac = trapezoid( @@ -143,12 +150,12 @@ class ResamplingVT(_BaseVT): def __init__( self, - model, - data, - n_events=np.inf, - marginalize_uncertainty=False, - enforce_convergence=True, - ): + model: Callable | list[Callable] | Model, + data: dict[str, Any], + n_events: float = np.inf, + marginalize_uncertainty: bool = False, + enforce_convergence: bool = True, + ) -> None: super(ResamplingVT, self).__init__(model=model, data=data) self.n_events = n_events self.total_injections = data.get("total_generated", len(data["prior"])) @@ -164,7 +171,7 @@ def __init__( lamb=0, analysis_time=self.analysis_time ) - def __call__(self, parameters): + def __call__(self, parameters: dict[str, Any]) -> tuple[float, float] | float: r""" Compute the expected fraction of detected sources given a set of injections for the specified population model. @@ -194,7 +201,7 @@ def __call__(self, parameters): vt_factor = self.vt_factor(parameters) return vt_factor - def check_convergence(self, mu, var): + def check_convergence(self, mu: float, var: float) -> tuple[Any, Any]: r""" Check if the estimate of the detection efficiency has converged beyond the threshold of :math:`\frac{\mu^2}{\sigma^2} > 4 n_{\rm events}`. @@ -205,7 +212,7 @@ def check_convergence(self, mu, var): xp.nan_to_num(xp.inf * (1 - converged), nan=0, posinf=xp.inf), ) - def vt_factor(self, parameters): + def vt_factor(self, parameters: dict[str, Any]) -> float: r""" Compute the expected number of detections given a set of injections. @@ -232,7 +239,7 @@ def vt_factor(self, parameters): vt_factor += correction return vt_factor - def detection_efficiency(self, parameters): + def detection_efficiency(self, parameters: dict[str, Any]) -> tuple[float, float]: r""" Compute the expected fraction of detections given a set of injections and the variance in the Monte Carlo estimate. @@ -258,7 +265,7 @@ def detection_efficiency(self, parameters): ) return mu, var - def surveyed_hypervolume(self, parameters): + def surveyed_hypervolume(self, parameters: dict[str, Any]) -> float: r""" The total surveyed 4-volume with units of :math:`{\rm Gpc}^3{\rm yr}`. From e4f6894bde2b139fbd2804a900c43d0ee5cbb1c1 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 15 Feb 2026 23:05:37 +0000 Subject: [PATCH 04/11] Add Python type hints to hyperpe.py - Added type hints to all functions, methods, and classes - Used Any for array-like types (numpy, cupy, jax arrays) - Used modern Python 3.10+ type hints (dict[str, Any] instead of Dict) - Used Callable from collections.abc - All existing tests pass Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- gwpopulation/hyperpe.py | 54 +++++++++++++++++++++-------------------- 1 file changed, 28 insertions(+), 26 deletions(-) diff --git a/gwpopulation/hyperpe.py b/gwpopulation/hyperpe.py index 7451ef83..d2a8650c 100644 --- a/gwpopulation/hyperpe.py +++ b/gwpopulation/hyperpe.py @@ -43,6 +43,8 @@ """ import types +from collections.abc import Callable +from typing import Any import numpy as np from bilby.core.likelihood import Likelihood @@ -71,14 +73,14 @@ class HyperparameterLikelihood(Likelihood): def __init__( self, - posteriors, - hyper_prior, - ln_evidences=None, - max_samples=1e100, - selection_function=lambda args: 1, - conversion_function=lambda args: (args, None), - maximum_uncertainty=xp.inf, - ): + posteriors: list[Any], + hyper_prior: Model | Callable, + ln_evidences: list[float] | None = None, + max_samples: float = 1e100, + selection_function: Callable[[dict[str, Any]], float | tuple[float, float]] = lambda args: 1.0, + conversion_function: Callable[[dict[str, Any]], tuple[dict[str, Any], Any]] = lambda args: (args, None), + maximum_uncertainty: float = xp.inf, + ) -> None: """ Parameters ---------- @@ -144,7 +146,7 @@ def __init__( __doc__ += __init__.__doc__ @property - def maximum_uncertainty(self): + def maximum_uncertainty(self) -> float: """ The maximum allowed uncertainty in the estimate of the log-likelihood. If the uncertainty is larger than this value a log likelihood of -inf @@ -153,14 +155,14 @@ def maximum_uncertainty(self): return self._maximum_uncertainty @maximum_uncertainty.setter - def maximum_uncertainty(self, value): + def maximum_uncertainty(self, value: float) -> None: self._maximum_uncertainty = value if value in [xp.inf, np.inf]: self._max_variance = value else: self._max_variance = value**2 - def ln_likelihood_and_variance(self, parameters): + def ln_likelihood_and_variance(self, parameters: dict[str, Any]) -> tuple[Any, float]: """ Compute the ln likelihood estimator and its variance. """ @@ -177,23 +179,23 @@ def ln_likelihood_and_variance(self, parameters): ln_l += selection return ln_l, to_number(variance, float) - def log_likelihood_ratio(self, parameters): + def log_likelihood_ratio(self, parameters: dict[str, Any]) -> float: ln_l, variance = self.ln_likelihood_and_variance(parameters=parameters) ln_l = xp.nan_to_num(ln_l, nan=-xp.inf) ln_l -= xp.nan_to_num(xp.inf * (self.maximum_uncertainty < variance), nan=0) return to_number(xp.nan_to_num(ln_l), float) - def noise_log_likelihood(self): + def noise_log_likelihood(self) -> float: return self.total_noise_evidence - def log_likelihood(self, parameters): + def log_likelihood(self, parameters: dict[str, Any]) -> float: return self.noise_log_likelihood() + self.log_likelihood_ratio( parameters=parameters ) def _compute_per_event_ln_bayes_factors( - self, parameters, *, return_uncertainty=True - ): + self, parameters: dict[str, Any], *, return_uncertainty: bool = True + ) -> Any | tuple[Any, Any]: weights = self.hyper_prior.prob(self.data, **parameters) / self.sampling_prior expectation = xp.mean(weights, axis=-1) if return_uncertainty: @@ -205,7 +207,7 @@ def _compute_per_event_ln_bayes_factors( else: return xp.log(expectation) - def _get_selection_factor(self, parameters, *, return_uncertainty=True): + def _get_selection_factor(self, parameters: dict[str, Any], *, return_uncertainty: bool = True) -> Any | tuple[Any, Any]: selection, variance = self._selection_function_with_uncertainty( parameters=parameters ) @@ -216,7 +218,7 @@ def _get_selection_factor(self, parameters, *, return_uncertainty=True): else: return total_selection - def _selection_function_with_uncertainty(self, parameters): + def _selection_function_with_uncertainty(self, parameters: dict[str, Any]) -> tuple[Any, Any]: result = self.selection_function(parameters) if isinstance(result, tuple): selection, variance = result @@ -225,7 +227,7 @@ def _selection_function_with_uncertainty(self, parameters): variance = 0.0 return selection, variance - def generate_extra_statistics(self, sample): + def generate_extra_statistics(self, sample: dict[str, Any]) -> dict[str, Any]: r""" Given an input sample, add extra statistics @@ -274,7 +276,7 @@ def generate_extra_statistics(self, sample): sample["variance"] = to_number(total_variance, float) return sample - def generate_rate_posterior_sample(self, parameters): + def generate_rate_posterior_sample(self, parameters: dict[str, Any]) -> float: r""" Generate a sample from the posterior distribution for rate assuming a :math:`1 / R` prior. @@ -311,7 +313,7 @@ def generate_rate_posterior_sample(self, parameters): rate = gamma(a=self.n_posteriors).rvs() / vt return rate - def resample_posteriors(self, posteriors, max_samples=1e300): + def resample_posteriors(self, posteriors: list[Any], max_samples: float = 1e300) -> dict[str, Any]: """ Convert list of pandas DataFrame object to dict of arrays. @@ -342,7 +344,7 @@ def resample_posteriors(self, posteriors, max_samples=1e300): data[key] = xp.array(data[key]) return data - def posterior_predictive_resample(self, samples, return_weights=False): + def posterior_predictive_resample(self, samples: Any, return_weights: bool = False) -> dict[str, Any] | tuple[dict[str, Any], Any]: """ Resample the original single event posteriors to use the PPD from each of the other events as the prior. @@ -420,7 +422,7 @@ def posterior_predictive_resample(self, samples, return_weights=False): return new_samples @property - def meta_data(self): + def meta_data(self) -> dict[str, Any]: return dict( model=[get_name(model) for model in self.hyper_prior.models], data={key: to_numpy(self.data[key]) for key in self.data}, @@ -442,7 +444,7 @@ class RateLikelihood(HyperparameterLikelihood): __doc__ += HyperparameterLikelihood.__init__.__doc__ - def _get_selection_factor(self, parameters, *, return_uncertainty=True): + def _get_selection_factor(self, parameters: dict[str, Any], *, return_uncertainty: bool = True) -> Any | tuple[Any, Any]: r""" The selection factor for the rate likelihood is @@ -472,7 +474,7 @@ def _get_selection_factor(self, parameters, *, return_uncertainty=True): else: return total_selection - def generate_rate_posterior_sample(self, parameters): + def generate_rate_posterior_sample(self, parameters: dict[str, Any]) -> Any: """ Since the rate is a sampled parameter, this simply returns the current value of the rate parameter. @@ -490,7 +492,7 @@ class NullHyperparameterLikelihood(HyperparameterLikelihood): `Farr `_. """ - def ln_likelihood_and_variance(self, parameters=None): + def ln_likelihood_and_variance(self, parameters: dict[str, Any] | None = None) -> tuple[float, float]: """ Compute the ln likelihood estimator and its variance. """ From 8e7185532f0c67be39214615f2a5ad9e8b7f305f Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 15 Feb 2026 23:11:06 +0000 Subject: [PATCH 05/11] Add Python type hints to all functions, methods, and classes in mass.py - Added typing.Any import for array-like types - Added type hints to all function parameters and return types - Added type hints to all class methods including __init__ (-> None) - Used modern Python 3.10+ syntax (dict[str, Any] instead of Dict[str, Any]) - Used Any for array-like types (numpy, cupy, or jax arrays) - All existing docstrings and functionality remain unchanged Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- gwpopulation/models/mass.py | 125 ++++++++++++++++++------------------ 1 file changed, 63 insertions(+), 62 deletions(-) diff --git a/gwpopulation/models/mass.py b/gwpopulation/models/mass.py index 826d8277..19d9e982 100644 --- a/gwpopulation/models/mass.py +++ b/gwpopulation/models/mass.py @@ -3,6 +3,7 @@ """ import inspect +from typing import Any import numpy as np import scipy.special as scs @@ -35,7 +36,7 @@ ] -def double_power_law_primary_mass(mass, alpha_1, alpha_2, mmin, mmax, break_fraction): +def double_power_law_primary_mass(mass: Any, alpha_1: float, alpha_2: float, mmin: float, mmax: float, break_fraction: float) -> Any: r""" Broken power-law mass distribution @@ -72,17 +73,17 @@ def double_power_law_primary_mass(mass, alpha_1, alpha_2, mmin, mmax, break_frac def double_power_law_peak_primary_mass( - mass, - alpha_1, - alpha_2, - mmin, - mmax, - break_fraction, - lam, - mpp, - sigpp, - gaussian_mass_maximum=100, -): + mass: Any, + alpha_1: float, + alpha_2: float, + mmin: float, + mmax: float, + break_fraction: float, + lam: float, + mpp: float, + sigpp: float, + gaussian_mass_maximum: float = 100, +) -> Any: r""" Broken power-law with a Gaussian component. @@ -138,8 +139,8 @@ def double_power_law_peak_primary_mass( def double_power_law_primary_power_law_mass_ratio( - dataset, alpha_1, alpha_2, beta, mmin, mmax, break_fraction -): + dataset: dict[str, Any], alpha_1: float, alpha_2: float, beta: float, mmin: float, mmax: float, break_fraction: float +) -> Any: r""" Power law model for two-dimensional mass distribution, modelling primary mass and conditional mass ratio distribution. @@ -184,7 +185,7 @@ def double_power_law_primary_power_law_mass_ratio( return prob -def power_law_primary_mass_ratio(dataset, alpha, beta, mmin, mmax): +def power_law_primary_mass_ratio(dataset: dict[str, Any], alpha: float, beta: float, mmin: float, mmax: float) -> Any: r""" Power law model for two-dimensional mass distribution, modelling primary mass and conditional mass ratio distribution. @@ -214,11 +215,11 @@ def power_law_primary_mass_ratio(dataset, alpha, beta, mmin, mmax): ) -def _primary_secondary_general(dataset, p_m1, p_m2): +def _primary_secondary_general(dataset: dict[str, Any], p_m1: Any, p_m2: Any) -> Any: return p_m1 * p_m2 * (dataset["mass_1"] >= dataset["mass_2"]) * 2 -def power_law_primary_secondary_independent(dataset, alpha, beta, mmin, mmax): +def power_law_primary_secondary_independent(dataset: dict[str, Any], alpha: float, beta: float, mmin: float, mmax: float) -> Any: r""" Power law model for two-dimensional mass distribution, modelling the primary and secondary masses as following independent distributions. @@ -247,7 +248,7 @@ def power_law_primary_secondary_independent(dataset, alpha, beta, mmin, mmax): return prob -def power_law_primary_secondary_identical(dataset, alpha, mmin, mmax): +def power_law_primary_secondary_identical(dataset: dict[str, Any], alpha: float, mmin: float, mmax: float) -> Any: r""" Power law model for two-dimensional mass distribution, modelling the primary and secondary masses as following independent distributions. @@ -273,7 +274,7 @@ def power_law_primary_secondary_identical(dataset, alpha, mmin, mmax): ) -def power_law_mass(mass, alpha, mmin, mmax): +def power_law_mass(mass: Any, alpha: float, mmin: float, mmax: float) -> Any: r""" Power law model for one-dimensional mass distribution. @@ -295,8 +296,8 @@ def power_law_mass(mass, alpha, mmin, mmax): def two_component_single( - mass, alpha, mmin, mmax, lam, mpp, sigpp, gaussian_mass_maximum=100 -): + mass: Any, alpha: float, mmin: float, mmax: float, lam: float, mpp: float, sigpp: float, gaussian_mass_maximum: float = 100 +) -> Any: r""" Power law model for one-dimensional mass distribution with a Gaussian component. @@ -333,18 +334,18 @@ def two_component_single( def three_component_single( - mass, - alpha, - mmin, - mmax, - lam, - lam_1, - mpp_1, - sigpp_1, - mpp_2, - sigpp_2, - gaussian_mass_maximum=100, -): + mass: Any, + alpha: float, + mmin: float, + mmax: float, + lam: float, + lam_1: float, + mpp_1: float, + sigpp_1: float, + mpp_2: float, + sigpp_2: float, + gaussian_mass_maximum: float = 100, +) -> Any: r""" Power law model for one-dimensional mass distribution with two Gaussian components. @@ -395,8 +396,8 @@ def three_component_single( def two_component_primary_mass_ratio( - dataset, alpha, beta, mmin, mmax, lam, mpp, sigpp, gaussian_mass_maximum=100 -): + dataset: dict[str, Any], alpha: float, beta: float, mmin: float, mmax: float, lam: float, mpp: float, sigpp: float, gaussian_mass_maximum: float = 100 +) -> Any: r""" Power law model for two-dimensional mass distribution, modelling primary mass and conditional mass ratio distribution. @@ -440,8 +441,8 @@ def two_component_primary_mass_ratio( def two_component_primary_secondary_independent( - dataset, alpha, beta, mmin, mmax, lam, mpp, sigpp, gaussian_mass_maximum=100 -): + dataset: dict[str, Any], alpha: float, beta: float, mmin: float, mmax: float, lam: float, mpp: float, sigpp: float, gaussian_mass_maximum: float = 100 +) -> Any: r""" Power law model for two-dimensional mass distribution, modelling the primary and secondary masses as following independent distributions. @@ -486,8 +487,8 @@ def two_component_primary_secondary_independent( def two_component_primary_secondary_identical( - dataset, alpha, mmin, mmax, lam, mpp, sigpp, gaussian_mass_maximum=100 -): + dataset: dict[str, Any], alpha: float, mmin: float, mmax: float, lam: float, mpp: float, sigpp: float, gaussian_mass_maximum: float = 100 +) -> Any: r""" Power law model for two-dimensional mass distribution, modelling the primary and secondary masses as following independent distributions. @@ -545,7 +546,7 @@ class BaseSmoothedMassDistribution: primary_model = None @property - def variable_names(self): + def variable_names(self) -> set[str]: vars = getattr( self.primary_model, "variable_names", @@ -556,10 +557,10 @@ def variable_names(self): return vars @property - def kwargs(self): + def kwargs(self) -> dict[str, Any]: return dict() - def __init__(self, mmin=2, mmax=100, normalization_shape=(1000, 500), cache=True): + def __init__(self, mmin: float = 2, mmax: float = 100, normalization_shape: tuple[int, int] = (1000, 500), cache: bool = True) -> None: self.mmin = mmin self.mmax = mmax self.m1s = xp.linspace(mmin, mmax, normalization_shape[0]) @@ -569,7 +570,7 @@ def __init__(self, mmin=2, mmax=100, normalization_shape=(1000, 500), cache=True self.m1s_grid, self.qs_grid = xp.meshgrid(self.m1s, self.qs) self.cache = cache - def __call__(self, dataset, *args, **kwargs): + def __call__(self, dataset: dict[str, Any], *args: Any, **kwargs: Any) -> Any: beta = kwargs.pop("beta") mmin = kwargs.get("mmin", self.mmin) mmax = kwargs.get("mmax", self.mmax) @@ -588,7 +589,7 @@ def __call__(self, dataset, *args, **kwargs): prob = p_m1 * p_q return prob - def p_m1(self, dataset, **kwargs): + def p_m1(self, dataset: dict[str, Any], **kwargs: Any) -> Any: mmin = kwargs.get("mmin", self.mmin) delta_m = kwargs.pop("delta_m", 0) p_m = self.__class__.primary_model(dataset["mass_1"], **kwargs) @@ -598,7 +599,7 @@ def p_m1(self, dataset, **kwargs): norm = self.norm_p_m1(delta_m=delta_m, **kwargs) return p_m / norm - def norm_p_m1(self, delta_m, **kwargs): + def norm_p_m1(self, delta_m: float, **kwargs: Any) -> Any: """Calculate the normalisation factor for the primary mass""" mmin = kwargs.get("mmin", self.mmin) if "jax" not in xp.__name__ and delta_m == 0: @@ -611,7 +612,7 @@ def norm_p_m1(self, delta_m, **kwargs): ) return norm - def p_q(self, dataset, beta, mmin, delta_m): + def p_q(self, dataset: dict[str, Any], beta: float, mmin: float, delta_m: float) -> Any: p_q = powerlaw(dataset["mass_ratio"], beta, 1, mmin / dataset["mass_1"]) p_q *= self.smoothing( dataset["mass_1"] * dataset["mass_ratio"], @@ -632,7 +633,7 @@ def p_q(self, dataset, beta, mmin, delta_m): return xp.nan_to_num(p_q) - def norm_p_q(self, beta, mmin, delta_m): + def norm_p_q(self, beta: float, mmin: float, delta_m: float) -> Any: """Calculate the mass ratio normalisation by linear interpolation""" p_q = powerlaw(self.qs_grid, beta, 1, mmin / self.m1s_grid) p_q *= self.smoothing( @@ -645,7 +646,7 @@ def norm_p_q(self, beta, mmin, delta_m): return self._q_interpolant(norms) - def _cache_q_norms(self, masses): + def _cache_q_norms(self, masses: Any) -> None: """ Cache the information necessary for linear interpolation of the mass ratio normalisation @@ -657,7 +658,7 @@ def _cache_q_norms(self, masses): ) @staticmethod - def smoothing(masses, mmin, mmax, delta_m): + def smoothing(masses: Any, mmin: float, mmax: float, delta_m: float) -> Any: """ Apply a one sided window between mmin and mmin + delta_m to the mass pdf. @@ -721,7 +722,7 @@ class SinglePeakSmoothedMassDistribution(BaseSmoothedMassDistribution): primary_model = two_component_single @property - def kwargs(self): + def kwargs(self) -> dict[str, Any]: return dict(gaussian_mass_maximum=self.mmax) @@ -766,7 +767,7 @@ class MultiPeakSmoothedMassDistribution(BaseSmoothedMassDistribution): primary_model = three_component_single @property - def kwargs(self): + def kwargs(self) -> dict[str, Any]: return dict(gaussian_mass_maximum=self.mmax) @@ -837,7 +838,7 @@ class BrokenPowerLawPeakSmoothedMassDistribution(BaseSmoothedMassDistribution): primary_model = double_power_law_peak_primary_mass @property - def kwargs(self): + def kwargs(self) -> dict[str, Any]: return dict(gaussian_mass_maximum=self.mmax) @@ -873,13 +874,13 @@ class InterpolatedPowerlaw( def __init__( self, - nodes=10, - kind="cubic", - mmin=2, - mmax=100, - normalization_shape=(1000, 500), - regularize=False, - ): + nodes: int = 10, + kind: str = "cubic", + mmin: float = 2, + mmax: float = 100, + normalization_shape: tuple[int, int] = (1000, 500), + regularize: bool = False, + ) -> None: """ Parameters ========== @@ -916,13 +917,13 @@ def __init__( self._xs = self.m1s @property - def variable_names(self): + def variable_names(self) -> set[str]: variable_names = super().variable_names.union( InterpolatedNoBaseModelIdentical.variable_names.fget(self) ) return variable_names - def p_m1(self, dataset, **kwargs): + def p_m1(self, dataset: dict[str, Any], **kwargs: Any) -> Any: f_splines, m_splines = self.extract_spline_points(kwargs) @@ -939,7 +940,7 @@ def p_m1(self, dataset, **kwargs): norm = self.norm_p_m1(delta_m=delta_m, f_splines=f_splines, **kwargs) return p_m / norm - def norm_p_m1(self, delta_m, f_splines=None, **kwargs): + def norm_p_m1(self, delta_m: float, f_splines: Any = None, **kwargs: Any) -> Any: mmin = kwargs.get("mmin", self.mmin) p_m = self.__class__.primary_model( self.m1s, **{key: kwargs[key] for key in ["alpha", "mmin", "mmax"]} From f75c88ce13233830da11acf1f51275b290d94f5b Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 15 Feb 2026 23:12:20 +0000 Subject: [PATCH 06/11] Fix line length issues in type hints Break long parameter lists into multiple lines for better readability Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- gwpopulation/models/mass.py | 46 +++++++++++++++++++++++++++++++++---- 1 file changed, 41 insertions(+), 5 deletions(-) diff --git a/gwpopulation/models/mass.py b/gwpopulation/models/mass.py index 19d9e982..38150195 100644 --- a/gwpopulation/models/mass.py +++ b/gwpopulation/models/mass.py @@ -139,7 +139,13 @@ def double_power_law_peak_primary_mass( def double_power_law_primary_power_law_mass_ratio( - dataset: dict[str, Any], alpha_1: float, alpha_2: float, beta: float, mmin: float, mmax: float, break_fraction: float + dataset: dict[str, Any], + alpha_1: float, + alpha_2: float, + beta: float, + mmin: float, + mmax: float, + break_fraction: float, ) -> Any: r""" Power law model for two-dimensional mass distribution, modelling primary @@ -296,7 +302,14 @@ def power_law_mass(mass: Any, alpha: float, mmin: float, mmax: float) -> Any: def two_component_single( - mass: Any, alpha: float, mmin: float, mmax: float, lam: float, mpp: float, sigpp: float, gaussian_mass_maximum: float = 100 + mass: Any, + alpha: float, + mmin: float, + mmax: float, + lam: float, + mpp: float, + sigpp: float, + gaussian_mass_maximum: float = 100, ) -> Any: r""" Power law model for one-dimensional mass distribution with a Gaussian component. @@ -396,7 +409,15 @@ def three_component_single( def two_component_primary_mass_ratio( - dataset: dict[str, Any], alpha: float, beta: float, mmin: float, mmax: float, lam: float, mpp: float, sigpp: float, gaussian_mass_maximum: float = 100 + dataset: dict[str, Any], + alpha: float, + beta: float, + mmin: float, + mmax: float, + lam: float, + mpp: float, + sigpp: float, + gaussian_mass_maximum: float = 100, ) -> Any: r""" Power law model for two-dimensional mass distribution, modelling primary @@ -441,7 +462,15 @@ def two_component_primary_mass_ratio( def two_component_primary_secondary_independent( - dataset: dict[str, Any], alpha: float, beta: float, mmin: float, mmax: float, lam: float, mpp: float, sigpp: float, gaussian_mass_maximum: float = 100 + dataset: dict[str, Any], + alpha: float, + beta: float, + mmin: float, + mmax: float, + lam: float, + mpp: float, + sigpp: float, + gaussian_mass_maximum: float = 100, ) -> Any: r""" Power law model for two-dimensional mass distribution, modelling the @@ -487,7 +516,14 @@ def two_component_primary_secondary_independent( def two_component_primary_secondary_identical( - dataset: dict[str, Any], alpha: float, mmin: float, mmax: float, lam: float, mpp: float, sigpp: float, gaussian_mass_maximum: float = 100 + dataset: dict[str, Any], + alpha: float, + mmin: float, + mmax: float, + lam: float, + mpp: float, + sigpp: float, + gaussian_mass_maximum: float = 100, ) -> Any: r""" Power law model for two-dimensional mass distribution, modelling the From eb08584ce6f10bf1c63d4dbbc495f2b34d4ba26a Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 15 Feb 2026 23:17:01 +0000 Subject: [PATCH 07/11] Add Python type hints to spin.py - Add type hints to all functions, methods, and classes - Use 'Any' for array-like types (numpy/cupy/jax arrays) - Use modern Python 3.10+ type syntax (dict, list instead of Dict, List) - Add '-> None' return type to __init__ methods - All existing tests pass (30 tests) - Docstrings remain unchanged Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- gwpopulation/models/spin.py | 83 ++++++++++++++++++++++++++++++------- 1 file changed, 67 insertions(+), 16 deletions(-) diff --git a/gwpopulation/models/spin.py b/gwpopulation/models/spin.py index 373a87e2..ae574f1b 100644 --- a/gwpopulation/models/spin.py +++ b/gwpopulation/models/spin.py @@ -2,6 +2,8 @@ Implemented spin models """ +from typing import Any + import numpy as xp from ..utils import beta_dist, trapezoid, truncnorm, unnormalized_2d_gaussian @@ -21,7 +23,14 @@ ] -def iid_spin(dataset, xi_spin, sigma_spin, amax, alpha_chi, beta_chi): +def iid_spin( + dataset: dict[str, Any], + xi_spin: float, + sigma_spin: float, + amax: float, + alpha_chi: float, + beta_chi: float, +) -> Any: r""" Independently and identically distributed spins. The magnitudes are assumed to follow a Beta distribution and the @@ -47,7 +56,9 @@ def iid_spin(dataset, xi_spin, sigma_spin, amax, alpha_chi, beta_chi): return prior -def iid_spin_magnitude_beta(dataset, amax=1, alpha_chi=1, beta_chi=1): +def iid_spin_magnitude_beta( + dataset: dict[str, Any], amax: float = 1, alpha_chi: float = 1, beta_chi: float = 1 +) -> Any: """ Independent and identically distributed beta distributions for both spin magnitudes. @@ -68,8 +79,14 @@ def iid_spin_magnitude_beta(dataset, amax=1, alpha_chi=1, beta_chi=1): def independent_spin_magnitude_beta( - dataset, alpha_chi_1, alpha_chi_2, beta_chi_1, beta_chi_2, amax_1, amax_2 -): + dataset: dict[str, Any], + alpha_chi_1: float, + alpha_chi_2: float, + beta_chi_1: float, + beta_chi_2: float, + amax_1: float, + amax_2: float, +) -> Any: """ Independent beta distributions for both spin magnitudes. @@ -92,7 +109,9 @@ def independent_spin_magnitude_beta( return prior -def iid_spin_orientation_gaussian_isotropic(dataset, xi_spin, sigma_spin, mu_spin=1): +def iid_spin_orientation_gaussian_isotropic( + dataset: dict[str, Any], xi_spin: float, sigma_spin: float, mu_spin: float = 1 +) -> Any: r""" A mixture model of spin orientations with isotropic and normally distributed components. The distribution of primary and secondary spin @@ -125,8 +144,13 @@ def iid_spin_orientation_gaussian_isotropic(dataset, xi_spin, sigma_spin, mu_spi def independent_spin_orientation_gaussian_isotropic( - dataset, xi_spin, sigma_1, sigma_2, mu_1=1, mu_2=1 -): + dataset: dict[str, Any], + xi_spin: float, + sigma_1: float, + sigma_2: float, + mu_1: float = 1, + mu_2: float = 1, +) -> Any: r""" A mixture model of spin orientations with isotropic and normally distributed components. @@ -166,7 +190,9 @@ def independent_spin_orientation_gaussian_isotropic( return prior -def gaussian_chi_eff(dataset, mu_chi_eff, sigma_chi_eff): +def gaussian_chi_eff( + dataset: dict[str, Any], mu_chi_eff: float, sigma_chi_eff: float +) -> Any: r""" A Gaussian in chi effective distribution @@ -197,7 +223,7 @@ def gaussian_chi_eff(dataset, mu_chi_eff, sigma_chi_eff): ) -def gaussian_chi_p(dataset, mu_chi_p, sigma_chi_p): +def gaussian_chi_p(dataset: dict[str, Any], mu_chi_p: float, sigma_chi_p: float) -> Any: r""" A Gaussian distribution in precessing effective spin (chi p) @@ -261,14 +287,20 @@ class GaussianChiEffChiP(object): Covariance between the two parameters (:math:`\rho`) """ - def __init__(self): + def __init__(self) -> None: self.chi_eff = xp.linspace(-1, 1, 500) self.chi_p = xp.linspace(0, 1, 250) self.chi_eff_grid, self.chi_p_grid = xp.meshgrid(self.chi_eff, self.chi_p) def __call__( - self, dataset, mu_chi_eff, sigma_chi_eff, mu_chi_p, sigma_chi_p, spin_covariance - ): + self, + dataset: dict[str, Any], + mu_chi_eff: float, + sigma_chi_eff: float, + mu_chi_p: float, + sigma_chi_p: float, + spin_covariance: float, + ) -> Any: prob = unnormalized_2d_gaussian( dataset["chi_eff"], @@ -292,8 +324,13 @@ def __call__( return prob def _normalization( - self, mu_chi_eff, sigma_chi_eff, mu_chi_p, sigma_chi_p, spin_covariance - ): + self, + mu_chi_eff: float, + sigma_chi_eff: float, + mu_chi_p: float, + sigma_chi_p: float, + spin_covariance: float, + ) -> Any: r""" Numerically calculate the normalization over a two-dimensional grid with trapezoidal integration @@ -351,7 +388,14 @@ class SplineSpinMagnitudeIdentical(InterpolatedNoBaseModelIdentical): value :code:`rms{name}`, default=False. """ - def __init__(self, minimum=0, maximum=1, nodes=5, kind="cubic", regularize=False): + def __init__( + self, + minimum: float = 0, + maximum: float = 1, + nodes: int = 5, + kind: str = "cubic", + regularize: bool = False, + ) -> None: super(SplineSpinMagnitudeIdentical, self).__init__( parameters=["a_1", "a_2"], @@ -384,7 +428,14 @@ class SplineSpinTiltIdentical(InterpolatedNoBaseModelIdentical): value :code:`rms{name}`, default=False. """ - def __init__(self, minimum=-1, maximum=1, nodes=5, kind="cubic", regularize=False): + def __init__( + self, + minimum: float = -1, + maximum: float = 1, + nodes: int = 5, + kind: str = "cubic", + regularize: bool = False, + ) -> None: super(SplineSpinTiltIdentical, self).__init__( parameters=["cos_tilt_1", "cos_tilt_2"], From b5d792c1df45a61f21ba740f47e127fdbe18fb57 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 15 Feb 2026 23:21:14 +0000 Subject: [PATCH 08/11] Add Python type hints to redshift.py - Add type hints to all functions, methods, and classes - Use Any for array-like types (numpy/cupy/jax arrays) - Use modern Python 3.10+ syntax (dict[str, Any] vs Dict[str, Any]) - Add return type -> None to __init__ methods - Import typing.Any for array types - Keep all docstrings unchanged Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- gwpopulation/models/redshift.py | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/gwpopulation/models/redshift.py b/gwpopulation/models/redshift.py index 6444b7d4..15df2e27 100644 --- a/gwpopulation/models/redshift.py +++ b/gwpopulation/models/redshift.py @@ -2,6 +2,8 @@ Implemented redshift models """ +from typing import Any + import numpy as xp from ..experimental.cosmo_models import CosmoMixin @@ -41,10 +43,10 @@ class _Redshift(CosmoMixin): for the model. """ - base_variable_names = None + base_variable_names: list[str] | None = None @property - def variable_names(self): + def variable_names(self) -> list[str]: """ Variable names for the model @@ -60,18 +62,18 @@ def variable_names(self): vars += self.base_variable_names return vars - def __init__(self, z_max=2.3, cosmo_model="Planck15"): + def __init__(self, z_max: float = 2.3, cosmo_model: str = "Planck15") -> None: super().__init__(cosmo_model=cosmo_model) self.z_max = z_max self.zs = xp.linspace(1e-6, z_max, 2500) - def __call__(self, dataset, **kwargs): + def __call__(self, dataset: dict[str, Any], **kwargs: Any) -> Any: """ Wrapper to :func:`probability`. """ return self.probability(dataset=dataset, **kwargs) - def normalisation(self, parameters): + def normalisation(self, parameters: dict[str, Any]) -> Any: r""" Compute the normalization of the rate-weighted spacetime volume. @@ -97,7 +99,7 @@ def normalisation(self, parameters): norm = trapezoid(normalisation_data, self.zs) return norm - def probability(self, dataset, **parameters): + def probability(self, dataset: dict[str, Any], **parameters: Any) -> Any: """ Compute the normalized probability of a merger occurring at the specified redshift. @@ -121,7 +123,7 @@ def probability(self, dataset, **parameters): ) return differential_volume / normalisation - def psi_of_z(self, redshift, **parameters): + def psi_of_z(self, redshift: Any, **parameters: Any) -> Any: r""" Method encoding the redshift evolution of the merger rate. This should be overwritten in child classes. @@ -145,7 +147,7 @@ def psi_of_z(self, redshift, **parameters): """ raise NotImplementedError - def dvc_dz(self, redshift, **parameters): + def dvc_dz(self, redshift: Any, **parameters: Any) -> Any: r""" .. note:: @@ -166,7 +168,7 @@ def dvc_dz(self, redshift, **parameters): * self.cosmology(parameters).differential_comoving_volume(redshift) ) - def differential_spacetime_volume(self, dataset, bounds=False, **parameters): + def differential_spacetime_volume(self, dataset: dict[str, Any], bounds: bool = False, **parameters: Any) -> Any: r""" Compute the differential spacetime volume. @@ -196,9 +198,9 @@ def differential_spacetime_volume(self, dataset, bounds=False, **parameters): class PowerLawRedshift(_Redshift): - base_variable_names = ["lamb"] + base_variable_names: list[str] = ["lamb"] - def psi_of_z(self, redshift, **parameters): + def psi_of_z(self, redshift: Any, **parameters: Any) -> Any: r""" Redshift model from Fishbach+ https://arxiv.org/abs/1805.10270 (`arXiv:1805.10270 `_. @@ -217,9 +219,9 @@ def psi_of_z(self, redshift, **parameters): class MadauDickinsonRedshift(_Redshift): - base_variable_names = ["gamma", "kappa", "z_peak"] + base_variable_names: list[str] = ["gamma", "kappa", "z_peak"] - def psi_of_z(self, redshift, **parameters): + def psi_of_z(self, redshift: Any, **parameters: Any) -> Any: r""" Redshift model from Fishbach+ (`arXiv:1805.10270 `_ Eq. (33)) @@ -249,7 +251,7 @@ def psi_of_z(self, redshift, **parameters): return psi_of_z -def total_four_volume(lamb, analysis_time, max_redshift=2.3): +def total_four_volume(lamb: float, analysis_time: float, max_redshift: float = 2.3) -> float: r""" Calculate the rate-weighted four-volume for a given power-law redshift model. From f3eda2f14bd8ab21b2dad9159f12965a8cfe76ef Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 15 Feb 2026 23:24:35 +0000 Subject: [PATCH 09/11] Add Python type hints to gwpopulation/models/interped.py - Add comprehensive type hints to all functions, methods, and classes - Use Any for array-like types (numpy/cupy/jax arrays) - Use modern Python 3.10+ type hints (dict, list, tuple) - Add return type hints including -> None for __init__ - All docstrings and code functionality remain unchanged - All tests pass successfully Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- gwpopulation/models/interped.py | 35 +++++++++++++++++---------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/gwpopulation/models/interped.py b/gwpopulation/models/interped.py index 9a93929e..36d10b24 100644 --- a/gwpopulation/models/interped.py +++ b/gwpopulation/models/interped.py @@ -1,4 +1,5 @@ from functools import partial +from typing import Any import numpy as np @@ -12,7 +13,7 @@ ] -def _setup_interpolant(nodes, values, kind="cubic", backend=None): +def _setup_interpolant(nodes: Any, values: Any, kind: str = "cubic", backend: Any = None) -> Any: """ Create a caching spline interpolant. @@ -68,14 +69,14 @@ class InterpolatedNoBaseModelIdentical: def __init__( self, - parameters, - minimum, - maximum, - nodes=10, - kind="cubic", - log_nodes=False, - regularize=False, - ): + parameters: list[str], + minimum: float, + maximum: float, + nodes: int = 10, + kind: str = "cubic", + log_nodes: bool = False, + regularize: bool = False, + ) -> None: self.nodes = nodes self._norm_spline = None self._data_spline = dict() @@ -91,14 +92,14 @@ def __init__( self.fkeys = [f"f{self.base}{ii}" for ii in range(self.nodes)] self.regularize = regularize - def __call__(self, dataset, **kwargs): + def __call__(self, dataset: dict[str, Any], **kwargs: Any) -> Any: """ A wrapper to :func:`p_x_identical` """ return self.p_x_identical(dataset, **kwargs) @property - def variable_names(self): + def variable_names(self) -> list[str]: """ The names of the hyperparameters of the model. @@ -113,7 +114,7 @@ def variable_names(self): keys += [f"rms{self.base}"] return keys - def setup_interpolant(self, nodes, values): + def setup_interpolant(self, nodes: Any, values: dict[str, Any]) -> None: if self.log_nodes: func = xp.log else: @@ -125,7 +126,7 @@ def setup_interpolant(self, nodes, values): for param in self.parameters } - def p_x_unnormed(self, dataset, parameter, x_splines, f_splines, **kwargs): + def p_x_unnormed(self, dataset: dict[str, Any], parameter: str, x_splines: Any, f_splines: Any, **kwargs: Any) -> Any: """ Calculate the unnormalized likelihood of the dataset given the model @@ -158,7 +159,7 @@ def p_x_unnormed(self, dataset, parameter, x_splines, f_splines, **kwargs): ) return p_x - def norm_p_x(self, f_splines=None, x_splines=None, **kwargs): + def norm_p_x(self, f_splines: Any = None, x_splines: Any = None, **kwargs: Any) -> Any: """ Calculate the normalization of the spline @@ -181,7 +182,7 @@ def norm_p_x(self, f_splines=None, x_splines=None, **kwargs): norm = trapezoid(p_x, self._xs) return norm - def extract_spline_points(self, kwargs): + def extract_spline_points(self, kwargs: dict[str, Any]) -> tuple[Any, Any]: """ Extract the node positions and values from the dictionary of parameters @@ -204,7 +205,7 @@ def extract_spline_points(self, kwargs): x_splines = xp.array([kwargs[key] for key in self.xkeys]) return f_splines, x_splines - def p_x_identical(self, dataset, **kwargs): + def p_x_identical(self, dataset: dict[str, Any], **kwargs: Any) -> Any: """ Calculate the likelihood of the dataset given the model assuming that all the parameters are identically distributed. @@ -236,7 +237,7 @@ def p_x_identical(self, dataset, **kwargs): p_x /= norm ** len(self.parameters) return p_x - def infer_n_nodes(self, **kwargs): + def infer_n_nodes(self, **kwargs: Any) -> None: """ Infer the number of nodes from the dictionary of parameters. This method looks for the first missing parameter matching From 3e13b23180fd9fb9452c65149cf5141961f0437a Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 15 Feb 2026 23:27:37 +0000 Subject: [PATCH 10/11] Add Python type hints to experimental modules - Added type hints to all functions, methods, and classes in cosmo_models.py, jax.py, and numpyro.py - Used modern Python 3.10+ syntax (dict, list, tuple instead of Dict, List, Tuple) - Used 'Any' for array-like types compatible with numpy/cupy/jax arrays - Added necessary imports: typing.Any and collections.abc.Callable - Added return type hints including '-> None' for __init__ methods - Preserved all existing docstrings and code logic unchanged Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- gwpopulation/experimental/cosmo_models.py | 14 ++++++++------ gwpopulation/experimental/jax.py | 12 +++++++----- gwpopulation/experimental/numpyro.py | 20 +++++++++++--------- 3 files changed, 26 insertions(+), 20 deletions(-) diff --git a/gwpopulation/experimental/cosmo_models.py b/gwpopulation/experimental/cosmo_models.py index 14b346a7..fbafa253 100644 --- a/gwpopulation/experimental/cosmo_models.py +++ b/gwpopulation/experimental/cosmo_models.py @@ -6,6 +6,8 @@ can be used to add cosmological functionality to a population model. """ +from typing import Any + import numpy as xp from bilby.hyper.model import Model from wcosmo import z_at_value @@ -24,7 +26,7 @@ class CosmoMixin: Should be of :code:`wcosmo.available.keys()`. """ - def __init__(self, cosmo_model="Planck15"): + def __init__(self, cosmo_model: str = "Planck15") -> None: wcosmo_disable_units() self.cosmo_model = cosmo_model if self.cosmo_model == "FlatwCDM": @@ -35,7 +37,7 @@ def __init__(self, cosmo_model="Planck15"): self.cosmology_names = [] self._cosmo = available[cosmo_model] - def cosmology_variables(self, parameters): + def cosmology_variables(self, parameters: dict[str, Any]) -> dict[str, Any]: """ Extract the cosmological parameters from the provided parameters. @@ -51,7 +53,7 @@ def cosmology_variables(self, parameters): """ return {key: parameters[key] for key in self.cosmology_names} - def cosmology(self, parameters): + def cosmology(self, parameters: dict[str, Any]) -> WCosmoMixin: """ Return the cosmology model given the parameters. @@ -70,7 +72,7 @@ def cosmology(self, parameters): else: return self._cosmo(**self.cosmology_variables(parameters)) - def detector_frame_to_source_frame(self, data, **parameters): + def detector_frame_to_source_frame(self, data: dict[str, Any], **parameters: Any) -> tuple[dict[str, Any], Any]: r""" Convert detector frame samples to source frame samples given cosmological parameters. Calculate the corresponding @@ -133,11 +135,11 @@ class CosmoModel(Model, CosmoMixin): Should be of :code:`wcosmo.available.keys()`. """ - def __init__(self, model_functions=None, cosmo_model="Planck15"): + def __init__(self, model_functions: list[Any] | None = None, cosmo_model: str = "Planck15") -> None: Model.__init__(self, model_functions=model_functions, cache=False) CosmoMixin.__init__(self, cosmo_model=cosmo_model) - def prob(self, data, **kwargs): + def prob(self, data: dict[str, Any], **kwargs: Any) -> Any: """ Compute the total population probability for the provided data given the keyword arguments. diff --git a/gwpopulation/experimental/jax.py b/gwpopulation/experimental/jax.py index 0bf670d4..b544a82e 100644 --- a/gwpopulation/experimental/jax.py +++ b/gwpopulation/experimental/jax.py @@ -1,12 +1,14 @@ import warnings +from collections.abc import Callable from copy import deepcopy from functools import partial +from typing import Any import numpy as np from bilby.core.likelihood import Likelihood -def generic_bilby_likelihood_function(likelihood, parameters, use_ratio=True): +def generic_bilby_likelihood_function(likelihood: Likelihood, parameters: dict[str, Any], use_ratio: bool = True) -> float: """ A wrapper to allow a :code:`Bilby` likelihood to be used with :code:`jax`. @@ -48,8 +50,8 @@ class JittedLikelihood(Likelihood): """ def __init__( - self, likelihood, likelihood_func=generic_bilby_likelihood_function, kwargs=None - ): + self, likelihood: Likelihood, likelihood_func: Callable = generic_bilby_likelihood_function, kwargs: dict[str, Any] | None = None + ) -> None: from jax import jit if kwargs is None: @@ -59,8 +61,8 @@ def __init__( self.likelihood_func = jit(partial(likelihood_func, likelihood)) super().__init__() - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: return getattr(self._likelihood, name) - def log_likelihood_ratio(self, parameters): + def log_likelihood_ratio(self, parameters: dict[str, Any]) -> float: return float(np.nan_to_num(self.likelihood_func(parameters, **self.kwargs))) diff --git a/gwpopulation/experimental/numpyro.py b/gwpopulation/experimental/numpyro.py index 9c55d93c..883949d8 100644 --- a/gwpopulation/experimental/numpyro.py +++ b/gwpopulation/experimental/numpyro.py @@ -1,3 +1,5 @@ +from typing import Any + import jax import jax.numpy as jnp import numpy as np @@ -11,13 +13,13 @@ def gwpopulation_likelihood_model( - likelihood, - hyper_params, - varmax=np.inf, - apply_selection=True, - predictive_resample=True, - fit_keys=None, -): + likelihood: Any, + hyper_params: dict[str, Any], + varmax: float = np.inf, + apply_selection: bool = True, + predictive_resample: bool = True, + fit_keys: list[str] | None = None, +) -> Any: """ A :code:`numpyro` implementation of :func:`gwpopulation.hyperpe.HyperparameterLikelihood`. @@ -106,8 +108,8 @@ def gwpopulation_likelihood_model( def posterior_predictive_resample( - data, weights, label, event=None, shape=(), fit_keys=None -): + data: dict[str, Any], weights: Any, label: str | int, event: int | None = None, shape: tuple = (), fit_keys: list[str] | None = None +) -> None: """ Draw a sample from the predictive distribution given the weights for each sample. From f157fa576c949ef19ea75314f5ea5c8f4e4f48df Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 15 Feb 2026 23:43:17 +0000 Subject: [PATCH 11/11] Use ModuleType for xp and Array from array-api-typing for array parameters Co-authored-by: ColmTalbot <25602909+ColmTalbot@users.noreply.github.com> --- gwpopulation/utils.py | 22 ++++++++++++---------- requirements.txt | 1 + 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/gwpopulation/utils.py b/gwpopulation/utils.py index aca9003f..53284a80 100644 --- a/gwpopulation/utils.py +++ b/gwpopulation/utils.py @@ -5,12 +5,14 @@ from collections.abc import Callable from numbers import Number from operator import ge, gt, ne +from types import ModuleType from typing import Any +from array_api_typing import Array import numpy as np from scipy import special as scs -xp: Any = np +xp: ModuleType = np def apply_conditions(conditions: dict[str, Any]) -> Callable[[Callable[..., Any]], Callable[..., Any]]: @@ -63,7 +65,7 @@ def wrapped_function(*args: Any, **kwargs: Any) -> Any: @apply_conditions(dict(alpha=(gt, 0), beta=(gt, 0), scale=(gt, 0))) -def beta_dist(xx: Any, alpha: float, beta: float, scale: float | Any = 1) -> Any: +def beta_dist(xx: Array, alpha: float, beta: float, scale: float | Array = 1) -> Array: r""" Beta distribution probability @@ -97,7 +99,7 @@ def beta_dist(xx: Any, alpha: float, beta: float, scale: float | Any = 1) -> Any @apply_conditions(dict(low=(ge, 0), alpha=(ne, 1))) -def powerlaw(xx: Any, alpha: float | Any, high: float | Any, low: float | Any) -> Any: +def powerlaw(xx: Array, alpha: float | Array, high: float | Array, low: float | Array) -> Array: r""" Power-law probability @@ -133,7 +135,7 @@ def powerlaw(xx: Any, alpha: float | Any, high: float | Any, low: float | Any) - @apply_conditions(dict(sigma=(gt, 0))) -def truncnorm(xx: Any, mu: float | Any, sigma: float, high: float | Any, low: float | Any) -> Any: +def truncnorm(xx: Array, mu: float | Array, sigma: float, high: float | Array, low: float | Array) -> Array: r""" Truncated normal probability @@ -163,7 +165,7 @@ def truncnorm(xx: Any, mu: float | Any, sigma: float, high: float | Any, low: fl """ - def logsubexp(log_p: Any, log_q: Any) -> Any: + def logsubexp(log_p: Array, log_q: Array) -> Array: return log_p + xp.log(1 - xp.exp(log_q - log_p)) zz = xp.array(xx - mu) / sigma @@ -185,7 +187,7 @@ def logsubexp(log_p: Any, log_q: Any) -> Any: return xp.nan_to_num(xp.exp(log_pdf)) * (xx >= low) * (xx <= high) -def unnormalized_2d_gaussian(xx: Any, yy: Any, mu_x: float, mu_y: float, sigma_x: float, sigma_y: float, covariance: float) -> Any: +def unnormalized_2d_gaussian(xx: Array, yy: Array, mu_x: float, mu_y: float, sigma_x: float, sigma_y: float, covariance: float) -> Array: r""" Compute the probability distribution for a correlated 2-dimensional Gaussian neglecting normalization terms. @@ -230,7 +232,7 @@ def unnormalized_2d_gaussian(xx: Any, yy: Any, mu_x: float, mu_y: float, sigma_x return prob -def von_mises(xx: Any, mu: float, kappa: float) -> Any: +def von_mises(xx: Array, mu: float, kappa: float) -> Array: r""" PDF of the von Mises distribution defined on the standard interval. @@ -280,7 +282,7 @@ def get_name(input: Any) -> str: return input.__class__.__name__ -def to_number(value: Any, func: type[int] | type[float] | type[complex]) -> int | float | complex: +def to_number(value: Array, func: type[int] | type[float] | type[complex]) -> int | float | complex: """ Convert a zero-dimensional array to a number. @@ -298,7 +300,7 @@ def to_number(value: Any, func: type[int] | type[float] | type[complex]) -> int return func(value) -def to_numpy(array: Any) -> Any: +def to_numpy(array: Array) -> Array: """ Convert an array to a numpy array. Numeric types and pandas objects are returned unchanged. @@ -320,7 +322,7 @@ def to_numpy(array: Any) -> Any: raise TypeError(f"Cannot convert {type(array)} to numpy array") -def trapezoid(y: Any, x: Any | None = None, dx: float = 1.0, axis: int = -1) -> Any: +def trapezoid(y: Array, x: Array | None = None, dx: float = 1.0, axis: int = -1) -> Array: """ A wrapper of `trapz` or `trapezoid` that can handle different names in different backends. diff --git a/requirements.txt b/requirements.txt index 425117f5..acdd70cf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,4 @@ tqdm bilby>=2.7.0 # we don't allow parameters as state cached_interpolate>=0.3.2 wcosmo>=0.3.0 +array-api-typing @ git+https://github.com/data-apis/array-api-typing.git