From 35ad3a31f866475f3d1ed99e586a1734e6c5b57c Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Thu, 19 Feb 2026 15:14:36 -0500 Subject: [PATCH 1/7] add day of week effect code and tests --- pyrenew/model/multisignal_model.py | 24 +++ pyrenew/observation/base.py | 30 +++ pyrenew/observation/count_observations.py | 78 +++++++- test/test_observation_counts.py | 226 ++++++++++++++++++++++ 4 files changed, 356 insertions(+), 2 deletions(-) diff --git a/pyrenew/model/multisignal_model.py b/pyrenew/model/multisignal_model.py index d9aad03d..0026309c 100644 --- a/pyrenew/model/multisignal_model.py +++ b/pyrenew/model/multisignal_model.py @@ -131,6 +131,30 @@ def pad_observations( padding = jnp.full(pad_shape, jnp.nan) return jnp.concatenate([padding, obs], axis=axis) + def compute_first_day_dow(self, obs_start_dow: int) -> int: + """ + Compute the day of the week for the start of the shared time axis. + + The shared time axis begins ``n_init`` days before the first + observation. This method converts the known day of the week of + the first observation into the day of the week of the shared + time axis start (element 0), accounting for the initialization + period offset. + + Parameters + ---------- + obs_start_dow : int + Day of the week of the first observation day + (0=Monday, 6=Sunday, ISO convention). + + Returns + ------- + int + Day of the week for element 0 of the shared time axis. + """ + n_init = self.latent.n_initialization_points + return (obs_start_dow - n_init % 7) % 7 + def shift_times(self, times: jnp.ndarray) -> jnp.ndarray: """ Shift time indices from natural coordinates to shared time axis. diff --git a/pyrenew/observation/base.py b/pyrenew/observation/base.py index b87df223..04e1c9fc 100644 --- a/pyrenew/observation/base.py +++ b/pyrenew/observation/base.py @@ -185,6 +185,36 @@ def _validate_pmf( if jnp.any(pmf < 0): raise ValueError(f"{param_name} must have non-negative values") + def _validate_dow_effect( + self, + dow_effect: ArrayLike, + param_name: str, + ) -> None: + """ + Validate a day-of-week effect vector. + + Checks that the vector has exactly 7 non-negative elements + (one per day, 0=Monday through 6=Sunday, ISO convention). + + Parameters + ---------- + dow_effect : ArrayLike + Day-of-week multiplicative effects to validate. + param_name : str + Name of the parameter (for error messages). + + Raises + ------ + ValueError + If shape is not (7,) or any values are negative. + """ + if dow_effect.shape != (7,): + raise ValueError( + f"{param_name} must return shape (7,), got {dow_effect.shape}" + ) + if jnp.any(dow_effect < 0): + raise ValueError(f"{param_name} must have non-negative values") + def _convolve_with_alignment( self, latent_incidence: ArrayLike, diff --git a/pyrenew/observation/count_observations.py b/pyrenew/observation/count_observations.py index 72199f77..8f86d2a1 100644 --- a/pyrenew/observation/count_observations.py +++ b/pyrenew/observation/count_observations.py @@ -11,11 +11,13 @@ import jax.numpy as jnp from jax.typing import ArrayLike +from pyrenew.arrayutils import tile_until_n from pyrenew.convolve import compute_prop_already_reported from pyrenew.metaclass import RandomVariable from pyrenew.observation.base import BaseObservationProcess from pyrenew.observation.noise import CountNoise from pyrenew.observation.types import ObservationSample +from pyrenew.time import validate_dow class _CountBase(BaseObservationProcess): @@ -32,6 +34,7 @@ def __init__( delay_distribution_rv: RandomVariable, noise: CountNoise, right_truncation_rv: RandomVariable | None = None, + day_of_week_rv: RandomVariable | None = None, ) -> None: """ Initialize count observation base. @@ -52,11 +55,22 @@ def __init__( When provided (along with ``right_truncation_offset`` at sample time), predicted counts are scaled down for recent timepoints to account for incomplete reporting. + day_of_week_rv : RandomVariable | None + Optional day-of-week multiplicative effect. Must sample to + shape (7,) with non-negative values, where entry j is the + multiplier for day-of-week j (0=Monday, 6=Sunday, ISO + convention). An effect of 1.0 means no adjustment for that + day. Values summing to 7.0 preserve weekly totals and keep + the ascertainment rate interpretable; other sums rescale + overall predicted counts. When provided (along with + ``first_day_dow`` at sample time), predicted counts are + scaled by a periodic weekly pattern. """ super().__init__(name=name, temporal_pmf_rv=delay_distribution_rv) self.ascertainment_rate_rv = ascertainment_rate_rv self.noise = noise self.right_truncation_rv = right_truncation_rv + self.day_of_week_rv = day_of_week_rv def validate(self) -> None: """ @@ -84,6 +98,10 @@ def validate(self) -> None: rt_pmf = self.right_truncation_rv() self._validate_pmf(rt_pmf, "right_truncation_rv") + if self.day_of_week_rv is not None: + dow_effect = self.day_of_week_rv() + self._validate_dow_effect(dow_effect, "day_of_week_rv") + def lookback_days(self) -> int: """ Return required lookback days for this observation. @@ -191,6 +209,42 @@ def _apply_right_truncation( prop = prop[:, None] return predicted * prop + def _apply_day_of_week( + self, + predicted: ArrayLike, + first_day_dow: int, + ) -> ArrayLike: + """ + Apply day-of-week multiplicative adjustment to predicted counts. + + Tiles a 7-element effect vector across the full time axis, + aligned to the calendar via ``first_day_dow``. NaN values + in the initialization period propagate unchanged (NaN * effect = NaN), + which is correct since masked days are excluded from the likelihood. + + Parameters + ---------- + predicted : ArrayLike + Predicted counts. Shape: (n_timepoints,) or + (n_timepoints, n_subpops). + first_day_dow : int + Day of the week for element 0 of the time axis + (0=Monday, 6=Sunday, ISO convention). + + Returns + ------- + ArrayLike + Adjusted predicted counts, same shape as input. + """ + validate_dow(first_day_dow, "first_day_dow") + dow_effect = self.day_of_week_rv() + n_timepoints = predicted.shape[0] + daily_effect = tile_until_n(dow_effect, n_timepoints, offset=first_day_dow) + self._deterministic("day_of_week_effect", daily_effect) + if predicted.ndim == 2: + daily_effect = daily_effect[:, None] + return predicted * daily_effect + class Counts(_CountBase): """ @@ -231,7 +285,8 @@ def __repr__(self) -> str: f"ascertainment_rate_rv={self.ascertainment_rate_rv!r}, " f"delay_distribution_rv={self.temporal_pmf_rv!r}, " f"noise={self.noise!r}, " - f"right_truncation_rv={self.right_truncation_rv!r})" + f"right_truncation_rv={self.right_truncation_rv!r}, " + f"day_of_week_rv={self.day_of_week_rv!r})" ) def validate_data( @@ -268,6 +323,7 @@ def sample( infections: ArrayLike, obs: ArrayLike | None = None, right_truncation_offset: int | None = None, + first_day_dow: int | None = None, ) -> ObservationSample: """ Sample aggregated counts. @@ -288,6 +344,12 @@ def sample( right_truncation_offset : int | None If provided (and ``right_truncation_rv`` was set at construction), apply right-truncation adjustment to predicted counts. + first_day_dow : int | None + Day of the week for the first timepoint on the shared time + axis (0=Monday, 6=Sunday, ISO convention). Required when + ``day_of_week_rv`` was set at construction. Use + ``model.compute_first_day_dow(obs_start_dow)`` to convert + from the day of the week of the first observation. Returns ------- @@ -296,6 +358,8 @@ def sample( `predicted` (predicted counts before noise, shape: n_total). """ predicted_counts = self._predicted_obs(infections) + if self.day_of_week_rv is not None and first_day_dow is not None: + predicted_counts = self._apply_day_of_week(predicted_counts, first_day_dow) if self.right_truncation_rv is not None and right_truncation_offset is not None: predicted_counts = self._apply_right_truncation( predicted_counts, right_truncation_offset @@ -358,7 +422,8 @@ def __repr__(self) -> str: f"ascertainment_rate_rv={self.ascertainment_rate_rv!r}, " f"delay_distribution_rv={self.temporal_pmf_rv!r}, " f"noise={self.noise!r}, " - f"right_truncation_rv={self.right_truncation_rv!r})" + f"right_truncation_rv={self.right_truncation_rv!r}, " + f"day_of_week_rv={self.day_of_week_rv!r})" ) def infection_resolution(self) -> str: @@ -419,6 +484,7 @@ def sample( subpop_indices: ArrayLike, obs: ArrayLike | None = None, right_truncation_offset: int | None = None, + first_day_dow: int | None = None, ) -> ObservationSample: """ Sample subpopulation-level counts. @@ -443,6 +509,12 @@ def sample( right_truncation_offset : int | None If provided (and ``right_truncation_rv`` was set at construction), apply right-truncation adjustment to predicted counts. + first_day_dow : int | None + Day of the week for the first timepoint on the shared time + axis (0=Monday, 6=Sunday, ISO convention). Required when + ``day_of_week_rv`` was set at construction. Use + ``model.compute_first_day_dow(obs_start_dow)`` to convert + from the day of the week of the first observation. Returns ------- @@ -451,6 +523,8 @@ def sample( `predicted` (predicted counts before noise, shape: n_total x n_subpops). """ predicted_counts = self._predicted_obs(infections) + if self.day_of_week_rv is not None and first_day_dow is not None: + predicted_counts = self._apply_day_of_week(predicted_counts, first_day_dow) if self.right_truncation_rv is not None and right_truncation_offset is not None: predicted_counts = self._apply_right_truncation( predicted_counts, right_truncation_offset diff --git a/test/test_observation_counts.py b/test/test_observation_counts.py index 0190d429..717478e6 100644 --- a/test/test_observation_counts.py +++ b/test/test_observation_counts.py @@ -576,5 +576,231 @@ def test_counts_by_subpop_2d_broadcasting(self): assert jnp.allclose(result.predicted[:, 0], result.predicted[:, 1]) +class TestDayOfWeek: + """Test day-of-week multiplicative adjustment in count observations.""" + + def test_no_dow_rv_unchanged(self, simple_delay_pmf): + """Test that day_of_week_rv=None ignores first_day_dow.""" + process = Counts( + name="test", + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=PoissonNoise(), + ) + infections = jnp.ones(20) * 1000 + + with numpyro.handlers.seed(rng_seed=42): + result = process.sample(infections=infections, obs=None, first_day_dow=3) + + assert jnp.allclose(result.predicted, 10.0) + + def test_dow_rv_without_offset_unchanged(self, simple_delay_pmf): + """Test that first_day_dow=None skips adjustment.""" + dow_effect = jnp.array([2.0, 0.5, 0.5, 0.5, 0.5, 1.5, 1.5]) + process = Counts( + name="test", + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=PoissonNoise(), + day_of_week_rv=DeterministicVariable("dow", dow_effect), + ) + infections = jnp.ones(20) * 1000 + + with numpyro.handlers.seed(rng_seed=42): + result = process.sample(infections=infections, obs=None, first_day_dow=None) + + assert jnp.allclose(result.predicted, 10.0) + + def test_uniform_dow_effect_unchanged(self, simple_delay_pmf): + """Test that uniform effect [1,1,...,1] leaves predictions unchanged.""" + dow_effect = jnp.ones(7) + process = Counts( + name="test", + ascertainment_rate_rv=DeterministicVariable("ihr", 1.0), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=PoissonNoise(), + day_of_week_rv=DeterministicVariable("dow", dow_effect), + ) + infections = jnp.ones(14) * 100 + + with numpyro.handlers.seed(rng_seed=42): + result = process.sample(infections=infections, obs=None, first_day_dow=0) + + assert jnp.allclose(result.predicted, 100.0) + + def test_dow_effect_scales_predictions(self, simple_delay_pmf): + """Test that known day-of-week effects produce correct per-day scaling. + + With constant infections of 100, ascertainment 1.0, no delay, + and first_day_dow=0 (Monday), element i of predicted should + equal 100 * dow_effect[i % 7]. + """ + dow_effect = jnp.array([2.0, 1.5, 1.0, 1.0, 0.5, 0.5, 0.5]) + process = Counts( + name="test", + ascertainment_rate_rv=DeterministicVariable("ihr", 1.0), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=PoissonNoise(), + day_of_week_rv=DeterministicVariable("dow", dow_effect), + ) + infections = jnp.ones(14) * 100 + + with numpyro.handlers.seed(rng_seed=42): + result = process.sample(infections=infections, obs=None, first_day_dow=0) + + assert jnp.isclose(result.predicted[0], 200.0) + assert jnp.isclose(result.predicted[1], 150.0) + assert jnp.isclose(result.predicted[4], 50.0) + assert jnp.isclose(result.predicted[7], 200.0) + + def test_dow_offset_shifts_pattern(self, simple_delay_pmf): + """Test that first_day_dow offsets the weekly pattern correctly. + + Starting on Wednesday (dow=2) means element 0 gets + dow_effect[2], element 1 gets dow_effect[3], etc. + """ + dow_effect = jnp.array([2.0, 1.5, 1.0, 0.8, 0.7, 0.5, 0.5]) + process = Counts( + name="test", + ascertainment_rate_rv=DeterministicVariable("ihr", 1.0), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=PoissonNoise(), + day_of_week_rv=DeterministicVariable("dow", dow_effect), + ) + infections = jnp.ones(7) * 100 + + with numpyro.handlers.seed(rng_seed=42): + result = process.sample(infections=infections, obs=None, first_day_dow=2) + + assert jnp.isclose(result.predicted[0], 100.0) + assert jnp.isclose(result.predicted[1], 80.0) + assert jnp.isclose(result.predicted[5], 200.0) + + def test_deterministic_site_recorded(self, simple_delay_pmf): + """Test that day_of_week_effect deterministic site is recorded.""" + dow_effect = jnp.ones(7) + process = Counts( + name="ed", + ascertainment_rate_rv=DeterministicVariable("ihr", 1.0), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=PoissonNoise(), + day_of_week_rv=DeterministicVariable("dow", dow_effect), + ) + infections = jnp.ones(10) * 100 + + with numpyro.handlers.seed(rng_seed=42): + with numpyro.handlers.trace() as trace: + process.sample(infections=infections, obs=None, first_day_dow=0) + + assert "ed_day_of_week_effect" in trace + effect = trace["ed_day_of_week_effect"]["value"] + assert effect.shape == (10,) + + def test_counts_by_subpop_2d_broadcasting(self): + """Test day-of-week with CountsBySubpop 2D infections.""" + dow_effect = jnp.array([2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]) + delay_pmf = jnp.array([1.0]) + process = CountsBySubpop( + name="test", + ascertainment_rate_rv=DeterministicVariable("ihr", 1.0), + delay_distribution_rv=DeterministicPMF("delay", delay_pmf), + noise=PoissonNoise(), + day_of_week_rv=DeterministicVariable("dow", dow_effect), + ) + + n_days = 14 + n_subpops = 3 + infections = jnp.ones((n_days, n_subpops)) * 100 + times = jnp.array([0, 1, 7]) + subpop_indices = jnp.array([0, 1, 2]) + + with numpyro.handlers.seed(rng_seed=42): + result = process.sample( + infections=infections, + times=times, + subpop_indices=subpop_indices, + obs=None, + first_day_dow=0, + ) + + assert result.predicted.shape == (n_days, n_subpops) + assert jnp.isclose(result.predicted[0, 0], 200.0) + assert jnp.isclose(result.predicted[1, 0], 100.0) + assert jnp.isclose(result.predicted[7, 0], 200.0) + assert jnp.allclose(result.predicted[:, 0], result.predicted[:, 1]) + + def test_dow_with_right_truncation(self, simple_delay_pmf): + """Test that day-of-week and right-truncation compose correctly. + + Day-of-week is applied first, then right-truncation scales + the adjusted predictions. + """ + dow_effect = jnp.array([2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]) + rt_pmf = jnp.array([0.2, 0.3, 0.5]) + process = Counts( + name="test", + ascertainment_rate_rv=DeterministicVariable("ihr", 1.0), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=PoissonNoise(), + right_truncation_rv=DeterministicPMF("rt_delay", rt_pmf), + day_of_week_rv=DeterministicVariable("dow", dow_effect), + ) + infections = jnp.ones(10) * 100 + + with numpyro.handlers.seed(rng_seed=42): + result = process.sample( + infections=infections, + obs=None, + right_truncation_offset=0, + first_day_dow=0, + ) + + assert jnp.isclose(result.predicted[0], 200.0) + assert jnp.isclose(result.predicted[1], 100.0) + assert result.predicted[-1] < result.predicted[0] + + def test_validate_catches_wrong_shape(self, simple_delay_pmf): + """Test that validate() rejects non-length-7 effect vectors.""" + process = Counts( + name="test", + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=PoissonNoise(), + day_of_week_rv=DeterministicVariable("dow", jnp.ones(5)), + ) + with pytest.raises(ValueError, match="must return shape \\(7,\\)"): + process.validate() + + def test_validate_catches_negative_values(self, simple_delay_pmf): + """Test that validate() rejects negative effect values.""" + process = Counts( + name="test", + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=PoissonNoise(), + day_of_week_rv=DeterministicVariable( + "dow", jnp.array([1.0, 1.0, 1.0, -0.5, 1.0, 1.0, 1.0]) + ), + ) + with pytest.raises(ValueError, match="must have non-negative values"): + process.validate() + + def test_invalid_first_day_dow_raises(self, simple_delay_pmf): + """Test that out-of-range first_day_dow raises ValueError.""" + dow_effect = jnp.ones(7) + process = Counts( + name="test", + ascertainment_rate_rv=DeterministicVariable("ihr", 1.0), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=PoissonNoise(), + day_of_week_rv=DeterministicVariable("dow", dow_effect), + ) + infections = jnp.ones(14) * 100 + + with numpyro.handlers.seed(rng_seed=42): + with pytest.raises(ValueError, match="Day-of-week"): + process.sample(infections=infections, obs=None, first_day_dow=7) + + if __name__ == "__main__": pytest.main([__file__, "-v"]) From e9e138d883e650ec4bca6e59c4c78cb45bb9ade9 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Thu, 19 Feb 2026 15:39:26 -0500 Subject: [PATCH 2/7] adding tutorial --- docs/tutorials/.pages | 1 + docs/tutorials/day_of_week_effects.qmd | 414 +++++++++++++++++++++++++ test/test_observation_counts.py | 36 +++ 3 files changed, 451 insertions(+) create mode 100644 docs/tutorials/day_of_week_effects.qmd diff --git a/docs/tutorials/.pages b/docs/tutorials/.pages index 0d261a7a..e615fc3e 100644 --- a/docs/tutorials/.pages +++ b/docs/tutorials/.pages @@ -5,4 +5,5 @@ nav: - observation_processes_measurements.md - latent_hierarchical_infections.md - right_truncation.md + - day_of_week_effects.md - periodic_effects.md diff --git a/docs/tutorials/day_of_week_effects.qmd b/docs/tutorials/day_of_week_effects.qmd new file mode 100644 index 00000000..3a9aa9ff --- /dev/null +++ b/docs/tutorials/day_of_week_effects.qmd @@ -0,0 +1,414 @@ +--- +title: Day-of-week effects for count data +format: + gfm: + code-fold: true +engine: jupyter +jupyter: + jupytext: + text_representation: + extension: .qmd + format_name: quarto + format_version: '1.0' + jupytext_version: 1.18.1 + kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +--- + +```{python} +# | label: setup +# | output: false +import jax.numpy as jnp +import numpy as np +import numpyro +import pandas as pd + +from pyrenew.observation import Counts, NegativeBinomialNoise, PoissonNoise +from pyrenew.deterministic import DeterministicVariable, DeterministicPMF +from pyrenew import datasets + +import plotnine as p9 +from plotnine.exceptions import PlotnineWarning +import warnings + +warnings.filterwarnings("ignore", category=PlotnineWarning) + +from _tutorial_theme import theme_tutorial +``` + +Many health surveillance signals exhibit strong day-of-week patterns. +Emergency department visits and hospital admissions tend to be higher on weekdays and lower on weekends, driven by staffing, patient behavior, and reporting practices. +Ignoring this weekly periodicity forces the noise model to absorb systematic variation, inflating dispersion estimates and obscuring the underlying epidemic trend. + +PyRenew models day-of-week effects as a **multiplicative adjustment** applied to predicted counts after the delay convolution and ascertainment scaling: + +$$\lambda(t) = d_{w(t)} \cdot \alpha \sum_{s} I(t-s)\,\pi(s)$$ + +where $d_{w(t)}$ is the day-of-week multiplier for the weekday of timepoint $t$, $\alpha$ is the ascertainment rate, and $\pi(s)$ is the delay PMF. +The effect vector $\mathbf{d} = (d_0, d_1, \ldots, d_6)$ has one entry per day (0=Monday through 6=Sunday, ISO convention). +An effect of 1.0 means no adjustment for that day. +When the effects sum to 7.0, the average daily multiplier is 1.0, preserving weekly totals and keeping the ascertainment rate directly interpretable as the fraction of infections observed. + +## Defining a day-of-week effect + +A typical pattern for ED visits might show weekday effects above 1.0 and weekend effects below 1.0: + +```{python} +# | label: define-dow-effect +dow_values = jnp.array([1.20, 1.15, 1.10, 1.05, 1.00, 0.75, 0.75]) +day_names = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"] + +print(f"Day-of-week effects: {np.round(np.array(dow_values), 2)}") +print(f"Sum: {float(jnp.sum(dow_values)):.2f}") +``` + +```{python} +# | label: plot-dow-effect +dow_df = pd.DataFrame({"day": day_names, "effect": np.array(dow_values)}) +dow_df["day"] = pd.Categorical( + dow_df["day"], categories=day_names, ordered=True +) + +( + p9.ggplot(dow_df, p9.aes(x="day", y="effect")) + + p9.geom_col(fill="steelblue", alpha=0.7, color="black") + + p9.geom_hline(yintercept=1.0, linetype="dashed", color="grey") + + p9.labs( + x="Day of Week", + y="Multiplicative Effect", + title="Day-of-Week Effect Vector", + ) + + theme_tutorial +) +``` + +Values above the dashed line (1.0) increase predicted counts for that day; values below decrease them. +Monday at 1.20 means 20% more counts than an average day; Saturday and Sunday at 0.75 mean 25% fewer. + +## Observation process with and without day-of-week effects + +We construct two `Counts` observation processes using the same delay distribution and ascertainment rate. +The only difference is whether `day_of_week_rv` is provided. + +```{python} +# | label: create-processes +hosp_delay_pmf = jnp.array( + datasets.load_infection_admission_interval()["probability_mass"].to_numpy() +) +delay_rv = DeterministicPMF("inf_to_hosp_delay", hosp_delay_pmf) +ihr_rv = DeterministicVariable("ihr", 0.01) +concentration_rv = DeterministicVariable("concentration", 20.0) + +process_no_dow = Counts( + name="hosp_no_dow", + ascertainment_rate_rv=ihr_rv, + delay_distribution_rv=delay_rv, + noise=NegativeBinomialNoise(concentration_rv), +) + +process_with_dow = Counts( + name="hosp_dow", + ascertainment_rate_rv=ihr_rv, + delay_distribution_rv=delay_rv, + noise=NegativeBinomialNoise(concentration_rv), + day_of_week_rv=DeterministicVariable("dow_effect", dow_values), +) +``` + +We simulate a growing epidemic and generate predicted counts from both processes. +The `first_day_dow` parameter tells PyRenew which day of the week corresponds to element 0 of the time axis. +Here we set `first_day_dow=0` (Monday). + +```{python} +# | label: simulate-and-sample +day_one = process_no_dow.lookback_days() +n_total = 130 +infections = 5000.0 * jnp.exp(0.03 * jnp.arange(n_total)) + +with numpyro.handlers.seed(rng_seed=0): + result_no_dow = process_no_dow.sample(infections=infections, obs=None) +with numpyro.handlers.seed(rng_seed=0): + result_with_dow = process_with_dow.sample( + infections=infections, obs=None, first_day_dow=0 + ) +``` + +```{python} +# | label: plot-predicted-comparison +n_plot_days = n_total - day_one +pred_rows = [] +for i in range(n_plot_days): + day_idx = day_one + i + pred_rows.append( + { + "day": i, + "admissions": float(result_no_dow.predicted[day_idx]), + "type": "No day-of-week effect", + } + ) + pred_rows.append( + { + "day": i, + "admissions": float(result_with_dow.predicted[day_idx]), + "type": "With day-of-week effect", + } + ) +pred_df = pd.DataFrame(pred_rows) +pred_df["type"] = pd.Categorical( + pred_df["type"], + categories=["No day-of-week effect", "With day-of-week effect"], + ordered=True, +) + +( + p9.ggplot( + pred_df, p9.aes(x="day", y="admissions", color="type", linetype="type") + ) + + p9.geom_line(size=1) + + p9.scale_color_manual(values=["steelblue", "#e41a1c"]) + + p9.scale_linetype_manual(values=["solid", "dashed"]) + + p9.labs( + x="Day", + y="Predicted Admissions", + title="Predicted Admissions:\nWith vs. Without Day-of-Week Effect", + color="", + linetype="", + ) + + theme_tutorial +) +``` + +Without the day-of-week effect the predicted curve is smooth. +With it, the curve oscillates with a 7-day period — dipping on weekends and rising on weekdays — while following the same overall trend. + +## Effect of the offset + +The `first_day_dow` parameter aligns the weekly pattern to the calendar. +Changing it shifts which days receive which multiplier. +Here we compare starting on Monday vs. Wednesday: + +```{python} +# | label: offset-comparison +with numpyro.handlers.seed(rng_seed=0): + result_monday = process_with_dow.sample( + infections=infections, obs=None, first_day_dow=0 + ) +with numpyro.handlers.seed(rng_seed=0): + result_wednesday = process_with_dow.sample( + infections=infections, obs=None, first_day_dow=2 + ) +``` + +```{python} +# | label: plot-offset-comparison +offset_rows = [] +for i in range(21): + day_idx = day_one + i + offset_rows.append( + { + "day": i, + "admissions": float(result_monday.predicted[day_idx]), + "offset": "first_day_dow=0 (Monday)", + } + ) + offset_rows.append( + { + "day": i, + "admissions": float(result_wednesday.predicted[day_idx]), + "offset": "first_day_dow=2 (Wednesday)", + } + ) +offset_df = pd.DataFrame(offset_rows) +offset_df["offset"] = pd.Categorical( + offset_df["offset"], + categories=[ + "first_day_dow=0 (Monday)", + "first_day_dow=2 (Wednesday)", + ], + ordered=True, +) + +( + p9.ggplot( + offset_df, + p9.aes(x="day", y="admissions", color="offset"), + ) + + p9.geom_line(size=1) + + p9.geom_point(size=2) + + p9.scale_color_manual(values=["steelblue", "#e41a1c"]) + + p9.labs( + x="Day", + y="Predicted Admissions", + title="Effect of first_day_dow on Weekly Pattern Alignment", + color="", + ) + + theme_tutorial +) +``` + +The two curves have the same shape but are phase-shifted: their weekend dips fall on different days. +Getting `first_day_dow` right matters — a misaligned offset would attribute Monday's high to Sunday or vice versa. + +When using `MultiSignalModel`, the shared time axis starts `n_init` days before the first observation. +The convenience method `model.compute_first_day_dow(obs_start_dow)` converts the known day of the week of the first observation to the correct offset for element 0 of the time axis. + +## Sampled observations + +Day-of-week effects shape the noise draws, not just the predicted means. +The noise model samples from a distribution centered on the adjusted predictions, so sampled observations inherit the weekly pattern. + +```{python} +# | label: sample-noisy +n_samples = 30 +noisy_results = [] +for seed in range(n_samples): + with numpyro.handlers.seed(rng_seed=seed): + result_no = process_no_dow.sample(infections=infections, obs=None) + with numpyro.handlers.seed(rng_seed=seed): + result_yes = process_with_dow.sample( + infections=infections, obs=None, first_day_dow=0 + ) + for i in range(n_plot_days): + day_idx = day_one + i + noisy_results.append( + { + "day": i, + "admissions": float(result_no.observed[day_idx]), + "type": "No day-of-week effect", + "sample": seed, + } + ) + noisy_results.append( + { + "day": i, + "admissions": float(result_yes.observed[day_idx]), + "type": "With day-of-week effect", + "sample": seed, + } + ) +``` + +```{python} +# | label: plot-noisy +noisy_df = pd.DataFrame(noisy_results) +mean_df = noisy_df.groupby(["day", "type"])["admissions"].mean().reset_index() + +( + p9.ggplot(noisy_df, p9.aes(x="day", y="admissions")) + + p9.geom_line( + p9.aes(group="sample"), alpha=0.15, size=0.4, color="steelblue" + ) + + p9.geom_line( + data=mean_df, + mapping=p9.aes(x="day", y="admissions"), + color="#e41a1c", + size=1.2, + ) + + p9.facet_wrap("~ type", ncol=1) + + p9.labs( + x="Day", + y="Hospital Admissions", + title="Sampled Observations:\nWith vs. Without Day-of-Week Effect", + ) + + theme_tutorial +) +``` + +The top panel shows smooth variation around the trend. +The bottom panel shows systematic weekly oscillation in both the mean (red) and individual samples (blue) — the weekend dips are visible even through the noise. + +## Composing with right-truncation + +Day-of-week effects and right-truncation are independent adjustments that compose naturally. +Day-of-week is applied first (adjusting the expected counts for reporting patterns), then right-truncation scales down recent counts for incomplete reporting: + +$$\lambda(t) = F(k_t) \cdot d_{w(t)} \cdot \alpha \sum_s I(t-s)\,\pi(s)$$ + +```{python} +# | label: compose-with-truncation +reporting_delay_pmf = jnp.array([0.4, 0.3, 0.15, 0.08, 0.04, 0.02, 0.01]) + +process_both = Counts( + name="hosp_both", + ascertainment_rate_rv=ihr_rv, + delay_distribution_rv=delay_rv, + noise=NegativeBinomialNoise(concentration_rv), + day_of_week_rv=DeterministicVariable("dow_effect", dow_values), + right_truncation_rv=DeterministicPMF( + "reporting_delay", reporting_delay_pmf + ), +) + +with numpyro.handlers.seed(rng_seed=0): + result_both = process_both.sample( + infections=infections, + obs=None, + first_day_dow=0, + right_truncation_offset=0, + ) +``` + +```{python} +# | label: plot-composed +compose_rows = [] +for i in range(n_plot_days): + day_idx = day_one + i + compose_rows.append( + { + "day": i, + "admissions": float(result_with_dow.predicted[day_idx]), + "type": "Day-of-week only", + } + ) + compose_rows.append( + { + "day": i, + "admissions": float(result_both.predicted[day_idx]), + "type": "Day-of-week + right-truncation", + } + ) +compose_df = pd.DataFrame(compose_rows) +compose_df["type"] = pd.Categorical( + compose_df["type"], + categories=["Day-of-week only", "Day-of-week + right-truncation"], + ordered=True, +) + +( + p9.ggplot( + compose_df, + p9.aes(x="day", y="admissions", color="type", linetype="type"), + ) + + p9.geom_line(size=1) + + p9.scale_color_manual(values=["steelblue", "#e41a1c"]) + + p9.scale_linetype_manual(values=["solid", "dashed"]) + + p9.labs( + x="Day", + y="Predicted Admissions", + title="Day-of-Week Effect Composed with Right-Truncation", + color="", + linetype="", + ) + + theme_tutorial +) +``` + +The two curves agree in the early period. +Near the right edge, right-truncation pulls the curve downward on top of the weekly oscillation. +Each adjustment operates on its own concern — weekly reporting patterns vs. incomplete recent data — and they combine multiplicatively without interfering. + +## Summary + +Day-of-week adjustment is enabled by passing a `day_of_week_rv` at construction time and a `first_day_dow` at sample time. + +| Parameter | Where | Purpose | +|-----------|-------|---------| +| `day_of_week_rv` | Constructor | 7-element multiplicative effect vector (0=Mon, 6=Sun) | +| `first_day_dow` | `sample()` | Day of the week for element 0 of the time axis | + +When either is `None`, the adjustment is disabled and the process behaves identically to one without day-of-week effects. + +The effect vector can be supplied as a fixed `DeterministicVariable` from empirical data, or as a stochastic `RandomVariable` (e.g., a scaled Dirichlet prior) to infer the weekly pattern from data. +Effects summing to 7.0 preserve weekly totals and keep the ascertainment rate interpretable; other sums rescale overall predicted counts. diff --git a/test/test_observation_counts.py b/test/test_observation_counts.py index 717478e6..c627aa67 100644 --- a/test/test_observation_counts.py +++ b/test/test_observation_counts.py @@ -653,6 +653,42 @@ def test_dow_effect_scales_predictions(self, simple_delay_pmf): assert jnp.isclose(result.predicted[4], 50.0) assert jnp.isclose(result.predicted[7], 200.0) + def test_dow_effect_with_multiday_delay(self, short_delay_pmf): + """Test that DOW ratios are correct with a multi-day delay PMF. + + With a 2-day delay, the first element is NaN (init period). + Post-init predicted values should satisfy: + predicted_with_dow[t] / predicted_no_dow[t] == dow_effect[t % 7]. + """ + dow_effect = jnp.array([2.0, 1.5, 1.0, 1.0, 0.5, 0.5, 0.5]) + process_no_dow = Counts( + name="base", + ascertainment_rate_rv=DeterministicVariable("ihr", 1.0), + delay_distribution_rv=DeterministicPMF("delay", short_delay_pmf), + noise=PoissonNoise(), + ) + process_with_dow = Counts( + name="dow", + ascertainment_rate_rv=DeterministicVariable("ihr", 1.0), + delay_distribution_rv=DeterministicPMF("delay", short_delay_pmf), + noise=PoissonNoise(), + day_of_week_rv=DeterministicVariable("dow", dow_effect), + ) + infections = jnp.ones(21) * 100 + + with numpyro.handlers.seed(rng_seed=42): + result_no = process_no_dow.sample(infections=infections, obs=None) + with numpyro.handlers.seed(rng_seed=42): + result_yes = process_with_dow.sample( + infections=infections, obs=None, first_day_dow=0 + ) + + day_one = 1 + for t in range(day_one, 14): + expected_ratio = float(dow_effect[t % 7]) + actual_ratio = float(result_yes.predicted[t] / result_no.predicted[t]) + assert jnp.isclose(actual_ratio, expected_ratio, atol=1e-5) + def test_dow_offset_shifts_pattern(self, simple_delay_pmf): """Test that first_day_dow offsets the weekly pattern correctly. From 243f61e730639fbdcc0b5a658382f22096024475 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Wed, 25 Feb 2026 21:57:15 -0500 Subject: [PATCH 3/7] upate counts tutorial; modify counts.py accordingly --- .../observation_processes_counts.qmd | 517 ++++++++++++++++-- pyrenew/datasets/__init__.py | 8 +- pyrenew/datasets/hospital_admissions.py | 55 +- pyrenew/observation/__init__.py | 4 +- pyrenew/observation/count_observations.py | 19 +- test/test_interface_coverage.py | 23 +- 6 files changed, 547 insertions(+), 79 deletions(-) diff --git a/docs/tutorials/observation_processes_counts.qmd b/docs/tutorials/observation_processes_counts.qmd index 3f1305a2..0b900c55 100644 --- a/docs/tutorials/observation_processes_counts.qmd +++ b/docs/tutorials/observation_processes_counts.qmd @@ -1,13 +1,22 @@ --- -title: "Observation processes for count data" +title: Observation processes for count data format: gfm: code-fold: true engine: jupyter +jupyter: + jupytext: + text_representation: + extension: .qmd + format_name: quarto + format_version: '1.0' + jupytext_version: 1.18.1 + kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 --- -This tutorial demonstrates how to use the `Counts` observation process to model count data such as hospital admissions, emergency department visits, or deaths. - ```{python} # | label: setup # | output: false @@ -22,7 +31,12 @@ from plotnine.exceptions import PlotnineWarning warnings.filterwarnings("ignore", category=PlotnineWarning) from _tutorial_theme import theme_tutorial -from pyrenew.observation import Counts, NegativeBinomialNoise, PoissonNoise +from pyrenew.observation import ( + CountBase, + Counts, + NegativeBinomialNoise, + PoissonNoise, +) from pyrenew.deterministic import DeterministicVariable, DeterministicPMF from pyrenew import datasets ``` @@ -30,9 +44,21 @@ from pyrenew import datasets ## Overview Count observation processes model the lag between infections and an observed outcome such as hospital admissions, emergency department visits, confirmed cases, or deaths. -Observed data can be aggregated or available as subpopulation-level counts, which are modeled by classes `Counts` and `CountsBySubpop`, respectively. +All count observation processes inherit from `CountBase`, which provides the core ascertainment x delay convolution pipeline. +PyRenew provides two built-in subclasses: `Counts` for aggregate daily counts and `CountsBySubpop` for subpopulation-level daily counts. +We demonstrates how to extend `CountBase` to create an observation process for weekly data. + +### The generative model + +In PyRenew, the observation process maps latent daily infections *forward* to observed data. +The generative direction is always:
+   latent infections $\to$ ascertainment $\times$ delay convolution $\to$ predicted counts $\to$ noise $\to$ observed data.
+This forward direction is fundamental: the observation process transforms predictions to match the scale and resolution of the data, never the reverse. +During inference, the likelihood compares model predictions to observed data at the data's own resolution. + +### The count observation equation -Count observation processes transform infections into predicted counts by applying an event probability and/or ascertainment rate and convolving with a delay distribution. +Count observation processes transform infections into predicted counts by applying an ascertainment rate and convolving with a delay distribution. The predicted observations on day $t$ are: @@ -44,12 +70,72 @@ where: - $\alpha$ is the rate of ascertained counts per infection (e.g., infection-to-hospital admission rate). This can model a mix of biological effects (e.g. some percentage of infections lead to hospital admissions, but not all) and reporting effects (e.g. some percentage of admissions that occur are reported, but not all). - $\pi(s)$ is the delay distribution from infection to observation, conditional on an infection leading to an observation -Discrete observations are generated by sampling from a noise distribution—e.g. Poisson or negative binomial—to model reporting variability. +Discrete observations are generated by sampling from a noise distribution - e.g. Poisson or negative binomial - to model reporting variability. Poisson assumes variance equals the mean; negative binomial accommodates the overdispersion common in surveillance data. -**Note on terminology:** In real-world inference, incident infections are typically a *latent* (unobserved) quantity and must be estimated from observed data like hospital admissions. In this tutorial, we simulate the observation process by specifying infections directly and showing how they produce hospital admissions through convolution and sampling. +## The CountBase Class + +`CountBase` provides the core ascertainment $\times$ delay convolution operations that all subclasses inherit: + +| Inherited method | What it does | +|---|---| +| `_predicted_obs(infections)` | Convolves infections with the delay PMF, scaled by the ascertainment rate. Returns daily predicted counts. | +| `validate()` | Validates the delay PMF, ascertainment rate, noise model, and optional day-of-week / right-truncation parameters. | +| `lookback_days()` | Returns `len(delay_pmf) - 1`: the number of initialization days needed. | +| `_apply_day_of_week(predicted, first_day_dow)` | Applies a multiplicative 7-day periodic pattern (optional). | +| `_apply_right_truncation(predicted, offset)` | Scales recent predictions for incomplete reporting (optional). | + +The last two methods are optional adjustments that subclasses can apply in their `sample()` method. Day-of-week effects model systematic within-week reporting patterns (see [Day-of-Week Effects](day_of_week_effects.md)). Right-truncation adjustment accounts for incomplete reporting of recent observations (see [Right Truncation](right_truncation.md)). + +Subclasses must implement three methods: -## Hospital admissions example +| Required method | Purpose | +|---|---| +| `infection_resolution()` | Return `"aggregate"` or `"subpop"` to tell the model which latent infections to route to this observation. | +| `sample(infections, ...)` | The main forward model: compute predictions, apply any transformations, then sample from the noise distribution. | +| `validate_data(n_total, n_subpops, ...)` | Validate observation data shapes before JAX tracing begins. + +### Built-in subclasses + +PyRenew provides two observation process classes: + +- **`Counts`**: Aggregate daily counts. `infection_resolution()` returns `"aggregate"`. Accepts 1D infections of shape `(n_total,)`. Observations are on the shared dense time axis with NaN masking for initialization and missing data. +- **`CountsBySubpop`**: Subpopulation-level daily counts. `infection_resolution()` returns `"subpop"`. Accepts 2D infections of shape `(n_total, n_subpops)`. Uses sparse indexing via `times` and `subpop_indices`. + + +### Subclassing CountBase + +The following code sketch outlines the functions that must be defined on a subclass of `CountBase`. + +```python +from pyrenew.observation import CountBase +from pyrenew.observation.types import ObservationSample + +class MyCustomCounts(CountBase): + def __init__(self, name, ascertainment_rate_rv, delay_distribution_rv, noise): + super().__init__( + name=name, + ascertainment_rate_rv=ascertainment_rate_rv, + delay_distribution_rv=delay_distribution_rv, + noise=noise, + ) + + def infection_resolution(self): + return "aggregate" # or "subpop" + + def validate_data(self, n_total, n_subpops, **kwargs): + # Check observation data shapes + ... + + def sample(self, infections, ...): + predicted = self._predicted_obs(infections) # inherited + # ... transform predictions (e.g., aggregate to weekly) ... + observed = self.noise.sample(name=..., predicted=..., obs=...) + return ObservationSample(observed=observed, predicted=predicted) +``` + + +## Using the `Counts` class for Hospital Admissions Data For hospital admissions data, we construct a `Counts` observation process. The delay is the key mechanism: infections from $s$ days ago ($I(t-s)$) contribute to today's expected hospital admissions ($\mu(t)$) weighted by the probability ($\pi(s)$) that an infection leads to hospitalization after exactly $s$ days. The convolution sums these contributions across all past days. @@ -60,14 +146,14 @@ $$Y_t \sim \text{NegativeBinomial}(\text{mean} = \mu(t), \text{concentration} = The concentration parameter $\phi$ (sometimes called $k$ or the dispersion parameter) controls overdispersion: as $\phi \to \infty$, the distribution approaches Poisson; smaller values allow greater overdispersion. -We use the negative binomial distribution because real-world hospital admission counts exhibit overdispersion—the variance exceeds the mean. +We use the negative binomial distribution because real-world hospital admission counts exhibit overdispersion - the variance exceeds the mean. The Poisson distribution assumes variance equals the mean, which is too restrictive. The negative binomial adds an overdispersion term: $$\text{Var}[Y_t] = \mu + \frac{\mu^2}{\phi}$$ In this example, we use fixed parameter values for illustration; in practice, these parameters would be estimated from data using weakly informative priors. -## Infection-to-hospitalization delay distribution +### Infection-to-hospitalization delay distribution The delay distribution specifies the probability that an infected person is hospitalized $d$ days after infection, conditional on the infection leading to a hospitalization. For example, if `hosp_delay_pmf[5] = 0.2`, then 20% of infections that result in hospitalization will appear as hospital admissions 5 days after infection. @@ -131,9 +217,9 @@ plot_delay = ( plot_delay ``` -## Creating a Counts observation process +### Defining a Counts observation process -A `Counts` object takes the following arguments: +A `Counts` object inherits the full convolution pipeline from `CountBase`. It takes the following arguments: - **`name`**: unique, meaningful identifier for this observation process (e.g., `"hospital"`, `"deaths"`) - **`ascertainment_rate_rv`**: the probability an infection results in an observation (e.g., IHR) @@ -181,10 +267,14 @@ def first_valid_observation_day(obs_process) -> int: return obs_process.lookback_days() ``` -## Simulating observed hospital admissions given a single day's worth of infections +## Simulations -To demonstrate how a `Counts` observation process works, we examine how infections occurring on a single day result in observed hospital admissions. +In real-world inference, incident infections are a *latent* (unobserved) quantity and must be estimated from observed data. +To simulate the observation process we specify infections directly to show how they produce observed counts through convolution and sampling. +### Single-day infection spike + +To demonstrate how a `Counts` observation process works, we examine how infections occurring on a single day result in observed hospital admissions. ```{python} # | label: simulate-spike @@ -198,6 +288,7 @@ infections = infections.at[infection_spike_day].set(2000) ``` We plot the infections starting from day_one (the first valid observation day, after the lookback period). + ```{python} # | label: plot-infections # Plot relative to first valid observation day @@ -239,7 +330,7 @@ plot_infections Because all infections occur on a single day, this allows us to see how one day's worth of infections result in hospital admissions spread over subsequent days according to the delay distribution. -## Predicted admissions without observation noise. +### Predicted admissions without observation noise First, we compute the predicted admissions from the convolution alone, without observation noise. This is the mean of the distribution from which samples are drawn. @@ -319,7 +410,7 @@ plot_predicted The predicted admissions mirror the delay distribution, shifted by the infection spike day and scaled by the IHR. -## Observation Noise (Negative Binomial) +### Observation noise (Negative Binomial) The negative binomial distribution adds stochastic variation. Sampling multiple times from the same infections shows the range of possible observations: @@ -418,7 +509,7 @@ print( ) ``` -## Effect of the ascertainment rate +### Effect of the ascertainment rate The ascertainment rate (here, the infection-hospitalization rate or IHR) directly scales the number of predicted hospital admissions. We compare two contrasting IHR values: **0.5%** and **2.5%**. @@ -451,7 +542,6 @@ for ihr_val in ihr_values: ) ``` - ```{python} # | label: plot-ihr-comparisons results_df = pd.DataFrame(results_list) @@ -471,7 +561,7 @@ plot_ihr = ( plot_ihr ``` -## Negative binomial concentration parameter +### Negative binomial concentration parameter The concentration parameter $\phi$ controls overdispersion: @@ -480,9 +570,9 @@ The concentration parameter $\phi$ controls overdispersion: We compare three concentration values spanning two orders of magnitude: -- **φ = 1**: high overdispersion (noisy) -- **φ = 10**: moderate overdispersion -- **φ = 100**: nearly Poisson (minimal noise) +- $\phi = 1$: high overdispersion (noisy) +- $\phi = 10$: moderate overdispersion +- $\phi = 100$: nearly Poisson (minimal noise) ```{python} # | label: concentration-comparisons @@ -518,7 +608,7 @@ for conc_val in concentration_values: { "day": i, "admissions": float(admit), - "concentration": f"φ = {int(conc_val)}", + "concentration": f"$\\phi$ = {int(conc_val)}", "replicate": seed, } ) @@ -531,7 +621,7 @@ conc_df = pd.DataFrame(conc_results) # Convert to ordered categorical conc_df["concentration"] = pd.Categorical( conc_df["concentration"], - categories=["φ = 1", "φ = 10", "φ = 100"], + categories=["$\\phi$ = 1", "$\\phi$ = 10", "$\\phi$ = 100"], ordered=True, ) @@ -542,14 +632,14 @@ plot_concentration = ( + p9.labs( x="Day", y="Hospital Admissions", - title="Effect of Concentration Parameter on Variability", + title="Effect of Negative Binomial Concentration Parameter on Variability", ) + theme_tutorial ) plot_concentration ``` -## Swapping noise models +### Swapping noise models To use Poisson noise instead of negative binomial, change the noise model: @@ -573,44 +663,373 @@ print( ) ``` -We can visualize the Poisson noise model using the same constant infection scenario as the concentration comparison above. Since Poisson assumes variance equals the mean, it produces less variability than the negative binomial with low concentration values. - -To see the reduction in noise, it is necessary to keep the y-axis on the same scale as in the previous plot. +To compare Poisson noise directly against negative binomial, we plot 10 replicates from three noise models side by side using the same constant infection input. The shared y-axis makes the difference in variability immediately visible: Poisson ($\text{Var} = \mu$) is the tightest, negative binomial with $\phi = 100$ is nearly identical, and $\phi = 10$ shows noticeably more spread. ```{python} # | label: poisson-realizations -# Sample multiple realizations with Poisson noise -n_replicates_poisson = 10 +noise_comparison = [] +noise_configs = [ + ("Poisson", PoissonNoise()), + ( + "NegBin $\\phi$=100", + NegativeBinomialNoise(DeterministicVariable("c100", 100.0)), + ), + ( + "NegBin $\\phi$=10", + NegativeBinomialNoise(DeterministicVariable("c10", 10.0)), + ), +] + +for label, noise_model in noise_configs: + process_tmp = Counts( + name="hospital", + ascertainment_rate_rv=ihr_rv, + delay_distribution_rv=delay_rv, + noise=noise_model, + ) + for seed in range(10): + with numpyro.handlers.seed(rng_seed=seed): + result_tmp = process_tmp.sample( + infections=infections_constant, obs=None + ) + for i, admit in enumerate(result_tmp.observed[day_one:]): + noise_comparison.append( + { + "day": i, + "admissions": float(admit), + "noise": label, + "replicate": seed, + } + ) + +noise_df = pd.DataFrame(noise_comparison) +noise_df["noise"] = pd.Categorical( + noise_df["noise"], + categories=["Poisson", "NegBin $\\phi$=100", "NegBin $\\phi$=10"], + ordered=True, +) + +( + p9.ggplot(noise_df, p9.aes(x="day", y="admissions", group="replicate")) + + p9.geom_line(alpha=0.5, size=0.8, color="steelblue") + + p9.facet_wrap("~ noise", ncol=3) + + p9.labs( + x="Day", + y="Hospital Admissions", + title="Noise Model Comparison: Poisson vs. Negative Binomial", + ) + + theme_tutorial +) +``` + +## Weekly Observations with WeeklyCounts + +Some surveillance signals are reported at coarser temporal resolution than daily. For example, NHSN hospital admissions are now reported as weekly (MMWR epiweek) totals rather than daily counts. + +The correct approach is to aggregate *predictions* up to the observation's temporal resolution, not to disaggregate observations down to daily. +Disaggregating weekly counts to daily values would fabricate within-week timing information that does not exist in the data. +Instead, the latent model produces daily predictions via `_predicted_obs()` (inherited from `CountBase`), and the observation process sums them into weekly totals. +The likelihood then evaluates at the weekly resolution - comparing weekly predicted totals to weekly observed totals. +This preserves the generative model's causal direction: latent daily infections flow forward through the observation process to produce predictions at whatever resolution the data requires. + +The predicted weekly admissions for epiweek $w$ are: + +$$\mu_w = \sum_{d \in w} \mu(d)$$ + +where $\mu(d)$ is the daily predicted count. Observations are weekly totals with negative binomial noise: + +$$Y_w \sim \text{NegativeBinomial}(\text{mean} = \mu_w, \text{concentration} = \phi)$$ + +Weekly aggregation naturally reduces variance relative to daily counts, so weekly observations typically use a higher concentration parameter (less overdispersion) than daily observations. +The choice of $\phi$ at each resolution reflects prior knowledge about the noise structure of the data. +Daily counts are subject to day-to-day reporting irregularities: staffing variation, batch reporting, and weekday/weekend effects all introduce overdispersion beyond what the Poisson model predicts. +A moderate $\phi$ (e.g., 10) captures this extra daily noise. +Weekly totals average over these within-week fluctuations, so the remaining noise after aggregation is closer to Poisson. +A high $\phi$ (e.g., 100) is appropriate because most of the reporting-driven overdispersion has been smoothed out by summing over 7 days. +In practice, both concentration parameters would be given informative priors and estimated from data, but the prior for the weekly $\phi$ should be centered higher than the prior for the daily $\phi$. + +Day-of-week effects and right-truncation are not applicable to weekly data: weekly aggregation absorbs within-week patterns and mitigates reporting delays. + +### Implementing the WeeklyCounts class + +```{python} +# | label: weekly-counts-class +from jax.typing import ArrayLike +from pyrenew.observation import CountBase +from pyrenew.observation.noise import CountNoise +from pyrenew.observation.types import ObservationSample +from pyrenew.metaclass import RandomVariable +from pyrenew.time import daily_to_mmwr_epiweekly + + +class WeeklyCounts(CountBase): + """Weekly (MMWR epiweek) aggregate count observation process.""" + + def __init__( + self, + name: str, + ascertainment_rate_rv: RandomVariable, + delay_distribution_rv: RandomVariable, + noise: CountNoise, + ) -> None: + """ + Initialize weekly count observation process. + + Parameters + ---------- + name : str + Unique name for this observation process. + ascertainment_rate_rv : RandomVariable + Ascertainment rate in [0, 1] (e.g., IHR). + delay_distribution_rv : RandomVariable + Delay distribution PMF (must sum to ~1.0). + noise : CountNoise + Noise model for weekly count observations. + """ + super().__init__( + name=name, + ascertainment_rate_rv=ascertainment_rate_rv, + delay_distribution_rv=delay_distribution_rv, + noise=noise, + ) + + def infection_resolution(self) -> str: + """Return 'aggregate' for jurisdiction-level observations.""" + return "aggregate" + + def validate_data( + self, + n_total: int, + n_subpops: int, + first_day_dow: int | None = None, + week_indices: ArrayLike | None = None, + obs: ArrayLike | None = None, + **kwargs, + ) -> None: + """ + Validate weekly observation data. + + Parameters + ---------- + n_total : int + Total time steps on the shared daily axis. + n_subpops : int + Number of subpopulations (unused). + first_day_dow : int | None + Day of the week for element 0 of the shared time axis. + week_indices : ArrayLike | None + Indices into the weekly-aggregated predictions array. + obs : ArrayLike | None + Weekly observed counts. + **kwargs + Additional keyword arguments (ignored). + """ + if obs is not None and week_indices is not None: + obs = jnp.asarray(obs) + week_indices = jnp.asarray(week_indices) + if obs.shape != week_indices.shape: + raise ValueError( + f"Observation '{self.name}': obs shape {obs.shape} " + f"must match week_indices shape {week_indices.shape}" + ) + + def sample( + self, + infections: ArrayLike, + first_day_dow: int, + week_indices: ArrayLike, + obs: ArrayLike | None = None, + ) -> ObservationSample: + """ + Sample weekly aggregated counts. + + Parameters + ---------- + infections : ArrayLike + Daily aggregate infections, shape (n_total,). + first_day_dow : int + ISO day-of-week for element 0 of the shared time axis + (0=Monday, 6=Sunday). + week_indices : ArrayLike + Indices into the weekly predictions array identifying + which weeks have observations. + obs : ArrayLike | None + Weekly observed counts, shape (n_obs_weeks,). + None for prior predictive sampling. + + Returns + ------- + ObservationSample + Named tuple with observed (weekly) and predicted (daily). + """ + daily_predicted = self._predicted_obs(infections) + self._deterministic("predicted_daily", daily_predicted) + + weekly_predicted = daily_to_mmwr_epiweekly( + daily_predicted, input_data_first_dow=first_day_dow + ) + self._deterministic("predicted_weekly", weekly_predicted) + + predicted_at_obs = weekly_predicted[week_indices] + + observed = self.noise.sample( + name=self._sample_site_name("obs"), + predicted=predicted_at_obs, + obs=obs, + ) + + return ObservationSample(observed=observed, predicted=daily_predicted) +``` + +Key design choices: + +- **No day-of-week or right-truncation**: The constructor passes neither `day_of_week_rv` nor `right_truncation_rv` to `CountBase`. Weekly aggregation absorbs within-week patterns and mitigates reporting delays. +- **`week_indices`**: Maps observed weeks to positions in the aggregated predictions. This handles partial weeks at the start/end of the time series and allows for missing weeks. +- **Two deterministic sites**: `predicted_daily` (full daily time series) and `predicted_weekly` (aggregated epiweek totals) are both recorded for posterior analysis. + +### Configuring a weekly hospital admissions process + +```{python} +# | label: create-weekly-process +weekly_ihr_rv = DeterministicVariable("weekly_ihr", 0.01) +weekly_concentration_rv = DeterministicVariable("weekly_concentration", 100.0) + +weekly_hosp_process = WeeklyCounts( + name="hospital_weekly", + ascertainment_rate_rv=weekly_ihr_rv, + delay_distribution_rv=delay_rv, + noise=NegativeBinomialNoise(weekly_concentration_rv), +) + +print(f"Required lookback: {weekly_hosp_process.lookback_days()} days") +``` + +### Comparing daily and weekly observations from the same infections + +Using the exponentially decaying infection curve from earlier, we can see how the same underlying epidemic produces different observations at daily vs. weekly resolution. + +```{python} +# | label: weekly-simulate +import datetime as dt + +peak_value = 3000 +infections_decay = peak_value * jnp.exp(-jnp.arange(n_days) / 20.0) + +# The shared time axis starts on a Sunday (2023-01-01 was a Sunday = ISO dow 6) +first_dow = 6 + +# Compute weekly predictions to determine valid week indices +with numpyro.handlers.seed(rng_seed=0): + daily_predicted = weekly_hosp_process._predicted_obs(infections_decay) + +weekly_predicted = daily_to_mmwr_epiweekly( + daily_predicted, input_data_first_dow=first_dow +) +n_valid_weeks = int(jnp.sum(~jnp.isnan(weekly_predicted))) +n_total_weeks = len(weekly_predicted) + +# Use all valid (non-NaN) weeks +all_week_indices = jnp.arange(n_total_weeks) +valid_mask = ~jnp.isnan(weekly_predicted) +week_indices = all_week_indices[valid_mask] + +print( + f"Total weeks: {n_total_weeks}, " + f"valid weeks (after lookback): {n_valid_weeks}" +) +``` -poisson_results = [] -for seed in range(n_replicates_poisson): +```{python} +# | label: weekly-daily-comparison +# Sample daily observations (using existing Counts process) +daily_process = Counts( + name="hospital_daily", + ascertainment_rate_rv=weekly_ihr_rv, + delay_distribution_rv=delay_rv, + noise=NegativeBinomialNoise(DeterministicVariable("conc_daily", 10.0)), +) + +# Collect daily and weekly samples +comparison_list = [] + +for seed in range(50): with numpyro.handlers.seed(rng_seed=seed): - poisson_temp = hosp_process_poisson.sample( - infections=infections_constant, + daily_result = daily_process.sample( + infections=infections_decay, obs=None + ) + + for i, val in enumerate(daily_result.observed[day_one:]): + comparison_list.append( + { + "time": i, + "admissions": float(val), + "resolution": "Daily ($\\phi$=10)", + "replicate": seed, + } + ) + + with numpyro.handlers.seed(rng_seed=seed): + weekly_result = weekly_hosp_process.sample( + infections=infections_decay, + first_day_dow=first_dow, + week_indices=week_indices, obs=None, ) - # Slice from day_one to align with valid observation period - for i, admit in enumerate(poisson_temp.observed[day_one:]): - poisson_results.append( + for j, (wi, val) in enumerate(zip(week_indices, weekly_result.observed)): + comparison_list.append( { - "day": i, - "admissions": float(admit), + "time": int(wi) * 7 + 3 - day_one, + "admissions": float(val), + "resolution": "Weekly ($\\phi$=100)", "replicate": seed, } ) -poisson_df = pd.DataFrame(poisson_results) -plot_poisson = ( - p9.ggplot(poisson_df, p9.aes(x="day", y="admissions", group="replicate")) - + p9.geom_line(alpha=0.5, size=0.8, color="steelblue") +comparison_df = pd.DataFrame(comparison_list) +``` + +```{python} +# | label: plot-weekly-daily-comparison +daily_comp = comparison_df[comparison_df["resolution"] == "Daily ($\\phi$=10)"] +weekly_comp = comparison_df[ + comparison_df["resolution"] == "Weekly ($\\phi$=100)" +] + +( + p9.ggplot() + + p9.geom_line( + p9.aes(x="time", y="admissions", group="replicate"), + data=daily_comp, + color="steelblue", + alpha=0.2, + size=0.5, + ) + + p9.geom_jitter( + p9.aes(x="time", y="admissions", group="replicate"), + data=weekly_comp, + color="orange", + alpha=0.6, + size=2, + width=1.2, + height=0, + ) + p9.labs( - x="Day", + x="Day (relative to first valid observation day)", y="Hospital Admissions", - title="Poisson Noise Model (Variance = Mean)", + title="Daily vs. Weekly: 50 Sample Observations from Same Infections ", + subtitle="Blue lines: daily ($\\phi$=10) | Orange points: weekly totals ($\\phi$=100)", ) + theme_tutorial - + p9.ylim(0, 105) ) -plot_poisson ``` + +Weekly aggregation collapses seven daily values into a single total per epiweek. +The weekly points may appear more dispersed than the daily lines, even though the weekly process uses a much higher concentration parameter ($\phi = 100$ vs. $\phi = 10$). +This is not a modeling error. +The negative binomial variance is $\text{Var}[Y] = \mu + \mu^2 / \phi$. +Weekly totals have means roughly 7 times larger than daily means ($\mu_w \approx 7 \mu_d$), so the quadratic term $\mu_w^2 / \phi$ grows with the square of the mean. +Even with $\phi = 100$, the absolute spread of the weekly distribution is wider than the daily distribution with $\phi = 10$, because the weekly mean is so much larger. +In relative terms (coefficient of variation), the weekly observations are tighter, which is why weekly data is often considered less noisy for inference. + +In a multi-signal model, pairing weekly hospital admissions with a daily signal (such as ED visits) allows the daily signal to resolve within-week dynamics that the weekly signal cannot capture. diff --git a/pyrenew/datasets/__init__.py b/pyrenew/datasets/__init__.py index 91927f71..71b0ebb8 100644 --- a/pyrenew/datasets/__init__.py +++ b/pyrenew/datasets/__init__.py @@ -1,7 +1,11 @@ # numpydoc ignore=GL08 +from pyrenew.datasets.ed_visits import load_ed_visits_data_for_state from pyrenew.datasets.generation_interval import load_generation_interval -from pyrenew.datasets.hospital_admissions import load_hospital_data_for_state +from pyrenew.datasets.hospital_admissions import ( + load_hospital_data_for_state, + load_weekly_hospital_data_for_state, +) from pyrenew.datasets.infection_admission_interval import ( load_infection_admission_interval, ) @@ -13,5 +17,7 @@ "load_infection_admission_interval", "load_generation_interval", "load_hospital_data_for_state", + "load_weekly_hospital_data_for_state", "load_wastewater_data_for_state", + "load_ed_visits_data_for_state", ] diff --git a/pyrenew/datasets/hospital_admissions.py b/pyrenew/datasets/hospital_admissions.py index 5cb01b58..126d27a3 100644 --- a/pyrenew/datasets/hospital_admissions.py +++ b/pyrenew/datasets/hospital_admissions.py @@ -3,7 +3,7 @@ Load hospital admissions data for use in tutorials and examples. This module provides functions to load COVID-19 hospital admissions -data from the CDC's cfa-forecast-renewal-ww project. +data (daily and weekly) from the CDC's cfa-forecast-renewal-ww project. """ from importlib.resources import files @@ -63,3 +63,56 @@ def load_hospital_data_for_state( "dates": dates, "n_days": len(daily_admits), } + + +def load_weekly_hospital_data_for_state( + state_abbr: str = "CA", + filename: str = "2023-11-06_weekly.csv", +) -> dict: + """ + Load weekly (epiweek) hospital admissions data for a specific state. + + Parameters + ---------- + state_abbr : str + State abbreviation (e.g., "CA"). Default is "CA". + filename : str + CSV filename. Default is "2023-11-06_weekly.csv". + + Returns + ------- + dict + Dictionary containing: + + - weekly_admits: JAX array of weekly hospital admissions + - population: Population size (scalar) + - week_ends: List of datetime.date objects (epiweek ending dates) + - n_weeks: Number of weeks + + Notes + ----- + Data source: aggregated from daily CDC cfa-forecast-renewal-ww data. + License: Public Domain (CC0 1.0 Universal) - U.S. Government work. + """ + data_path = files("pyrenew.datasets.hospital_admissions_data") / filename + df = pl.read_csv(source=data_path) + + df = ( + df.with_columns(pl.col("week_end").str.to_date()) + .filter(pl.col("location") == state_abbr) + .sort("week_end") + ) + + if len(df) == 0: + raise ValueError(f"No data found for state {state_abbr} in {filename}") + + weekly_admits = jnp.array(df["weekly_hosp_admits"].to_numpy()) + population = int(df["pop"][0]) + week_ends = df["week_end"].to_list() + + return { + "weekly_admits": weekly_admits, + "population": population, + "week_ends": week_ends, + "n_weeks": len(weekly_admits), + } diff --git a/pyrenew/observation/__init__.py b/pyrenew/observation/__init__.py index 50616364..64e0ef4d 100644 --- a/pyrenew/observation/__init__.py +++ b/pyrenew/observation/__init__.py @@ -4,6 +4,7 @@ ``BaseObservationProcess`` is the abstract base. Concrete subclasses: +- ``CountBase``: Base class for count observations (ascertainment x delay convolution) - ``Counts``: Aggregate counts (admissions, deaths) - ``CountsBySubpop``: Subpopulation-level counts - ``Measurements``: Continuous subpopulation-level signals (e.g., wastewater) @@ -19,7 +20,7 @@ """ from pyrenew.observation.base import BaseObservationProcess -from pyrenew.observation.count_observations import Counts, CountsBySubpop +from pyrenew.observation.count_observations import CountBase, Counts, CountsBySubpop from pyrenew.observation.measurements import Measurements from pyrenew.observation.negativebinomial import NegativeBinomialObservation from pyrenew.observation.noise import ( @@ -46,6 +47,7 @@ "HierarchicalNormalNoise", "VectorizedRV", # Observation processes + "CountBase", "Counts", "CountsBySubpop", "Measurements", diff --git a/pyrenew/observation/count_observations.py b/pyrenew/observation/count_observations.py index 8f86d2a1..6788f9d7 100644 --- a/pyrenew/observation/count_observations.py +++ b/pyrenew/observation/count_observations.py @@ -20,9 +20,9 @@ from pyrenew.time import validate_dow -class _CountBase(BaseObservationProcess): +class CountBase(BaseObservationProcess): """ - Internal base for count observation processes. + Base class for count observation processes. Implements ascertainment x delay convolution with pluggable noise model. """ @@ -116,17 +116,6 @@ def lookback_days(self) -> int: """ return len(self.temporal_pmf_rv()) - 1 - def infection_resolution(self) -> str: - """ - Return required infection resolution. - - Returns - ------- - str - "aggregate" or "subpop". - """ - raise NotImplementedError("Subclasses must implement infection_resolution()") - def _predicted_obs( self, infections: ArrayLike, @@ -246,7 +235,7 @@ def _apply_day_of_week( return predicted * daily_effect -class Counts(_CountBase): +class Counts(CountBase): """ Aggregated count observation. @@ -395,7 +384,7 @@ def sample( return ObservationSample(observed=observed, predicted=predicted_counts) -class CountsBySubpop(_CountBase): +class CountsBySubpop(CountBase): """ Subpopulation-level count observation. diff --git a/test/test_interface_coverage.py b/test/test_interface_coverage.py index 63c72fd7..b3ccaa26 100644 --- a/test/test_interface_coverage.py +++ b/test/test_interface_coverage.py @@ -253,11 +253,11 @@ def test_measurements_infection_resolution(): def test_base_count_observation_infection_resolution_raises(): - """Base _CountBase.infection_resolution() raises NotImplementedError.""" - from pyrenew.observation.count_observations import _CountBase + """Subclass of CountBase without infection_resolution cannot be instantiated.""" + from pyrenew.observation.count_observations import CountBase - class _MinimalCounts(_CountBase): - """Minimal subclass that inherits infection_resolution unchanged.""" + class _MinimalCounts(CountBase): + """Minimal subclass missing infection_resolution.""" def sample(self, *args, **kwargs): # numpydoc ignore=GL08 pass @@ -265,14 +265,13 @@ def sample(self, *args, **kwargs): # numpydoc ignore=GL08 def validate_data(self, n_total, n_subpops, **obs_data): # numpydoc ignore=GL08 pass - obs = _MinimalCounts( - name="test_base", - ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), - delay_distribution_rv=DeterministicPMF("delay", jnp.array([1.0])), - noise=PoissonNoise(), - ) - with pytest.raises(NotImplementedError): - obs.infection_resolution() + with pytest.raises(TypeError, match="infection_resolution"): + _MinimalCounts( + name="test_base", + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", jnp.array([1.0])), + noise=PoissonNoise(), + ) # ============================================================================= From 0b3a0fae2b104dbe67381e21399cc75381dad2c6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 26 Feb 2026 03:07:22 +0000 Subject: [PATCH 4/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pyrenew/datasets/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pyrenew/datasets/__init__.py b/pyrenew/datasets/__init__.py index 71b0ebb8..71440409 100644 --- a/pyrenew/datasets/__init__.py +++ b/pyrenew/datasets/__init__.py @@ -1,6 +1,7 @@ # numpydoc ignore=GL08 from pyrenew.datasets.ed_visits import load_ed_visits_data_for_state + from pyrenew.datasets.generation_interval import load_generation_interval from pyrenew.datasets.hospital_admissions import ( load_hospital_data_for_state, From 9795402099ce10f94bc2a42261b17c89cf9040e3 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Thu, 26 Feb 2026 13:33:11 -0500 Subject: [PATCH 5/7] remove ed_visit import --- pyrenew/datasets/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pyrenew/datasets/__init__.py b/pyrenew/datasets/__init__.py index 71440409..17ba02f8 100644 --- a/pyrenew/datasets/__init__.py +++ b/pyrenew/datasets/__init__.py @@ -20,5 +20,4 @@ "load_hospital_data_for_state", "load_weekly_hospital_data_for_state", "load_wastewater_data_for_state", - "load_ed_visits_data_for_state", ] From 7c3887bb86e11984a64e87aeb128cbd2454d7e72 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Thu, 26 Feb 2026 13:34:01 -0500 Subject: [PATCH 6/7] remove ed_visit import --- pyrenew/datasets/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pyrenew/datasets/__init__.py b/pyrenew/datasets/__init__.py index 17ba02f8..7a2998bb 100644 --- a/pyrenew/datasets/__init__.py +++ b/pyrenew/datasets/__init__.py @@ -1,7 +1,5 @@ # numpydoc ignore=GL08 -from pyrenew.datasets.ed_visits import load_ed_visits_data_for_state - from pyrenew.datasets.generation_interval import load_generation_interval from pyrenew.datasets.hospital_admissions import ( load_hospital_data_for_state, From 8b884af3057e3979545e556bd406541ea44da585 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Thu, 26 Feb 2026 14:29:16 -0500 Subject: [PATCH 7/7] remove unused code --- pyrenew/datasets/__init__.py | 6 +-- pyrenew/datasets/hospital_admissions.py | 53 ------------------------- 2 files changed, 1 insertion(+), 58 deletions(-) diff --git a/pyrenew/datasets/__init__.py b/pyrenew/datasets/__init__.py index 7a2998bb..91927f71 100644 --- a/pyrenew/datasets/__init__.py +++ b/pyrenew/datasets/__init__.py @@ -1,10 +1,7 @@ # numpydoc ignore=GL08 from pyrenew.datasets.generation_interval import load_generation_interval -from pyrenew.datasets.hospital_admissions import ( - load_hospital_data_for_state, - load_weekly_hospital_data_for_state, -) +from pyrenew.datasets.hospital_admissions import load_hospital_data_for_state from pyrenew.datasets.infection_admission_interval import ( load_infection_admission_interval, ) @@ -16,6 +13,5 @@ "load_infection_admission_interval", "load_generation_interval", "load_hospital_data_for_state", - "load_weekly_hospital_data_for_state", "load_wastewater_data_for_state", ] diff --git a/pyrenew/datasets/hospital_admissions.py b/pyrenew/datasets/hospital_admissions.py index 365db3e4..cfc52aad 100644 --- a/pyrenew/datasets/hospital_admissions.py +++ b/pyrenew/datasets/hospital_admissions.py @@ -63,56 +63,3 @@ def load_hospital_data_for_state( "dates": dates, "n_days": len(daily_admits), } - - -def load_weekly_hospital_data_for_state( - state_abbr: str = "CA", - filename: str = "2023-11-06_weekly.csv", -) -> dict: - """ - Load weekly (epiweek) hospital admissions data for a specific state. - - Parameters - ---------- - state_abbr : str - State abbreviation (e.g., "CA"). Default is "CA". - filename : str - CSV filename. Default is "2023-11-06_weekly.csv". - - Returns - ------- - dict - Dictionary containing: - - - weekly_admits: JAX array of weekly hospital admissions - - population: Population size (scalar) - - week_ends: List of datetime.date objects (epiweek ending dates) - - n_weeks: Number of weeks - - Notes - ----- - Data source: aggregated from daily CDC cfa-forecast-renewal-ww data. - License: Public Domain (CC0 1.0 Universal) - U.S. Government work. - """ - data_path = files("pyrenew.datasets.hospital_admissions_data") / filename - df = pl.read_csv(source=data_path) - - df = ( - df.with_columns(pl.col("week_end").str.to_date()) - .filter(pl.col("location") == state_abbr) - .sort("week_end") - ) - - if len(df) == 0: - raise ValueError(f"No data found for state {state_abbr} in {filename}") - - weekly_admits = jnp.array(df["weekly_hosp_admits"].to_numpy()) - population = int(df["pop"][0]) - week_ends = df["week_end"].to_list() - - return { - "weekly_admits": weekly_admits, - "population": population, - "week_ends": week_ends, - "n_weeks": len(weekly_admits), - }