5 changes: 3 additions & 2 deletions pyproject.toml
@@ -31,6 +31,7 @@ dependencies = [
"jaxlib==0.4.38; sys_platform == 'darwin' and platform_machine == 'x86_64'",
"tables~=3.10",
"scipy==1.15",
"pip>=26.0.1",
]

[project.scripts]
@@ -42,8 +43,8 @@ ringdown_scan = "ringdown.cli.ringdown_scan:main"
requires = ["hatchling", "hatch-vcs"]
build-backend = "hatchling.build"

-[tool.uv]
-dev-dependencies = [
+[dependency-groups]
+dev = [
"ringdown",
]

38 changes: 32 additions & 6 deletions ringdown/fit.py
@@ -16,7 +16,9 @@
import xarray as xr
import lal
import logging
+import scipy.linalg  # Added for AR coefficient calculation
from .data import Data, AutoCovariance, PowerSpectrum
+import jax
from . import utils
from .utils import try_parse
from .target import Target, TargetCollection
@@ -347,7 +349,7 @@ def update_prior(self, *args, **kwargs):
@property
def run_input(self) -> list:
"""Arguments to be passed to model function at runtime:
-        [times, strains, ls, fp, fc].
+        [times, strains, ar_coeffs, sigmas, fp, fc].
"""
if not self.has_data:
raise ValueError("no data loaded")
@@ -383,16 +385,40 @@ def run_input(self) -> list:
np.array(d.time) - self.start_times[i] for i, d in data_dict.items()
]

+        # MODIFIED: convert ACFs to AR coefficients + innovation sigma for
+        # the Gohberg-Semencul likelihood evaluation
+        ar_coeffs = []
+        sigmas = []
+
+        for acf_obj in self.acfs.values():
+            # extract the relevant ACF slice and normalize by the strain
+            # scale: the strain is scaled by `scale`, so the variance is
+            # scaled by `scale**2`
+            acf_vals = acf_obj.iloc[: self.n_analyze].values / scale**2
+
+            # use Levinson-Durbin (via the Toeplitz solver) to get the AR
+            # coefficients: solve R a = -r, where R = Toeplitz(acf[:-1])
+            # and r = acf[1:]
+            R = acf_vals[:-1]
+            r = acf_vals[1:]
+
+            # this solves for [a1, a2, ..., ap]
+            a_coeffs = scipy.linalg.solve_toeplitz((R, R), -r)
+
+            # prepend 1.0 to get the full AR filter [1, a1, ..., ap]
+            full_ar_coeffs = np.concatenate(([1.0], a_coeffs))
+            ar_coeffs.append(full_ar_coeffs)
+
+            # innovation variance: sigma^2 = gamma_0 + sum_k a_k gamma_k
+            sigma_sq = acf_vals[0] + np.dot(a_coeffs, r)
+            sigmas.append(np.sqrt(sigma_sq))
+
# arguments to be passed to function returned by model_function
# make sure this agrees with that function call!
-        # [times, strains, ls, fp, fc]
+        # [times, strains, ar_coeffs, sigmas, fp, fc]
input = [
times,
[s.values / scale for s in data_dict.values()],
-            [
-                (a.iloc[: self.n_analyze] / scale**2).cholesky
-                for a in self.acfs.values()
-            ],
+            ar_coeffs,  # new argument for the optimized model
+            sigmas,  # new argument for the optimized model
fp,
fc,
]
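For context on the conversion above: the Toeplitz solve is the Yule-Walker system, so its output can be checked against a process whose AR representation is known in closed form. A minimal standalone sketch (illustrative, not part of the PR): for an AR(1) process with parameter phi and innovation standard deviation sigma_eps, the ACF is gamma_k = sigma_eps**2 / (1 - phi**2) * phi**k, so the fit should recover a_1 = -phi and sigma = sigma_eps.

import numpy as np
import scipy.linalg

phi, sigma_eps, p = 0.7, 2.0, 8
# closed-form ACF of an AR(1) process, gamma_0 ... gamma_p
gamma = sigma_eps**2 / (1 - phi**2) * phi ** np.arange(p + 1)

R, r = gamma[:-1], gamma[1:]
a = scipy.linalg.solve_toeplitz((R, R), -r)  # [a_1, ..., a_p]
sigma = np.sqrt(gamma[0] + np.dot(a, r))     # innovation std

print(np.round(a, 6))  # approx [-0.7, 0, 0, ...]
print(sigma)           # approx 2.0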
149 changes: 103 additions & 46 deletions ringdown/model.py
@@ -5,6 +5,7 @@
__all__ = ["make_model", "get_arviz", "rd_design_matrix"]

import numpy as np
+import jax
import jax.numpy as jnp
import jax.scipy as jsp

@@ -14,6 +15,7 @@
from .indexing import ModeIndexList
from .result import Result
from .utils.swsh import construct_sYlm, calc_YpYc
+from .utils.matrix import apply_matrix_fft_precomputed, apply_cinv_gs_fast, next_fast_len

import arviz as az
from arviz.data.base import dict_to_dataset
Expand Down Expand Up @@ -186,10 +188,10 @@ def rd_design_matrix(
dm = jnp.concatenate(
[
# Yp * Fp * cos + Yc * Fc * sin
-            Yp_mat * dm[:, :, :nmode] + Yc_mat * dm[:, :, 3 * nmode :],
+            Yp_mat * dm[:, :, :nmode] + Yc_mat * dm[:, :, 3 * nmode:],
# Yp * Fp * sin - Yc * Fc * cos
-            Yp_mat * dm[:, :, nmode : 2 * nmode]
-            - Yc_mat * dm[:, :, 2 * nmode : 3 * nmode],
+            Yp_mat * dm[:, :, nmode: 2 * nmode]
+            - Yc_mat * dm[:, :, 2 * nmode: 3 * nmode],
],
axis=2,
)
@@ -272,9 +274,9 @@ def get_quad_derived_quantities(
# ellip = 0 and theta = 0,pi/2 for the single polarization model
else:
apx_unit = quads[:nmodes]
-        apy_unit = quads[nmodes : 2 * nmodes]
-        acx_unit = quads[2 * nmodes : 3 * nmodes]
-        acy_unit = quads[3 * nmodes :]
+        apy_unit = quads[nmodes: 2 * nmodes]
+        acx_unit = quads[2 * nmodes: 3 * nmodes]
+        acy_unit = quads[3 * nmodes:]

numpyro.deterministic("apx", apx_unit * a_scale)
numpyro.deterministic("apy", apy_unit * a_scale)
@@ -525,7 +527,8 @@ def make_model(
def model(
times,
strains,
-        ls,
+        ar_coeffs,
+        sigmas,
fps,
fcs,
predictive: bool = predictive,
@@ -543,19 +546,43 @@
strains : array_like
The strain data; list of 1D arrays for each IFO, or a 2D array with
shape (n_det, n_times).
-        ls : array_like
-            The noise covariance matrices; list of 2D arrays for each IFO, or a
-            3D array with shape (n_det, n_times, n_times).
+        ar_coeffs : array_like
+            The AR filters of the noise model, one per IFO, each of the
+            form [1, a_1, ..., a_p] (used for FFT-based inversion).
+        sigmas : array_like
+            The innovation standard deviations of the noise; length `n_det`.
fps : array_like
The "plus" polarization coefficients for each IFO; length `n_det`.
fcs : array_like
The "cross" polarization coefficients for each IFO; length `n_det`.
"""
-        times, strains, ls, fps, fcs = map(
-            jnp.array, (times, strains, ls, fps, fcs)
+        times, strains, fps, fcs = map(
+            jnp.array, (times, strains, fps, fcs)
         )
+        ar_coeffs = [jnp.array(a) for a in ar_coeffs]
+        sigmas = jnp.array(sigmas)

n_det = times.shape[0]
+        n_time = times.shape[1]
+
+        # -----------------------------------------------------------------
+        # Precompute FFTs of the AR filters for Gohberg-Semencul inversion
+        # -----------------------------------------------------------------
+        # The AR coefficients are fixed per detector (not sampled), so the
+        # FFTs can be computed once, outside the sampled part of the model.
+        # The smallest fast FFT size for exact (non-circular) convolution
+        # is >= n_time + P - 1.
+        P = ar_coeffs[0].shape[0]  # length of the AR filter [1, a1, ..., ap]
+        n_fft = next_fast_len(n_time + P - 1)
+
+        fft_as = []
+        fft_bs = []
+
+        for i in range(n_det):
+            ac = ar_coeffs[i]
+            # A column: [1, a1, ..., ap], zero-padded
+            fft_as.append(jnp.fft.rfft(ac, n=n_fft))
+
+            # B column: [0, ap, ..., a1], i.e. the reversed coefficients
+            rev_coeffs = jnp.pad(ac[1:][::-1], (1, 0))
+            fft_bs.append(jnp.fft.rfft(rev_coeffs, n=n_fft))
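For reference, here is a sketch of what apply_cinv_gs_fast plausibly does with these precomputed spectra; the real implementation lives in ringdown/utils/matrix.py and is not shown in this diff, so treat the body below as an assumption. The Gohberg-Semencul identity writes the inverse Toeplitz covariance as C^{-1} = (A A^T - B B^T) / sigma^2, where A and B are lower-triangular Toeplitz matrices built from the two columns FFT'd above, and each product with a triangular Toeplitz matrix is a causal convolution or correlation evaluated with FFTs. With n_fft >= n_time + P - 1 the circular products equal the linear ones, which is exactly what next_fast_len(n_time + P - 1) guarantees.

import jax.numpy as jnp

def apply_cinv_gs_sketch(y, fft_a, fft_b, n_fft, sigma):
    # hypothetical stand-in for apply_cinv_gs_fast; computes
    # (A A^T y - B B^T y) / sigma^2 with FFT convolutions
    n = y.shape[0]
    fy = jnp.fft.rfft(y, n=n_fft)
    # A^T y is a correlation (conjugated spectrum), A (A^T y) a convolution
    At_y = jnp.fft.irfft(jnp.conj(fft_a) * fy, n=n_fft)[:n]
    AAt_y = jnp.fft.irfft(fft_a * jnp.fft.rfft(At_y, n=n_fft), n=n_fft)[:n]
    # same pattern for the B factor
    Bt_y = jnp.fft.irfft(jnp.conj(fft_b) * fy, n=n_fft)[:n]
    BBt_y = jnp.fft.irfft(fft_b * jnp.fft.rfft(Bt_y, n=n_fft), n=n_fft)[:n]
    return (AAt_y - BBt_y) / sigma**2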

# Here is where the particular model choice is made:
#
@@ -583,7 +610,8 @@ def model(
# which, happily, is provided by the composed transformation
f_latent = numpyro.sample(
"f_latent",
-                dist.ImproperUniform(dist.constraints.real, (), (n_modes,)),
+                dist.ImproperUniform(
+                    dist.constraints.real, (), (n_modes,)),
)
f_transform = dist.transforms.ComposeTransform(
[
Expand All @@ -594,7 +622,8 @@ def model(
)
f = numpyro.deterministic("f", f_transform(f_latent))
numpyro.factor(
"f_transform", f_transform.log_abs_det_jacobian(f_latent, f)
"f_transform", f_transform.log_abs_det_jacobian(
f_latent, f)
)

g = numpyro.sample(
Expand All @@ -607,7 +636,8 @@ def model(

g_latent = numpyro.sample(
"g_latent",
-                dist.ImproperUniform(dist.constraints.real, (), (n_modes,)),
+                dist.ImproperUniform(
+                    dist.constraints.real, (), (n_modes,)),
)
g_transform = dist.transforms.ComposeTransform(
[
Expand All @@ -618,7 +648,8 @@ def model(
)
g = numpyro.deterministic("g", g_transform(g_latent))
numpyro.factor(
"g_transform", g_transform.log_abs_det_jacobian(g_latent, g)
"g_transform", g_transform.log_abs_det_jacobian(
g_latent, g)
)
else:
f = numpyro.sample("f", dist.Uniform(f_min, f_max))
@@ -726,8 +757,10 @@ def model(
# and the strain (y) for the current detector
# (ndet, ntime, nquads*nmode) => (i, ntime, nquads*nmode)
M = dms[i, :, :]
-                L = ls[i, :, :]
                 y = strains[i, :]
+                sigma = sigmas[i]
+                fft_a = fft_as[i]
+                fft_b = fft_bs[i]

# M acts as a coordinate transformation matrix, taking us
                # from the space of quadratures to the space of the data,
@@ -744,21 +777,27 @@ def model(
# likelihood precision (M^T C^-1 M):
# A_inv = Lambda_inv + M^T C^-1 M
# so that A and A_inv are (nquads*nmode, nquads*nmode)
-                A_inv = Lambda_inv + jnp.dot(
-                    M.T, jsp.linalg.cho_solve((L, True), M)
-                )
+
+                # compute the C^{-1} terms with the FFT-based GS solver:
+                # C^{-1} y
+                Cinv_y = apply_cinv_gs_fast(y, fft_a, fft_b, n_fft, sigma)
+
+                # C^{-1} M, vectorized over the columns of M
+                def fast_gs_col(col):
+                    return apply_cinv_gs_fast(col, fft_a, fft_b, n_fft, sigma)
+
+                Cinv_M = jax.vmap(fast_gs_col, in_axes=1, out_axes=1)(M)
+
+                A_inv = Lambda_inv + jnp.dot(M.T, Cinv_M)
A_inv_chol = jsp.linalg.cholesky(A_inv, lower=True)

# we can also compute the marginal-posterior mean (a),
# which is the precision-weighted sum of the prior mean
# (mu) and the likelihood mean (M^T C^-1 y):
# a = A_inv (Lambda_inv mu + M^T C^-1 y)
# so that a is (nquads*nmode,)
-                a = jsp.linalg.cho_solve(
-                    (A_inv_chol, True),
-                    jnp.dot(Lambda_inv, mu)
-                    + jnp.dot(M.T, jsp.linalg.cho_solve((L, True), y)),
-                )
+                rhs = jnp.dot(Lambda_inv, mu) + jnp.dot(M.T, Cinv_y)
+                a = jsp.linalg.cho_solve((A_inv_chol, True), rhs)

# the mean (b) of the marginal likelihood p(y|b, B),
# i.e., the likelihood obtained after integrating out
@@ -789,17 +828,14 @@ def model(
# ignore the 2pi factor since it introduces a term like
# - 0.5*ntime*log(2pi), which is constant
r = y - b
-                Cinv_r = jsp.linalg.cho_solve((L, True), r)

-                M_A_Mt_Cinv_r = jnp.dot(
-                    M,
-                    jsp.linalg.cho_solve(
-                        (A_inv_chol, True), jnp.dot(M.T, Cinv_r)
-                    ),
-                )
+                # C^{-1} r via the FFT-based GS solver
+                Cinv_r = apply_cinv_gs_fast(r, fft_a, fft_b, n_fft, sigma)

-                Cinv_M_A_Mt_Cinv_r = jsp.linalg.cho_solve(
-                    (L, True), M_A_Mt_Cinv_r
+                Mt_Cinv_r = jnp.dot(M.T, Cinv_r)
+                woodbury_corr = jnp.dot(
+                    Mt_Cinv_r.T,
+                    jsp.linalg.cho_solve((A_inv_chol, True), Mt_Cinv_r)
)

# now all we have left to compute is the log determinant
Expand All @@ -814,16 +850,21 @@ def model(
# writing similarly for |A| and |Lambda|, we thus have
# that log_sqrt_det_B = 0.5 log|B| is
# (note that |A| = -|A_inv|)

+                # log-determinant of C: for a stationary AR noise process,
+                # log|C| ~= 2 * N * log(sigma), neglecting edge effects
+                log_det_C = 2.0 * len(y) * jnp.log(sigma)

log_sqrt_det_B = (
-                    jnp.sum(jnp.log(jnp.diag(L)))
+                    0.5 * log_det_C
- jnp.sum(jnp.log(jnp.diag(Lambda_inv_chol)))
+ jnp.sum(jnp.log(jnp.diag(A_inv_chol)))
)

# putting it all together we can get the contribution
# to the log likelihood from this detector
logl = (
-                    -0.5 * jnp.dot(r, Cinv_r - Cinv_M_A_Mt_Cinv_r)
+                    -0.5 * (jnp.dot(r, Cinv_r) - woodbury_corr)
- log_sqrt_det_B
)
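A small self-contained numerical check (illustrative, not part of the PR) of the two identities the marginalized block relies on, with B = C + M Lambda M^T: the Woodbury form of the quadratic term and the determinant lemma behind log_sqrt_det_B. Note the PR additionally approximates log|C| by 2 * N * log(sigma), whereas this check uses exact determinants.

import numpy as np

rng = np.random.default_rng(0)
n, k = 50, 4
M = rng.normal(size=(n, k))
Lam = np.diag(rng.uniform(0.5, 2.0, size=k))  # prior covariance
C = np.diag(rng.uniform(0.5, 2.0, size=n))    # stand-in noise covariance
r = rng.normal(size=n)

B = C + M @ Lam @ M.T
Cinv_r = np.linalg.solve(C, r)
A_inv = np.linalg.inv(Lam) + M.T @ np.linalg.solve(C, M)

# Woodbury: r^T B^{-1} r = r^T C^{-1} r - w^T A_inv^{-1} w, w = M^T C^{-1} r
w = M.T @ Cinv_r
assert np.isclose(r @ np.linalg.solve(B, r),
                  r @ Cinv_r - w @ np.linalg.solve(A_inv, w))

# matrix determinant lemma: |B| = |C| * |Lambda| * |A_inv|
assert np.isclose(np.linalg.slogdet(B)[1],
                  np.linalg.slogdet(C)[1]
                  + np.linalg.slogdet(Lam)[1]
                  + np.linalg.slogdet(A_inv)[1])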

@@ -952,13 +993,18 @@ def model(

if not prior:
for i, strain in enumerate(strains):
-                numpyro.sample(
-                    f"logl_{i}",
-                    dist.MultivariateNormal(
-                        h_det[i, :], scale_tril=ls[i, :, :]
-                    ),
-                    obs=strain,
-                )
+                # log likelihood via GS: -0.5 * r^T C^{-1} r - 0.5 * log|C|
+                model_strain = h_det[i, :]
+                r = strain - model_strain
+                sigma = sigmas[i]
+
+                Cinv_r = apply_cinv_gs_fast(
+                    r, fft_as[i], fft_bs[i], n_fft, sigma)
+                log_det_C = 2.0 * len(r) * jnp.log(sigma)
+
+                logl = -0.5 * jnp.dot(r, Cinv_r) - 0.5 * log_det_C
+
+                numpyro.factor(f"logl_{i}", logl)

return model
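Replacing the observed MultivariateNormal with numpyro.factor changes the recorded log likelihood only by the constant -0.5 * n_time * log(2*pi), as noted earlier in the marginalized branch. A tiny standalone check (illustrative, not part of the PR), using white noise so that C^{-1} r is simply r / sigma^2:

import numpy as np
from scipy.stats import multivariate_normal

rng = np.random.default_rng(1)
n = 32
h = rng.normal(size=n)      # model strain
y = h + rng.normal(size=n)  # observed strain
sigma = 1.3

r = y - h
logl_factor = -0.5 * (r @ r) / sigma**2 - 0.5 * 2 * n * np.log(sigma)
logl_mvn = multivariate_normal(mean=h, cov=sigma**2 * np.eye(n)).logpdf(y)

# the two agree up to the dropped constant
assert np.isclose(logl_factor - logl_mvn, 0.5 * n * np.log(2 * np.pi))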

@@ -1049,7 +1095,8 @@ def get_arviz(
if len(modes) != n_mode:
raise ValueError(f"expected {n_mode} modes, got {len(modes)}")
# get ifo from shape of Fc, assuming it's last argument provided to model
-    n_ifo = len(sampler._args[-1])
+    # NOTE: in the new model signature fcs is at index 5 and strains at
+    # index 1
+    n_ifo = len(sampler._args[1])
if ifos is None:
ifos = np.arange(n_ifo, dtype=int)
elif len(ifos) != n_ifo:
@@ -1073,20 +1120,30 @@
in_dims = {
"time": ["ifo", "time_index"],
"strain": ["ifo", "time_index"],
"cholesky_factor": ["ifo", "time_index", "time_index_1"],
"ar_coeffs": ["ifo", "ar_lag"],
"sigma": ["ifo"],
"fp": ["ifo"],
"fc": ["ifo"],
"epoch": ["ifo"],
}
+    # updated mapping for the new model signature:
+    # args = (times, strains, ar_coeffs, sigmas, fps, fcs)
+    args = sampler._args
     in_data = {
-        k: np.array(v) for k, v in zip(in_dims.keys(), sampler._args)
+        "time": np.array(args[0]),
+        "strain": np.array(args[1]),
+        # ar_coeffs may be a ragged list of arrays; skipped to keep the
+        # netCDF export simple
+        "sigma": np.array(args[3]),
+        "fp": np.array(args[4]),
+        "fc": np.array(args[5]),
}
in_data["epoch"] = np.array(epoch)
in_data["scale"] = scale or 1.0
# get injections, if provided
if injections is not None:
in_data["injection"] = np.array(injections)
in_dims["injection"] = ["ifo", "time_index"]

dims.update(in_dims)
obs_data = {"strain": in_data.pop("strain")}
else: