diff --git a/cpp/examples/abm_parameter_study.cpp b/cpp/examples/abm_parameter_study.cpp index cfaba3ae94..d6b44c4bb1 100644 --- a/cpp/examples/abm_parameter_study.cpp +++ b/cpp/examples/abm_parameter_study.cpp @@ -34,6 +34,7 @@ #include "memilio/utils/stl_util.h" #include +#include constexpr size_t num_age_groups = 4; @@ -173,9 +174,9 @@ int main() // Set start and end time for the simulation. auto t0 = mio::abm::TimePoint(0); - auto tmax = t0 + mio::abm::days(5); + auto tmax = t0 + mio::abm::days(10); // Set the number of simulations to run in the study - const size_t num_runs = 3; + const size_t num_runs = 10; // Create a parameter study. // Note that the study for the ABM currently does not make use of the arguments "parameters" or "dt", as we create @@ -187,42 +188,74 @@ int main() // study.get_rng().seed({12341234, 53456, 63451, 5232576, 84586, 52345}); const std::string result_dir = mio::path_join(mio::base_dir(), "example_results"); + std::filesystem::remove_all(result_dir); if (!mio::create_directory(result_dir)) { mio::log_error("Could not create result directory \"{}\".", result_dir); return 1; } + const std::string result_directory_standard = mio::path_join(result_dir, "standart_results"); + if (!mio::create_directory(result_directory_standard)) { + mio::log_error("Could not create result directory \"{}\".", result_directory_standard); + return 1; + } + + const std::string result_dir_detailed = mio::path_join(result_dir, "detailed"); + if (!mio::create_directory(result_dir_detailed)) { + mio::log_error("Could not create result directory \"{}\".", result_dir_detailed); + return 1; + } + auto ensemble_results = study.run( [](auto, auto t0_, auto, size_t) { return mio::abm::ResultSimulation(make_model(mio::thread_local_rng()), t0_); }, [result_dir](auto&& sim, auto&& run_idx) { - auto interpolated_result = mio::interpolate_simulation_result(sim.get_result()); + auto interpolated_result = mio::interpolate_simulation_result(sim.get_result()); + auto interpolated_result_detailed = mio::interpolate_simulation_result(sim.get_result_detailed()); + std::string outpath = mio::path_join(result_dir, "abm_minimal_run_" + std::to_string(run_idx) + ".txt"); std::ofstream outfile_run(outpath); sim.get_result().print_table(outfile_run, {"S", "E", "I_NS", "I_Sy", "I_Sev", "I_Crit", "R", "D"}, 7, 4); + std::cout << "Results written to " << outpath << std::endl; - auto params = std::vector{}; - return std::vector{interpolated_result}; + + return std::vector>{interpolated_result, interpolated_result_detailed}; }); if (ensemble_results.size() > 0) { - auto ensemble_results_p05 = ensemble_percentile(ensemble_results, 0.05); - auto ensemble_results_p25 = ensemble_percentile(ensemble_results, 0.25); - auto ensemble_results_p50 = ensemble_percentile(ensemble_results, 0.50); - auto ensemble_results_p75 = ensemble_percentile(ensemble_results, 0.75); - auto ensemble_results_p95 = ensemble_percentile(ensemble_results, 0.95); - - mio::unused(save_result(ensemble_results_p05, {0}, num_age_groups, - mio::path_join(result_dir, "Results_" + std::string("p05") + ".h5"))); - mio::unused(save_result(ensemble_results_p25, {0}, num_age_groups, - mio::path_join(result_dir, "Results_" + std::string("p25") + ".h5"))); - mio::unused(save_result(ensemble_results_p50, {0}, num_age_groups, - mio::path_join(result_dir, "Results_" + std::string("p50") + ".h5"))); - mio::unused(save_result(ensemble_results_p75, {0}, num_age_groups, - mio::path_join(result_dir, "Results_" + std::string("p75") + ".h5"))); - mio::unused(save_result(ensemble_results_p95, {0}, num_age_groups, - mio::path_join(result_dir, "Results_" + std::string("p95") + ".h5"))); + + std::vector>> ensemble_new, ensemble_detailed; + for (auto& run : ensemble_results) { + ensemble_new.push_back({std::move(run[0])}); + ensemble_detailed.push_back({std::move(run[1])}); + } + + // Percentiles for aggregated results + auto new_p05 = ensemble_percentile(ensemble_new, 0.05); + auto new_p25 = ensemble_percentile(ensemble_new, 0.25); + auto new_p50 = ensemble_percentile(ensemble_new, 0.50); + auto new_p75 = ensemble_percentile(ensemble_new, 0.75); + auto new_p95 = ensemble_percentile(ensemble_new, 0.95); + + mio::unused(save_result(new_p05, {0}, num_age_groups, mio::path_join(result_directory_standard, "Results_p05.h5"))); + mio::unused(save_result(new_p25, {0}, num_age_groups, mio::path_join(result_directory_standard, "Results_p25.h5"))); + mio::unused(save_result(new_p50, {0}, num_age_groups, mio::path_join(result_directory_standard, "Results_p50.h5"))); + mio::unused(save_result(new_p75, {0}, num_age_groups, mio::path_join(result_directory_standard, "Results_p75.h5"))); + mio::unused(save_result(new_p95, {0}, num_age_groups, mio::path_join(result_directory_standard, "Results_p95.h5"))); + + // Percentiles for detailed results + auto det_p05 = ensemble_percentile(ensemble_detailed, 0.05); + auto det_p25 = ensemble_percentile(ensemble_detailed, 0.25); + auto det_p50 = ensemble_percentile(ensemble_detailed, 0.50); + auto det_p75 = ensemble_percentile(ensemble_detailed, 0.75); + auto det_p95 = ensemble_percentile(ensemble_detailed, 0.95); + + mio::unused(save_result(det_p05, {0}, num_age_groups, mio::path_join(result_dir_detailed, "Results_p05.h5"))); + mio::unused(save_result(det_p25, {0}, num_age_groups, mio::path_join(result_dir_detailed, "Results_p25.h5"))); + mio::unused(save_result(det_p50, {0}, num_age_groups, mio::path_join(result_dir_detailed, "Results_p50.h5"))); + mio::unused(save_result(det_p75, {0}, num_age_groups, mio::path_join(result_dir_detailed, "Results_p75.h5"))); + mio::unused(save_result(det_p95, {0}, num_age_groups, mio::path_join(result_dir_detailed, "Results_p95.h5"))); } mio::mpi::finalize(); diff --git a/cpp/models/abm/common_abm_loggers.h b/cpp/models/abm/common_abm_loggers.h index 0d91c69c24..b0a40d5959 100644 --- a/cpp/models/abm/common_abm_loggers.h +++ b/cpp/models/abm/common_abm_loggers.h @@ -189,6 +189,67 @@ struct LogInfectionState : mio::LogAlways { } }; +struct LogInfectionStatePerAgeGroup : mio::LogAlways { + using Type = std::pair; + /** + * @brief Log the TimeSeries of the number of Person%s in an #InfectionState. + * @param[in] sim The simulation of the abm. + * @return A pair of the TimePoint and the TimeSeries of the number of Person%s in an #InfectionState. + */ + static Type log(const mio::abm::Simulation<>& sim) + { + + Eigen::VectorXd sum = Eigen::VectorXd::Zero( + Eigen::Index((size_t)mio::abm::InfectionState::Count * sim.get_model().parameters.get_num_groups())); + const auto curr_time = sim.get_time(); + const auto persons = sim.get_model().get_persons(); + + + for (auto i = size_t(0); i < persons.size(); ++i) { + auto& p = persons[i]; + auto index = (((size_t)(mio::abm::InfectionState::Count)) * ((uint32_t)p.get_age().get())) + + ((uint32_t)p.get_infection_state(curr_time)); + // PRAGMA_OMP(atomic) + sum[index] += 1; + } + return std::make_pair(curr_time, sum); + } +}; + +struct LogInfectionPerLocationTypePerAgeGroup : mio::LogAlways { + using Type = std::pair; + /** + * @brief Log the TimeSeries of the number of Person%s in an #InfectionState. + * @param[in] sim The simulation of the abm. + * @return A pair of the TimePoint and the TimeSeries of the number of Person%s in an #InfectionState. + */ + static Type log(const mio::abm::Simulation<>& sim) + { + + Eigen::VectorXd sum = Eigen::VectorXd::Zero( + Eigen::Index((size_t)mio::abm::LocationType::Count * sim.get_model().parameters.get_num_groups())); + auto curr_time = sim.get_time(); + auto prev_time = sim.get_prev_time(); + const auto persons = sim.get_model().get_persons(); + + + for (auto i = size_t(0); i < persons.size(); ++i) { + auto& p = persons[i]; + + + if ((p.get_infection_state(prev_time) != mio::abm::InfectionState::Exposed) && + (p.get_infection_state(curr_time) == mio::abm::InfectionState::Exposed)) { + auto index = (((size_t)(mio::abm::LocationType::Count)) * ((uint32_t)p.get_age().get())) + + ((uint32_t)p.get_location_type()); + sum[index] += 1; + + } + } + return std::make_pair(curr_time, sum); + } +}; + + /** * @brief This is like the DataWriterToMemory, but it only logs time series data. * @tparam Loggers The loggers that are used to log data. The loggers must return a touple with a TimePoint and a value. diff --git a/cpp/models/abm/result_simulation.h b/cpp/models/abm/result_simulation.h index 2f1d02131a..02aeffe069 100644 --- a/cpp/models/abm/result_simulation.h +++ b/cpp/models/abm/result_simulation.h @@ -45,19 +45,33 @@ class ResultSimulation : public Simulation */ void advance(TimePoint tmax) { - Simulation::advance(tmax, history); + Simulation::advance(tmax, history, history_detailed); } /** * @brief Return the simulation result aggregated by infection states. */ - const mio::TimeSeries& get_result() const + mio::TimeSeries get_result() const { - return get<0>(history.get_log()); + return std::get<0>(history.get_log()); } + /** + * @brief Return the detailed simulation result aggregated by infection states. + */ + mio::TimeSeries get_result_detailed() const + { + return std::get<0>(history_detailed.get_log()); + } + + mio::History history{ - Eigen::Index(InfectionState::Count)}; ///< History used to create the result TimeSeries. + Eigen::Index(InfectionState::Count)}; + + mio::History history_detailed{ + Eigen::Index(LocationType::Count) * this->get_model().parameters.get_num_groups()}; + + }; } // namespace abm diff --git a/cpp/models/abm/simulation.h b/cpp/models/abm/simulation.h index 8e6ff52839..283fcd26bd 100644 --- a/cpp/models/abm/simulation.h +++ b/cpp/models/abm/simulation.h @@ -46,6 +46,7 @@ class Simulation Simulation(TimePoint t0, Model&& model) : m_model(std::move(model)) , m_t(t0) + , m_t_prev(t0-hours(1)) , m_dt(hours(1)) { } @@ -85,6 +86,15 @@ class Simulation return m_t; } + /** + * @brief Get the previous time of the Simulation. + */ + TimePoint get_prev_time() const + { + return m_t_prev; + } + + /** * @brief Get the Model that this Simulation evolves. */ @@ -103,11 +113,14 @@ class Simulation { auto dt = std::min(m_dt, tmax - m_t); m_model.evolve(m_t, dt); + m_t_prev = m_t; m_t += m_dt; + } Model m_model; ///< The Model to simulate. TimePoint m_t; ///< The current TimePoint of the Simulation. + TimePoint m_t_prev; ///< The previous TimePoint of the Simulation. TimeSpan m_dt; ///< The length of the time steps. }; diff --git a/pycode/memilio-plot/memilio/plot/plotAbmInfectionStates.py b/pycode/memilio-plot/memilio/plot/plotAbmInfectionStates.py new file mode 100644 index 0000000000..3a4db0f75c --- /dev/null +++ b/pycode/memilio-plot/memilio/plot/plotAbmInfectionStates.py @@ -0,0 +1,406 @@ +############################################################################# +# Copyright (C) 2020-2025 MEmilio +# +# Authors: Sascha Korf +# +# Contact: Martin J. Kuehn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################# + +import sys +import argparse +import os +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt +import matplotlib +import h5py +from datetime import datetime +from scipy.ndimage import gaussian_filter1d + + +# Module for plotting number of agents per infection state and number of infected agents per location type from ABM results. +# This module provides functions to load and visualize infection states and +# location types from simulation results of the agent-based model (ABM) stored in HDF5 format. + +# The used Loggers are: +# LogInfectionStatePerAgeGroup +# LogInfectionPerLocationTypePerAgeGroup +# The output of the loggers of several runs is stored in HDF5 files using mio::save_results in mio/io/result_io.h, see abm_history_object.cpp. + +# Adjust these as needed. +state_labels = { + 1: 'Exposed', + 2: 'I_Asymp', + 3: 'I_Symp', + 4: 'I_Severe', + 5: 'I_Critical', + 7: 'Dead' +} + +age_groups = ['Group1', 'Group2', 'Group3', 'Group4', + 'Group5', 'Group6', 'Total'] + +age_groups_dict = { + 'Group1': 'Ages 0-4', + 'Group2': 'Ages 5-14', + 'Group3': 'Ages 15-34', + 'Group4': 'Ages 35-59', + 'Group5': 'Ages 60-79', + 'Group6': 'Ages 80+', + 'Total': 'All Ages' +} + +location_type_labels = { + 0: 'Home', + 1: 'School', + 2: 'Work', + 3: 'SocialEvent', + 4: 'BasicsShop', + 5: 'Hospital', + 6: 'ICU' +} + + +def load_h5_results(base_path, percentile): + """ Reads HDF5 results for a given group and percentile. + + @param[in] base_path Path to results directory. + @param[in] percentile Subdirectory for percentile (e.g. 'p50'). + @return Dictionary with data arrays. Keys are dataset names from the HDF5 file + (e.g., 'Time', 'Total', age group names like 'Group1', 'Group2', etc.). + Values are numpy arrays containing the corresponding time series data. + """ + # Try flat structure first (e.g., Results_p50.h5) + file_path_flat = os.path.join(base_path, f"Results_{percentile}.h5") + # Try subdirectory structure (e.g., p50/Results.h5) + file_path_nested = os.path.join(base_path, percentile, "Results.h5") + + # Determine which file exists + if os.path.exists(file_path_flat): + file_path = file_path_flat + elif os.path.exists(file_path_nested): + file_path = file_path_nested + else: + raise FileNotFoundError( + f"Could not find percentile results file. Tried:\n" + f" - {file_path_flat}\n" + f" - {file_path_nested}\n\n" + f"Expected directory structure (option 1):\n" + f" {base_path}/\n" + f" Results_p05.h5\n" + f" Results_p25.h5\n" + f" Results_p50.h5\n" + f" Results_p75.h5\n" + f" Results_p95.h5\n\n" + f"Or (option 2):\n" + f" {base_path}/\n" + f" p05/Results.h5\n" + f" p25/Results.h5\n" + f" p50/Results.h5\n" + f" p75/Results.h5\n" + f" p95/Results.h5\n" + ) + + with h5py.File(file_path, 'r') as f: + data = {k: v[()] for k, v in f['0'].items()} + return data + + +def plot_infections_loc_types_average( + path_to_loc_types, + start_date='2021-03-01', + colormap='Set1', + smooth_sigma=1, + rolling_window=24, + xtick_step=150): + """ Plots rolling sum of new infections per 24 hours location type for the median run. + + @param[in] base_path Path to results directory. + @param[in] start_date Start date as string. + @param[in] colormap Matplotlib colormap. + @param[in] smooth_sigma Sigma for Gaussian smoothing. + @param[in] rolling_window Window size for rolling sum. + @param[in] xtick_step Step size for x-axis ticks. + """ + # Load data + p50 = load_h5_results(path_to_loc_types, "p50") + time = p50['Time'] + total_50 = p50['Total'] + + plt.figure('Infection_location_types') + plt.title( + 'Number of new infections per location type for the median run, rolling sum over 24 hours') + color_plot = matplotlib.colormaps.get_cmap(colormap).colors + + # Check if data dimensions match expected location types + num_cols = total_50.shape[1] if len(total_50.shape) > 1 else 1 + num_location_types = len(location_type_labels) + + if num_cols < num_location_types: + print(f"Warning: Data has {num_cols} columns but {num_location_types} location types defined.") + print(f"Only plotting first {num_cols} location types.") + + for idx, i in enumerate(location_type_labels.keys()): + if i >= num_cols: + break # Skip if we don't have data for this location type + color = color_plot[i % len(color_plot)] if i < len( + color_plot) else "black" + # Sum up every 24 hours, then smooth + indexer = pd.api.indexers.FixedForwardWindowIndexer( + window_size=rolling_window) + y = pd.DataFrame(total_50[:, i]).rolling( + window=indexer, min_periods=1).sum().to_numpy() + y = y[0::rolling_window].flatten() + y = gaussian_filter1d(y, sigma=smooth_sigma, mode='nearest') + plt.plot(time[0::rolling_window], y, color=color, linewidth=2.5, label=location_type_labels[i]) + + plt.legend() # Auto-generate legend from plot labels + _format_x_axis(time, start_date, xtick_step) + plt.xlabel('Date') + plt.ylabel('Number of individuals') + + +def plot_infection_states_results( + path_to_infection_states, + start_date='2021-03-01', + colormap='Set1', + xtick_step=150, + show90=False +): + """ Loads and plots infection state results. + + @param[in] path_to_infection_states Path to results directory containing infection state data. + @param[in] start_date Start date as string (YYYY-MM-DD format). + @param[in] colormap Matplotlib colormap name. + @param[in] xtick_step Step size for x-axis ticks. + @param[in] show90 If True, plot 90% percentile (5% and 95%) in addition to 50% percentile. + """ + + # Load data + p50 = load_h5_results(path_to_infection_states, "p50") + p25 = load_h5_results(path_to_infection_states, "p25") + p75 = load_h5_results(path_to_infection_states, "p75") + time = p50['Time'] + total_50 = p50['Total'] + total_25 = p25['Total'] + total_75 = p75['Total'] + p05 = p95 = None + total_05 = total_95 = None + if show90: + total_95 = load_h5_results(path_to_infection_states, "p95") + total_05 = load_h5_results(path_to_infection_states, "p05") + p95 = total_95['Total'] + p05 = total_05['Total'] + + plot_infection_states_by_age_group( + time, p50, p25, p75, colormap, + p05_bs=total_05 if show90 else None, + p95_bs=total_95 if show90 else None, + show90=show90 + ) + plot_infection_states(time, total_50, total_25, + total_75, start_date, colormap, xtick_step, + y05=p05, y95=p95, show_90=show90) + + +def plot_infection_states( + x, y50, y25, y75, + start_date='2021-03-01', + colormap='Set1', + xtick_step=150, + y05=None, y95=None, show_90=False): + """ Plots infection states with percentile bands. + + @param[in] x Time array for x-axis. + @param[in] y50 50th percentile data array. + @param[in] y25 25th percentile data array. + @param[in] y75 75th percentile data array. + @param[in] start_date Start date as string (YYYY-MM-DD format). + @param[in] colormap Matplotlib colormap name. + @param[in] xtick_step Step size for x-axis ticks. + @param[in] y05 5th percentile data array (optional). + @param[in] y95 95th percentile data array (optional). + @param[in] show_90 If True, plot 90% percentile bands in addition to 50% percentile. + """ + + plt.figure('Infection_states') + + plt.title('Infection states with 50% percentile') + if show_90: + plt.title('Infection states with 50% and 90% percentiles') + + color_plot = matplotlib.colormaps.get_cmap(colormap).colors + + states_plot = list(state_labels.keys()) + + for i in states_plot: + plt.plot(x, y50[:, i], color=color_plot[i], + linewidth=2.5, label=state_labels[i]) + # needs to be after the plot calls + plt.legend([state_labels[i] for i in states_plot]) + for i in states_plot: + plt.plot(x, y25[:, i], color=color_plot[i], + linestyle='dashdot', linewidth=1.2, alpha=0.7) + plt.plot(x, y75[:, i], color=color_plot[i], + linestyle='dashdot', linewidth=1.2, alpha=0.7) + plt.fill_between(x, y25[:, i], y75[:, i], + alpha=0.2, color=color_plot[i]) + # Optional: 90% percentile + if show_90 and y05 is not None and y95 is not None: + plt.plot(x, y05[:, i], color=color_plot[i], + linestyle='dashdot', linewidth=1.0, alpha=0.4) + plt.plot(x, y95[:, i], color=color_plot[i], + linestyle='dashdot', linewidth=1.0, alpha=0.4) + plt.fill_between(x, y05[:, i], y95[:, i], + # More transparent + alpha=0.25, color=color_plot[i]) + + _format_x_axis(x, start_date, xtick_step) + plt.xlabel('Date') + plt.ylabel('Number of individuals') + + +def plot_infection_states_by_age_group( + x, p50_bs, p25_bs, p75_bs, colormap='Set1', + p05_bs=None, p95_bs=None, show90=False +): + """ Plots infection states for each age group, with optional 90% percentile. + + @param[in] x Time array for x-axis. + @param[in] p50_bs Dictionary containing 50th percentile data for all age groups. + @param[in] p25_bs Dictionary containing 25th percentile data for all age groups. + @param[in] p75_bs Dictionary containing 75th percentile data for all age groups. + @param[in] colormap Matplotlib colormap name. + @param[in] p05_bs Dictionary containing 5th percentile data for all age groups (optional). + @param[in] p95_bs Dictionary containing 95th percentile data for all age groups (optional). + @param[in] show90 If True, plot 90% percentile bands in addition to 50% percentile. + """ + + # Dynamically detect available age groups from the data + available_groups = [key for key in p50_bs.keys() if key.startswith('Group') or key == 'Total'] + # Sort to ensure Group1, Group2, ..., Total order + available_groups = sorted(available_groups, key=lambda x: (x != 'Total', x)) + + color_plot = matplotlib.colormaps.get_cmap(colormap).colors + n_states = len(state_labels) + fig, ax = plt.subplots( + n_states, len(available_groups), constrained_layout=True, figsize=(20, 3 * n_states)) + + for col_idx, group in enumerate(available_groups): + y50 = p50_bs[group] + y25 = p25_bs[group] + y75 = p75_bs[group] + y05 = p05_bs[group] if (show90 and p05_bs is not None) else None + y95 = p95_bs[group] if (show90 and p95_bs is not None) else None + + # Get group label, using default if not in predefined dict + group_label = age_groups_dict.get(group, group) + + for row_idx, (state_idx, label) in enumerate(state_labels.items()): + _plot_state( + ax[row_idx, col_idx], x, y50[:, state_idx], y25[:, + state_idx], y75[:, state_idx], + color_plot[col_idx], f'#{label}, {group_label}', + y05=y05[:, state_idx] if y05 is not None else None, + y95=y95[:, state_idx] if y95 is not None else None, + show90=show90 + ) + # The legend should say: solid line = median, dashed line = 25% and 75% perc. and if show90 is True, dotted line = 5%, 25%, 75%, 95% perc. + perc_string = '25/75%' if not show90 else '5/25/75/95%' + ax[row_idx, col_idx].legend( + ['Median', f'{perc_string} perc.'], + loc='upper left', fontsize=8) + + # Add y label for leftmost column + if col_idx == 0: + ax[row_idx, col_idx].set_ylabel('Number of individuals') + + # Add x label for bottom row + if row_idx == n_states - 1: + ax[row_idx, col_idx].set_xlabel('Time (days)') + + string_short = ' and 90%' if show90 else '' + fig.suptitle( + 'Infection states per age group with 50%' + string_short + ' percentile', + fontsize=16) + + +def _plot_state(ax, x, y50, y25, y75, color, title, y05=None, y95=None, show90=False): + """ Helper to plot a single state with fill_between and optional 90% percentile. """ + ax.plot(x, y50, color=color, label='Median') + ax.fill_between(x, y25, y75, alpha=0.5, color=color) + if show90 and y05 is not None and y95 is not None: + ax.plot(x, y05, color=color, linestyle='dotted', + linewidth=1.0, alpha=0.4) + ax.plot(x, y95, color=color, linestyle='dotted', + linewidth=1.0, alpha=0.4) + ax.fill_between(x, y05, y95, alpha=0.15, color=color) + ax.tick_params(axis='y') + ax.set_title(title) + + +def _format_x_axis(x, start_date, xtick_step): + """ Helper to format x-axis as dates. """ + start = datetime.strptime(start_date, '%Y-%m-%d') + xx = [start + pd.Timedelta(days=int(i)) for i in x] + xx_str = [dt.strftime('%Y-%m-%d') for dt in xx] + plt.gca().set_xticks(x[::xtick_step]) + plt.gca().set_xticklabels(xx_str[::xtick_step]) + plt.gcf().autofmt_xdate() + + +def main(): + """ Main function for CLI usage. """ + parser = argparse.ArgumentParser( + description="Plot infection state and location type results.") + parser.add_argument("--path-to-infection-states", + help="Path to infection states results") + parser.add_argument("--path-to-loc-types", + help="Path to location types results") + parser.add_argument("--start-date", type=str, default='2021-03-01', + help="Simulation start date (YYYY-MM-DD)") + parser.add_argument("--colormap", type=str, + default='Set1', help="Matplotlib colormap") + parser.add_argument("--xtick-step", type=int, + default=150, help="Step for x-axis ticks (usually hours)") + parser.add_argument("--90percentile", action="store_true", + help="If set, plot 90% percentile as well") + args = parser.parse_args() + + if not args.path_to_infection_states and not args.path_to_loc_types: + print("Please provide a path to infection states or location types results.") + return + + if args.path_to_infection_states: + plot_infection_states_results( + path_to_infection_states=args.path_to_infection_states, + start_date=args.start_date, + colormap=args.colormap, + xtick_step=args.xtick_step, + show90=True + ) + + if args.path_to_loc_types: + plot_infections_loc_types_average( + path_to_loc_types=args.path_to_loc_types, + start_date=args.start_date, + colormap=args.colormap, + xtick_step=args.xtick_step) + + plt.show() + + +if __name__ == "__main__": + main() diff --git a/pycode/memilio-plot/tests/test_plot_plotAbmInfectionStates.py b/pycode/memilio-plot/tests/test_plot_plotAbmInfectionStates.py new file mode 100644 index 0000000000..4bc5dde5e1 --- /dev/null +++ b/pycode/memilio-plot/tests/test_plot_plotAbmInfectionStates.py @@ -0,0 +1,304 @@ +############################################################################# +# Copyright (C) 2020-2025 MEmilio +# +# Authors: Sascha Korf +# +# Contact: Martin J. Kuehn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################# + +import unittest +from unittest.mock import patch, MagicMock +import numpy as np +import pandas as pd + +import memilio.plot.plotAbmInfectionStates as abm + + +class TestPlotAbmInfectionStates(unittest.TestCase): + + @patch('memilio.plot.plotAbmInfectionStates.h5py.File') + def test_load_h5_results(self, mock_h5file): + mock_group = {'Time': np.arange(10), 'Total': np.ones((10, 8))} + mock_h5file().__enter__().get.return_value = {'0': mock_group} + mock_h5file().__enter__().items.return_value = [('0', mock_group)] + mock_h5file().__enter__().__getitem__.return_value = mock_group + with patch('memilio.plot.plotAbmInfectionStates.h5py.File', mock_h5file): + result = abm.load_h5_results('dummy_path', 'p50') + assert 'Time' in result + assert 'Total' in result + np.testing.assert_array_equal(result['Time'], np.arange(10)) + np.testing.assert_array_equal(result['Total'], np.ones((10, 8))) + + @patch('memilio.plot.plotAbmInfectionStates.load_h5_results') + @patch('memilio.plot.plotAbmInfectionStates.matplotlib') + @patch('memilio.plot.plotAbmInfectionStates.gaussian_filter1d', side_effect=lambda x, sigma, mode: x) + @patch('memilio.plot.plotAbmInfectionStates.pd.DataFrame') + def test_plot_infections_loc_types_average(self, mock_df, mock_gauss, mock_matplotlib, mock_load): + mock_load.return_value = { + 'Time': np.arange(48), 'Total': np.ones((48, 7))} + mock_df.return_value.rolling.return_value.sum.return_value.to_numpy.return_value = np.ones( + (48, 1)) + mock_matplotlib.colormaps.get_cmap.return_value.colors = [(1, 0, 0)]*7 + + # Patch plt methods + with patch.object(abm.plt, 'gca') as mock_gca, \ + patch.object(abm.plt, 'figure') as mock_figure, \ + patch.object(abm.plt, 'title') as mock_title, \ + patch.object(abm.plt, 'legend') as mock_legend, \ + patch.object(abm.plt, 'xlabel') as mock_xlabel, \ + patch.object(abm.plt, 'ylabel') as mock_ylabel, \ + patch.object(abm.plt, 'show') as mock_show: + + mock_ax = MagicMock() + mock_gca.return_value = mock_ax + + abm.plot_infections_loc_types_average('dummy_path') + + # Test basic plotting functionality + assert mock_ax.plot.called + assert mock_ax.set_xticks.called + assert mock_ax.set_xticklabels.called + + # Test figure settings + self.assertEqual(mock_figure.call_count, 2, + "figure should be called twice") + # Verify first call is with the figure name + mock_figure.assert_any_call('Infection_location_types') + # Verify second call is without arguments (for autofmt_xdate) + mock_figure.assert_any_call() + mock_title.assert_called_once_with( + 'Number of new infections per location type for the median run, rolling sum over 24 hours') + mock_legend.assert_called_once() + mock_xlabel.assert_called_once_with('Date') + mock_ylabel.assert_called_once_with('Number of individuals') + mock_show.assert_called_once() + + # Verify legend was called with location type labels + legend_call_args = mock_legend.call_args + if legend_call_args and legend_call_args[0]: + legend_labels = legend_call_args[0][0] + # Should contain the location type labels from the function + expected_labels = list(abm.location_type_labels.values()) + self.assertEqual(legend_labels, expected_labels) + + # Verify that plot was called for each location type + plot_calls = mock_ax.plot.call_args_list + # Should plot 7 location types + self.assertEqual(len(plot_calls), 7, + "Should plot all 7 location types") + + @patch('memilio.plot.plotAbmInfectionStates.load_h5_results') + @patch('memilio.plot.plotAbmInfectionStates.plot_infection_states') + @patch('memilio.plot.plotAbmInfectionStates.plot_infection_states_by_age_group') + def test_plot_infection_states_results(self, mock_indiv, mock_states, mock_load): + test_data = { + 'Time': np.arange(10), + 'Total': np.ones((10, 8)), + 'Group1': np.ones((10, 8)), + 'Group2': np.ones((10, 8)), + 'Group3': np.ones((10, 8)), + 'Group4': np.ones((10, 8)), + 'Group5': np.ones((10, 8)), + 'Group6': np.ones((10, 8)) + } + mock_load.side_effect = [test_data, test_data, test_data] + + abm.plot_infection_states_results('dummy_path') + + # Verify functions are called with correct arguments + self.assertEqual(mock_load.call_count, 3, + "load_h5_results should be called 3 times (p25, p50, p75)") + + # Check that load_h5_results was called with correct percentiles + expected_calls = [ + unittest.mock.call('dummy_path', 'p25'), + unittest.mock.call('dummy_path', 'p50'), + unittest.mock.call('dummy_path', 'p75') + ] + mock_load.assert_has_calls(expected_calls, any_order=True) + + # Verify plotting functions were called with the loaded data + mock_states.assert_called_once() + mock_indiv.assert_called_once() + + # Check that plot_infection_states was called with correct data structure + states_call_args = mock_states.call_args + self.assertIsNotNone(states_call_args) + # x, y50, y25, y75, y05, y95, show_90 + self.assertEqual(len(states_call_args[0]), 7) + + # Check that plot_infection_states_by_age_group was called with correct data structure + indiv_call_args = mock_indiv.call_args + self.assertIsNotNone(indiv_call_args) + # x, p50_bs, p25_bs, p75_bs, p05_bs + self.assertEqual(len(indiv_call_args[0]), 5) + + @patch('memilio.plot.plotAbmInfectionStates.matplotlib') + def test_plot_infection_states(self, mock_matplotlib): + x = np.arange(10) + y50 = np.ones((10, 8)) + y25 = np.zeros((10, 8)) + y75 = np.ones((10, 8))*2 + y05 = np.ones((10, 8))*-1 + y95 = np.ones((10, 8))*3 + mock_matplotlib.colormaps.get_cmap.return_value.colors = [(1, 0, 0)]*8 + + # Patch plt.gca().plot and fill_between + with patch.object(abm.plt, 'gca') as mock_gca: + mock_ax = MagicMock() + mock_gca.return_value = mock_ax + + abm.plot_infection_states( + x, y50, y25, y75, + start_date='2021-03-01', + colormap='Set1', + xtick_step=2, + y05=y05, + y95=y95, + show_90=True + ) + + # Verify plot was called with correct data + self.assertTrue(mock_ax.plot.called, "Plot should be called") + plot_calls = mock_ax.plot.call_args_list + # From the actual function: 6 infection states * 5 plot calls each (median + 4 percentile lines) = 30 + self.assertEqual(len(plot_calls), 30, + "Should plot 30 lines total (6 states * 5 lines each)") + + # Verify fill_between was called for confidence intervals + self.assertTrue(mock_ax.fill_between.called, + "fill_between should be called for confidence intervals") + fill_calls = mock_ax.fill_between.call_args_list + # Should have calls for both 50% and 90% confidence intervals if show_90=True + self.assertEqual(len(fill_calls), 12, + "Should have 12 fill_between calls (6 states * 2 confidence intervals each)") + + # Verify axis formatting with correct parameters + mock_ax.set_xticks.assert_called_once() + xticks_call = mock_ax.set_xticks.call_args[0][0] + expected_ticks = np.arange(0, len(x), 2) # xtick_step=2 + np.testing.assert_array_equal(xticks_call, expected_ticks) + + # Verify xticklabels are set correctly + mock_ax.set_xticklabels.assert_called_once() + xticklabels_call = mock_ax.set_xticklabels.call_args[0][0] + self.assertEqual(len(xticklabels_call), len(expected_ticks)) + + # Verify that the xticklabels contain the correct date formatting + expected_dates = ['2021-03-01', '2021-03-03', + '2021-03-05', '2021-03-07', '2021-03-09'] + for i, label in enumerate(xticklabels_call): + self.assertEqual(str(label), expected_dates[i], + f"Label at position {i} should be {expected_dates[i]}") + + # Verify that start_date is used in label formatting + if len(xticklabels_call) > 0: + # All labels should contain date strings when start_date is provided + self.assertTrue(all('2021' in str(label) + for label in xticklabels_call)) + # First label should match the start_date + self.assertEqual(str(xticklabels_call[0]), '2021-03-01') + + @patch('memilio.plot.plotAbmInfectionStates.matplotlib') + def test_plot_infection_states_by_age_group(self, mock_matplotlib): + x = np.arange(10) + group_data = np.ones((10, 8)) + groups = ['Group1', 'Group2', 'Group3', + 'Group4', 'Group5', 'Group6', 'Total'] + p50_bs = {g: group_data for g in groups} + p25_bs = {g: group_data for g in groups} + p75_bs = {g: group_data for g in groups} + p05_bs = {g: group_data*-1 for g in groups} + p95_bs = {g: group_data*3 for g in groups} + mock_matplotlib.colormaps.get_cmap.return_value.colors = [(1, 0, 0)]*8 + + # Patch plt.subplots to return a grid of MagicMock axes + with patch.object(abm.plt, 'subplots') as mock_subplots: + fig_mock = MagicMock() + # From the actual function: n_states (6) rows, len(age_groups) (7) columns + ax_mock = np.empty((6, 7), dtype=object) + for i in range(6): + for j in range(7): + ax_mock[i, j] = MagicMock() + mock_subplots.return_value = (fig_mock, ax_mock) + + abm.plot_infection_states_by_age_group( + x, p50_bs, p25_bs, p75_bs, + colormap='Set1', + p05_bs=p05_bs, + p95_bs=p95_bs, + show90=True + ) + + # Verify that subplots was called to create a grid of axes + mock_subplots.assert_called_once() + subplot_call = mock_subplots.call_args + + # Verify that subplots was called with reasonable dimensions + if subplot_call and len(subplot_call[0]) >= 2: + rows, cols = subplot_call[0][:2] + self.assertEqual( + rows, 6, "Should have 6 rows (number of infection states)") + self.assertEqual( + cols, 7, "Should have 7 columns (number of age groups)") + + # Verify figure title is set + fig_mock.suptitle.assert_called_once() + + def test__format_x_axis(self): + test_x = np.arange(10) + test_start_date = '2021-03-01' + test_xtick_step = 2 + + with patch('memilio.plot.plotAbmInfectionStates.plt') as mock_plt: + mock_ax = MagicMock() + mock_plt.gca.return_value = mock_ax + + abm._format_x_axis(test_x, test_start_date, test_xtick_step) + + # Verify that gca was called to get current axis (it's called twice in the function) + self.assertEqual(mock_plt.gca.call_count, 2, + "gca should be called twice") + + # Verify that gcf was called to get current figure + mock_plt.gcf.assert_called_once() + + # Verify axis formatting methods were called + self.assertTrue(mock_ax.set_xticks.called, + "set_xticks should be called") + self.assertTrue(mock_ax.set_xticklabels.called, + "set_xticklabels should be called") + + # Verify correct tick positions + xticks_call = mock_ax.set_xticks.call_args + if xticks_call and xticks_call[0]: + tick_positions = xticks_call[0][0] + expected_positions = np.arange(0, len(test_x), test_xtick_step) + np.testing.assert_array_equal( + tick_positions, expected_positions) + + # Verify tick labels are date strings + xticklabels_call = mock_ax.set_xticklabels.call_args + if xticklabels_call and xticklabels_call[0]: + tick_labels = xticklabels_call[0][0] + self.assertIsInstance(tick_labels, (list, np.ndarray)) + if len(tick_labels) > 0: + # Should contain date information when start_date is provided + self.assertTrue(any('2021' in str(label) + for label in tick_labels)) + + +if __name__ == '__main__': + unittest.main()