diff --git a/docs/index.md b/docs/index.md
index 1d45c148..ed973294 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -99,9 +99,10 @@ pip install git+https://github.com/CDCgov/PyRenew@main
- [The RandomVariable abstract base class](tutorials/random_variables.md) -- PyRenew's core abstraction and its concrete implementations.
- [Building multi-signal models](tutorials/building_multisignal_models.md) -- composing a renewal model from PyRenew components using `PyrenewBuilder`.
+- [Latent infections](tutorials/latent_infections.md) -- modeling latent infection trajectories over time.
+- [Latent subpopulation infections](tutorials/latent_subpopulation_infections.md) -- modeling latent infections with subpopulation structure.
- [Observation processes: count data](tutorials/observation_processes_counts.md) -- connecting latent infections to observed counts.
- [Observation processes: measurements](tutorials/observation_processes_measurements.md) -- connecting latent infections to continuous measurements.
-- [Latent hierarchical infections](tutorials/latent_hierarchical_infections.md) -- modeling infections with subpopulation structure.
## Resources
@@ -115,4 +116,4 @@ pip install git+https://github.com/CDCgov/PyRenew@main
### Further reading
- [Semi-mechanistic Bayesian modelling of COVID-19 with renewal processes](https://academic.oup.com/jrsssa/article-pdf/186/4/601/54770289/qnad030.pdf) (Bhatt et al., 2023)
-- [Unifying incidence and prevalence under a time-varying general branching process](https://link.springer.com/content/pdf/10.1007/s00285-023-01958-w.pdf)
+- [Unifying incidence and prevalence under a time-varying general branching process](https://link.springer.com/content/pdf/10.1007/s00285-023-01958-w.pdf) (Pakkanen et al., 2023)
diff --git a/docs/tutorials/.pages b/docs/tutorials/.pages
index 29657ec4..9780c954 100644
--- a/docs/tutorials/.pages
+++ b/docs/tutorials/.pages
@@ -1,9 +1,9 @@
nav:
- random_variables.md
- building_multisignal_models.md
+ - latent_infections.md
+ - latent_subpopulation_infections.md
- observation_processes_counts.md
- observation_processes_measurements.md
- - latent_infections.md
- - latent_hierarchical_infections.md
- right_truncation.md
- day_of_week_effects.md
diff --git a/docs/tutorials/building_multisignal_models.qmd b/docs/tutorials/building_multisignal_models.qmd
index 7d67e262..06fa2cbe 100644
--- a/docs/tutorials/building_multisignal_models.qmd
+++ b/docs/tutorials/building_multisignal_models.qmd
@@ -64,7 +64,7 @@ from pyrenew.metaclass import RandomVariable
from pyrenew.randomvariable import DistributionalVariable
from pyrenew.latent import (
- HierarchicalInfections,
+ SubpopulationInfections,
AR1,
RandomWalk,
GammaGroupSdPrior,
@@ -97,7 +97,7 @@ A **multi-signal model** combines multiple observation processes—each represen
The `PyrenewBuilder` class handles this plumbing. You specify:
-1. A single **latent process** (e.g., `HierarchicalInfections`) that defines how infections evolve.
+1. A single **latent process** (e.g., `SubpopulationInfections`) that defines how infections evolve.
2. One or more **observation processes** (e.g., `Counts`, `Measurements`) that define how infections become data.
The builder computes initialization requirements, wires components together, and produces a model ready for inference.
@@ -106,7 +106,7 @@ The builder computes initialization requirements, wires components together, and
Before diving into multi-signal models, you may want to review these foundational tutorials:
-- **[Hierarchical Latent Infections](latent_hierarchical_infections.md)**: Understanding temporal process choices for $\mathcal{R}(t)$
+- **[Latent Infections](latent_infections.md)** and **[Latent Subpopulation Infections](latent_subpopulation_infections.md)**: Understanding temporal process choices for $\mathcal{R}(t)$
- **[Observation Processes: Counts](observation_processes_counts.md)**: Modeling count data (admissions, deaths)
- **[Observation Processes: Measurements](observation_processes_measurements.md)**: Modeling continuous data (wastewater)
@@ -116,7 +116,7 @@ This tutorial shows how to combine these components into a complete multi-signal
This tutorial demonstrates building a multi-signal renewal model using:
-- `HierarchicalInfections` — subpopulations share a jurisdiction-level baseline $\mathcal{R}(t)$ with subpopulation-specific deviations
+- `SubpopulationInfections` — subpopulations share a jurisdiction-level baseline $\mathcal{R}(t)$ with subpopulation-specific deviations
- `Counts` — hospital admissions (jurisdiction-level)
- A custom `Wastewater` class — viral concentrations (subpopulation-level)
@@ -133,7 +133,7 @@ The diagram below shows how data flows through the model. The latent process gen
flowchart TB
subgraph Latent["Latent Infection Process"]
- L["Renewal equation
(HierarchicalInfections)"]
+ L["Renewal equation
(SubpopulationInfections)"]
end
subgraph Infections["Infection Trajectories"]
@@ -208,14 +208,14 @@ I0_rv = DistributionalVariable("I0", dist.Beta(1, 100))
-### Initial Log Rt
+### Log Rt at time $0$
-We place a prior on the initial log(Rt), centered at 0.0 (Rt = 1.0) with moderate uncertainty:
+We place a prior on the log(Rt) at time $0$, centered at 0.0 (Rt = 1.0) with moderate uncertainty:
```{python}
-# | label: initial-log-rt
-initial_log_rt_rv = DistributionalVariable(
- "initial_log_rt", dist.Normal(0.0, 0.5)
+# | label: log-rt-time-0
+log_rt_time_0_rv = DistributionalVariable(
+ "log_rt_time_0", dist.Normal(0.0, 0.5)
)
```
@@ -530,10 +530,10 @@ print("Latent process configuration:")
print(f" Generation interval length: {len(gen_int_rv())} days")
builder.configure_latent(
- HierarchicalInfections,
+ SubpopulationInfections,
gen_int_rv=gen_int_rv,
I0_rv=I0_rv,
- initial_log_rt_rv=initial_log_rt_rv,
+ log_rt_time_0_rv=log_rt_time_0_rv,
baseline_rt_process=baseline_rt_process,
subpop_rt_deviation_process=subpop_rt_deviation_process,
)
@@ -824,11 +824,11 @@ idata_90 = az.from_numpyro(
model.mcmc,
dims={
"latent_infections": ["time"],
- "HierarchicalInfections::infections_aggregate": ["time"],
- "HierarchicalInfections::log_rt_baseline": ["time", "dummy"],
- "HierarchicalInfections::rt_baseline": ["time", "dummy"],
- "HierarchicalInfections::rt_subpop": ["time", "subpop"],
- "HierarchicalInfections::subpop_deviations": ["time", "subpop"],
+ "SubpopulationInfections::infections_aggregate": ["time"],
+ "SubpopulationInfections::log_rt_baseline": ["time", "dummy"],
+ "SubpopulationInfections::rt_baseline": ["time", "dummy"],
+ "SubpopulationInfections::rt_subpop": ["time", "subpop"],
+ "SubpopulationInfections::subpop_deviations": ["time", "subpop"],
"latent_infections_by_subpop": ["time", "subpop"],
"hospital_predicted": ["time"],
"wastewater_predicted": ["time", "subpop"],
@@ -973,11 +973,11 @@ idata_180 = az.from_numpyro(
model.mcmc,
dims={
"latent_infections": ["time"],
- "HierarchicalInfections::infections_aggregate": ["time"],
- "HierarchicalInfections::log_rt_baseline": ["time", "dummy"],
- "HierarchicalInfections::rt_baseline": ["time", "dummy"],
- "HierarchicalInfections::rt_subpop": ["time", "subpop"],
- "HierarchicalInfections::subpop_deviations": ["time", "subpop"],
+ "SubpopulationInfections::infections_aggregate": ["time"],
+ "SubpopulationInfections::log_rt_baseline": ["time", "dummy"],
+ "SubpopulationInfections::rt_baseline": ["time", "dummy"],
+ "SubpopulationInfections::rt_subpop": ["time", "subpop"],
+ "SubpopulationInfections::subpop_deviations": ["time", "subpop"],
"latent_infections_by_subpop": ["time", "subpop"],
"hospital_predicted": ["time"],
"wastewater_predicted": ["time", "subpop"],
@@ -1139,6 +1139,6 @@ This tutorial demonstrated composing a multi-signal renewal model using `Pyrenew
### Next Steps
-- Explore different temporal processes for $\mathcal{R}(t)$ in the [Hierarchical Latent Infections](latent_hierarchical_infections.md) tutorial
+- Explore different temporal processes for $\mathcal{R}(t)$ in the [Latent Infections](latent_infections.md) and [Latent Subpopulation Infections](latent_subpopulation_infections.md) tutorials
- Learn about count-based observation models in [Observation Processes: Counts](observation_processes_counts.md)
- Learn about continuous measurement models in [Observation Processes: Measurements](observation_processes_measurements.md)
diff --git a/docs/tutorials/latent_infections.qmd b/docs/tutorials/latent_infections.qmd
index b06e62d3..b46a9a51 100644
--- a/docs/tutorials/latent_infections.qmd
+++ b/docs/tutorials/latent_infections.qmd
@@ -3,12 +3,6 @@ title: Latent Infection Processes
format:
gfm:
code-fold: true
- html:
- toc: true
- embed-resources: true
- self-contained-math: true
- code-fold: true
- code-tools: true
engine: jupyter
jupyter:
jupytext:
@@ -44,7 +38,12 @@ from _tutorial_theme import theme_tutorial
```{python}
# | label: imports
-from pyrenew.latent import SharedInfections, AR1, DifferencedAR1, RandomWalk
+from pyrenew.latent import (
+ PopulationInfections,
+ AR1,
+ DifferencedAR1,
+ RandomWalk,
+)
from pyrenew.deterministic import DeterministicPMF, DeterministicVariable
from pyrenew.randomvariable import DistributionalVariable
```
@@ -72,17 +71,17 @@ Here, $\tau$ indexes lags in the generation interval.
PyRenew provides two latent infection classes:
-- **`SharedInfections`**: A single $\mathcal{R}(t)$ drives one renewal equation. Appropriate when modeling one jurisdiction as a single population with one or more observation streams. This tutorial covers `SharedInfections`.
-- **`HierarchicalInfections`**: A baseline $\mathcal{R}(t)$ with per-subpopulation deviations. See [Hierarchical Latent Infections](latent_hierarchical_infections.md).
+- **`PopulationInfections`**: A single $\mathcal{R}(t)$ drives one renewal equation. Appropriate when modeling one jurisdiction as a single population with one or more observation streams. This tutorial covers `PopulationInfections`.
+- **`SubpopulationInfections`**: A baseline $\mathcal{R}(t)$ with per-subpopulation deviations. See [Latent Subpopulation Infections](latent_subpopulation_infections.md).
## Model Inputs
-`SharedInfections` requires four inputs:
+`PopulationInfections` requires four inputs:
1. Generation interval distribution $w_\tau$ (`gen_int_rv`)
2. Infection prevalence at time $0$ as a proportion of the population (`I0_rv`)
-3. Value for $log(\mathcal{R}(t))$ at time $0$ (`initial_log_rt_rv`)
-4. A temporal process for $\mathcal{R}(t)$ dynamics (`shared_rt_process`)
+3. Value for $log(\mathcal{R}(t))$ at time $0$ (`log_rt_time_0_rv`)
+4. A temporal process for $\mathcal{R}(t)$ dynamics (`single_rt_process`)
All inputs are **RandomVariables**, a quantity that is either known (observed, conditioned on) or unknown (to be inferred). See [PyRenew's RandomVariable abstract base class](random_variables.md). In this tutorial, we use `DeterministicVariable` and `DeterministicPMF` (fixed values) for illustration. In real inference, you would use `DistributionalVariable` with priors for quantities you want to estimate:
@@ -127,7 +126,7 @@ gi_df = pd.DataFrame({"day": days, "probability": np.array(gen_int_pmf)})
The generation interval length determines the minimum initialization period: with a $G$-point generation interval distribution, the renewal equation at time $0$ needs an infection-history vector long enough to supply the previous $G$ infection values used in the convolution.
-### Initial Conditions: `I0` and `initial_log_rt`
+### Initial Conditions: `I0` and `log_rt_time_0`
These two parameters jointly define the infection history before the observation
period begins. Understanding their interaction requires knowing how the latent
@@ -146,7 +145,7 @@ $$I_{\text{init}}(\tau) = I_0 \cdot e^{r \cdot \tau}, \quad \tau = 0, 1,
where $r$ is the asymptotic growth rate implied by the reproduction number at
the start of the observation period, $\mathcal{R}(t=0) =
-e^{\text{initial\_log\_rt}}$, and the generation interval. The function
+e^{\text{log\_rt\_time\_0}}$, and the generation interval. The function
`r_approx_from_R` converts $\mathcal{R}(t=0)$ and the generation interval into
$r$ using Newton's method.
@@ -156,17 +155,17 @@ period, $n_{\text{init}} - 1$ time points before $t = 0$. It sets the scale of
the entire initialization vector: $I_{\text{init}}(0) = I_0$, with subsequent
entries growing or declining exponentially toward $t = 0$.
-* **The shape is set by `initial_log_rt`**.
-`initial_log_rt` enters the model in two places: it is the starting point of
+* **The shape is set by `log_rt_time_0`**.
+`log_rt_time_0` enters the model in two places: it is the starting point of
the $\mathcal{R}(t)$ trajectory ($\mathcal{R}(t=0) =
-e^{\text{initial\_log\_rt}}$), and it determines the exponential growth rate
-$r$ used to construct the initialization vector. When `initial_log_rt = 0`,
+e^{\text{log\_rt\_time\_0}}$), and it determines the exponential growth rate
+$r$ used to construct the initialization vector. When `log_rt_time_0 = 0`,
$r = 0$ and the initialization vector is flat at level `I0`. When
-`initial_log_rt > 0`, infections are growing exponentially at $t = 0$; when
-`initial_log_rt < 0`, they are declining.
+`log_rt_time_0 > 0`, infections are growing exponentially at $t = 0$; when
+`log_rt_time_0 < 0`, they are declining.
-The initialization vector is what the renewal equation "sees" as recent infection history at time $0$. We can compute it directly for three values of `initial_log_rt`:
+The initialization vector is what the renewal equation "sees" as recent infection history at time $0$. We can compute it directly for three values of `log_rt_time_0`:
```{python}
# | label: backprojection-compute
@@ -191,11 +190,11 @@ for label, log_rt in init_rt_values.items():
{
"day": int(t) - n_init,
"infections": float(I0_init[t]),
- "config": f"initial_log_rt = {log_rt} ({label})",
+ "config": f"log_rt_time_0 = {log_rt} ({label})",
}
)
print(
- f"initial_log_rt = {log_rt:+.1f}: Rt(0) = {float(Rt0):.2f}, r = {float(r):.4f}, "
+ f"log_rt_time_0 = {log_rt:+.1f}: Rt(0) = {float(Rt0):.2f}, r = {float(r):.4f}, "
f"I_init range = [{float(I0_init[0]):.6f}, {float(I0_init[-1]):.6f}]"
)
@@ -204,7 +203,7 @@ init_df = pd.DataFrame(init_data)
```{python}
# | label: fig-backprojection
-# | fig-cap: Initialization vectors for three values of initial_log_rt. Days are numbered relative to day 0, which is when the temporal process and renewal equation take over. When initial_log_rt = 0 (stable), the vector is flat. Nonzero values produce exponential growth or decay in the pre-observation period.
+# | fig-cap: Initialization vectors for three values of log_rt_time_0. Days are numbered relative to day 0, which is when the temporal process and renewal equation take over. When log_rt_time_0 = 0 (stable), the vector is flat. Nonzero values produce exponential growth or decay in the pre-observation period.
(
p9.ggplot(init_df, p9.aes(x="day", y="infections", color="config"))
+ p9.geom_line(size=1)
@@ -233,7 +232,7 @@ After day 0, the temporal process takes over. How quickly the trajectory departs
The temporal process governs how $\log \mathcal{R}(t)$ evolves day to day.
To evaluate what a given process implies, we use **prior predictive checks**: drawing many samples from the model *before seeing any data* and examining the distribution of trajectories. A single sample tells you little (the trajectory depends on the random seed), but the envelope of many samples reveals the structural constraints built into the process.
-We fix the initial conditions to a growing epidemic (`initial_log_rt = 0.5`, so $\mathcal{R}(0) \approx 1.65$) with `I0 = 0.001`. Starting well above equilibrium rather than near it makes the behavioral differences between temporal processes visible: the median trajectory of a mean-reverting process drifts back toward $\mathcal{R} = 1$, while a non-reverting process does not.
+We fix the initial conditions to a growing epidemic (`log_rt_time_0 = 0.5`, so $\mathcal{R}(0) \approx 1.65$) with `I0 = 0.001`. Starting well above equilibrium rather than near it makes the behavioral differences between temporal processes visible: the median trajectory of a mean-reverting process drifts back toward $\mathcal{R} = 1$, while a non-reverting process does not.
This section is primarily **modeling guidance for prior specification in PyRenew**, not a claim that epidemiologic theory uniquely determines one temporal process choice. The appropriate process depends on the scientific setting, time horizon, and how strongly you want the prior to regularize latent transmission dynamics.
@@ -242,21 +241,19 @@ This section is primarily **modeling guidance for prior specification in PyRenew
n_days = 28
n_init = len(gen_int_pmf)
n_samples = 200
-initial_log_rt = 0.5
+log_rt_time_0 = 0.5
I0_val = 0.001
rt_cap = 3.0
def sample_process(rt_process, label):
"""Draw prior predictive samples for a given temporal process."""
- model = SharedInfections(
- name="SharedInfections",
+ model = PopulationInfections(
+ name="PopulationInfections",
gen_int_rv=gen_int_rv,
I0_rv=DeterministicVariable("I0", I0_val),
- initial_log_rt_rv=DeterministicVariable(
- "initial_log_rt", initial_log_rt
- ),
- shared_rt_process=rt_process,
+ log_rt_time_0_rv=DeterministicVariable("log_rt_time_0", log_rt_time_0),
+ single_rt_process=rt_process,
n_initialization_points=n_init,
)
@@ -266,9 +263,11 @@ def sample_process(rt_process, label):
samples = Predictive(sampler, num_samples=n_samples)(random.PRNGKey(42))
return {
- "rt": np.array(samples["SharedInfections::rt_shared"])[:, n_init:, 0],
+ "rt": np.array(samples["PopulationInfections::rt_single"])[
+ :, n_init:, 0
+ ],
"infections": np.array(
- samples["SharedInfections::infections_aggregate"]
+ samples["PopulationInfections::infections_aggregate"]
)[:, n_init:],
}
@@ -686,7 +685,7 @@ The latent infection trajectory is not observed directly. To connect it to data,
The `PyrenewBuilder` handles the wiring:
-1. **`configure_latent()`** sets the shared infection process (called once)
+1. **`configure_latent()`** sets the single infection process (called once)
2. **`add_observation()`** adds an observation process (called once per data stream)
3. **`build()`** computes `n_initialization_points` from all delay distributions and produces a model ready for inference
diff --git a/docs/tutorials/latent_hierarchical_infections.qmd b/docs/tutorials/latent_subpopulation_infections.qmd
similarity index 90%
rename from docs/tutorials/latent_hierarchical_infections.qmd
rename to docs/tutorials/latent_subpopulation_infections.qmd
index 8ef3092d..0839764b 100644
--- a/docs/tutorials/latent_hierarchical_infections.qmd
+++ b/docs/tutorials/latent_subpopulation_infections.qmd
@@ -1,5 +1,5 @@
---
-title: Hierarchical Latent Infections
+title: Latent Subpopulation Infections
format:
gfm:
code-fold: true
@@ -38,7 +38,7 @@ from _tutorial_theme import theme_tutorial
```{python}
# | label: imports
from pyrenew.latent import (
- HierarchicalInfections,
+ SubpopulationInfections,
AR1,
DifferencedAR1,
RandomWalk,
@@ -49,7 +49,7 @@ from pyrenew.deterministic import DeterministicPMF, DeterministicVariable
## Overview
-`HierarchicalInfections` extends the renewal model to a population composed of $K$ subpopulations, each with its own latent infection trajectory.
+`SubpopulationInfections` extends the renewal model to a population composed of $K$ subpopulations, each with its own latent infection trajectory.
As in the single-population case, infections evolve according to the renewal equation. For each subpopulation $k = 1, \dots, K$, we define
@@ -107,24 +107,24 @@ The aggregate trajectory $I_{\text{aggregate}}(t)$ is then passed to observation
This model generalizes the single-population renewal model:
-- `SharedInfections`: one trajectory $I(t)$ with one $\mathcal{R}(t)$
-- `HierarchicalInfections`: multiple trajectories $I_k(t)$ with a shared baseline $\mathcal{R}_{\text{baseline}}(t)$ and subpopulation deviations $\delta_k(t)$
+- `PopulationInfections`: one trajectory $I(t)$ with one $\mathcal{R}(t)$
+- `SubpopulationInfections`: multiple trajectories $I_k(t)$ with a shared baseline $\mathcal{R}_{\text{baseline}}(t)$ and subpopulation deviations $\delta_k(t)$
When all $\delta_k(t) = 0$, the model reduces to the shared (single-population) case.
-This tutorial assumes familiarity with the renewal equation, generation interval, initial conditions (`I0`, `initial_log_rt`), and temporal processes. See [Latent Infection Processes](latent_infections.md) for that background.
+This tutorial assumes familiarity with the renewal equation, generation interval, initial conditions (`I0`, `log_rt_time_0`), and temporal processes. See [Latent Infections](latent_infections.md) for that background.
---
## Model Structure
-`HierarchicalInfections` uses the same core inputs as `SharedInfections`, with additional structure for subpopulation variation.
+`SubpopulationInfections` uses the same core inputs as `PopulationInfections`, with additional structure for subpopulation variation.
### Core inputs (shared across subpopulations)
- **`gen_int_rv`**: Generation interval distribution $w_\tau$
- **`I0_rv`**: Initial infection level (shared by default across subpopulations, but can be specified per subpopulation)
-- **`initial_log_rt_rv`**: Initial value of $\log \mathcal{R}_{\text{baseline}}(t)$ at time $0$
+- **`log_rt_time_0_rv`**: Initial value of $\log \mathcal{R}_{\text{baseline}}(t)$ at time $0$
### Temporal processes
@@ -173,22 +173,22 @@ n_subpops = len(subpop_fractions)
n_days = 28
n_init = len(gen_int_pmf)
n_samples = 200
-initial_log_rt = 0.2
+log_rt_time_0 = 0.2
I0_val = 0.001
rt_cap = 3.0
print(f"Subpopulations: {n_subpops}")
print(f"Population fractions: {np.array(subpop_fractions)}")
-print(f"Initial Rt: {np.exp(initial_log_rt):.2f}")
+print(f"Log Rt at time 0: {np.exp(log_rt_time_0):.2f}")
```
```{python}
# | label: instantiate
-model = HierarchicalInfections(
- name="HierarchicalInfections",
+model = SubpopulationInfections(
+ name="SubpopulationInfections",
gen_int_rv=gen_int_rv,
I0_rv=DeterministicVariable("I0", I0_val),
- initial_log_rt_rv=DeterministicVariable("initial_log_rt", initial_log_rt),
+ log_rt_time_0_rv=DeterministicVariable("log_rt_time_0", log_rt_time_0),
baseline_rt_process=DifferencedAR1(autoreg=0.5, innovation_sd=0.01),
subpop_rt_deviation_process=AR1(autoreg=0.8, innovation_sd=0.05),
n_initialization_points=n_init,
@@ -197,7 +197,7 @@ model = HierarchicalInfections(
## The Sum-to-Zero Constraint
-Without a constraint, the baseline and deviations are not identifiable: shifting the baseline up by some amount $c$ and all deviations down by $c$ produces the same subpopulation $\mathcal{R}_k(t)$ values. `HierarchicalInfections` enforces $\sum_k \delta_k(t) = 0$ at every time step by centering the raw deviation trajectories.
+Without a constraint, the baseline and deviations are not identifiable: shifting the baseline up by some amount $c$ and all deviations down by $c$ produces the same subpopulation $\mathcal{R}_k(t)$ values. `SubpopulationInfections` enforces $\sum_k \delta_k(t) = 0$ at every time step by centering the raw deviation trajectories.
This ensures $\mathcal{R}_{\text{baseline}}(t)$ is the *unweighted* geometric mean of the subpopulation $\mathcal{R}_k(t)$ values, so the baseline represents the typical transmission level across subpopulations.
Note that this is the unweighted mean across subpopulations, not population-weighted by $p_k$. As a result, $\mathcal{R}_{\text{baseline}}(t)$ does **not** in general equal the jurisdiction-level reproduction number implied by the aggregate infection trajectory $I_{\text{aggregate}}(t)$. When subpopulations differ in size, a small subpopulation with a large $\mathcal{R}_k(t)$ contributes equally to the baseline but only marginally to the aggregate.
@@ -211,7 +211,7 @@ with numpyro.handlers.seed(rng_seed=42):
subpop_fractions=subpop_fractions,
)
-deviations = trace["HierarchicalInfections::subpop_deviations"]["value"]
+deviations = trace["SubpopulationInfections::subpop_deviations"]["value"]
deviation_sums = jnp.sum(deviations, axis=1)
print(f"Deviation shape: {deviations.shape} (n_total_days, n_subpops)")
print(
@@ -222,9 +222,9 @@ print(
## Choosing the Baseline Temporal Process
The baseline process governs the jurisdiction-level trend in $\mathcal{R}(t)$.
-The same temporal process options apply as in `SharedInfections` (see [Temporal Process Choice](latent_infections.md#temporal-process-choice)): AR(1) for mean reversion, DifferencedAR1 for persistent trends with stabilizing rate of change, RandomWalk for unconstrained drift.
+The same temporal process options apply as in `PopulationInfections` (see [Temporal Process Choice](latent_infections.md#temporal-process-choice)): AR(1) for mean reversion, DifferencedAR1 for persistent trends with stabilizing rate of change, RandomWalk for unconstrained drift.
-We use DifferencedAR1 with small `innovation_sd` for the baseline throughout this tutorial. The prior predictive for baseline $\mathcal{R}(t)$ behaves much like the corresponding `SharedInfections` example, so we do not repeat that full comparison here. The focus of this tutorial is the deviation process and how it changes the spread of subpopulation trajectories around the baseline.
+We use DifferencedAR1 with small `innovation_sd` for the baseline throughout this tutorial. The prior predictive for baseline $\mathcal{R}(t)$ behaves much like the corresponding `PopulationInfections` example, so we do not repeat that full comparison here. The focus of this tutorial is the deviation process and how it changes the spread of subpopulation trajectories around the baseline.
The same high-level decision rules apply here:
@@ -251,14 +251,12 @@ This determines whether subpopulations quickly return to the baseline or can div
```{python}
# | label: helpers
def sample_hierarchical(baseline_process, deviation_process, label):
- """Draw prior predictive samples from a HierarchicalInfections model."""
- m = HierarchicalInfections(
- name="HierarchicalInfections",
+ """Draw prior predictive samples from a SubpopulationInfections model."""
+ m = SubpopulationInfections(
+ name="SubpopulationInfections",
gen_int_rv=gen_int_rv,
I0_rv=DeterministicVariable("I0", I0_val),
- initial_log_rt_rv=DeterministicVariable(
- "initial_log_rt", initial_log_rt
- ),
+ log_rt_time_0_rv=DeterministicVariable("log_rt_time_0", log_rt_time_0),
baseline_rt_process=baseline_process,
subpop_rt_deviation_process=deviation_process,
n_initialization_points=n_init,
@@ -274,16 +272,16 @@ def sample_hierarchical(baseline_process, deviation_process, label):
samples = Predictive(sampler, num_samples=n_samples)(random.PRNGKey(42))
return {
"rt_baseline": np.array(
- samples["HierarchicalInfections::rt_baseline"]
+ samples["SubpopulationInfections::rt_baseline"]
)[:, n_init:, 0],
- "rt_subpop": np.array(samples["HierarchicalInfections::rt_subpop"])[
+ "rt_subpop": np.array(samples["SubpopulationInfections::rt_subpop"])[
:, n_init:, :
],
"deviations": np.array(
- samples["HierarchicalInfections::subpop_deviations"]
+ samples["SubpopulationInfections::subpop_deviations"]
)[:, n_init:, :],
"infections": np.array(
- samples["HierarchicalInfections::infections_aggregate"]
+ samples["SubpopulationInfections::infections_aggregate"]
)[:, n_init:],
}
@@ -461,7 +459,7 @@ for label, s in [("AR(1)", ar1_dev_samples), ("RandomWalk", rw_dev_samples)]:
pd.DataFrame(deviation_summary_rows)
```
-AR(1) deviations stay close to zero because mean reversion continuously pulls them back. RandomWalk deviations accumulate over time. As in the SharedInfections tutorial, this is prior-modeling guidance rather than a uniquely determined epidemiologic rule. The choice depends on the epidemiological setting:
+AR(1) deviations stay close to zero because mean reversion continuously pulls them back. RandomWalk deviations accumulate over time. As in the PopulationInfections tutorial, this is prior-modeling guidance rather than a uniquely determined epidemiologic rule. The choice depends on the epidemiological setting:
- **AR(1) deviations** when subpopulations are expected to track the jurisdiction average. Local outbreaks or lulls are temporary. This is typical for geographically close subpopulations (e.g., counties within a metropolitan area) where mobility mixes transmission.
- **RandomWalk deviations** when local differences can persist. This may be appropriate for subpopulations with distinct contact patterns, demographics, or intervention histories (e.g., urban vs. rural areas).
diff --git a/pyrenew/latent/__init__.py b/pyrenew/latent/__init__.py
index 8c428579..45f6f06b 100644
--- a/pyrenew/latent/__init__.py
+++ b/pyrenew/latent/__init__.py
@@ -5,7 +5,6 @@
LatentSample,
PopulationStructure,
)
-from pyrenew.latent.hierarchical_infections import HierarchicalInfections
from pyrenew.latent.hierarchical_priors import (
GammaGroupSdPrior,
HierarchicalNormalPrior,
@@ -27,7 +26,8 @@
)
from pyrenew.latent.infections import Infections
from pyrenew.latent.infectionswithfeedback import InfectionsWithFeedback
-from pyrenew.latent.shared_infections import SharedInfections
+from pyrenew.latent.population_infections import PopulationInfections
+from pyrenew.latent.subpopulation_infections import SubpopulationInfections
from pyrenew.latent.temporal_processes import (
AR1,
DifferencedAR1,
@@ -50,10 +50,8 @@
"BaseLatentInfectionProcess",
"LatentSample",
"PopulationStructure",
- # Hierarchical infection processes
- "HierarchicalInfections",
- # Shared infection processes
- "SharedInfections",
+ "SubpopulationInfections",
+ "PopulationInfections",
# Hierarchical priors
"HierarchicalNormalPrior",
"GammaGroupSdPrior",
diff --git a/pyrenew/latent/base.py b/pyrenew/latent/base.py
index 4d42308a..c84d393a 100644
--- a/pyrenew/latent/base.py
+++ b/pyrenew/latent/base.py
@@ -300,9 +300,9 @@ def _validate_and_prepare_I0(
Validate and prepare I0 for use in the renewal equation.
Subclasses override this to enforce shape constraints (e.g., scalar
- for SharedInfections) or broadcast I0 to match the population
+ for PopulationInfections) or broadcast I0 to match the population
structure (e.g., scalar to per-subpop array for
- HierarchicalInfections).
+ SubpopulationInfections).
Parameters
----------
diff --git a/pyrenew/latent/shared_infections.py b/pyrenew/latent/population_infections.py
similarity index 77%
rename from pyrenew/latent/shared_infections.py
rename to pyrenew/latent/population_infections.py
index ac7ec1df..1d1ce1cf 100644
--- a/pyrenew/latent/shared_infections.py
+++ b/pyrenew/latent/population_infections.py
@@ -1,5 +1,5 @@
"""
-Shared latent infection process renewal model.
+Populaton-level single-trajectory latent infection process renewal model.
"""
from __future__ import annotations
@@ -22,9 +22,9 @@
from pyrenew.metaclass import RandomVariable
-class SharedInfections(BaseLatentInfectionProcess):
+class PopulationInfections(BaseLatentInfectionProcess):
"""
- A single $\ mathcal{R}(t)$ trajectory drives one renewal equation.
+ A single $\\mathcal{R}(t)$ trajectory drives one renewal equation.
The constructor specifies model structure (priors, temporal processes).
"""
@@ -36,28 +36,28 @@ def __init__(
gen_int_rv: RandomVariable,
n_initialization_points: int,
I0_rv: RandomVariable,
- shared_rt_process: TemporalProcess,
- initial_log_rt_rv: RandomVariable,
+ single_rt_process: TemporalProcess,
+ log_rt_time_0_rv: RandomVariable,
) -> None:
"""
- Initialize shared infections process.
+ Initialize population-level infections process.
Parameters
----------
name
Name prefix for numpyro sample sites. All deterministic
quantities are recorded under this scope (e.g.,
- ``"{name}::rt_shared"``).
+ ``"{name}::rt_single"``).
gen_int_rv
Generation interval PMF
n_initialization_points
Number of initialization days before day 0.
I0_rv
Initial infection prevalence (proportion of population)
- shared_rt_process
- Temporal process for shared Rt dynamics
- initial_log_rt_rv
- Initial value for log(Rt) at time 0.
+ single_rt_process
+ Temporal process for single $\\mathcal{R}(t)$ dynamics
+ log_rt_time_0_rv
+ Initial value for log($\\mathcal{R}(t)$) at time 0.
Raises
------
@@ -77,13 +77,13 @@ def __init__(
if isinstance(I0_rv, DeterministicVariable):
self._validate_I0(I0_rv.value)
- if initial_log_rt_rv is None:
- raise ValueError("initial_log_rt_rv is required")
- self.initial_log_rt_rv = initial_log_rt_rv
+ if log_rt_time_0_rv is None:
+ raise ValueError("log_rt_time_0_rv is required")
+ self.log_rt_time_0_rv = log_rt_time_0_rv
- if shared_rt_process is None:
- raise ValueError("shared_rt_process is required")
- self.shared_rt_process = shared_rt_process
+ if single_rt_process is None:
+ raise ValueError("single_rt_process is required")
+ self.single_rt_process = single_rt_process
def default_subpop_fractions(self) -> ArrayLike:
"""
@@ -104,7 +104,7 @@ def _validate_and_prepare_I0(
"""
Validate that I0 is a scalar prevalence value.
- SharedInfections operates on a single population, so I0 must be
+ PopulationInfections operates on a single population, so I0 must be
a scalar (0-dimensional array).
Parameters
@@ -126,13 +126,13 @@ def _validate_and_prepare_I0(
"""
if I0.ndim != 0:
raise ValueError(
- "SharedInfections requires I0_rv to return a scalar prevalence"
+ "PopulationInfections requires I0_rv to return a scalar prevalence"
)
return super()._validate_and_prepare_I0(I0, pop)
def validate(self) -> None:
"""
- Validate shared infections parameters.
+ Validate population infections parameters.
Checks that the generation interval is a valid PMF.
@@ -150,9 +150,9 @@ def sample(
**kwargs: object,
) -> LatentSample:
"""
- Sample shared infections using a single renewal process.
+ Sample population infections using a single renewal process.
- Generates a shared Rt trajectory, computes initial infections via
+ Generates a single $\\mathcal{R}(t)$ trajectory, computes initial infections via
exponential backprojection, and runs one renewal equation.
Parameters
@@ -184,28 +184,28 @@ def sample(
frac_check = jnp.isclose(pop.fractions[0], 1.0, atol=1e-6)
if pop.n_subpops != 1 or (not_jax_tracer(frac_check) and not frac_check):
raise ValueError(
- "SharedInfections requires exactly one subpopulation "
+ "PopulationInfections requires exactly one subpopulation "
"with fraction [1.0]"
)
n_total_days = self.n_initialization_points + n_days_post_init
- initial_log_rt = self.initial_log_rt_rv()
+ initial_log_rt = self.log_rt_time_0_rv()
- log_rt_shared = self.shared_rt_process.sample(
+ log_rt_single = self.single_rt_process.sample(
n_timepoints=n_total_days,
initial_value=initial_log_rt,
- name_prefix="log_rt_shared",
+ name_prefix="log_rt_single",
)
- rt_shared = jnp.exp(log_rt_shared)
+ rt_single = jnp.exp(log_rt_single)
gen_int = self.gen_int_rv()
I0 = self._validate_and_prepare_I0(jnp.asarray(self.I0_rv()), pop)
initial_r = r_approx_from_R(
- R=rt_shared[0, 0],
+ R=rt_single[0, 0],
g=gen_int,
n_newton_steps=4,
)
@@ -218,7 +218,7 @@ def sample(
post_init_infections = compute_infections_from_rt(
I0=recent_I0,
- Rt=rt_shared[self.n_initialization_points :, 0],
+ Rt=rt_single[self.n_initialization_points :, 0],
reversed_generation_interval_pmf=gen_int_reversed,
)
@@ -235,8 +235,8 @@ def sample(
with numpyro.handlers.scope(prefix=self.name, divider="::"):
numpyro.deterministic("I0_init", I0_init)
- numpyro.deterministic("log_rt_shared", log_rt_shared)
- numpyro.deterministic("rt_shared", rt_shared)
+ numpyro.deterministic("log_rt_single", log_rt_single)
+ numpyro.deterministic("rt_single", rt_single)
numpyro.deterministic("infections_aggregate", infections_aggregate)
return LatentSample(
diff --git a/pyrenew/latent/hierarchical_infections.py b/pyrenew/latent/subpopulation_infections.py
similarity index 89%
rename from pyrenew/latent/hierarchical_infections.py
rename to pyrenew/latent/subpopulation_infections.py
index 3e13d6e9..613f75f8 100644
--- a/pyrenew/latent/hierarchical_infections.py
+++ b/pyrenew/latent/subpopulation_infections.py
@@ -24,17 +24,17 @@
from pyrenew.metaclass import RandomVariable
-class HierarchicalInfections(BaseLatentInfectionProcess):
+class SubpopulationInfections(BaseLatentInfectionProcess):
"""
- Multi-subpopulation renewal model with hierarchical Rt structure.
+ Multi-subpopulation renewal model with hierarchical $\\mathcal{R}(t)$ structure.
- Each subpopulation has its own renewal equation with Rt deviating from a
+ Each subpopulation has its own renewal equation with $\\mathcal{R}(t)$ deviating from a
shared baseline. Suitable when transmission dynamics vary substantially
across subpopulations.
Mathematical form:
- - Baseline Rt: log[R_baseline(t)] ~ TemporalProcess
- - Subpopulation Rt: log R_k(t) = log[R_baseline(t)] + delta_k(t)
+ - Baseline $\\mathcal{R}(t)$: log[R_baseline(t)] ~ TemporalProcess
+ - Subpopulation $\\mathcal{R}(t)$: log R_k(t) = log[R_baseline(t)] + delta_k(t)
- Deviations: delta_k(t) ~ TemporalProcess with sum-to-zero constraint
- Renewal per subpop: I_k(t) = R_k(t) * sum_tau I_k(t-tau) * g(tau)
- Aggregate total: I_aggregate(t) = sum_k p_k * I_k(t)
@@ -46,7 +46,7 @@ class HierarchicalInfections(BaseLatentInfectionProcess):
Notes
-----
Sum-to-zero constraint on deviations ensures R_baseline(t) is the geometric
- mean of subpopulation Rt values, providing identifiability.
+ mean of subpopulation $\\mathcal{R}(t)$ values, providing identifiability.
"""
def __init__(
@@ -58,7 +58,7 @@ def __init__(
I0_rv: RandomVariable,
baseline_rt_process: TemporalProcess,
subpop_rt_deviation_process: TemporalProcess,
- initial_log_rt_rv: RandomVariable,
+ log_rt_time_0_rv: RandomVariable,
) -> None:
"""
Initialize hierarchical infections process.
@@ -76,11 +76,11 @@ def __init__(
I0_rv
Initial infection prevalence (proportion of population)
baseline_rt_process
- Temporal process for baseline Rt dynamics
+ Temporal process for baseline $\\mathcal{R}(t)$ dynamics
subpop_rt_deviation_process
Temporal process for subpopulation deviations
- initial_log_rt_rv
- Initial value for log(Rt) at time 0.
+ log_rt_time_0_rv
+ Initial value for log($\\mathcal{R}(t)$) at time 0.
Raises
------
@@ -100,9 +100,9 @@ def __init__(
if isinstance(I0_rv, DeterministicVariable):
self._validate_I0(I0_rv.value)
- if initial_log_rt_rv is None:
- raise ValueError("initial_log_rt_rv is required")
- self.initial_log_rt_rv = initial_log_rt_rv
+ if log_rt_time_0_rv is None:
+ raise ValueError("log_rt_time_0_rv is required")
+ self.log_rt_time_0_rv = log_rt_time_0_rv
if baseline_rt_process is None:
raise ValueError("baseline_rt_process is required")
@@ -167,7 +167,7 @@ def sample(
"""
Sample hierarchical infections for all subpopulations.
- Generates baseline Rt, subpopulation deviations with sum-to-zero
+ Generates baseline $\\mathcal{R}(t)$, subpopulation deviations with sum-to-zero
constraint, initial infections, and runs n_subpops independent renewal processes.
Parameters
@@ -193,7 +193,7 @@ def sample(
n_total_days = self.n_initialization_points + n_days_post_init
- initial_log_rt = self.initial_log_rt_rv()
+ initial_log_rt = self.log_rt_time_0_rv()
log_rt_baseline = self.baseline_rt_process.sample(
n_timepoints=n_total_days,
diff --git a/pyrenew/model/multisignal_model.py b/pyrenew/model/multisignal_model.py
index 8d7d78b1..8a124348 100644
--- a/pyrenew/model/multisignal_model.py
+++ b/pyrenew/model/multisignal_model.py
@@ -20,9 +20,9 @@ class MultiSignalModel(Model):
"""
Multi-signal renewal model.
- Combines a latent infection process (e.g., HierarchicalInfections,
- PartitionedInfections) with multiple observation processes (e.g.,
- CountObservation, WastewaterObservation).
+ Combines a latent infection process (e.g., SubpopulationInfections)
+ with multiple observation processes (e.g., CountObservation,
+ WastewaterObservation).
Built via PyrenewBuilder to ensure n_initialization_points is computed
correctly from all components. Can also be constructed manually for
diff --git a/pyrenew/model/pyrenew_builder.py b/pyrenew/model/pyrenew_builder.py
index 98c0b797..e05d146a 100644
--- a/pyrenew/model/pyrenew_builder.py
+++ b/pyrenew/model/pyrenew_builder.py
@@ -61,8 +61,8 @@ def configure_latent(
Parameters
----------
latent_class
- Class to use for latent infections (e.g., HierarchicalInfections,
- PartitionedInfections, or a custom implementation)
+ Class to use for latent infections (e.g., PopulationInfections,
+ SubpopulationInfections, or a custom implementation)
**params
Parameters for latent class constructor (model structure).
DO NOT include n_initialization_points - it will be computed
diff --git a/test/conftest.py b/test/conftest.py
index 980c5b25..e9a5e16c 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -10,7 +10,12 @@
import pytest
from pyrenew.deterministic import DeterministicPMF, DeterministicVariable
-from pyrenew.latent import AR1, HierarchicalInfections, RandomWalk
+from pyrenew.latent import (
+ AR1,
+ PopulationInfections,
+ RandomWalk,
+ SubpopulationInfections,
+)
from pyrenew.observation import (
Counts,
HierarchicalNormalNoise,
@@ -146,31 +151,51 @@ def hierarchical_normal_noise_tight():
# =============================================================================
-# Hierarchical Infections Fixture
+# Latent Infections Fixtures
# =============================================================================
@pytest.fixture
-def hierarchical_infections(gen_int_rv):
+def subpopulation_infections(gen_int_rv):
"""
- Pre-configured HierarchicalInfections instance.
+ Pre-configured SubpopulationInfections instance.
Returns
-------
- HierarchicalInfections
+ SubpopulationInfections
Configured infection process with realistic parameters.
"""
- return HierarchicalInfections(
- name="hierarchical",
+ return SubpopulationInfections(
+ name="subpopulation",
gen_int_rv=gen_int_rv,
I0_rv=DeterministicVariable("I0", 0.001),
- initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0),
+ log_rt_time_0_rv=DeterministicVariable("log_rt_time_0", 0.0),
baseline_rt_process=AR1(autoreg=0.9, innovation_sd=0.05),
subpop_rt_deviation_process=RandomWalk(innovation_sd=0.025),
n_initialization_points=7,
)
+@pytest.fixture
+def population_infections(gen_int_rv):
+ """
+ Pre-configured PopulationInfections instance.
+
+ Returns
+ -------
+ PopulationInfections
+ Configured infection process with realistic parameters.
+ """
+ return PopulationInfections(
+ name="population",
+ gen_int_rv=gen_int_rv,
+ I0_rv=DeterministicVariable("I0", 0.001),
+ log_rt_time_0_rv=DeterministicVariable("log_rt_time_0", 0.0),
+ single_rt_process=AR1(autoreg=0.9, innovation_sd=0.05),
+ n_initialization_points=7,
+ )
+
+
# =============================================================================
# Counts Process Fixtures
# =============================================================================
diff --git a/test/integration/conftest.py b/test/integration/conftest.py
index 2592b083..8505b813 100644
--- a/test/integration/conftest.py
+++ b/test/integration/conftest.py
@@ -21,7 +21,7 @@
)
from pyrenew.deterministic import DeterministicPMF, DeterministicVariable
from pyrenew.latent import AR1
-from pyrenew.latent.shared_infections import SharedInfections
+from pyrenew.latent.population_infections import PopulationInfections
from pyrenew.model import PyrenewBuilder
from pyrenew.observation import Counts, NegativeBinomialNoise
from pyrenew.randomvariable import DistributionalVariable
@@ -137,7 +137,7 @@ def he_model(
ed_day_of_week_effects: jnp.ndarray,
) -> PyrenewBuilder:
"""
- Build a SharedInfections model with hospital + ED observation processes.
+ Build a PopulationInfections model with hospital + ED observation processes.
Parameters
----------
@@ -159,13 +159,11 @@ def he_model(
builder = PyrenewBuilder()
builder.configure_latent(
- SharedInfections,
+ PopulationInfections,
gen_int_rv=DeterministicPMF("gen_int", gen_int_pmf),
I0_rv=DistributionalVariable("I0", dist.Beta(1, 10)),
- initial_log_rt_rv=DistributionalVariable(
- "initial_log_rt", dist.Normal(0.0, 0.5)
- ),
- shared_rt_process=AR1(autoreg=0.9, innovation_sd=0.05),
+ log_rt_time_0_rv=DistributionalVariable("log_rt_time_0", dist.Normal(0.0, 0.5)),
+ single_rt_process=AR1(autoreg=0.9, innovation_sd=0.05),
)
hospital_obs = Counts(
diff --git a/test/integration/test_shared_infections_he.py b/test/integration/test_population_infections_he.py
similarity index 95%
rename from test/integration/test_shared_infections_he.py
rename to test/integration/test_population_infections_he.py
index 197913a4..12ac6553 100644
--- a/test/integration/test_shared_infections_he.py
+++ b/test/integration/test_population_infections_he.py
@@ -1,7 +1,7 @@
"""
-Integration test: SharedInfections H+E model with posterior recovery.
+Integration test: PopulationInfections H+E model with posterior recovery.
-Fits a SharedInfections model with hospital admissions and ED visit
+Fits a PopulationInfections model with hospital admissions and ED visit
observation processes to synthetic 120-day CA data, then checks that
posterior estimates recover known true parameters.
"""
@@ -173,9 +173,9 @@ def posterior_dt(
fitted_model.mcmc,
dims={
"latent_infections": ["time"],
- "SharedInfections::infections_aggregate": ["time"],
- "SharedInfections::log_rt_shared": ["time", "dummy"],
- "SharedInfections::rt_shared": ["time", "dummy"],
+ "PopulationInfections::infections_aggregate": ["time"],
+ "PopulationInfections::log_rt_single": ["time", "dummy"],
+ "PopulationInfections::rt_single": ["time", "dummy"],
"hospital_predicted": ["time"],
"ed_predicted": ["time"],
},
@@ -216,7 +216,7 @@ def test_mcmc_convergence(
"""
summary = az.summary(
posterior_dt,
- var_names=["I0", "initial_log_rt", "ihr", "iedr"],
+ var_names=["I0", "log_rt_time_0", "ihr", "iedr"],
)
rhat = summary["r_hat"].astype(float)
ess = summary["ess_bulk"].astype(float)
@@ -239,7 +239,7 @@ def test_rt_posterior_covers_truth(
daily_infections : pl.DataFrame
True infections and R(t) trajectory.
"""
- rt_posterior = posterior_dt.posterior["SharedInfections::rt_shared"]
+ rt_posterior = posterior_dt.posterior["PopulationInfections::rt_single"]
rt_q05 = rt_posterior.quantile(0.05, dim=["chain", "draw"]).values
rt_q95 = rt_posterior.quantile(0.95, dim=["chain", "draw"]).values
diff --git a/test/test_interface_coverage.py b/test/test_interface_coverage.py
index 500b65dc..0d03e3a5 100644
--- a/test/test_interface_coverage.py
+++ b/test/test_interface_coverage.py
@@ -23,12 +23,12 @@
AR1,
DifferencedAR1,
GammaGroupSdPrior,
- HierarchicalInfections,
HierarchicalNormalPrior,
Infections,
InfectionsWithFeedback,
RandomWalk,
StudentTGroupModePrior,
+ SubpopulationInfections,
)
from pyrenew.metaclass import RandomVariable
from pyrenew.observation import (
@@ -225,11 +225,11 @@ def validate_data(self, n_total, n_subpops, **obs_data): # numpydoc ignore=GL08
def test_get_required_lookback(gen_int_rv):
"""get_required_lookback returns generation interval PMF length."""
- infections = HierarchicalInfections(
- name="hierarchical",
+ infections = SubpopulationInfections(
+ name="subpopulation",
gen_int_rv=gen_int_rv,
I0_rv=DeterministicVariable("I0", 0.001),
- initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0),
+ log_rt_time_0_rv=DeterministicVariable("log_rt_time_0", 0.0),
baseline_rt_process=AR1(autoreg=0.9, innovation_sd=0.05),
subpop_rt_deviation_process=RandomWalk(innovation_sd=0.025),
n_initialization_points=7,
@@ -239,17 +239,17 @@ def test_get_required_lookback(gen_int_rv):
# =============================================================================
-# HierarchicalInfections.validate() coverage
+# SubpopulationInfections.validate() coverage
# =============================================================================
-def test_hierarchical_infections_validate(gen_int_rv):
- """HierarchicalInfections.validate() runs without error on valid PMF."""
- infections = HierarchicalInfections(
- name="hierarchical",
+def test_subpopulation_infections_validate(gen_int_rv):
+ """SubpopulationInfections.validate() runs without error on valid PMF."""
+ infections = SubpopulationInfections(
+ name="subpopulation",
gen_int_rv=gen_int_rv,
I0_rv=DeterministicVariable("I0", 0.001),
- initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0),
+ log_rt_time_0_rv=DeterministicVariable("log_rt_time_0", 0.0),
baseline_rt_process=AR1(autoreg=0.9, innovation_sd=0.05),
subpop_rt_deviation_process=RandomWalk(innovation_sd=0.025),
n_initialization_points=7,
@@ -434,13 +434,13 @@ def test_name_attribute_matches_expected(instance, expected_name):
assert instance.name == expected_name
-def test_hierarchical_infections_name(gen_int_rv):
- """HierarchicalInfections.name is correctly set during construction."""
- infections = HierarchicalInfections(
+def test_subpopulation_infections_name(gen_int_rv):
+ """SubpopulationInfections.name is correctly set during construction."""
+ infections = SubpopulationInfections(
name="test_hi",
gen_int_rv=gen_int_rv,
I0_rv=DeterministicVariable("I0", 0.001),
- initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0),
+ log_rt_time_0_rv=DeterministicVariable("log_rt_time_0", 0.0),
baseline_rt_process=AR1(autoreg=0.9, innovation_sd=0.05),
subpop_rt_deviation_process=RandomWalk(innovation_sd=0.025),
n_initialization_points=7,
diff --git a/test/test_shared_infections.py b/test/test_population_infections.py
similarity index 55%
rename from test/test_shared_infections.py
rename to test/test_population_infections.py
index dfefc5c7..261ac8f7 100644
--- a/test/test_shared_infections.py
+++ b/test/test_population_infections.py
@@ -1,5 +1,5 @@
"""
-Unit tests for SharedInfections.
+Unit tests for PopulationInfections.
"""
import jax.numpy as jnp
@@ -7,110 +7,90 @@
import pytest
from pyrenew.deterministic import DeterministicVariable
-from pyrenew.latent import AR1, RandomWalk
-from pyrenew.latent.shared_infections import SharedInfections
-
-
-@pytest.fixture
-def shared_infections(gen_int_rv):
- """
- Pre-configured SharedInfections instance.
-
- Returns
- -------
- SharedInfections
- Configured infection process with realistic parameters.
- """
- return SharedInfections(
- name="shared",
- gen_int_rv=gen_int_rv,
- I0_rv=DeterministicVariable("I0", 0.001),
- initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0),
- shared_rt_process=AR1(autoreg=0.9, innovation_sd=0.05),
- n_initialization_points=7,
- )
-
-
-class TestSharedInfectionsSample:
- """Test sample method for SharedInfections."""
-
- def test_output_shapes(self, shared_infections):
+from pyrenew.latent import RandomWalk
+from pyrenew.latent.population_infections import PopulationInfections
+
+
+class TestPopulationInfectionsSample:
+ """Test sample method for PopulationInfections."""
+
+ def test_output_shapes(self, population_infections):
"""Test that output shapes are correct for a single-population model."""
n_days_post_init = 30
- n_total = shared_infections.n_initialization_points + n_days_post_init
+ n_total = population_infections.n_initialization_points + n_days_post_init
with numpyro.handlers.seed(rng_seed=42):
- inf_agg, inf_all = shared_infections.sample(
+ inf_agg, inf_all = population_infections.sample(
n_days_post_init=n_days_post_init,
)
assert inf_agg.shape == (n_total,)
assert inf_all.shape == (n_total, 1)
- def test_aggregate_equals_single_subpop(self, shared_infections):
+ def test_aggregate_equals_single_subpop(self, population_infections):
"""Test that aggregate infections equal the single subpopulation column."""
with numpyro.handlers.seed(rng_seed=42):
- inf_agg, inf_all = shared_infections.sample(
+ inf_agg, inf_all = population_infections.sample(
n_days_post_init=30,
)
assert jnp.allclose(inf_agg, inf_all[:, 0], atol=1e-6)
- def test_infections_are_positive(self, shared_infections):
+ def test_infections_are_positive(self, population_infections):
"""Test that all infections are positive."""
with numpyro.handlers.seed(rng_seed=42):
- inf_agg, inf_all = shared_infections.sample(
+ inf_agg, inf_all = population_infections.sample(
n_days_post_init=30,
)
assert jnp.all(inf_agg > 0)
assert jnp.all(inf_all > 0)
- def test_deterministic_sites_recorded(self, shared_infections):
+ def test_deterministic_sites_recorded(self, population_infections):
"""Test that expected numpyro deterministic sites are recorded."""
with numpyro.handlers.seed(rng_seed=42):
with numpyro.handlers.trace() as trace:
- shared_infections.sample(n_days_post_init=30)
+ population_infections.sample(n_days_post_init=30)
expected_sites = [
- "shared::I0_init",
- "shared::log_rt_shared",
- "shared::rt_shared",
- "shared::infections_aggregate",
+ "population::I0_init",
+ "population::log_rt_single",
+ "population::rt_single",
+ "population::infections_aggregate",
]
for site in expected_sites:
assert site in trace, f"Missing deterministic site: {site}"
- def test_rt_is_exp_of_log_rt(self, shared_infections):
+ def test_rt_is_exp_of_log_rt(self, population_infections):
"""Test that recorded Rt equals exp of recorded log Rt."""
with numpyro.handlers.seed(rng_seed=42):
with numpyro.handlers.trace() as trace:
- shared_infections.sample(n_days_post_init=30)
+ population_infections.sample(n_days_post_init=30)
- log_rt = trace["shared::log_rt_shared"]["value"]
- rt = trace["shared::rt_shared"]["value"]
+ log_rt = trace["population::log_rt_single"]["value"]
+ rt = trace["population::rt_single"]["value"]
assert jnp.allclose(rt, jnp.exp(log_rt), atol=1e-6)
- def test_default_fractions_used_when_none(self, shared_infections):
+ def test_default_fractions_used_when_none(self, population_infections):
"""Test that default fractions [1.0] are used when not provided."""
with numpyro.handlers.seed(rng_seed=42):
- inf_agg, inf_all = shared_infections.sample(
+ inf_agg, inf_all = population_infections.sample(
n_days_post_init=30,
subpop_fractions=None,
)
assert inf_all.shape[1] == 1
- def test_explicit_fractions_one(self, shared_infections):
+ def test_explicit_fractions_one(self, population_infections):
"""Test that explicit fractions [1.0] produce same results as default."""
with numpyro.handlers.seed(rng_seed=42):
- inf_agg_default, inf_all_default = shared_infections.sample(
+ inf_agg_default, inf_all_default = population_infections.sample(
n_days_post_init=30,
)
with numpyro.handlers.seed(rng_seed=42):
- inf_agg_explicit, inf_all_explicit = shared_infections.sample(
+ inf_agg_explicit, inf_all_explicit = population_infections.sample(
n_days_post_init=30,
subpop_fractions=jnp.array([1.0]),
)
@@ -118,24 +98,24 @@ def test_explicit_fractions_one(self, shared_infections):
assert jnp.allclose(inf_agg_default, inf_agg_explicit, atol=1e-6)
assert jnp.allclose(inf_all_default, inf_all_explicit, atol=1e-6)
- def test_different_seeds_give_different_results(self, shared_infections):
+ def test_different_seeds_give_different_results(self, population_infections):
"""Test that different RNG seeds produce different trajectories."""
with numpyro.handlers.seed(rng_seed=1):
- inf_agg_1, _ = shared_infections.sample(n_days_post_init=30)
+ inf_agg_1, _ = population_infections.sample(n_days_post_init=30)
with numpyro.handlers.seed(rng_seed=999):
- inf_agg_2, _ = shared_infections.sample(n_days_post_init=30)
+ inf_agg_2, _ = population_infections.sample(n_days_post_init=30)
assert not jnp.allclose(inf_agg_1, inf_agg_2)
def test_custom_name_prefix(self, gen_int_rv):
"""Test that custom name prefix is used in deterministic sites."""
- process = SharedInfections(
+ process = PopulationInfections(
name="my_infections",
gen_int_rv=gen_int_rv,
I0_rv=DeterministicVariable("I0", 0.001),
- initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0),
- shared_rt_process=RandomWalk(),
+ log_rt_time_0_rv=DeterministicVariable("log_rt_time_0", 0.0),
+ single_rt_process=RandomWalk(),
n_initialization_points=7,
)
@@ -143,69 +123,69 @@ def test_custom_name_prefix(self, gen_int_rv):
with numpyro.handlers.trace() as trace:
process.sample(n_days_post_init=30)
- assert "my_infections::rt_shared" in trace
+ assert "my_infections::rt_single" in trace
-class TestSharedInfectionsValidation:
+class TestPopulationInfectionsValidation:
"""Test validation of inputs."""
def test_rejects_missing_I0_rv(self, gen_int_rv):
"""Test that None I0_rv is rejected."""
with pytest.raises(ValueError, match="I0_rv is required"):
- SharedInfections(
- name="shared",
+ PopulationInfections(
+ name="population",
gen_int_rv=gen_int_rv,
I0_rv=None,
- initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0),
- shared_rt_process=RandomWalk(),
+ log_rt_time_0_rv=DeterministicVariable("log_rt_time_0", 0.0),
+ single_rt_process=RandomWalk(),
n_initialization_points=7,
)
- def test_rejects_missing_initial_log_rt_rv(self, gen_int_rv):
- """Test that None initial_log_rt_rv is rejected."""
- with pytest.raises(ValueError, match="initial_log_rt_rv is required"):
- SharedInfections(
- name="shared",
+ def test_rejects_missing_log_rt_time_0_rv(self, gen_int_rv):
+ """Test that None log_rt_time_0_rv is rejected."""
+ with pytest.raises(ValueError, match="log_rt_time_0_rv is required"):
+ PopulationInfections(
+ name="population",
gen_int_rv=gen_int_rv,
I0_rv=DeterministicVariable("I0", 0.001),
- initial_log_rt_rv=None,
- shared_rt_process=RandomWalk(),
+ log_rt_time_0_rv=None,
+ single_rt_process=RandomWalk(),
n_initialization_points=7,
)
- def test_rejects_missing_shared_rt_process(self, gen_int_rv):
- """Test that None shared_rt_process is rejected."""
- with pytest.raises(ValueError, match="shared_rt_process is required"):
- SharedInfections(
- name="shared",
+ def test_rejects_missing_single_rt_process(self, gen_int_rv):
+ """Test that None single_rt_process is rejected."""
+ with pytest.raises(ValueError, match="single_rt_process is required"):
+ PopulationInfections(
+ name="population",
gen_int_rv=gen_int_rv,
I0_rv=DeterministicVariable("I0", 0.001),
- initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0),
- shared_rt_process=None,
+ log_rt_time_0_rv=DeterministicVariable("log_rt_time_0", 0.0),
+ single_rt_process=None,
n_initialization_points=7,
)
def test_rejects_invalid_I0(self, gen_int_rv):
"""Test that invalid I0 values are rejected at construction."""
with pytest.raises(ValueError, match="I0 must be positive"):
- SharedInfections(
- name="shared",
+ PopulationInfections(
+ name="population",
gen_int_rv=gen_int_rv,
I0_rv=DeterministicVariable("I0", -0.1),
- initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0),
- shared_rt_process=RandomWalk(),
+ log_rt_time_0_rv=DeterministicVariable("log_rt_time_0", 0.0),
+ single_rt_process=RandomWalk(),
n_initialization_points=7,
)
def test_rejects_I0_greater_than_one(self, gen_int_rv):
"""Test that I0 > 1 is rejected at construction."""
with pytest.raises(ValueError, match="I0 must be <= 1"):
- SharedInfections(
- name="shared",
+ PopulationInfections(
+ name="population",
gen_int_rv=gen_int_rv,
I0_rv=DeterministicVariable("I0", 1.5),
- initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0),
- shared_rt_process=RandomWalk(),
+ log_rt_time_0_rv=DeterministicVariable("log_rt_time_0", 0.0),
+ single_rt_process=RandomWalk(),
n_initialization_points=7,
)
@@ -214,46 +194,46 @@ def test_rejects_insufficient_n_initialization_points(self, gen_int_rv):
with pytest.raises(
ValueError, match="n_initialization_points must be at least"
):
- SharedInfections(
- name="shared",
+ PopulationInfections(
+ name="population",
gen_int_rv=gen_int_rv,
I0_rv=DeterministicVariable("I0", 0.001),
- initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0),
- shared_rt_process=RandomWalk(),
+ log_rt_time_0_rv=DeterministicVariable("log_rt_time_0", 0.0),
+ single_rt_process=RandomWalk(),
n_initialization_points=2,
)
- def test_rejects_fractions_not_summing_to_one(self, shared_infections):
+ def test_rejects_fractions_not_summing_to_one(self, population_infections):
"""Test that fractions not summing to 1 raises error at sample time."""
with pytest.raises(ValueError, match="must sum to 1.0"):
with numpyro.handlers.seed(rng_seed=42):
- shared_infections.sample(
+ population_infections.sample(
n_days_post_init=30,
subpop_fractions=jnp.array([0.5]),
)
def test_rejects_multiple_subpop_fractions_even_if_sum_to_one(
- self, shared_infections
+ self, population_infections
):
- """Test that multi-element fractions are rejected for shared infections."""
+ """Test that multi-element fractions are rejected for population infections."""
with pytest.raises(
ValueError,
match="requires exactly one subpopulation with fraction \\[1.0\\]",
):
with numpyro.handlers.seed(rng_seed=42):
- shared_infections.sample(
+ population_infections.sample(
n_days_post_init=30,
subpop_fractions=jnp.array([0.5, 0.5]),
)
def test_rejects_non_scalar_I0(self, gen_int_rv):
"""Test that vector-valued I0 is rejected with a clear error."""
- process = SharedInfections(
- name="shared",
+ process = PopulationInfections(
+ name="population",
gen_int_rv=gen_int_rv,
I0_rv=DeterministicVariable("I0", jnp.array([0.001, 0.002])),
- initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0),
- shared_rt_process=RandomWalk(),
+ log_rt_time_0_rv=DeterministicVariable("log_rt_time_0", 0.0),
+ single_rt_process=RandomWalk(),
n_initialization_points=7,
)
@@ -264,47 +244,47 @@ def test_rejects_non_scalar_I0(self, gen_int_rv):
with numpyro.handlers.seed(rng_seed=42):
process.sample(n_days_post_init=30)
- def test_validate_passes(self, shared_infections):
+ def test_validate_passes(self, population_infections):
"""Test that validate() succeeds for a properly constructed instance."""
- shared_infections.validate()
+ population_infections.validate()
- def test_default_subpop_fractions(self, shared_infections):
+ def test_default_subpop_fractions(self, population_infections):
"""Test that default_subpop_fractions returns [1.0]."""
- fracs = shared_infections.default_subpop_fractions()
+ fracs = population_infections.default_subpop_fractions()
assert jnp.allclose(fracs, jnp.array([1.0]))
-class TestSharedValidateAndPrepareI0:
- """Test _validate_and_prepare_I0 for SharedInfections."""
+class TestPopulationValidateAndPrepareI0:
+ """Test _validate_and_prepare_I0 for PopulationInfections."""
- def test_accepts_valid_scalar(self, shared_infections):
+ def test_accepts_valid_scalar(self, population_infections):
"""Test that a valid scalar I0 passes through unchanged."""
- pop = shared_infections._parse_and_validate_fractions()
+ pop = population_infections._parse_and_validate_fractions()
I0 = jnp.array(0.01)
- result = shared_infections._validate_and_prepare_I0(I0, pop)
+ result = population_infections._validate_and_prepare_I0(I0, pop)
assert result.ndim == 0
assert jnp.isclose(result, 0.01)
- def test_rejects_vector(self, shared_infections):
+ def test_rejects_vector(self, population_infections):
"""Test that a vector I0 is rejected."""
- pop = shared_infections._parse_and_validate_fractions()
+ pop = population_infections._parse_and_validate_fractions()
I0 = jnp.array([0.01, 0.02])
with pytest.raises(ValueError, match="scalar prevalence"):
- shared_infections._validate_and_prepare_I0(I0, pop)
+ population_infections._validate_and_prepare_I0(I0, pop)
- def test_rejects_negative(self, shared_infections):
+ def test_rejects_negative(self, population_infections):
"""Test that negative I0 is rejected."""
- pop = shared_infections._parse_and_validate_fractions()
+ pop = population_infections._parse_and_validate_fractions()
I0 = jnp.array(-0.01)
with pytest.raises(ValueError, match="I0 must be positive"):
- shared_infections._validate_and_prepare_I0(I0, pop)
+ population_infections._validate_and_prepare_I0(I0, pop)
- def test_rejects_greater_than_one(self, shared_infections):
+ def test_rejects_greater_than_one(self, population_infections):
"""Test that I0 > 1 is rejected."""
- pop = shared_infections._parse_and_validate_fractions()
+ pop = population_infections._parse_and_validate_fractions()
I0 = jnp.array(1.5)
with pytest.raises(ValueError, match="I0 must be <= 1"):
- shared_infections._validate_and_prepare_I0(I0, pop)
+ population_infections._validate_and_prepare_I0(I0, pop)
if __name__ == "__main__":
diff --git a/test/test_pyrenew_builder.py b/test/test_pyrenew_builder.py
index 74769411..211a3ace 100644
--- a/test/test_pyrenew_builder.py
+++ b/test/test_pyrenew_builder.py
@@ -6,7 +6,7 @@
import pytest
from pyrenew.deterministic import DeterministicPMF, DeterministicVariable
-from pyrenew.latent import HierarchicalInfections, RandomWalk
+from pyrenew.latent import RandomWalk, SubpopulationInfections
from pyrenew.model import MultiSignalModel, PyrenewBuilder
from pyrenew.observation import Counts, CountsBySubpop, NegativeBinomialNoise
@@ -28,10 +28,10 @@ def simple_builder():
gen_int = DeterministicPMF("gen_int", jnp.array([0.2, 0.5, 0.3]))
builder.configure_latent(
- HierarchicalInfections,
+ SubpopulationInfections,
gen_int_rv=gen_int,
I0_rv=DeterministicVariable("I0", 0.001),
- initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0),
+ log_rt_time_0_rv=DeterministicVariable("initial_log_rt", 0.0),
baseline_rt_process=RandomWalk(),
subpop_rt_deviation_process=RandomWalk(),
)
@@ -66,10 +66,10 @@ def validation_builder():
gen_int = DeterministicPMF("gen_int", jnp.array([0.2, 0.5, 0.3]))
builder.configure_latent(
- HierarchicalInfections,
+ SubpopulationInfections,
gen_int_rv=gen_int,
I0_rv=DeterministicVariable("I0", 0.001),
- initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0),
+ log_rt_time_0_rv=DeterministicVariable("initial_log_rt", 0.0),
baseline_rt_process=RandomWalk(),
subpop_rt_deviation_process=RandomWalk(),
)
@@ -105,10 +105,10 @@ def test_rejects_population_structure_at_configure_time(self):
with pytest.raises(ValueError, match="Do not specify"):
builder.configure_latent(
- HierarchicalInfections,
+ SubpopulationInfections,
gen_int_rv=gen_int,
I0_rv=DeterministicVariable("I0", 0.001),
- initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0),
+ log_rt_time_0_rv=DeterministicVariable("initial_log_rt", 0.0),
baseline_rt_process=RandomWalk(),
subpop_rt_deviation_process=RandomWalk(),
subpop_fractions=jnp.array([0.5, 0.5]),
@@ -121,10 +121,10 @@ def test_rejects_n_initialization_points_at_configure_time(self):
with pytest.raises(ValueError, match="Do not specify n_initialization_points"):
builder.configure_latent(
- HierarchicalInfections,
+ SubpopulationInfections,
gen_int_rv=gen_int,
I0_rv=DeterministicVariable("I0", 0.001),
- initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0),
+ log_rt_time_0_rv=DeterministicVariable("initial_log_rt", 0.0),
baseline_rt_process=RandomWalk(),
subpop_rt_deviation_process=RandomWalk(),
n_initialization_points=10,
@@ -136,20 +136,20 @@ def test_rejects_reconfiguring_latent(self):
gen_int = DeterministicPMF("gen_int", jnp.array([0.2, 0.5, 0.3]))
builder.configure_latent(
- HierarchicalInfections,
+ SubpopulationInfections,
gen_int_rv=gen_int,
I0_rv=DeterministicVariable("I0", 0.001),
- initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0),
+ log_rt_time_0_rv=DeterministicVariable("initial_log_rt", 0.0),
baseline_rt_process=RandomWalk(),
subpop_rt_deviation_process=RandomWalk(),
)
with pytest.raises(RuntimeError, match="already configured"):
builder.configure_latent(
- HierarchicalInfections,
+ SubpopulationInfections,
gen_int_rv=gen_int,
I0_rv=DeterministicVariable("I0", 0.001),
- initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0),
+ log_rt_time_0_rv=DeterministicVariable("initial_log_rt", 0.0),
baseline_rt_process=RandomWalk(),
subpop_rt_deviation_process=RandomWalk(),
)
@@ -189,7 +189,7 @@ def test_compute_n_initialization_points_without_latent_raises(self):
def test_compute_n_initialization_points_without_gen_int_raises(self):
"""Test that compute_n_initialization_points without gen_int_rv raises."""
builder = PyrenewBuilder()
- builder.latent_class = HierarchicalInfections
+ builder.latent_class = SubpopulationInfections
builder.latent_params = {}
with pytest.raises(ValueError, match="gen_int_rv is required"):
diff --git a/test/test_hierarchical_infections.py b/test/test_subpopulation_infections.py
similarity index 68%
rename from test/test_hierarchical_infections.py
rename to test/test_subpopulation_infections.py
index 60501e06..658546e3 100644
--- a/test/test_hierarchical_infections.py
+++ b/test/test_subpopulation_infections.py
@@ -1,5 +1,5 @@
"""
-Unit tests for HierarchicalInfections.
+Unit tests for SubpopulationInfections.
"""
import jax.numpy as jnp
@@ -7,18 +7,18 @@
import pytest
from pyrenew.deterministic import DeterministicVariable
-from pyrenew.latent import HierarchicalInfections, RandomWalk
+from pyrenew.latent import RandomWalk, SubpopulationInfections
-class TestHierarchicalInfectionsSample:
+class TestSubpopulationInfectionsSample:
"""Test sample method with population structure at sample time."""
- def test_jurisdiction_total_is_weighted_sum(self, hierarchical_infections):
+ def test_jurisdiction_total_is_weighted_sum(self, subpopulation_infections):
"""Test that jurisdiction total equals weighted sum of subpopulations."""
fractions = jnp.array([0.3, 0.25, 0.45])
with numpyro.handlers.seed(rng_seed=42):
- inf_juris, inf_all = hierarchical_infections.sample(
+ inf_juris, inf_all = subpopulation_infections.sample(
n_days_post_init=30,
subpop_fractions=fractions,
)
@@ -27,24 +27,24 @@ def test_jurisdiction_total_is_weighted_sum(self, hierarchical_infections):
assert jnp.allclose(inf_juris, expected, atol=1e-6)
- def test_deviations_sum_to_zero(self, hierarchical_infections):
+ def test_deviations_sum_to_zero(self, subpopulation_infections):
"""Test that subpopulation deviations sum to zero (identifiability)."""
with numpyro.handlers.seed(rng_seed=42):
with numpyro.handlers.trace() as trace:
- hierarchical_infections.sample(
+ subpopulation_infections.sample(
n_days_post_init=30,
subpop_fractions=jnp.array([0.3, 0.25, 0.45]),
)
- deviations = trace["hierarchical::subpop_deviations"]["value"]
+ deviations = trace["subpopulation::subpop_deviations"]["value"]
deviation_sums = jnp.sum(deviations, axis=1)
assert jnp.allclose(deviation_sums, 0.0, atol=1e-6)
- def test_infections_are_positive(self, hierarchical_infections):
+ def test_infections_are_positive(self, subpopulation_infections):
"""Test that all infections are positive (epidemiological invariant)."""
with numpyro.handlers.seed(rng_seed=42):
- inf_juris, inf_all = hierarchical_infections.sample(
+ inf_juris, inf_all = subpopulation_infections.sample(
n_days_post_init=30,
subpop_fractions=jnp.array([0.3, 0.25, 0.45]),
)
@@ -62,15 +62,15 @@ def test_infections_are_positive(self, hierarchical_infections):
ids=["K=1", "K=3", "K=6"],
)
def test_shape_and_positivity_across_subpop_counts(
- self, hierarchical_infections, fractions
+ self, subpopulation_infections, fractions
):
"""Test correct shapes and positivity for varying numbers of subpops."""
n_days_post_init = 30
- n_total = hierarchical_infections.n_initialization_points + n_days_post_init
+ n_total = subpopulation_infections.n_initialization_points + n_days_post_init
n_subpops = len(fractions)
with numpyro.handlers.seed(rng_seed=42):
- inf_juris, inf_all = hierarchical_infections.sample(
+ inf_juris, inf_all = subpopulation_infections.sample(
n_days_post_init=n_days_post_init,
subpop_fractions=fractions,
)
@@ -85,30 +85,30 @@ def test_shape_and_positivity_across_subpop_counts(
assert jnp.allclose(inf_juris, expected, atol=1e-6)
-class TestHierarchicalInfectionsValidation:
+class TestSubpopulationInfectionsValidation:
"""Test validation of inputs."""
def test_rejects_missing_I0_rv(self, gen_int_rv):
"""Test that None I0_rv is rejected."""
with pytest.raises(ValueError, match="I0_rv is required"):
- HierarchicalInfections(
- name="hierarchical",
+ SubpopulationInfections(
+ name="subpopulation",
gen_int_rv=gen_int_rv,
I0_rv=None,
- initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0),
+ log_rt_time_0_rv=DeterministicVariable("initial_log_rt", 0.0),
baseline_rt_process=RandomWalk(),
subpop_rt_deviation_process=RandomWalk(),
n_initialization_points=7,
)
- def test_rejects_missing_initial_log_rt_rv(self, gen_int_rv):
- """Test that None initial_log_rt_rv is rejected."""
- with pytest.raises(ValueError, match="initial_log_rt_rv is required"):
- HierarchicalInfections(
- name="hierarchical",
+ def test_rejects_missing_log_rt_time_0_rv(self, gen_int_rv):
+ """Test that None log_rt_time_0_rv is rejected."""
+ with pytest.raises(ValueError, match="log_rt_time_0_rv is required"):
+ SubpopulationInfections(
+ name="subpopulation",
gen_int_rv=gen_int_rv,
I0_rv=DeterministicVariable("I0", 0.001),
- initial_log_rt_rv=None,
+ log_rt_time_0_rv=None,
baseline_rt_process=RandomWalk(),
subpop_rt_deviation_process=RandomWalk(),
n_initialization_points=7,
@@ -117,11 +117,11 @@ def test_rejects_missing_initial_log_rt_rv(self, gen_int_rv):
def test_rejects_missing_baseline_rt_process(self, gen_int_rv):
"""Test that None baseline_rt_process is rejected."""
with pytest.raises(ValueError, match="baseline_rt_process is required"):
- HierarchicalInfections(
- name="hierarchical",
+ SubpopulationInfections(
+ name="subpopulation",
gen_int_rv=gen_int_rv,
I0_rv=DeterministicVariable("I0", 0.001),
- initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0),
+ log_rt_time_0_rv=DeterministicVariable("initial_log_rt", 0.0),
baseline_rt_process=None,
subpop_rt_deviation_process=RandomWalk(),
n_initialization_points=7,
@@ -130,11 +130,11 @@ def test_rejects_missing_baseline_rt_process(self, gen_int_rv):
def test_rejects_missing_subpop_rt_deviation_process(self, gen_int_rv):
"""Test that None subpop_rt_deviation_process is rejected."""
with pytest.raises(ValueError, match="subpop_rt_deviation_process is required"):
- HierarchicalInfections(
- name="hierarchical",
+ SubpopulationInfections(
+ name="subpopulation",
gen_int_rv=gen_int_rv,
I0_rv=DeterministicVariable("I0", 0.001),
- initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0),
+ log_rt_time_0_rv=DeterministicVariable("initial_log_rt", 0.0),
baseline_rt_process=RandomWalk(),
subpop_rt_deviation_process=None,
n_initialization_points=7,
@@ -143,11 +143,11 @@ def test_rejects_missing_subpop_rt_deviation_process(self, gen_int_rv):
def test_rejects_invalid_I0(self, gen_int_rv):
"""Test that invalid I0 values are rejected."""
with pytest.raises(ValueError, match="I0 must be positive"):
- HierarchicalInfections(
- name="hierarchical",
+ SubpopulationInfections(
+ name="subpopulation",
gen_int_rv=gen_int_rv,
I0_rv=DeterministicVariable("I0", -0.1),
- initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0),
+ log_rt_time_0_rv=DeterministicVariable("initial_log_rt", 0.0),
baseline_rt_process=RandomWalk(),
subpop_rt_deviation_process=RandomWalk(),
n_initialization_points=7,
@@ -158,36 +158,36 @@ def test_rejects_insufficient_n_initialization_points(self, gen_int_rv):
with pytest.raises(
ValueError, match="n_initialization_points must be at least"
):
- HierarchicalInfections(
- name="hierarchical",
+ SubpopulationInfections(
+ name="subpopulation",
gen_int_rv=gen_int_rv,
I0_rv=DeterministicVariable("I0", 0.001),
- initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0),
+ log_rt_time_0_rv=DeterministicVariable("initial_log_rt", 0.0),
baseline_rt_process=RandomWalk(),
subpop_rt_deviation_process=RandomWalk(),
n_initialization_points=2,
)
- def test_rejects_fractions_not_summing_to_one(self, hierarchical_infections):
+ def test_rejects_fractions_not_summing_to_one(self, subpopulation_infections):
"""Test that fractions not summing to 1 raises error at sample time."""
with pytest.raises(ValueError, match="must sum to 1.0"):
with numpyro.handlers.seed(rng_seed=42):
- hierarchical_infections.sample(
+ subpopulation_infections.sample(
n_days_post_init=30,
subpop_fractions=jnp.array([0.3, 0.25, 0.40]),
)
-class TestHierarchicalInfectionsPerSubpopI0:
+class TestSubpopulationInfectionsPerSubpopI0:
"""Test per-subpopulation I0 values."""
def test_per_subpop_I0_array(self, gen_int_rv):
"""Test with per-subpopulation I0 values and verify positivity."""
- process = HierarchicalInfections(
- name="hierarchical",
+ process = SubpopulationInfections(
+ name="subpopulation",
gen_int_rv=gen_int_rv,
I0_rv=DeterministicVariable("I0", jnp.array([0.001, 0.002, 0.0015])),
- initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0),
+ log_rt_time_0_rv=DeterministicVariable("initial_log_rt", 0.0),
baseline_rt_process=RandomWalk(),
subpop_rt_deviation_process=RandomWalk(),
n_initialization_points=7,
@@ -207,46 +207,46 @@ def test_per_subpop_I0_array(self, gen_int_rv):
assert jnp.all(inf_all > 0)
-class TestHierarchicalValidateAndPrepareI0:
- """Test _validate_and_prepare_I0 for HierarchicalInfections."""
+class TestSubpopulationValidateAndPrepareI0:
+ """Test _validate_and_prepare_I0 for SubpopulationInfections."""
- def test_broadcasts_scalar_to_subpop_array(self, hierarchical_infections):
+ def test_broadcasts_scalar_to_subpop_array(self, subpopulation_infections):
"""Test that scalar I0 is broadcast to per-subpopulation array."""
- pop = hierarchical_infections._parse_and_validate_fractions(
+ pop = subpopulation_infections._parse_and_validate_fractions(
subpop_fractions=jnp.array([0.3, 0.25, 0.45])
)
I0 = jnp.array(0.01)
- result = hierarchical_infections._validate_and_prepare_I0(I0, pop)
+ result = subpopulation_infections._validate_and_prepare_I0(I0, pop)
assert result.shape == (3,)
assert jnp.allclose(result, 0.01)
- def test_passes_through_matching_array(self, hierarchical_infections):
+ def test_passes_through_matching_array(self, subpopulation_infections):
"""Test that a per-subpopulation I0 array passes through unchanged."""
- pop = hierarchical_infections._parse_and_validate_fractions(
+ pop = subpopulation_infections._parse_and_validate_fractions(
subpop_fractions=jnp.array([0.3, 0.25, 0.45])
)
I0 = jnp.array([0.001, 0.002, 0.0015])
- result = hierarchical_infections._validate_and_prepare_I0(I0, pop)
+ result = subpopulation_infections._validate_and_prepare_I0(I0, pop)
assert result.shape == (3,)
assert jnp.allclose(result, I0)
- def test_rejects_negative(self, hierarchical_infections):
+ def test_rejects_negative(self, subpopulation_infections):
"""Test that negative I0 is rejected."""
- pop = hierarchical_infections._parse_and_validate_fractions(
+ pop = subpopulation_infections._parse_and_validate_fractions(
subpop_fractions=jnp.array([0.5, 0.5])
)
I0 = jnp.array(-0.01)
with pytest.raises(ValueError, match="I0 must be positive"):
- hierarchical_infections._validate_and_prepare_I0(I0, pop)
+ subpopulation_infections._validate_and_prepare_I0(I0, pop)
- def test_rejects_greater_than_one(self, hierarchical_infections):
+ def test_rejects_greater_than_one(self, subpopulation_infections):
"""Test that I0 > 1 is rejected."""
- pop = hierarchical_infections._parse_and_validate_fractions(
+ pop = subpopulation_infections._parse_and_validate_fractions(
subpop_fractions=jnp.array([0.5, 0.5])
)
I0 = jnp.array(1.5)
with pytest.raises(ValueError, match="I0 must be <= 1"):
- hierarchical_infections._validate_and_prepare_I0(I0, pop)
+ subpopulation_infections._validate_and_prepare_I0(I0, pop)
if __name__ == "__main__":