From 6b52cfebd2b0e689d9ac72330ee79c209561f55b Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Fri, 10 Apr 2026 23:45:13 -0400 Subject: [PATCH 1/7] updated files; unit tests passing --- .../tutorials/building_multisignal_models.qmd | 34 +-- .../latent_hierarchical_infections.qmd | 44 ++-- docs/tutorials/latent_infections.qmd | 33 +-- pyrenew/latent/__init__.py | 10 +- pyrenew/latent/base.py | 4 +- ...infections.py => population_infections.py} | 58 ++--- ...ections.py => subpopulation_infections.py} | 14 +- pyrenew/model/multisignal_model.py | 2 +- pyrenew/model/pyrenew_builder.py | 2 +- test/conftest.py | 41 +++- test/integration/conftest.py | 12 +- ...he.py => test_population_infections_he.py} | 14 +- test/test_interface_coverage.py | 20 +- ...tions.py => test_population_infections.py} | 210 ++++++++---------- test/test_pyrenew_builder.py | 28 +-- ...ns.py => test_subpopulation_infections.py} | 110 ++++----- 16 files changed, 322 insertions(+), 314 deletions(-) rename pyrenew/latent/{shared_infections.py => population_infections.py} (79%) rename pyrenew/latent/{hierarchical_infections.py => subpopulation_infections.py} (96%) rename test/integration/{test_shared_infections_he.py => test_population_infections_he.py} (95%) rename test/{test_shared_infections.py => test_population_infections.py} (55%) rename test/{test_hierarchical_infections.py => test_subpopulation_infections.py} (68%) diff --git a/docs/tutorials/building_multisignal_models.qmd b/docs/tutorials/building_multisignal_models.qmd index 7d67e262..41d73d2b 100644 --- a/docs/tutorials/building_multisignal_models.qmd +++ b/docs/tutorials/building_multisignal_models.qmd @@ -64,7 +64,7 @@ from pyrenew.metaclass import RandomVariable from pyrenew.randomvariable import DistributionalVariable from pyrenew.latent import ( - HierarchicalInfections, + SubpopulationInfections, AR1, RandomWalk, GammaGroupSdPrior, @@ -97,7 +97,7 @@ A **multi-signal model** combines multiple observation processes—each represen The `PyrenewBuilder` class handles this plumbing. You specify: -1. A single **latent process** (e.g., `HierarchicalInfections`) that defines how infections evolve. +1. A single **latent process** (e.g., `SubpopulationInfections`) that defines how infections evolve. 2. One or more **observation processes** (e.g., `Counts`, `Measurements`) that define how infections become data. The builder computes initialization requirements, wires components together, and produces a model ready for inference. @@ -116,7 +116,7 @@ This tutorial shows how to combine these components into a complete multi-signal This tutorial demonstrates building a multi-signal renewal model using: -- `HierarchicalInfections` — subpopulations share a jurisdiction-level baseline $\mathcal{R}(t)$ with subpopulation-specific deviations +- `SubpopulationInfections` — subpopulations share a jurisdiction-level baseline $\mathcal{R}(t)$ with subpopulation-specific deviations - `Counts` — hospital admissions (jurisdiction-level) - A custom `Wastewater` class — viral concentrations (subpopulation-level) @@ -133,7 +133,7 @@ The diagram below shows how data flows through the model. The latent process gen flowchart TB subgraph Latent["Latent Infection Process"] - L["Renewal equation
(HierarchicalInfections)"] + L["Renewal equation
(SubpopulationInfections)"] end subgraph Infections["Infection Trajectories"] @@ -214,7 +214,7 @@ We place a prior on the initial log(Rt), centered at 0.0 (Rt = 1.0) with moderat ```{python} # | label: initial-log-rt -initial_log_rt_rv = DistributionalVariable( +log_rt_time_0_rv = DistributionalVariable( "initial_log_rt", dist.Normal(0.0, 0.5) ) ``` @@ -530,10 +530,10 @@ print("Latent process configuration:") print(f" Generation interval length: {len(gen_int_rv())} days") builder.configure_latent( - HierarchicalInfections, + SubpopulationInfections, gen_int_rv=gen_int_rv, I0_rv=I0_rv, - initial_log_rt_rv=initial_log_rt_rv, + log_rt_time_0_rv=log_rt_time_0_rv, baseline_rt_process=baseline_rt_process, subpop_rt_deviation_process=subpop_rt_deviation_process, ) @@ -824,11 +824,11 @@ idata_90 = az.from_numpyro( model.mcmc, dims={ "latent_infections": ["time"], - "HierarchicalInfections::infections_aggregate": ["time"], - "HierarchicalInfections::log_rt_baseline": ["time", "dummy"], - "HierarchicalInfections::rt_baseline": ["time", "dummy"], - "HierarchicalInfections::rt_subpop": ["time", "subpop"], - "HierarchicalInfections::subpop_deviations": ["time", "subpop"], + "SubpopulationInfections::infections_aggregate": ["time"], + "SubpopulationInfections::log_rt_baseline": ["time", "dummy"], + "SubpopulationInfections::rt_baseline": ["time", "dummy"], + "SubpopulationInfections::rt_subpop": ["time", "subpop"], + "SubpopulationInfections::subpop_deviations": ["time", "subpop"], "latent_infections_by_subpop": ["time", "subpop"], "hospital_predicted": ["time"], "wastewater_predicted": ["time", "subpop"], @@ -973,11 +973,11 @@ idata_180 = az.from_numpyro( model.mcmc, dims={ "latent_infections": ["time"], - "HierarchicalInfections::infections_aggregate": ["time"], - "HierarchicalInfections::log_rt_baseline": ["time", "dummy"], - "HierarchicalInfections::rt_baseline": ["time", "dummy"], - "HierarchicalInfections::rt_subpop": ["time", "subpop"], - "HierarchicalInfections::subpop_deviations": ["time", "subpop"], + "SubpopulationInfections::infections_aggregate": ["time"], + "SubpopulationInfections::log_rt_baseline": ["time", "dummy"], + "SubpopulationInfections::rt_baseline": ["time", "dummy"], + "SubpopulationInfections::rt_subpop": ["time", "subpop"], + "SubpopulationInfections::subpop_deviations": ["time", "subpop"], "latent_infections_by_subpop": ["time", "subpop"], "hospital_predicted": ["time"], "wastewater_predicted": ["time", "subpop"], diff --git a/docs/tutorials/latent_hierarchical_infections.qmd b/docs/tutorials/latent_hierarchical_infections.qmd index 8ef3092d..3897d117 100644 --- a/docs/tutorials/latent_hierarchical_infections.qmd +++ b/docs/tutorials/latent_hierarchical_infections.qmd @@ -38,7 +38,7 @@ from _tutorial_theme import theme_tutorial ```{python} # | label: imports from pyrenew.latent import ( - HierarchicalInfections, + SubpopulationInfections, AR1, DifferencedAR1, RandomWalk, @@ -49,7 +49,7 @@ from pyrenew.deterministic import DeterministicPMF, DeterministicVariable ## Overview -`HierarchicalInfections` extends the renewal model to a population composed of $K$ subpopulations, each with its own latent infection trajectory. +`SubpopulationInfections` extends the renewal model to a population composed of $K$ subpopulations, each with its own latent infection trajectory. As in the single-population case, infections evolve according to the renewal equation. For each subpopulation $k = 1, \dots, K$, we define @@ -107,8 +107,8 @@ The aggregate trajectory $I_{\text{aggregate}}(t)$ is then passed to observation This model generalizes the single-population renewal model: -- `SharedInfections`: one trajectory $I(t)$ with one $\mathcal{R}(t)$ -- `HierarchicalInfections`: multiple trajectories $I_k(t)$ with a shared baseline $\mathcal{R}_{\text{baseline}}(t)$ and subpopulation deviations $\delta_k(t)$ +- `PopulationInfections`: one trajectory $I(t)$ with one $\mathcal{R}(t)$ +- `SubpopulationInfections`: multiple trajectories $I_k(t)$ with a shared baseline $\mathcal{R}_{\text{baseline}}(t)$ and subpopulation deviations $\delta_k(t)$ When all $\delta_k(t) = 0$, the model reduces to the shared (single-population) case. @@ -118,13 +118,13 @@ This tutorial assumes familiarity with the renewal equation, generation interval ## Model Structure -`HierarchicalInfections` uses the same core inputs as `SharedInfections`, with additional structure for subpopulation variation. +`SubpopulationInfections` uses the same core inputs as `PopulationInfections`, with additional structure for subpopulation variation. ### Core inputs (shared across subpopulations) - **`gen_int_rv`**: Generation interval distribution $w_\tau$ - **`I0_rv`**: Initial infection level (shared by default across subpopulations, but can be specified per subpopulation) -- **`initial_log_rt_rv`**: Initial value of $\log \mathcal{R}_{\text{baseline}}(t)$ at time $0$ +- **`log_rt_time_0_rv`**: Initial value of $\log \mathcal{R}_{\text{baseline}}(t)$ at time $0$ ### Temporal processes @@ -184,11 +184,11 @@ print(f"Initial Rt: {np.exp(initial_log_rt):.2f}") ```{python} # | label: instantiate -model = HierarchicalInfections( - name="HierarchicalInfections", +model = SubpopulationInfections( + name="SubpopulationInfections", gen_int_rv=gen_int_rv, I0_rv=DeterministicVariable("I0", I0_val), - initial_log_rt_rv=DeterministicVariable("initial_log_rt", initial_log_rt), + log_rt_time_0_rv=DeterministicVariable("initial_log_rt", initial_log_rt), baseline_rt_process=DifferencedAR1(autoreg=0.5, innovation_sd=0.01), subpop_rt_deviation_process=AR1(autoreg=0.8, innovation_sd=0.05), n_initialization_points=n_init, @@ -197,7 +197,7 @@ model = HierarchicalInfections( ## The Sum-to-Zero Constraint -Without a constraint, the baseline and deviations are not identifiable: shifting the baseline up by some amount $c$ and all deviations down by $c$ produces the same subpopulation $\mathcal{R}_k(t)$ values. `HierarchicalInfections` enforces $\sum_k \delta_k(t) = 0$ at every time step by centering the raw deviation trajectories. +Without a constraint, the baseline and deviations are not identifiable: shifting the baseline up by some amount $c$ and all deviations down by $c$ produces the same subpopulation $\mathcal{R}_k(t)$ values. `SubpopulationInfections` enforces $\sum_k \delta_k(t) = 0$ at every time step by centering the raw deviation trajectories. This ensures $\mathcal{R}_{\text{baseline}}(t)$ is the *unweighted* geometric mean of the subpopulation $\mathcal{R}_k(t)$ values, so the baseline represents the typical transmission level across subpopulations. Note that this is the unweighted mean across subpopulations, not population-weighted by $p_k$. As a result, $\mathcal{R}_{\text{baseline}}(t)$ does **not** in general equal the jurisdiction-level reproduction number implied by the aggregate infection trajectory $I_{\text{aggregate}}(t)$. When subpopulations differ in size, a small subpopulation with a large $\mathcal{R}_k(t)$ contributes equally to the baseline but only marginally to the aggregate. @@ -211,7 +211,7 @@ with numpyro.handlers.seed(rng_seed=42): subpop_fractions=subpop_fractions, ) -deviations = trace["HierarchicalInfections::subpop_deviations"]["value"] +deviations = trace["SubpopulationInfections::subpop_deviations"]["value"] deviation_sums = jnp.sum(deviations, axis=1) print(f"Deviation shape: {deviations.shape} (n_total_days, n_subpops)") print( @@ -222,9 +222,9 @@ print( ## Choosing the Baseline Temporal Process The baseline process governs the jurisdiction-level trend in $\mathcal{R}(t)$. -The same temporal process options apply as in `SharedInfections` (see [Temporal Process Choice](latent_infections.md#temporal-process-choice)): AR(1) for mean reversion, DifferencedAR1 for persistent trends with stabilizing rate of change, RandomWalk for unconstrained drift. +The same temporal process options apply as in `PopulationInfections` (see [Temporal Process Choice](latent_infections.md#temporal-process-choice)): AR(1) for mean reversion, DifferencedAR1 for persistent trends with stabilizing rate of change, RandomWalk for unconstrained drift. -We use DifferencedAR1 with small `innovation_sd` for the baseline throughout this tutorial. The prior predictive for baseline $\mathcal{R}(t)$ behaves much like the corresponding `SharedInfections` example, so we do not repeat that full comparison here. The focus of this tutorial is the deviation process and how it changes the spread of subpopulation trajectories around the baseline. +We use DifferencedAR1 with small `innovation_sd` for the baseline throughout this tutorial. The prior predictive for baseline $\mathcal{R}(t)$ behaves much like the corresponding `PopulationInfections` example, so we do not repeat that full comparison here. The focus of this tutorial is the deviation process and how it changes the spread of subpopulation trajectories around the baseline. The same high-level decision rules apply here: @@ -251,12 +251,12 @@ This determines whether subpopulations quickly return to the baseline or can div ```{python} # | label: helpers def sample_hierarchical(baseline_process, deviation_process, label): - """Draw prior predictive samples from a HierarchicalInfections model.""" - m = HierarchicalInfections( - name="HierarchicalInfections", + """Draw prior predictive samples from a SubpopulationInfections model.""" + m = SubpopulationInfections( + name="SubpopulationInfections", gen_int_rv=gen_int_rv, I0_rv=DeterministicVariable("I0", I0_val), - initial_log_rt_rv=DeterministicVariable( + log_rt_time_0_rv=DeterministicVariable( "initial_log_rt", initial_log_rt ), baseline_rt_process=baseline_process, @@ -274,16 +274,16 @@ def sample_hierarchical(baseline_process, deviation_process, label): samples = Predictive(sampler, num_samples=n_samples)(random.PRNGKey(42)) return { "rt_baseline": np.array( - samples["HierarchicalInfections::rt_baseline"] + samples["SubpopulationInfections::rt_baseline"] )[:, n_init:, 0], - "rt_subpop": np.array(samples["HierarchicalInfections::rt_subpop"])[ + "rt_subpop": np.array(samples["SubpopulationInfections::rt_subpop"])[ :, n_init:, : ], "deviations": np.array( - samples["HierarchicalInfections::subpop_deviations"] + samples["SubpopulationInfections::subpop_deviations"] )[:, n_init:, :], "infections": np.array( - samples["HierarchicalInfections::infections_aggregate"] + samples["SubpopulationInfections::infections_aggregate"] )[:, n_init:], } @@ -461,7 +461,7 @@ for label, s in [("AR(1)", ar1_dev_samples), ("RandomWalk", rw_dev_samples)]: pd.DataFrame(deviation_summary_rows) ``` -AR(1) deviations stay close to zero because mean reversion continuously pulls them back. RandomWalk deviations accumulate over time. As in the SharedInfections tutorial, this is prior-modeling guidance rather than a uniquely determined epidemiologic rule. The choice depends on the epidemiological setting: +AR(1) deviations stay close to zero because mean reversion continuously pulls them back. RandomWalk deviations accumulate over time. As in the PopulationInfections tutorial, this is prior-modeling guidance rather than a uniquely determined epidemiologic rule. The choice depends on the epidemiological setting: - **AR(1) deviations** when subpopulations are expected to track the jurisdiction average. Local outbreaks or lulls are temporary. This is typical for geographically close subpopulations (e.g., counties within a metropolitan area) where mobility mixes transmission. - **RandomWalk deviations** when local differences can persist. This may be appropriate for subpopulations with distinct contact patterns, demographics, or intervention histories (e.g., urban vs. rural areas). diff --git a/docs/tutorials/latent_infections.qmd b/docs/tutorials/latent_infections.qmd index b06e62d3..cf4d865b 100644 --- a/docs/tutorials/latent_infections.qmd +++ b/docs/tutorials/latent_infections.qmd @@ -44,7 +44,12 @@ from _tutorial_theme import theme_tutorial ```{python} # | label: imports -from pyrenew.latent import SharedInfections, AR1, DifferencedAR1, RandomWalk +from pyrenew.latent import ( + PopulationInfections, + AR1, + DifferencedAR1, + RandomWalk, +) from pyrenew.deterministic import DeterministicPMF, DeterministicVariable from pyrenew.randomvariable import DistributionalVariable ``` @@ -72,17 +77,17 @@ Here, $\tau$ indexes lags in the generation interval. PyRenew provides two latent infection classes: -- **`SharedInfections`**: A single $\mathcal{R}(t)$ drives one renewal equation. Appropriate when modeling one jurisdiction as a single population with one or more observation streams. This tutorial covers `SharedInfections`. -- **`HierarchicalInfections`**: A baseline $\mathcal{R}(t)$ with per-subpopulation deviations. See [Hierarchical Latent Infections](latent_hierarchical_infections.md). +- **`PopulationInfections`**: A single $\mathcal{R}(t)$ drives one renewal equation. Appropriate when modeling one jurisdiction as a single population with one or more observation streams. This tutorial covers `PopulationInfections`. +- **`SubpopulationInfections`**: A baseline $\mathcal{R}(t)$ with per-subpopulation deviations. See [Hierarchical Latent Infections](latent_hierarchical_infections.md). ## Model Inputs -`SharedInfections` requires four inputs: +`PopulationInfections` requires four inputs: 1. Generation interval distribution $w_\tau$ (`gen_int_rv`) 2. Infection prevalence at time $0$ as a proportion of the population (`I0_rv`) -3. Value for $log(\mathcal{R}(t))$ at time $0$ (`initial_log_rt_rv`) -4. A temporal process for $\mathcal{R}(t)$ dynamics (`shared_rt_process`) +3. Value for $log(\mathcal{R}(t))$ at time $0$ (`log_rt_time_0_rv`) +4. A temporal process for $\mathcal{R}(t)$ dynamics (`single_rt_process`) All inputs are **RandomVariables**, a quantity that is either known (observed, conditioned on) or unknown (to be inferred). See [PyRenew's RandomVariable abstract base class](random_variables.md). In this tutorial, we use `DeterministicVariable` and `DeterministicPMF` (fixed values) for illustration. In real inference, you would use `DistributionalVariable` with priors for quantities you want to estimate: @@ -249,14 +254,14 @@ rt_cap = 3.0 def sample_process(rt_process, label): """Draw prior predictive samples for a given temporal process.""" - model = SharedInfections( - name="SharedInfections", + model = PopulationInfections( + name="PopulationInfections", gen_int_rv=gen_int_rv, I0_rv=DeterministicVariable("I0", I0_val), - initial_log_rt_rv=DeterministicVariable( + log_rt_time_0_rv=DeterministicVariable( "initial_log_rt", initial_log_rt ), - shared_rt_process=rt_process, + single_rt_process=rt_process, n_initialization_points=n_init, ) @@ -266,9 +271,11 @@ def sample_process(rt_process, label): samples = Predictive(sampler, num_samples=n_samples)(random.PRNGKey(42)) return { - "rt": np.array(samples["SharedInfections::rt_shared"])[:, n_init:, 0], + "rt": np.array(samples["PopulationInfections::rt_single"])[ + :, n_init:, 0 + ], "infections": np.array( - samples["SharedInfections::infections_aggregate"] + samples["PopulationInfections::infections_aggregate"] )[:, n_init:], } @@ -686,7 +693,7 @@ The latent infection trajectory is not observed directly. To connect it to data, The `PyrenewBuilder` handles the wiring: -1. **`configure_latent()`** sets the shared infection process (called once) +1. **`configure_latent()`** sets the single infection process (called once) 2. **`add_observation()`** adds an observation process (called once per data stream) 3. **`build()`** computes `n_initialization_points` from all delay distributions and produces a model ready for inference diff --git a/pyrenew/latent/__init__.py b/pyrenew/latent/__init__.py index 8c428579..45f6f06b 100644 --- a/pyrenew/latent/__init__.py +++ b/pyrenew/latent/__init__.py @@ -5,7 +5,6 @@ LatentSample, PopulationStructure, ) -from pyrenew.latent.hierarchical_infections import HierarchicalInfections from pyrenew.latent.hierarchical_priors import ( GammaGroupSdPrior, HierarchicalNormalPrior, @@ -27,7 +26,8 @@ ) from pyrenew.latent.infections import Infections from pyrenew.latent.infectionswithfeedback import InfectionsWithFeedback -from pyrenew.latent.shared_infections import SharedInfections +from pyrenew.latent.population_infections import PopulationInfections +from pyrenew.latent.subpopulation_infections import SubpopulationInfections from pyrenew.latent.temporal_processes import ( AR1, DifferencedAR1, @@ -50,10 +50,8 @@ "BaseLatentInfectionProcess", "LatentSample", "PopulationStructure", - # Hierarchical infection processes - "HierarchicalInfections", - # Shared infection processes - "SharedInfections", + "SubpopulationInfections", + "PopulationInfections", # Hierarchical priors "HierarchicalNormalPrior", "GammaGroupSdPrior", diff --git a/pyrenew/latent/base.py b/pyrenew/latent/base.py index 4d42308a..c84d393a 100644 --- a/pyrenew/latent/base.py +++ b/pyrenew/latent/base.py @@ -300,9 +300,9 @@ def _validate_and_prepare_I0( Validate and prepare I0 for use in the renewal equation. Subclasses override this to enforce shape constraints (e.g., scalar - for SharedInfections) or broadcast I0 to match the population + for PopulationInfections) or broadcast I0 to match the population structure (e.g., scalar to per-subpop array for - HierarchicalInfections). + SubpopulationInfections). Parameters ---------- diff --git a/pyrenew/latent/shared_infections.py b/pyrenew/latent/population_infections.py similarity index 79% rename from pyrenew/latent/shared_infections.py rename to pyrenew/latent/population_infections.py index ac7ec1df..2671372a 100644 --- a/pyrenew/latent/shared_infections.py +++ b/pyrenew/latent/population_infections.py @@ -1,5 +1,5 @@ """ -Shared latent infection process renewal model. +Populaton-level single-trajectory latent infection process renewal model. """ from __future__ import annotations @@ -22,7 +22,7 @@ from pyrenew.metaclass import RandomVariable -class SharedInfections(BaseLatentInfectionProcess): +class PopulationInfections(BaseLatentInfectionProcess): """ A single $\ mathcal{R}(t)$ trajectory drives one renewal equation. @@ -36,27 +36,27 @@ def __init__( gen_int_rv: RandomVariable, n_initialization_points: int, I0_rv: RandomVariable, - shared_rt_process: TemporalProcess, - initial_log_rt_rv: RandomVariable, + single_rt_process: TemporalProcess, + log_rt_time_0_rv: RandomVariable, ) -> None: """ - Initialize shared infections process. + Initialize population-level infections process. Parameters ---------- name Name prefix for numpyro sample sites. All deterministic quantities are recorded under this scope (e.g., - ``"{name}::rt_shared"``). + ``"{name}::rt_single"``). gen_int_rv Generation interval PMF n_initialization_points Number of initialization days before day 0. I0_rv Initial infection prevalence (proportion of population) - shared_rt_process - Temporal process for shared Rt dynamics - initial_log_rt_rv + single_rt_process + Temporal process for single Rt dynamics + log_rt_time_0_rv Initial value for log(Rt) at time 0. Raises @@ -77,13 +77,13 @@ def __init__( if isinstance(I0_rv, DeterministicVariable): self._validate_I0(I0_rv.value) - if initial_log_rt_rv is None: - raise ValueError("initial_log_rt_rv is required") - self.initial_log_rt_rv = initial_log_rt_rv + if log_rt_time_0_rv is None: + raise ValueError("log_rt_time_0_rv is required") + self.log_rt_time_0_rv = log_rt_time_0_rv - if shared_rt_process is None: - raise ValueError("shared_rt_process is required") - self.shared_rt_process = shared_rt_process + if single_rt_process is None: + raise ValueError("single_rt_process is required") + self.single_rt_process = single_rt_process def default_subpop_fractions(self) -> ArrayLike: """ @@ -104,7 +104,7 @@ def _validate_and_prepare_I0( """ Validate that I0 is a scalar prevalence value. - SharedInfections operates on a single population, so I0 must be + PopulationInfections operates on a single population, so I0 must be a scalar (0-dimensional array). Parameters @@ -126,13 +126,13 @@ def _validate_and_prepare_I0( """ if I0.ndim != 0: raise ValueError( - "SharedInfections requires I0_rv to return a scalar prevalence" + "PopulationInfections requires I0_rv to return a scalar prevalence" ) return super()._validate_and_prepare_I0(I0, pop) def validate(self) -> None: """ - Validate shared infections parameters. + Validate population infections parameters. Checks that the generation interval is a valid PMF. @@ -150,9 +150,9 @@ def sample( **kwargs: object, ) -> LatentSample: """ - Sample shared infections using a single renewal process. + Sample population infections using a single renewal process. - Generates a shared Rt trajectory, computes initial infections via + Generates a single Rt trajectory, computes initial infections via exponential backprojection, and runs one renewal equation. Parameters @@ -184,28 +184,28 @@ def sample( frac_check = jnp.isclose(pop.fractions[0], 1.0, atol=1e-6) if pop.n_subpops != 1 or (not_jax_tracer(frac_check) and not frac_check): raise ValueError( - "SharedInfections requires exactly one subpopulation " + "PopulationInfections requires exactly one subpopulation " "with fraction [1.0]" ) n_total_days = self.n_initialization_points + n_days_post_init - initial_log_rt = self.initial_log_rt_rv() + initial_log_rt = self.log_rt_time_0_rv() - log_rt_shared = self.shared_rt_process.sample( + log_rt_single = self.single_rt_process.sample( n_timepoints=n_total_days, initial_value=initial_log_rt, - name_prefix="log_rt_shared", + name_prefix="log_rt_single", ) - rt_shared = jnp.exp(log_rt_shared) + rt_single = jnp.exp(log_rt_single) gen_int = self.gen_int_rv() I0 = self._validate_and_prepare_I0(jnp.asarray(self.I0_rv()), pop) initial_r = r_approx_from_R( - R=rt_shared[0, 0], + R=rt_single[0, 0], g=gen_int, n_newton_steps=4, ) @@ -218,7 +218,7 @@ def sample( post_init_infections = compute_infections_from_rt( I0=recent_I0, - Rt=rt_shared[self.n_initialization_points :, 0], + Rt=rt_single[self.n_initialization_points :, 0], reversed_generation_interval_pmf=gen_int_reversed, ) @@ -235,8 +235,8 @@ def sample( with numpyro.handlers.scope(prefix=self.name, divider="::"): numpyro.deterministic("I0_init", I0_init) - numpyro.deterministic("log_rt_shared", log_rt_shared) - numpyro.deterministic("rt_shared", rt_shared) + numpyro.deterministic("log_rt_single", log_rt_single) + numpyro.deterministic("rt_single", rt_single) numpyro.deterministic("infections_aggregate", infections_aggregate) return LatentSample( diff --git a/pyrenew/latent/hierarchical_infections.py b/pyrenew/latent/subpopulation_infections.py similarity index 96% rename from pyrenew/latent/hierarchical_infections.py rename to pyrenew/latent/subpopulation_infections.py index 3e13d6e9..8eb43067 100644 --- a/pyrenew/latent/hierarchical_infections.py +++ b/pyrenew/latent/subpopulation_infections.py @@ -24,7 +24,7 @@ from pyrenew.metaclass import RandomVariable -class HierarchicalInfections(BaseLatentInfectionProcess): +class SubpopulationInfections(BaseLatentInfectionProcess): """ Multi-subpopulation renewal model with hierarchical Rt structure. @@ -58,7 +58,7 @@ def __init__( I0_rv: RandomVariable, baseline_rt_process: TemporalProcess, subpop_rt_deviation_process: TemporalProcess, - initial_log_rt_rv: RandomVariable, + log_rt_time_0_rv: RandomVariable, ) -> None: """ Initialize hierarchical infections process. @@ -79,7 +79,7 @@ def __init__( Temporal process for baseline Rt dynamics subpop_rt_deviation_process Temporal process for subpopulation deviations - initial_log_rt_rv + log_rt_time_0_rv Initial value for log(Rt) at time 0. Raises @@ -100,9 +100,9 @@ def __init__( if isinstance(I0_rv, DeterministicVariable): self._validate_I0(I0_rv.value) - if initial_log_rt_rv is None: - raise ValueError("initial_log_rt_rv is required") - self.initial_log_rt_rv = initial_log_rt_rv + if log_rt_time_0_rv is None: + raise ValueError("log_rt_time_0_rv is required") + self.log_rt_time_0_rv = log_rt_time_0_rv if baseline_rt_process is None: raise ValueError("baseline_rt_process is required") @@ -193,7 +193,7 @@ def sample( n_total_days = self.n_initialization_points + n_days_post_init - initial_log_rt = self.initial_log_rt_rv() + initial_log_rt = self.log_rt_time_0_rv() log_rt_baseline = self.baseline_rt_process.sample( n_timepoints=n_total_days, diff --git a/pyrenew/model/multisignal_model.py b/pyrenew/model/multisignal_model.py index 8d7d78b1..b2ec8768 100644 --- a/pyrenew/model/multisignal_model.py +++ b/pyrenew/model/multisignal_model.py @@ -20,7 +20,7 @@ class MultiSignalModel(Model): """ Multi-signal renewal model. - Combines a latent infection process (e.g., HierarchicalInfections, + Combines a latent infection process (e.g., SubpopulationInfections, PartitionedInfections) with multiple observation processes (e.g., CountObservation, WastewaterObservation). diff --git a/pyrenew/model/pyrenew_builder.py b/pyrenew/model/pyrenew_builder.py index 98c0b797..a8d30fc0 100644 --- a/pyrenew/model/pyrenew_builder.py +++ b/pyrenew/model/pyrenew_builder.py @@ -61,7 +61,7 @@ def configure_latent( Parameters ---------- latent_class - Class to use for latent infections (e.g., HierarchicalInfections, + Class to use for latent infections (e.g., SubpopulationInfections, PartitionedInfections, or a custom implementation) **params Parameters for latent class constructor (model structure). diff --git a/test/conftest.py b/test/conftest.py index 980c5b25..e9a5e16c 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -10,7 +10,12 @@ import pytest from pyrenew.deterministic import DeterministicPMF, DeterministicVariable -from pyrenew.latent import AR1, HierarchicalInfections, RandomWalk +from pyrenew.latent import ( + AR1, + PopulationInfections, + RandomWalk, + SubpopulationInfections, +) from pyrenew.observation import ( Counts, HierarchicalNormalNoise, @@ -146,31 +151,51 @@ def hierarchical_normal_noise_tight(): # ============================================================================= -# Hierarchical Infections Fixture +# Latent Infections Fixtures # ============================================================================= @pytest.fixture -def hierarchical_infections(gen_int_rv): +def subpopulation_infections(gen_int_rv): """ - Pre-configured HierarchicalInfections instance. + Pre-configured SubpopulationInfections instance. Returns ------- - HierarchicalInfections + SubpopulationInfections Configured infection process with realistic parameters. """ - return HierarchicalInfections( - name="hierarchical", + return SubpopulationInfections( + name="subpopulation", gen_int_rv=gen_int_rv, I0_rv=DeterministicVariable("I0", 0.001), - initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0), + log_rt_time_0_rv=DeterministicVariable("log_rt_time_0", 0.0), baseline_rt_process=AR1(autoreg=0.9, innovation_sd=0.05), subpop_rt_deviation_process=RandomWalk(innovation_sd=0.025), n_initialization_points=7, ) +@pytest.fixture +def population_infections(gen_int_rv): + """ + Pre-configured PopulationInfections instance. + + Returns + ------- + PopulationInfections + Configured infection process with realistic parameters. + """ + return PopulationInfections( + name="population", + gen_int_rv=gen_int_rv, + I0_rv=DeterministicVariable("I0", 0.001), + log_rt_time_0_rv=DeterministicVariable("log_rt_time_0", 0.0), + single_rt_process=AR1(autoreg=0.9, innovation_sd=0.05), + n_initialization_points=7, + ) + + # ============================================================================= # Counts Process Fixtures # ============================================================================= diff --git a/test/integration/conftest.py b/test/integration/conftest.py index 2592b083..8505b813 100644 --- a/test/integration/conftest.py +++ b/test/integration/conftest.py @@ -21,7 +21,7 @@ ) from pyrenew.deterministic import DeterministicPMF, DeterministicVariable from pyrenew.latent import AR1 -from pyrenew.latent.shared_infections import SharedInfections +from pyrenew.latent.population_infections import PopulationInfections from pyrenew.model import PyrenewBuilder from pyrenew.observation import Counts, NegativeBinomialNoise from pyrenew.randomvariable import DistributionalVariable @@ -137,7 +137,7 @@ def he_model( ed_day_of_week_effects: jnp.ndarray, ) -> PyrenewBuilder: """ - Build a SharedInfections model with hospital + ED observation processes. + Build a PopulationInfections model with hospital + ED observation processes. Parameters ---------- @@ -159,13 +159,11 @@ def he_model( builder = PyrenewBuilder() builder.configure_latent( - SharedInfections, + PopulationInfections, gen_int_rv=DeterministicPMF("gen_int", gen_int_pmf), I0_rv=DistributionalVariable("I0", dist.Beta(1, 10)), - initial_log_rt_rv=DistributionalVariable( - "initial_log_rt", dist.Normal(0.0, 0.5) - ), - shared_rt_process=AR1(autoreg=0.9, innovation_sd=0.05), + log_rt_time_0_rv=DistributionalVariable("log_rt_time_0", dist.Normal(0.0, 0.5)), + single_rt_process=AR1(autoreg=0.9, innovation_sd=0.05), ) hospital_obs = Counts( diff --git a/test/integration/test_shared_infections_he.py b/test/integration/test_population_infections_he.py similarity index 95% rename from test/integration/test_shared_infections_he.py rename to test/integration/test_population_infections_he.py index 197913a4..12ac6553 100644 --- a/test/integration/test_shared_infections_he.py +++ b/test/integration/test_population_infections_he.py @@ -1,7 +1,7 @@ """ -Integration test: SharedInfections H+E model with posterior recovery. +Integration test: PopulationInfections H+E model with posterior recovery. -Fits a SharedInfections model with hospital admissions and ED visit +Fits a PopulationInfections model with hospital admissions and ED visit observation processes to synthetic 120-day CA data, then checks that posterior estimates recover known true parameters. """ @@ -173,9 +173,9 @@ def posterior_dt( fitted_model.mcmc, dims={ "latent_infections": ["time"], - "SharedInfections::infections_aggregate": ["time"], - "SharedInfections::log_rt_shared": ["time", "dummy"], - "SharedInfections::rt_shared": ["time", "dummy"], + "PopulationInfections::infections_aggregate": ["time"], + "PopulationInfections::log_rt_single": ["time", "dummy"], + "PopulationInfections::rt_single": ["time", "dummy"], "hospital_predicted": ["time"], "ed_predicted": ["time"], }, @@ -216,7 +216,7 @@ def test_mcmc_convergence( """ summary = az.summary( posterior_dt, - var_names=["I0", "initial_log_rt", "ihr", "iedr"], + var_names=["I0", "log_rt_time_0", "ihr", "iedr"], ) rhat = summary["r_hat"].astype(float) ess = summary["ess_bulk"].astype(float) @@ -239,7 +239,7 @@ def test_rt_posterior_covers_truth( daily_infections : pl.DataFrame True infections and R(t) trajectory. """ - rt_posterior = posterior_dt.posterior["SharedInfections::rt_shared"] + rt_posterior = posterior_dt.posterior["PopulationInfections::rt_single"] rt_q05 = rt_posterior.quantile(0.05, dim=["chain", "draw"]).values rt_q95 = rt_posterior.quantile(0.95, dim=["chain", "draw"]).values diff --git a/test/test_interface_coverage.py b/test/test_interface_coverage.py index 500b65dc..b919efed 100644 --- a/test/test_interface_coverage.py +++ b/test/test_interface_coverage.py @@ -23,12 +23,12 @@ AR1, DifferencedAR1, GammaGroupSdPrior, - HierarchicalInfections, HierarchicalNormalPrior, Infections, InfectionsWithFeedback, RandomWalk, StudentTGroupModePrior, + SubpopulationInfections, ) from pyrenew.metaclass import RandomVariable from pyrenew.observation import ( @@ -225,11 +225,11 @@ def validate_data(self, n_total, n_subpops, **obs_data): # numpydoc ignore=GL08 def test_get_required_lookback(gen_int_rv): """get_required_lookback returns generation interval PMF length.""" - infections = HierarchicalInfections( + infections = SubpopulationInfections( name="hierarchical", gen_int_rv=gen_int_rv, I0_rv=DeterministicVariable("I0", 0.001), - initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0), + log_rt_time_0_rv=DeterministicVariable("initial_log_rt", 0.0), baseline_rt_process=AR1(autoreg=0.9, innovation_sd=0.05), subpop_rt_deviation_process=RandomWalk(innovation_sd=0.025), n_initialization_points=7, @@ -239,17 +239,17 @@ def test_get_required_lookback(gen_int_rv): # ============================================================================= -# HierarchicalInfections.validate() coverage +# SubpopulationInfections.validate() coverage # ============================================================================= def test_hierarchical_infections_validate(gen_int_rv): - """HierarchicalInfections.validate() runs without error on valid PMF.""" - infections = HierarchicalInfections( + """SubpopulationInfections.validate() runs without error on valid PMF.""" + infections = SubpopulationInfections( name="hierarchical", gen_int_rv=gen_int_rv, I0_rv=DeterministicVariable("I0", 0.001), - initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0), + log_rt_time_0_rv=DeterministicVariable("initial_log_rt", 0.0), baseline_rt_process=AR1(autoreg=0.9, innovation_sd=0.05), subpop_rt_deviation_process=RandomWalk(innovation_sd=0.025), n_initialization_points=7, @@ -435,12 +435,12 @@ def test_name_attribute_matches_expected(instance, expected_name): def test_hierarchical_infections_name(gen_int_rv): - """HierarchicalInfections.name is correctly set during construction.""" - infections = HierarchicalInfections( + """SubpopulationInfections.name is correctly set during construction.""" + infections = SubpopulationInfections( name="test_hi", gen_int_rv=gen_int_rv, I0_rv=DeterministicVariable("I0", 0.001), - initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0), + log_rt_time_0_rv=DeterministicVariable("initial_log_rt", 0.0), baseline_rt_process=AR1(autoreg=0.9, innovation_sd=0.05), subpop_rt_deviation_process=RandomWalk(innovation_sd=0.025), n_initialization_points=7, diff --git a/test/test_shared_infections.py b/test/test_population_infections.py similarity index 55% rename from test/test_shared_infections.py rename to test/test_population_infections.py index dfefc5c7..261ac8f7 100644 --- a/test/test_shared_infections.py +++ b/test/test_population_infections.py @@ -1,5 +1,5 @@ """ -Unit tests for SharedInfections. +Unit tests for PopulationInfections. """ import jax.numpy as jnp @@ -7,110 +7,90 @@ import pytest from pyrenew.deterministic import DeterministicVariable -from pyrenew.latent import AR1, RandomWalk -from pyrenew.latent.shared_infections import SharedInfections - - -@pytest.fixture -def shared_infections(gen_int_rv): - """ - Pre-configured SharedInfections instance. - - Returns - ------- - SharedInfections - Configured infection process with realistic parameters. - """ - return SharedInfections( - name="shared", - gen_int_rv=gen_int_rv, - I0_rv=DeterministicVariable("I0", 0.001), - initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0), - shared_rt_process=AR1(autoreg=0.9, innovation_sd=0.05), - n_initialization_points=7, - ) - - -class TestSharedInfectionsSample: - """Test sample method for SharedInfections.""" - - def test_output_shapes(self, shared_infections): +from pyrenew.latent import RandomWalk +from pyrenew.latent.population_infections import PopulationInfections + + +class TestPopulationInfectionsSample: + """Test sample method for PopulationInfections.""" + + def test_output_shapes(self, population_infections): """Test that output shapes are correct for a single-population model.""" n_days_post_init = 30 - n_total = shared_infections.n_initialization_points + n_days_post_init + n_total = population_infections.n_initialization_points + n_days_post_init with numpyro.handlers.seed(rng_seed=42): - inf_agg, inf_all = shared_infections.sample( + inf_agg, inf_all = population_infections.sample( n_days_post_init=n_days_post_init, ) assert inf_agg.shape == (n_total,) assert inf_all.shape == (n_total, 1) - def test_aggregate_equals_single_subpop(self, shared_infections): + def test_aggregate_equals_single_subpop(self, population_infections): """Test that aggregate infections equal the single subpopulation column.""" with numpyro.handlers.seed(rng_seed=42): - inf_agg, inf_all = shared_infections.sample( + inf_agg, inf_all = population_infections.sample( n_days_post_init=30, ) assert jnp.allclose(inf_agg, inf_all[:, 0], atol=1e-6) - def test_infections_are_positive(self, shared_infections): + def test_infections_are_positive(self, population_infections): """Test that all infections are positive.""" with numpyro.handlers.seed(rng_seed=42): - inf_agg, inf_all = shared_infections.sample( + inf_agg, inf_all = population_infections.sample( n_days_post_init=30, ) assert jnp.all(inf_agg > 0) assert jnp.all(inf_all > 0) - def test_deterministic_sites_recorded(self, shared_infections): + def test_deterministic_sites_recorded(self, population_infections): """Test that expected numpyro deterministic sites are recorded.""" with numpyro.handlers.seed(rng_seed=42): with numpyro.handlers.trace() as trace: - shared_infections.sample(n_days_post_init=30) + population_infections.sample(n_days_post_init=30) expected_sites = [ - "shared::I0_init", - "shared::log_rt_shared", - "shared::rt_shared", - "shared::infections_aggregate", + "population::I0_init", + "population::log_rt_single", + "population::rt_single", + "population::infections_aggregate", ] for site in expected_sites: assert site in trace, f"Missing deterministic site: {site}" - def test_rt_is_exp_of_log_rt(self, shared_infections): + def test_rt_is_exp_of_log_rt(self, population_infections): """Test that recorded Rt equals exp of recorded log Rt.""" with numpyro.handlers.seed(rng_seed=42): with numpyro.handlers.trace() as trace: - shared_infections.sample(n_days_post_init=30) + population_infections.sample(n_days_post_init=30) - log_rt = trace["shared::log_rt_shared"]["value"] - rt = trace["shared::rt_shared"]["value"] + log_rt = trace["population::log_rt_single"]["value"] + rt = trace["population::rt_single"]["value"] assert jnp.allclose(rt, jnp.exp(log_rt), atol=1e-6) - def test_default_fractions_used_when_none(self, shared_infections): + def test_default_fractions_used_when_none(self, population_infections): """Test that default fractions [1.0] are used when not provided.""" with numpyro.handlers.seed(rng_seed=42): - inf_agg, inf_all = shared_infections.sample( + inf_agg, inf_all = population_infections.sample( n_days_post_init=30, subpop_fractions=None, ) assert inf_all.shape[1] == 1 - def test_explicit_fractions_one(self, shared_infections): + def test_explicit_fractions_one(self, population_infections): """Test that explicit fractions [1.0] produce same results as default.""" with numpyro.handlers.seed(rng_seed=42): - inf_agg_default, inf_all_default = shared_infections.sample( + inf_agg_default, inf_all_default = population_infections.sample( n_days_post_init=30, ) with numpyro.handlers.seed(rng_seed=42): - inf_agg_explicit, inf_all_explicit = shared_infections.sample( + inf_agg_explicit, inf_all_explicit = population_infections.sample( n_days_post_init=30, subpop_fractions=jnp.array([1.0]), ) @@ -118,24 +98,24 @@ def test_explicit_fractions_one(self, shared_infections): assert jnp.allclose(inf_agg_default, inf_agg_explicit, atol=1e-6) assert jnp.allclose(inf_all_default, inf_all_explicit, atol=1e-6) - def test_different_seeds_give_different_results(self, shared_infections): + def test_different_seeds_give_different_results(self, population_infections): """Test that different RNG seeds produce different trajectories.""" with numpyro.handlers.seed(rng_seed=1): - inf_agg_1, _ = shared_infections.sample(n_days_post_init=30) + inf_agg_1, _ = population_infections.sample(n_days_post_init=30) with numpyro.handlers.seed(rng_seed=999): - inf_agg_2, _ = shared_infections.sample(n_days_post_init=30) + inf_agg_2, _ = population_infections.sample(n_days_post_init=30) assert not jnp.allclose(inf_agg_1, inf_agg_2) def test_custom_name_prefix(self, gen_int_rv): """Test that custom name prefix is used in deterministic sites.""" - process = SharedInfections( + process = PopulationInfections( name="my_infections", gen_int_rv=gen_int_rv, I0_rv=DeterministicVariable("I0", 0.001), - initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0), - shared_rt_process=RandomWalk(), + log_rt_time_0_rv=DeterministicVariable("log_rt_time_0", 0.0), + single_rt_process=RandomWalk(), n_initialization_points=7, ) @@ -143,69 +123,69 @@ def test_custom_name_prefix(self, gen_int_rv): with numpyro.handlers.trace() as trace: process.sample(n_days_post_init=30) - assert "my_infections::rt_shared" in trace + assert "my_infections::rt_single" in trace -class TestSharedInfectionsValidation: +class TestPopulationInfectionsValidation: """Test validation of inputs.""" def test_rejects_missing_I0_rv(self, gen_int_rv): """Test that None I0_rv is rejected.""" with pytest.raises(ValueError, match="I0_rv is required"): - SharedInfections( - name="shared", + PopulationInfections( + name="population", gen_int_rv=gen_int_rv, I0_rv=None, - initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0), - shared_rt_process=RandomWalk(), + log_rt_time_0_rv=DeterministicVariable("log_rt_time_0", 0.0), + single_rt_process=RandomWalk(), n_initialization_points=7, ) - def test_rejects_missing_initial_log_rt_rv(self, gen_int_rv): - """Test that None initial_log_rt_rv is rejected.""" - with pytest.raises(ValueError, match="initial_log_rt_rv is required"): - SharedInfections( - name="shared", + def test_rejects_missing_log_rt_time_0_rv(self, gen_int_rv): + """Test that None log_rt_time_0_rv is rejected.""" + with pytest.raises(ValueError, match="log_rt_time_0_rv is required"): + PopulationInfections( + name="population", gen_int_rv=gen_int_rv, I0_rv=DeterministicVariable("I0", 0.001), - initial_log_rt_rv=None, - shared_rt_process=RandomWalk(), + log_rt_time_0_rv=None, + single_rt_process=RandomWalk(), n_initialization_points=7, ) - def test_rejects_missing_shared_rt_process(self, gen_int_rv): - """Test that None shared_rt_process is rejected.""" - with pytest.raises(ValueError, match="shared_rt_process is required"): - SharedInfections( - name="shared", + def test_rejects_missing_single_rt_process(self, gen_int_rv): + """Test that None single_rt_process is rejected.""" + with pytest.raises(ValueError, match="single_rt_process is required"): + PopulationInfections( + name="population", gen_int_rv=gen_int_rv, I0_rv=DeterministicVariable("I0", 0.001), - initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0), - shared_rt_process=None, + log_rt_time_0_rv=DeterministicVariable("log_rt_time_0", 0.0), + single_rt_process=None, n_initialization_points=7, ) def test_rejects_invalid_I0(self, gen_int_rv): """Test that invalid I0 values are rejected at construction.""" with pytest.raises(ValueError, match="I0 must be positive"): - SharedInfections( - name="shared", + PopulationInfections( + name="population", gen_int_rv=gen_int_rv, I0_rv=DeterministicVariable("I0", -0.1), - initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0), - shared_rt_process=RandomWalk(), + log_rt_time_0_rv=DeterministicVariable("log_rt_time_0", 0.0), + single_rt_process=RandomWalk(), n_initialization_points=7, ) def test_rejects_I0_greater_than_one(self, gen_int_rv): """Test that I0 > 1 is rejected at construction.""" with pytest.raises(ValueError, match="I0 must be <= 1"): - SharedInfections( - name="shared", + PopulationInfections( + name="population", gen_int_rv=gen_int_rv, I0_rv=DeterministicVariable("I0", 1.5), - initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0), - shared_rt_process=RandomWalk(), + log_rt_time_0_rv=DeterministicVariable("log_rt_time_0", 0.0), + single_rt_process=RandomWalk(), n_initialization_points=7, ) @@ -214,46 +194,46 @@ def test_rejects_insufficient_n_initialization_points(self, gen_int_rv): with pytest.raises( ValueError, match="n_initialization_points must be at least" ): - SharedInfections( - name="shared", + PopulationInfections( + name="population", gen_int_rv=gen_int_rv, I0_rv=DeterministicVariable("I0", 0.001), - initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0), - shared_rt_process=RandomWalk(), + log_rt_time_0_rv=DeterministicVariable("log_rt_time_0", 0.0), + single_rt_process=RandomWalk(), n_initialization_points=2, ) - def test_rejects_fractions_not_summing_to_one(self, shared_infections): + def test_rejects_fractions_not_summing_to_one(self, population_infections): """Test that fractions not summing to 1 raises error at sample time.""" with pytest.raises(ValueError, match="must sum to 1.0"): with numpyro.handlers.seed(rng_seed=42): - shared_infections.sample( + population_infections.sample( n_days_post_init=30, subpop_fractions=jnp.array([0.5]), ) def test_rejects_multiple_subpop_fractions_even_if_sum_to_one( - self, shared_infections + self, population_infections ): - """Test that multi-element fractions are rejected for shared infections.""" + """Test that multi-element fractions are rejected for population infections.""" with pytest.raises( ValueError, match="requires exactly one subpopulation with fraction \\[1.0\\]", ): with numpyro.handlers.seed(rng_seed=42): - shared_infections.sample( + population_infections.sample( n_days_post_init=30, subpop_fractions=jnp.array([0.5, 0.5]), ) def test_rejects_non_scalar_I0(self, gen_int_rv): """Test that vector-valued I0 is rejected with a clear error.""" - process = SharedInfections( - name="shared", + process = PopulationInfections( + name="population", gen_int_rv=gen_int_rv, I0_rv=DeterministicVariable("I0", jnp.array([0.001, 0.002])), - initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0), - shared_rt_process=RandomWalk(), + log_rt_time_0_rv=DeterministicVariable("log_rt_time_0", 0.0), + single_rt_process=RandomWalk(), n_initialization_points=7, ) @@ -264,47 +244,47 @@ def test_rejects_non_scalar_I0(self, gen_int_rv): with numpyro.handlers.seed(rng_seed=42): process.sample(n_days_post_init=30) - def test_validate_passes(self, shared_infections): + def test_validate_passes(self, population_infections): """Test that validate() succeeds for a properly constructed instance.""" - shared_infections.validate() + population_infections.validate() - def test_default_subpop_fractions(self, shared_infections): + def test_default_subpop_fractions(self, population_infections): """Test that default_subpop_fractions returns [1.0].""" - fracs = shared_infections.default_subpop_fractions() + fracs = population_infections.default_subpop_fractions() assert jnp.allclose(fracs, jnp.array([1.0])) -class TestSharedValidateAndPrepareI0: - """Test _validate_and_prepare_I0 for SharedInfections.""" +class TestPopulationValidateAndPrepareI0: + """Test _validate_and_prepare_I0 for PopulationInfections.""" - def test_accepts_valid_scalar(self, shared_infections): + def test_accepts_valid_scalar(self, population_infections): """Test that a valid scalar I0 passes through unchanged.""" - pop = shared_infections._parse_and_validate_fractions() + pop = population_infections._parse_and_validate_fractions() I0 = jnp.array(0.01) - result = shared_infections._validate_and_prepare_I0(I0, pop) + result = population_infections._validate_and_prepare_I0(I0, pop) assert result.ndim == 0 assert jnp.isclose(result, 0.01) - def test_rejects_vector(self, shared_infections): + def test_rejects_vector(self, population_infections): """Test that a vector I0 is rejected.""" - pop = shared_infections._parse_and_validate_fractions() + pop = population_infections._parse_and_validate_fractions() I0 = jnp.array([0.01, 0.02]) with pytest.raises(ValueError, match="scalar prevalence"): - shared_infections._validate_and_prepare_I0(I0, pop) + population_infections._validate_and_prepare_I0(I0, pop) - def test_rejects_negative(self, shared_infections): + def test_rejects_negative(self, population_infections): """Test that negative I0 is rejected.""" - pop = shared_infections._parse_and_validate_fractions() + pop = population_infections._parse_and_validate_fractions() I0 = jnp.array(-0.01) with pytest.raises(ValueError, match="I0 must be positive"): - shared_infections._validate_and_prepare_I0(I0, pop) + population_infections._validate_and_prepare_I0(I0, pop) - def test_rejects_greater_than_one(self, shared_infections): + def test_rejects_greater_than_one(self, population_infections): """Test that I0 > 1 is rejected.""" - pop = shared_infections._parse_and_validate_fractions() + pop = population_infections._parse_and_validate_fractions() I0 = jnp.array(1.5) with pytest.raises(ValueError, match="I0 must be <= 1"): - shared_infections._validate_and_prepare_I0(I0, pop) + population_infections._validate_and_prepare_I0(I0, pop) if __name__ == "__main__": diff --git a/test/test_pyrenew_builder.py b/test/test_pyrenew_builder.py index 74769411..211a3ace 100644 --- a/test/test_pyrenew_builder.py +++ b/test/test_pyrenew_builder.py @@ -6,7 +6,7 @@ import pytest from pyrenew.deterministic import DeterministicPMF, DeterministicVariable -from pyrenew.latent import HierarchicalInfections, RandomWalk +from pyrenew.latent import RandomWalk, SubpopulationInfections from pyrenew.model import MultiSignalModel, PyrenewBuilder from pyrenew.observation import Counts, CountsBySubpop, NegativeBinomialNoise @@ -28,10 +28,10 @@ def simple_builder(): gen_int = DeterministicPMF("gen_int", jnp.array([0.2, 0.5, 0.3])) builder.configure_latent( - HierarchicalInfections, + SubpopulationInfections, gen_int_rv=gen_int, I0_rv=DeterministicVariable("I0", 0.001), - initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0), + log_rt_time_0_rv=DeterministicVariable("initial_log_rt", 0.0), baseline_rt_process=RandomWalk(), subpop_rt_deviation_process=RandomWalk(), ) @@ -66,10 +66,10 @@ def validation_builder(): gen_int = DeterministicPMF("gen_int", jnp.array([0.2, 0.5, 0.3])) builder.configure_latent( - HierarchicalInfections, + SubpopulationInfections, gen_int_rv=gen_int, I0_rv=DeterministicVariable("I0", 0.001), - initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0), + log_rt_time_0_rv=DeterministicVariable("initial_log_rt", 0.0), baseline_rt_process=RandomWalk(), subpop_rt_deviation_process=RandomWalk(), ) @@ -105,10 +105,10 @@ def test_rejects_population_structure_at_configure_time(self): with pytest.raises(ValueError, match="Do not specify"): builder.configure_latent( - HierarchicalInfections, + SubpopulationInfections, gen_int_rv=gen_int, I0_rv=DeterministicVariable("I0", 0.001), - initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0), + log_rt_time_0_rv=DeterministicVariable("initial_log_rt", 0.0), baseline_rt_process=RandomWalk(), subpop_rt_deviation_process=RandomWalk(), subpop_fractions=jnp.array([0.5, 0.5]), @@ -121,10 +121,10 @@ def test_rejects_n_initialization_points_at_configure_time(self): with pytest.raises(ValueError, match="Do not specify n_initialization_points"): builder.configure_latent( - HierarchicalInfections, + SubpopulationInfections, gen_int_rv=gen_int, I0_rv=DeterministicVariable("I0", 0.001), - initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0), + log_rt_time_0_rv=DeterministicVariable("initial_log_rt", 0.0), baseline_rt_process=RandomWalk(), subpop_rt_deviation_process=RandomWalk(), n_initialization_points=10, @@ -136,20 +136,20 @@ def test_rejects_reconfiguring_latent(self): gen_int = DeterministicPMF("gen_int", jnp.array([0.2, 0.5, 0.3])) builder.configure_latent( - HierarchicalInfections, + SubpopulationInfections, gen_int_rv=gen_int, I0_rv=DeterministicVariable("I0", 0.001), - initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0), + log_rt_time_0_rv=DeterministicVariable("initial_log_rt", 0.0), baseline_rt_process=RandomWalk(), subpop_rt_deviation_process=RandomWalk(), ) with pytest.raises(RuntimeError, match="already configured"): builder.configure_latent( - HierarchicalInfections, + SubpopulationInfections, gen_int_rv=gen_int, I0_rv=DeterministicVariable("I0", 0.001), - initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0), + log_rt_time_0_rv=DeterministicVariable("initial_log_rt", 0.0), baseline_rt_process=RandomWalk(), subpop_rt_deviation_process=RandomWalk(), ) @@ -189,7 +189,7 @@ def test_compute_n_initialization_points_without_latent_raises(self): def test_compute_n_initialization_points_without_gen_int_raises(self): """Test that compute_n_initialization_points without gen_int_rv raises.""" builder = PyrenewBuilder() - builder.latent_class = HierarchicalInfections + builder.latent_class = SubpopulationInfections builder.latent_params = {} with pytest.raises(ValueError, match="gen_int_rv is required"): diff --git a/test/test_hierarchical_infections.py b/test/test_subpopulation_infections.py similarity index 68% rename from test/test_hierarchical_infections.py rename to test/test_subpopulation_infections.py index 60501e06..658546e3 100644 --- a/test/test_hierarchical_infections.py +++ b/test/test_subpopulation_infections.py @@ -1,5 +1,5 @@ """ -Unit tests for HierarchicalInfections. +Unit tests for SubpopulationInfections. """ import jax.numpy as jnp @@ -7,18 +7,18 @@ import pytest from pyrenew.deterministic import DeterministicVariable -from pyrenew.latent import HierarchicalInfections, RandomWalk +from pyrenew.latent import RandomWalk, SubpopulationInfections -class TestHierarchicalInfectionsSample: +class TestSubpopulationInfectionsSample: """Test sample method with population structure at sample time.""" - def test_jurisdiction_total_is_weighted_sum(self, hierarchical_infections): + def test_jurisdiction_total_is_weighted_sum(self, subpopulation_infections): """Test that jurisdiction total equals weighted sum of subpopulations.""" fractions = jnp.array([0.3, 0.25, 0.45]) with numpyro.handlers.seed(rng_seed=42): - inf_juris, inf_all = hierarchical_infections.sample( + inf_juris, inf_all = subpopulation_infections.sample( n_days_post_init=30, subpop_fractions=fractions, ) @@ -27,24 +27,24 @@ def test_jurisdiction_total_is_weighted_sum(self, hierarchical_infections): assert jnp.allclose(inf_juris, expected, atol=1e-6) - def test_deviations_sum_to_zero(self, hierarchical_infections): + def test_deviations_sum_to_zero(self, subpopulation_infections): """Test that subpopulation deviations sum to zero (identifiability).""" with numpyro.handlers.seed(rng_seed=42): with numpyro.handlers.trace() as trace: - hierarchical_infections.sample( + subpopulation_infections.sample( n_days_post_init=30, subpop_fractions=jnp.array([0.3, 0.25, 0.45]), ) - deviations = trace["hierarchical::subpop_deviations"]["value"] + deviations = trace["subpopulation::subpop_deviations"]["value"] deviation_sums = jnp.sum(deviations, axis=1) assert jnp.allclose(deviation_sums, 0.0, atol=1e-6) - def test_infections_are_positive(self, hierarchical_infections): + def test_infections_are_positive(self, subpopulation_infections): """Test that all infections are positive (epidemiological invariant).""" with numpyro.handlers.seed(rng_seed=42): - inf_juris, inf_all = hierarchical_infections.sample( + inf_juris, inf_all = subpopulation_infections.sample( n_days_post_init=30, subpop_fractions=jnp.array([0.3, 0.25, 0.45]), ) @@ -62,15 +62,15 @@ def test_infections_are_positive(self, hierarchical_infections): ids=["K=1", "K=3", "K=6"], ) def test_shape_and_positivity_across_subpop_counts( - self, hierarchical_infections, fractions + self, subpopulation_infections, fractions ): """Test correct shapes and positivity for varying numbers of subpops.""" n_days_post_init = 30 - n_total = hierarchical_infections.n_initialization_points + n_days_post_init + n_total = subpopulation_infections.n_initialization_points + n_days_post_init n_subpops = len(fractions) with numpyro.handlers.seed(rng_seed=42): - inf_juris, inf_all = hierarchical_infections.sample( + inf_juris, inf_all = subpopulation_infections.sample( n_days_post_init=n_days_post_init, subpop_fractions=fractions, ) @@ -85,30 +85,30 @@ def test_shape_and_positivity_across_subpop_counts( assert jnp.allclose(inf_juris, expected, atol=1e-6) -class TestHierarchicalInfectionsValidation: +class TestSubpopulationInfectionsValidation: """Test validation of inputs.""" def test_rejects_missing_I0_rv(self, gen_int_rv): """Test that None I0_rv is rejected.""" with pytest.raises(ValueError, match="I0_rv is required"): - HierarchicalInfections( - name="hierarchical", + SubpopulationInfections( + name="subpopulation", gen_int_rv=gen_int_rv, I0_rv=None, - initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0), + log_rt_time_0_rv=DeterministicVariable("initial_log_rt", 0.0), baseline_rt_process=RandomWalk(), subpop_rt_deviation_process=RandomWalk(), n_initialization_points=7, ) - def test_rejects_missing_initial_log_rt_rv(self, gen_int_rv): - """Test that None initial_log_rt_rv is rejected.""" - with pytest.raises(ValueError, match="initial_log_rt_rv is required"): - HierarchicalInfections( - name="hierarchical", + def test_rejects_missing_log_rt_time_0_rv(self, gen_int_rv): + """Test that None log_rt_time_0_rv is rejected.""" + with pytest.raises(ValueError, match="log_rt_time_0_rv is required"): + SubpopulationInfections( + name="subpopulation", gen_int_rv=gen_int_rv, I0_rv=DeterministicVariable("I0", 0.001), - initial_log_rt_rv=None, + log_rt_time_0_rv=None, baseline_rt_process=RandomWalk(), subpop_rt_deviation_process=RandomWalk(), n_initialization_points=7, @@ -117,11 +117,11 @@ def test_rejects_missing_initial_log_rt_rv(self, gen_int_rv): def test_rejects_missing_baseline_rt_process(self, gen_int_rv): """Test that None baseline_rt_process is rejected.""" with pytest.raises(ValueError, match="baseline_rt_process is required"): - HierarchicalInfections( - name="hierarchical", + SubpopulationInfections( + name="subpopulation", gen_int_rv=gen_int_rv, I0_rv=DeterministicVariable("I0", 0.001), - initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0), + log_rt_time_0_rv=DeterministicVariable("initial_log_rt", 0.0), baseline_rt_process=None, subpop_rt_deviation_process=RandomWalk(), n_initialization_points=7, @@ -130,11 +130,11 @@ def test_rejects_missing_baseline_rt_process(self, gen_int_rv): def test_rejects_missing_subpop_rt_deviation_process(self, gen_int_rv): """Test that None subpop_rt_deviation_process is rejected.""" with pytest.raises(ValueError, match="subpop_rt_deviation_process is required"): - HierarchicalInfections( - name="hierarchical", + SubpopulationInfections( + name="subpopulation", gen_int_rv=gen_int_rv, I0_rv=DeterministicVariable("I0", 0.001), - initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0), + log_rt_time_0_rv=DeterministicVariable("initial_log_rt", 0.0), baseline_rt_process=RandomWalk(), subpop_rt_deviation_process=None, n_initialization_points=7, @@ -143,11 +143,11 @@ def test_rejects_missing_subpop_rt_deviation_process(self, gen_int_rv): def test_rejects_invalid_I0(self, gen_int_rv): """Test that invalid I0 values are rejected.""" with pytest.raises(ValueError, match="I0 must be positive"): - HierarchicalInfections( - name="hierarchical", + SubpopulationInfections( + name="subpopulation", gen_int_rv=gen_int_rv, I0_rv=DeterministicVariable("I0", -0.1), - initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0), + log_rt_time_0_rv=DeterministicVariable("initial_log_rt", 0.0), baseline_rt_process=RandomWalk(), subpop_rt_deviation_process=RandomWalk(), n_initialization_points=7, @@ -158,36 +158,36 @@ def test_rejects_insufficient_n_initialization_points(self, gen_int_rv): with pytest.raises( ValueError, match="n_initialization_points must be at least" ): - HierarchicalInfections( - name="hierarchical", + SubpopulationInfections( + name="subpopulation", gen_int_rv=gen_int_rv, I0_rv=DeterministicVariable("I0", 0.001), - initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0), + log_rt_time_0_rv=DeterministicVariable("initial_log_rt", 0.0), baseline_rt_process=RandomWalk(), subpop_rt_deviation_process=RandomWalk(), n_initialization_points=2, ) - def test_rejects_fractions_not_summing_to_one(self, hierarchical_infections): + def test_rejects_fractions_not_summing_to_one(self, subpopulation_infections): """Test that fractions not summing to 1 raises error at sample time.""" with pytest.raises(ValueError, match="must sum to 1.0"): with numpyro.handlers.seed(rng_seed=42): - hierarchical_infections.sample( + subpopulation_infections.sample( n_days_post_init=30, subpop_fractions=jnp.array([0.3, 0.25, 0.40]), ) -class TestHierarchicalInfectionsPerSubpopI0: +class TestSubpopulationInfectionsPerSubpopI0: """Test per-subpopulation I0 values.""" def test_per_subpop_I0_array(self, gen_int_rv): """Test with per-subpopulation I0 values and verify positivity.""" - process = HierarchicalInfections( - name="hierarchical", + process = SubpopulationInfections( + name="subpopulation", gen_int_rv=gen_int_rv, I0_rv=DeterministicVariable("I0", jnp.array([0.001, 0.002, 0.0015])), - initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0), + log_rt_time_0_rv=DeterministicVariable("initial_log_rt", 0.0), baseline_rt_process=RandomWalk(), subpop_rt_deviation_process=RandomWalk(), n_initialization_points=7, @@ -207,46 +207,46 @@ def test_per_subpop_I0_array(self, gen_int_rv): assert jnp.all(inf_all > 0) -class TestHierarchicalValidateAndPrepareI0: - """Test _validate_and_prepare_I0 for HierarchicalInfections.""" +class TestSubpopulationValidateAndPrepareI0: + """Test _validate_and_prepare_I0 for SubpopulationInfections.""" - def test_broadcasts_scalar_to_subpop_array(self, hierarchical_infections): + def test_broadcasts_scalar_to_subpop_array(self, subpopulation_infections): """Test that scalar I0 is broadcast to per-subpopulation array.""" - pop = hierarchical_infections._parse_and_validate_fractions( + pop = subpopulation_infections._parse_and_validate_fractions( subpop_fractions=jnp.array([0.3, 0.25, 0.45]) ) I0 = jnp.array(0.01) - result = hierarchical_infections._validate_and_prepare_I0(I0, pop) + result = subpopulation_infections._validate_and_prepare_I0(I0, pop) assert result.shape == (3,) assert jnp.allclose(result, 0.01) - def test_passes_through_matching_array(self, hierarchical_infections): + def test_passes_through_matching_array(self, subpopulation_infections): """Test that a per-subpopulation I0 array passes through unchanged.""" - pop = hierarchical_infections._parse_and_validate_fractions( + pop = subpopulation_infections._parse_and_validate_fractions( subpop_fractions=jnp.array([0.3, 0.25, 0.45]) ) I0 = jnp.array([0.001, 0.002, 0.0015]) - result = hierarchical_infections._validate_and_prepare_I0(I0, pop) + result = subpopulation_infections._validate_and_prepare_I0(I0, pop) assert result.shape == (3,) assert jnp.allclose(result, I0) - def test_rejects_negative(self, hierarchical_infections): + def test_rejects_negative(self, subpopulation_infections): """Test that negative I0 is rejected.""" - pop = hierarchical_infections._parse_and_validate_fractions( + pop = subpopulation_infections._parse_and_validate_fractions( subpop_fractions=jnp.array([0.5, 0.5]) ) I0 = jnp.array(-0.01) with pytest.raises(ValueError, match="I0 must be positive"): - hierarchical_infections._validate_and_prepare_I0(I0, pop) + subpopulation_infections._validate_and_prepare_I0(I0, pop) - def test_rejects_greater_than_one(self, hierarchical_infections): + def test_rejects_greater_than_one(self, subpopulation_infections): """Test that I0 > 1 is rejected.""" - pop = hierarchical_infections._parse_and_validate_fractions( + pop = subpopulation_infections._parse_and_validate_fractions( subpop_fractions=jnp.array([0.5, 0.5]) ) I0 = jnp.array(1.5) with pytest.raises(ValueError, match="I0 must be <= 1"): - hierarchical_infections._validate_and_prepare_I0(I0, pop) + subpopulation_infections._validate_and_prepare_I0(I0, pop) if __name__ == "__main__": From c9ef607b68435253d83730e42fb8e276976698ad Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Sat, 11 Apr 2026 13:05:43 -0400 Subject: [PATCH 2/7] consistent names across docs --- docs/index.md | 5 +-- docs/tutorials/.pages | 2 +- .../tutorials/building_multisignal_models.qmd | 12 +++---- docs/tutorials/latent_infections.qmd | 36 ++++++++----------- ...md => latent_subpopulation_infections.qmd} | 14 ++++---- 5 files changed, 30 insertions(+), 39 deletions(-) rename docs/tutorials/{latent_hierarchical_infections.qmd => latent_subpopulation_infections.qmd} (98%) diff --git a/docs/index.md b/docs/index.md index 1d45c148..ed973294 100644 --- a/docs/index.md +++ b/docs/index.md @@ -99,9 +99,10 @@ pip install git+https://github.com/CDCgov/PyRenew@main - [The RandomVariable abstract base class](tutorials/random_variables.md) -- PyRenew's core abstraction and its concrete implementations. - [Building multi-signal models](tutorials/building_multisignal_models.md) -- composing a renewal model from PyRenew components using `PyrenewBuilder`. +- [Latent infections](tutorials/latent_infections.md) -- modeling latent infection trajectories over time. +- [Latent subpopulation infections](tutorials/latent_subpopulation_infections.md) -- modeling latent infections with subpopulation structure. - [Observation processes: count data](tutorials/observation_processes_counts.md) -- connecting latent infections to observed counts. - [Observation processes: measurements](tutorials/observation_processes_measurements.md) -- connecting latent infections to continuous measurements. -- [Latent hierarchical infections](tutorials/latent_hierarchical_infections.md) -- modeling infections with subpopulation structure. ## Resources @@ -115,4 +116,4 @@ pip install git+https://github.com/CDCgov/PyRenew@main ### Further reading - [Semi-mechanistic Bayesian modelling of COVID-19 with renewal processes](https://academic.oup.com/jrsssa/article-pdf/186/4/601/54770289/qnad030.pdf) (Bhatt et al., 2023) -- [Unifying incidence and prevalence under a time-varying general branching process](https://link.springer.com/content/pdf/10.1007/s00285-023-01958-w.pdf) +- [Unifying incidence and prevalence under a time-varying general branching process](https://link.springer.com/content/pdf/10.1007/s00285-023-01958-w.pdf) (Pakkanen et al., 2023) diff --git a/docs/tutorials/.pages b/docs/tutorials/.pages index 29657ec4..83759923 100644 --- a/docs/tutorials/.pages +++ b/docs/tutorials/.pages @@ -4,6 +4,6 @@ nav: - observation_processes_counts.md - observation_processes_measurements.md - latent_infections.md - - latent_hierarchical_infections.md + - latent_subpopulation_infections.md - right_truncation.md - day_of_week_effects.md diff --git a/docs/tutorials/building_multisignal_models.qmd b/docs/tutorials/building_multisignal_models.qmd index 41d73d2b..06fa2cbe 100644 --- a/docs/tutorials/building_multisignal_models.qmd +++ b/docs/tutorials/building_multisignal_models.qmd @@ -106,7 +106,7 @@ The builder computes initialization requirements, wires components together, and Before diving into multi-signal models, you may want to review these foundational tutorials: -- **[Hierarchical Latent Infections](latent_hierarchical_infections.md)**: Understanding temporal process choices for $\mathcal{R}(t)$ +- **[Latent Infections](latent_infections.md)** and **[Latent Subpopulation Infections](latent_subpopulation_infections.md)**: Understanding temporal process choices for $\mathcal{R}(t)$ - **[Observation Processes: Counts](observation_processes_counts.md)**: Modeling count data (admissions, deaths) - **[Observation Processes: Measurements](observation_processes_measurements.md)**: Modeling continuous data (wastewater) @@ -208,14 +208,14 @@ I0_rv = DistributionalVariable("I0", dist.Beta(1, 100)) -### Initial Log Rt +### Log Rt at time $0$ -We place a prior on the initial log(Rt), centered at 0.0 (Rt = 1.0) with moderate uncertainty: +We place a prior on the log(Rt) at time $0$, centered at 0.0 (Rt = 1.0) with moderate uncertainty: ```{python} -# | label: initial-log-rt +# | label: log-rt-time-0 log_rt_time_0_rv = DistributionalVariable( - "initial_log_rt", dist.Normal(0.0, 0.5) + "log_rt_time_0", dist.Normal(0.0, 0.5) ) ``` @@ -1139,6 +1139,6 @@ This tutorial demonstrated composing a multi-signal renewal model using `Pyrenew ### Next Steps -- Explore different temporal processes for $\mathcal{R}(t)$ in the [Hierarchical Latent Infections](latent_hierarchical_infections.md) tutorial +- Explore different temporal processes for $\mathcal{R}(t)$ in the [Latent Infections](latent_infections.md) and [Latent Subpopulation Infections](latent_subpopulation_infections.md) tutorials - Learn about count-based observation models in [Observation Processes: Counts](observation_processes_counts.md) - Learn about continuous measurement models in [Observation Processes: Measurements](observation_processes_measurements.md) diff --git a/docs/tutorials/latent_infections.qmd b/docs/tutorials/latent_infections.qmd index cf4d865b..a1ad5766 100644 --- a/docs/tutorials/latent_infections.qmd +++ b/docs/tutorials/latent_infections.qmd @@ -3,12 +3,6 @@ title: Latent Infection Processes format: gfm: code-fold: true - html: - toc: true - embed-resources: true - self-contained-math: true - code-fold: true - code-tools: true engine: jupyter jupyter: jupytext: @@ -78,7 +72,7 @@ Here, $\tau$ indexes lags in the generation interval. PyRenew provides two latent infection classes: - **`PopulationInfections`**: A single $\mathcal{R}(t)$ drives one renewal equation. Appropriate when modeling one jurisdiction as a single population with one or more observation streams. This tutorial covers `PopulationInfections`. -- **`SubpopulationInfections`**: A baseline $\mathcal{R}(t)$ with per-subpopulation deviations. See [Hierarchical Latent Infections](latent_hierarchical_infections.md). +- **`SubpopulationInfections`**: A baseline $\mathcal{R}(t)$ with per-subpopulation deviations. See [Latent Subpopulation Infections](latent_subpopulation_infections.md). ## Model Inputs @@ -132,7 +126,7 @@ gi_df = pd.DataFrame({"day": days, "probability": np.array(gen_int_pmf)}) The generation interval length determines the minimum initialization period: with a $G$-point generation interval distribution, the renewal equation at time $0$ needs an infection-history vector long enough to supply the previous $G$ infection values used in the convolution. -### Initial Conditions: `I0` and `initial_log_rt` +### Initial Conditions: `I0` and `log_rt_time_0` These two parameters jointly define the infection history before the observation period begins. Understanding their interaction requires knowing how the latent @@ -161,17 +155,17 @@ period, $n_{\text{init}} - 1$ time points before $t = 0$. It sets the scale of the entire initialization vector: $I_{\text{init}}(0) = I_0$, with subsequent entries growing or declining exponentially toward $t = 0$. -* **The shape is set by `initial_log_rt`**.
-`initial_log_rt` enters the model in two places: it is the starting point of +* **The shape is set by `log_rt_time_0`**.
+`log_rt_time_0` enters the model in two places: it is the starting point of the $\mathcal{R}(t)$ trajectory ($\mathcal{R}(t=0) = e^{\text{initial\_log\_rt}}$), and it determines the exponential growth rate -$r$ used to construct the initialization vector. When `initial_log_rt = 0`, +$r$ used to construct the initialization vector. When `log_rt_time_0 = 0`, $r = 0$ and the initialization vector is flat at level `I0`. When -`initial_log_rt > 0`, infections are growing exponentially at $t = 0$; when -`initial_log_rt < 0`, they are declining. +`log_rt_time_0 > 0`, infections are growing exponentially at $t = 0$; when +`log_rt_time_0 < 0`, they are declining. -The initialization vector is what the renewal equation "sees" as recent infection history at time $0$. We can compute it directly for three values of `initial_log_rt`: +The initialization vector is what the renewal equation "sees" as recent infection history at time $0$. We can compute it directly for three values of `log_rt_time_0`: ```{python} # | label: backprojection-compute @@ -196,11 +190,11 @@ for label, log_rt in init_rt_values.items(): { "day": int(t) - n_init, "infections": float(I0_init[t]), - "config": f"initial_log_rt = {log_rt} ({label})", + "config": f"log_rt_time_0 = {log_rt} ({label})", } ) print( - f"initial_log_rt = {log_rt:+.1f}: Rt(0) = {float(Rt0):.2f}, r = {float(r):.4f}, " + f"log_rt_time_0 = {log_rt:+.1f}: Rt(0) = {float(Rt0):.2f}, r = {float(r):.4f}, " f"I_init range = [{float(I0_init[0]):.6f}, {float(I0_init[-1]):.6f}]" ) @@ -209,7 +203,7 @@ init_df = pd.DataFrame(init_data) ```{python} # | label: fig-backprojection -# | fig-cap: Initialization vectors for three values of initial_log_rt. Days are numbered relative to day 0, which is when the temporal process and renewal equation take over. When initial_log_rt = 0 (stable), the vector is flat. Nonzero values produce exponential growth or decay in the pre-observation period. +# | fig-cap: Initialization vectors for three values of log_rt_time_0. Days are numbered relative to day 0, which is when the temporal process and renewal equation take over. When log_rt_time_0 = 0 (stable), the vector is flat. Nonzero values produce exponential growth or decay in the pre-observation period. ( p9.ggplot(init_df, p9.aes(x="day", y="infections", color="config")) + p9.geom_line(size=1) @@ -238,7 +232,7 @@ After day 0, the temporal process takes over. How quickly the trajectory departs The temporal process governs how $\log \mathcal{R}(t)$ evolves day to day. To evaluate what a given process implies, we use **prior predictive checks**: drawing many samples from the model *before seeing any data* and examining the distribution of trajectories. A single sample tells you little (the trajectory depends on the random seed), but the envelope of many samples reveals the structural constraints built into the process. -We fix the initial conditions to a growing epidemic (`initial_log_rt = 0.5`, so $\mathcal{R}(0) \approx 1.65$) with `I0 = 0.001`. Starting well above equilibrium rather than near it makes the behavioral differences between temporal processes visible: the median trajectory of a mean-reverting process drifts back toward $\mathcal{R} = 1$, while a non-reverting process does not. +We fix the initial conditions to a growing epidemic (`log_rt_time_0 = 0.5`, so $\mathcal{R}(0) \approx 1.65$) with `I0 = 0.001`. Starting well above equilibrium rather than near it makes the behavioral differences between temporal processes visible: the median trajectory of a mean-reverting process drifts back toward $\mathcal{R} = 1$, while a non-reverting process does not. This section is primarily **modeling guidance for prior specification in PyRenew**, not a claim that epidemiologic theory uniquely determines one temporal process choice. The appropriate process depends on the scientific setting, time horizon, and how strongly you want the prior to regularize latent transmission dynamics. @@ -247,7 +241,7 @@ This section is primarily **modeling guidance for prior specification in PyRenew n_days = 28 n_init = len(gen_int_pmf) n_samples = 200 -initial_log_rt = 0.5 +log_rt_time_0 = 0.5 I0_val = 0.001 rt_cap = 3.0 @@ -258,9 +252,7 @@ def sample_process(rt_process, label): name="PopulationInfections", gen_int_rv=gen_int_rv, I0_rv=DeterministicVariable("I0", I0_val), - log_rt_time_0_rv=DeterministicVariable( - "initial_log_rt", initial_log_rt - ), + log_rt_time_0_rv=DeterministicVariable("log_rt_time_0", log_rt_time_0), single_rt_process=rt_process, n_initialization_points=n_init, ) diff --git a/docs/tutorials/latent_hierarchical_infections.qmd b/docs/tutorials/latent_subpopulation_infections.qmd similarity index 98% rename from docs/tutorials/latent_hierarchical_infections.qmd rename to docs/tutorials/latent_subpopulation_infections.qmd index 3897d117..52397a9e 100644 --- a/docs/tutorials/latent_hierarchical_infections.qmd +++ b/docs/tutorials/latent_subpopulation_infections.qmd @@ -1,5 +1,5 @@ --- -title: Hierarchical Latent Infections +title: Latent Subpopulation Infections format: gfm: code-fold: true @@ -112,7 +112,7 @@ This model generalizes the single-population renewal model: When all $\delta_k(t) = 0$, the model reduces to the shared (single-population) case. -This tutorial assumes familiarity with the renewal equation, generation interval, initial conditions (`I0`, `initial_log_rt`), and temporal processes. See [Latent Infection Processes](latent_infections.md) for that background. +This tutorial assumes familiarity with the renewal equation, generation interval, initial conditions (`I0`, `initial_log_rt`), and temporal processes. See [Latent Infections](latent_infections.md) for that background. --- @@ -173,13 +173,13 @@ n_subpops = len(subpop_fractions) n_days = 28 n_init = len(gen_int_pmf) n_samples = 200 -initial_log_rt = 0.2 +log_rt_time_0 = 0.2 I0_val = 0.001 rt_cap = 3.0 print(f"Subpopulations: {n_subpops}") print(f"Population fractions: {np.array(subpop_fractions)}") -print(f"Initial Rt: {np.exp(initial_log_rt):.2f}") +print(f"Log Rt at time 0: {np.exp(log_rt_time_0):.2f}") ``` ```{python} @@ -188,7 +188,7 @@ model = SubpopulationInfections( name="SubpopulationInfections", gen_int_rv=gen_int_rv, I0_rv=DeterministicVariable("I0", I0_val), - log_rt_time_0_rv=DeterministicVariable("initial_log_rt", initial_log_rt), + log_rt_time_0_rv=DeterministicVariable("log_rt_time_0", log_rt_time_0), baseline_rt_process=DifferencedAR1(autoreg=0.5, innovation_sd=0.01), subpop_rt_deviation_process=AR1(autoreg=0.8, innovation_sd=0.05), n_initialization_points=n_init, @@ -256,9 +256,7 @@ def sample_hierarchical(baseline_process, deviation_process, label): name="SubpopulationInfections", gen_int_rv=gen_int_rv, I0_rv=DeterministicVariable("I0", I0_val), - log_rt_time_0_rv=DeterministicVariable( - "initial_log_rt", initial_log_rt - ), + log_rt_time_0_rv=DeterministicVariable("log_rt_time_0", log_rt_time_0), baseline_rt_process=baseline_process, subpop_rt_deviation_process=deviation_process, n_initialization_points=n_init, From e7b6ec09b1e013138f152c6ea4e845827527d682 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Sat, 11 Apr 2026 16:54:21 -0400 Subject: [PATCH 3/7] Update test/test_interface_coverage.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- test/test_interface_coverage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_interface_coverage.py b/test/test_interface_coverage.py index b919efed..e35a5d3c 100644 --- a/test/test_interface_coverage.py +++ b/test/test_interface_coverage.py @@ -243,7 +243,7 @@ def test_get_required_lookback(gen_int_rv): # ============================================================================= -def test_hierarchical_infections_validate(gen_int_rv): +def test_subpopulation_infections_validate(gen_int_rv): """SubpopulationInfections.validate() runs without error on valid PMF.""" infections = SubpopulationInfections( name="hierarchical", From c8732a06f25662c95af273111352c9a3a12fb4da Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Sat, 11 Apr 2026 16:57:53 -0400 Subject: [PATCH 4/7] copilot fix --- test/test_interface_coverage.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/test/test_interface_coverage.py b/test/test_interface_coverage.py index b919efed..0d03e3a5 100644 --- a/test/test_interface_coverage.py +++ b/test/test_interface_coverage.py @@ -226,10 +226,10 @@ def validate_data(self, n_total, n_subpops, **obs_data): # numpydoc ignore=GL08 def test_get_required_lookback(gen_int_rv): """get_required_lookback returns generation interval PMF length.""" infections = SubpopulationInfections( - name="hierarchical", + name="subpopulation", gen_int_rv=gen_int_rv, I0_rv=DeterministicVariable("I0", 0.001), - log_rt_time_0_rv=DeterministicVariable("initial_log_rt", 0.0), + log_rt_time_0_rv=DeterministicVariable("log_rt_time_0", 0.0), baseline_rt_process=AR1(autoreg=0.9, innovation_sd=0.05), subpop_rt_deviation_process=RandomWalk(innovation_sd=0.025), n_initialization_points=7, @@ -243,13 +243,13 @@ def test_get_required_lookback(gen_int_rv): # ============================================================================= -def test_hierarchical_infections_validate(gen_int_rv): +def test_subpopulation_infections_validate(gen_int_rv): """SubpopulationInfections.validate() runs without error on valid PMF.""" infections = SubpopulationInfections( - name="hierarchical", + name="subpopulation", gen_int_rv=gen_int_rv, I0_rv=DeterministicVariable("I0", 0.001), - log_rt_time_0_rv=DeterministicVariable("initial_log_rt", 0.0), + log_rt_time_0_rv=DeterministicVariable("log_rt_time_0", 0.0), baseline_rt_process=AR1(autoreg=0.9, innovation_sd=0.05), subpop_rt_deviation_process=RandomWalk(innovation_sd=0.025), n_initialization_points=7, @@ -434,13 +434,13 @@ def test_name_attribute_matches_expected(instance, expected_name): assert instance.name == expected_name -def test_hierarchical_infections_name(gen_int_rv): +def test_subpopulation_infections_name(gen_int_rv): """SubpopulationInfections.name is correctly set during construction.""" infections = SubpopulationInfections( name="test_hi", gen_int_rv=gen_int_rv, I0_rv=DeterministicVariable("I0", 0.001), - log_rt_time_0_rv=DeterministicVariable("initial_log_rt", 0.0), + log_rt_time_0_rv=DeterministicVariable("log_rt_time_0", 0.0), baseline_rt_process=AR1(autoreg=0.9, innovation_sd=0.05), subpop_rt_deviation_process=RandomWalk(innovation_sd=0.025), n_initialization_points=7, From 6290cff86912895f60c7b0a4ebde62ec606e2942 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Mon, 13 Apr 2026 11:19:23 -0400 Subject: [PATCH 5/7] mathcal Rt fixes --- pyrenew/latent/population_infections.py | 8 ++++---- pyrenew/latent/subpopulation_infections.py | 16 ++++++++-------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/pyrenew/latent/population_infections.py b/pyrenew/latent/population_infections.py index 2671372a..1d1ce1cf 100644 --- a/pyrenew/latent/population_infections.py +++ b/pyrenew/latent/population_infections.py @@ -24,7 +24,7 @@ class PopulationInfections(BaseLatentInfectionProcess): """ - A single $\ mathcal{R}(t)$ trajectory drives one renewal equation. + A single $\\mathcal{R}(t)$ trajectory drives one renewal equation. The constructor specifies model structure (priors, temporal processes). """ @@ -55,9 +55,9 @@ def __init__( I0_rv Initial infection prevalence (proportion of population) single_rt_process - Temporal process for single Rt dynamics + Temporal process for single $\\mathcal{R}(t)$ dynamics log_rt_time_0_rv - Initial value for log(Rt) at time 0. + Initial value for log($\\mathcal{R}(t)$) at time 0. Raises ------ @@ -152,7 +152,7 @@ def sample( """ Sample population infections using a single renewal process. - Generates a single Rt trajectory, computes initial infections via + Generates a single $\\mathcal{R}(t)$ trajectory, computes initial infections via exponential backprojection, and runs one renewal equation. Parameters diff --git a/pyrenew/latent/subpopulation_infections.py b/pyrenew/latent/subpopulation_infections.py index 8eb43067..613f75f8 100644 --- a/pyrenew/latent/subpopulation_infections.py +++ b/pyrenew/latent/subpopulation_infections.py @@ -26,15 +26,15 @@ class SubpopulationInfections(BaseLatentInfectionProcess): """ - Multi-subpopulation renewal model with hierarchical Rt structure. + Multi-subpopulation renewal model with hierarchical $\\mathcal{R}(t)$ structure. - Each subpopulation has its own renewal equation with Rt deviating from a + Each subpopulation has its own renewal equation with $\\mathcal{R}(t)$ deviating from a shared baseline. Suitable when transmission dynamics vary substantially across subpopulations. Mathematical form: - - Baseline Rt: log[R_baseline(t)] ~ TemporalProcess - - Subpopulation Rt: log R_k(t) = log[R_baseline(t)] + delta_k(t) + - Baseline $\\mathcal{R}(t)$: log[R_baseline(t)] ~ TemporalProcess + - Subpopulation $\\mathcal{R}(t)$: log R_k(t) = log[R_baseline(t)] + delta_k(t) - Deviations: delta_k(t) ~ TemporalProcess with sum-to-zero constraint - Renewal per subpop: I_k(t) = R_k(t) * sum_tau I_k(t-tau) * g(tau) - Aggregate total: I_aggregate(t) = sum_k p_k * I_k(t) @@ -46,7 +46,7 @@ class SubpopulationInfections(BaseLatentInfectionProcess): Notes ----- Sum-to-zero constraint on deviations ensures R_baseline(t) is the geometric - mean of subpopulation Rt values, providing identifiability. + mean of subpopulation $\\mathcal{R}(t)$ values, providing identifiability. """ def __init__( @@ -76,11 +76,11 @@ def __init__( I0_rv Initial infection prevalence (proportion of population) baseline_rt_process - Temporal process for baseline Rt dynamics + Temporal process for baseline $\\mathcal{R}(t)$ dynamics subpop_rt_deviation_process Temporal process for subpopulation deviations log_rt_time_0_rv - Initial value for log(Rt) at time 0. + Initial value for log($\\mathcal{R}(t)$) at time 0. Raises ------ @@ -167,7 +167,7 @@ def sample( """ Sample hierarchical infections for all subpopulations. - Generates baseline Rt, subpopulation deviations with sum-to-zero + Generates baseline $\\mathcal{R}(t)$, subpopulation deviations with sum-to-zero constraint, initial infections, and runs n_subpops independent renewal processes. Parameters From de9e80a3cb5b270f90093595cf6abd9e5e3bcb02 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Mon, 13 Apr 2026 12:05:30 -0400 Subject: [PATCH 6/7] docstring comment cleanup --- pyrenew/model/multisignal_model.py | 6 +++--- pyrenew/model/pyrenew_builder.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pyrenew/model/multisignal_model.py b/pyrenew/model/multisignal_model.py index b2ec8768..8a124348 100644 --- a/pyrenew/model/multisignal_model.py +++ b/pyrenew/model/multisignal_model.py @@ -20,9 +20,9 @@ class MultiSignalModel(Model): """ Multi-signal renewal model. - Combines a latent infection process (e.g., SubpopulationInfections, - PartitionedInfections) with multiple observation processes (e.g., - CountObservation, WastewaterObservation). + Combines a latent infection process (e.g., SubpopulationInfections) + with multiple observation processes (e.g., CountObservation, + WastewaterObservation). Built via PyrenewBuilder to ensure n_initialization_points is computed correctly from all components. Can also be constructed manually for diff --git a/pyrenew/model/pyrenew_builder.py b/pyrenew/model/pyrenew_builder.py index a8d30fc0..e05d146a 100644 --- a/pyrenew/model/pyrenew_builder.py +++ b/pyrenew/model/pyrenew_builder.py @@ -61,8 +61,8 @@ def configure_latent( Parameters ---------- latent_class - Class to use for latent infections (e.g., SubpopulationInfections, - PartitionedInfections, or a custom implementation) + Class to use for latent infections (e.g., PopulationInfections, + SubpopulationInfections, or a custom implementation) **params Parameters for latent class constructor (model structure). DO NOT include n_initialization_points - it will be computed From f0de42f3167e4e0ab6a893fd0aec51634b6fa47d Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Mon, 13 Apr 2026 14:14:51 -0400 Subject: [PATCH 7/7] docs cleanup --- docs/tutorials/.pages | 4 ++-- docs/tutorials/latent_infections.qmd | 4 ++-- docs/tutorials/latent_subpopulation_infections.qmd | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/tutorials/.pages b/docs/tutorials/.pages index 83759923..9780c954 100644 --- a/docs/tutorials/.pages +++ b/docs/tutorials/.pages @@ -1,9 +1,9 @@ nav: - random_variables.md - building_multisignal_models.md - - observation_processes_counts.md - - observation_processes_measurements.md - latent_infections.md - latent_subpopulation_infections.md + - observation_processes_counts.md + - observation_processes_measurements.md - right_truncation.md - day_of_week_effects.md diff --git a/docs/tutorials/latent_infections.qmd b/docs/tutorials/latent_infections.qmd index a1ad5766..b46a9a51 100644 --- a/docs/tutorials/latent_infections.qmd +++ b/docs/tutorials/latent_infections.qmd @@ -145,7 +145,7 @@ $$I_{\text{init}}(\tau) = I_0 \cdot e^{r \cdot \tau}, \quad \tau = 0, 1, where $r$ is the asymptotic growth rate implied by the reproduction number at the start of the observation period, $\mathcal{R}(t=0) = -e^{\text{initial\_log\_rt}}$, and the generation interval. The function +e^{\text{log\_rt\_time\_0}}$, and the generation interval. The function `r_approx_from_R` converts $\mathcal{R}(t=0)$ and the generation interval into $r$ using Newton's method. @@ -158,7 +158,7 @@ entries growing or declining exponentially toward $t = 0$. * **The shape is set by `log_rt_time_0`**.
`log_rt_time_0` enters the model in two places: it is the starting point of the $\mathcal{R}(t)$ trajectory ($\mathcal{R}(t=0) = -e^{\text{initial\_log\_rt}}$), and it determines the exponential growth rate +e^{\text{log\_rt\_time\_0}}$), and it determines the exponential growth rate $r$ used to construct the initialization vector. When `log_rt_time_0 = 0`, $r = 0$ and the initialization vector is flat at level `I0`. When `log_rt_time_0 > 0`, infections are growing exponentially at $t = 0$; when diff --git a/docs/tutorials/latent_subpopulation_infections.qmd b/docs/tutorials/latent_subpopulation_infections.qmd index 52397a9e..0839764b 100644 --- a/docs/tutorials/latent_subpopulation_infections.qmd +++ b/docs/tutorials/latent_subpopulation_infections.qmd @@ -112,7 +112,7 @@ This model generalizes the single-population renewal model: When all $\delta_k(t) = 0$, the model reduces to the shared (single-population) case. -This tutorial assumes familiarity with the renewal equation, generation interval, initial conditions (`I0`, `initial_log_rt`), and temporal processes. See [Latent Infections](latent_infections.md) for that background. +This tutorial assumes familiarity with the renewal equation, generation interval, initial conditions (`I0`, `log_rt_time_0`), and temporal processes. See [Latent Infections](latent_infections.md) for that background. ---