From 96f5d9d7184dfa116b7954b55bcf67cc80ac7bf8 Mon Sep 17 00:00:00 2001 From: MaxBlesch Date: Tue, 17 Jun 2025 14:18:42 +0200 Subject: [PATCH 01/34] Added debug infrastructure. ALl tests passing --- src/dcegm/backward_induction.py | 59 +----------- src/dcegm/final_periods.py | 74 ++++++++++++--- src/dcegm/interfaces/inspect_solution.py | 66 ++++++++++++++ src/dcegm/pre_processing/sol_container.py | 52 +++++++++++ src/dcegm/solve_single_period.py | 89 +++++++++++++------ .../test_two_period_continuous_experience.py | 19 ++-- 6 files changed, 261 insertions(+), 98 deletions(-) create mode 100644 src/dcegm/interfaces/inspect_solution.py create mode 100644 src/dcegm/pre_processing/sol_container.py diff --git a/src/dcegm/backward_induction.py b/src/dcegm/backward_induction.py index 470b46f7..94c634ee 100644 --- a/src/dcegm/backward_induction.py +++ b/src/dcegm/backward_induction.py @@ -1,6 +1,5 @@ """Interface for the DC-EGM algorithm.""" -from functools import partial from typing import Any, Callable, Dict, Tuple import jax.lax @@ -9,6 +8,7 @@ from dcegm.final_periods import solve_last_two_periods from dcegm.law_of_motion import calc_cont_grids_next_period +from dcegm.pre_processing.sol_container import create_solution_container from dcegm.solve_single_period import solve_single_period @@ -91,16 +91,13 @@ def backward_induction( model_funcs=model_funcs, ) - # Create solution containers. The 20 percent extra in wealth grid needs to go - # into tuning parameters - n_total_wealth_grid = model_config["tuning_params"]["n_total_wealth_grid"] ( value_solved, policy_solved, endog_grid_solved, ) = create_solution_container( model_config=model_config, - model_structure=model_structure, + n_state_choices=model_structure["state_choice_space"].shape[0], ) # Solve the last two periods. 
We do this separately as the marginal utility of @@ -120,6 +117,7 @@ def backward_induction( value_solved=value_solved, policy_solved=policy_solved, endog_grid_solved=endog_grid_solved, + debug_info=None, ) # If it is a two period model we are done. @@ -135,6 +133,7 @@ def partial_single_period(carry, xs): cont_grids_next_period=cont_grids_next_period, model_funcs=model_funcs, income_shock_weights=income_shock_weights, + debug_info=None, ) for id_segment in range(batch_info["n_segments"]): @@ -192,53 +191,3 @@ def partial_single_period(carry, xs): policy_solved, endog_grid_solved, ) - - -def create_solution_container( - model_config: Dict[str, Any], - model_structure: Dict[str, Any], -): - """Create solution containers for value, policy, and endog_grid.""" - - # Read out grid size - n_total_wealth_grid = model_config["tuning_params"]["n_total_wealth_grid"] - n_state_choices = model_structure["state_choice_space"].shape[0] - - # Check if second continuous state exists and read out array size - continuous_states_info = model_config["continuous_states_info"] - if continuous_states_info["second_continuous_exists"]: - n_second_continuous_grid = continuous_states_info["n_second_continuous_grid"] - - value_solved = jnp.full( - (n_state_choices, n_second_continuous_grid, n_total_wealth_grid), - dtype=jnp.float64, - fill_value=jnp.nan, - ) - policy_solved = jnp.full( - (n_state_choices, n_second_continuous_grid, n_total_wealth_grid), - dtype=jnp.float64, - fill_value=jnp.nan, - ) - endog_grid_solved = jnp.full( - (n_state_choices, n_second_continuous_grid, n_total_wealth_grid), - dtype=jnp.float64, - fill_value=jnp.nan, - ) - else: - value_solved = jnp.full( - (n_state_choices, n_total_wealth_grid), - dtype=jnp.float64, - fill_value=jnp.nan, - ) - policy_solved = jnp.full( - (n_state_choices, n_total_wealth_grid), - dtype=jnp.float64, - fill_value=jnp.nan, - ) - endog_grid_solved = jnp.full( - (n_state_choices, n_total_wealth_grid), - dtype=jnp.float64, - 
fill_value=jnp.nan, - ) - - return value_solved, policy_solved, endog_grid_solved diff --git a/src/dcegm/final_periods.py b/src/dcegm/final_periods.py index e5a0a4e7..b9f3caf3 100644 --- a/src/dcegm/final_periods.py +++ b/src/dcegm/final_periods.py @@ -21,6 +21,7 @@ def solve_last_two_periods( value_solved, policy_solved, endog_grid_solved, + debug_info, ): """Solves the last two periods of the model. @@ -48,6 +49,16 @@ def solve_last_two_periods( """ + idx_state_choices_final_period = last_two_period_batch_info[ + "idx_state_choices_final_period" + ] + if debug_info is not None: + if "rescale_idx" in debug_info.keys(): + # If we want to rescale the idx, because we only solve part of the model, then to this first. + idx_state_choices_final_period = ( + idx_state_choices_final_period - debug_info["rescale_idx"] + ) + ( value_solved, policy_solved, @@ -55,9 +66,7 @@ def solve_last_two_periods( value_interp_final_period, marginal_utility_final_last_period, ) = solve_final_period( - idx_state_choices_final_period=last_two_period_batch_info[ - "idx_state_choices_final_period" - ], + idx_state_choices_final_period=idx_state_choices_final_period, idx_parent_states_final_period=last_two_period_batch_info[ "idxs_parent_states_final_period" ], @@ -86,7 +95,7 @@ def solve_last_two_periods( last_two_period_batch_info["state_choice_mat_final_period"], params ) - endog_grid, policy, value = solve_for_interpolated_values( + out_dict_second_last = solve_for_interpolated_values( value_interpolated=value_interp_final_period, marginal_utility_interpolated=marginal_utility_final_last_period, state_choice_mat=last_two_period_batch_info[ @@ -104,19 +113,60 @@ def solve_last_two_periods( income_shock_weights=income_shock_weights, continuous_grids_info=continuous_states_info, model_funcs=model_funcs, + debug_info=debug_info, ) idx_second_last = last_two_period_batch_info["idx_state_choices_second_last_period"] - value_solved = value_solved.at[idx_second_last, ...].set(value) - 
policy_solved = policy_solved.at[idx_second_last, ...].set(policy) - endog_grid_solved = endog_grid_solved.at[idx_second_last, ...].set(endog_grid) + # If we do not call the function in debug mode. Assign everything and return + if debug_info is None: + value_solved = value_solved.at[idx_second_last, ...].set( + out_dict_second_last["value"] + ) + policy_solved = policy_solved.at[idx_second_last, ...].set( + out_dict_second_last["policy"] + ) + endog_grid_solved = endog_grid_solved.at[idx_second_last, ...].set( + out_dict_second_last["endog_grid"] + ) + return ( + value_solved, + policy_solved, + endog_grid_solved, + ) - return ( - value_solved, - policy_solved, - endog_grid_solved, - ) + else: + if "rescale_idx" in debug_info.keys(): + # If we want to rescale the idx, because we only solve part of the model, then to this first. + idx_rescaled_second_last = idx_second_last - debug_info["rescale_idx"] + # And then assign to the solution containers. + value_solved = value_solved.at[idx_rescaled_second_last, ...].set( + out_dict_second_last["value"] + ) + policy_solved = policy_solved.at[idx_rescaled_second_last, ...].set( + out_dict_second_last["policy"] + ) + endog_grid_solved = endog_grid_solved.at[idx_rescaled_second_last, ...].set( + out_dict_second_last["endog_grid"] + ) + + # If candidates are also needed to returned we return them additionally to the solution containers. 
+ if debug_info["return_candidates"]: + return ( + value_solved, + policy_solved, + endog_grid_solved, + out_dict_second_last["value_candidates"], + out_dict_second_last["policy_candidates"], + out_dict_second_last["endog_grid_candidates"], + ) + + else: + return ( + value_solved, + policy_solved, + endog_grid_solved, + ) def solve_final_period( diff --git a/src/dcegm/interfaces/inspect_solution.py b/src/dcegm/interfaces/inspect_solution.py new file mode 100644 index 00000000..87d71a9c --- /dev/null +++ b/src/dcegm/interfaces/inspect_solution.py @@ -0,0 +1,66 @@ +import jax.lax +import jax.numpy as jnp +import numpy as np + +from dcegm.final_periods import solve_last_two_periods +from dcegm.law_of_motion import calc_cont_grids_next_period +from dcegm.pre_processing.sol_container import create_solution_container + + +def partially_solve( + income_shock_draws_unscaled, + income_shock_weights, + model_config, + batch_info, + model_funcs, + model_structure, + params, + n_periods, + return_candidates=False, +): + """Partially solve the model for the last n_periods. + + This method allows for large models to only solve part of the model, to debug the solution process. + + Args: + params: Model parameters. + n_periods: Number of periods to solve. + return_candidates: If True, additionally return candidate solutions before applying the upper envelope. 
+ + """ + + continuous_states_info = model_config["continuous_states_info"] + + cont_grids_next_period = calc_cont_grids_next_period( + model_structure=model_structure, + model_config=model_config, + income_shock_draws_unscaled=income_shock_draws_unscaled, + params=params, + model_funcs=model_funcs, + ) + + last_relevant_period = model_config["n_periods"] - n_periods - 1 + + relevant_state_choices_mask = ( + model_structure["state_choice_space"][:, 0] >= last_relevant_period + ) + relevant_state_choice_space = model_structure["state_choice_space"][ + relevant_state_choices_mask + ] + + ( + value_solved, + policy_solved, + endog_grid_solved, + ) = create_solution_container( + model_config=model_config, + n_state_choices=relevant_state_choice_space.shape[0], + ) + + if return_candidates: + (value_candidates, policy_candidates, endog_grid_candidates) = ( + create_solution_container( + model_config=model_config, + n_state_choices=relevant_state_choice_space.shape[0], + ) + ) diff --git a/src/dcegm/pre_processing/sol_container.py b/src/dcegm/pre_processing/sol_container.py new file mode 100644 index 00000000..8efca2f0 --- /dev/null +++ b/src/dcegm/pre_processing/sol_container.py @@ -0,0 +1,52 @@ +from typing import Any, Dict + +from jax import numpy as jnp + + +def create_solution_container( + model_config: Dict[str, Any], + n_state_choices: int, +): + """Create solution containers for value, policy, and endog_grid.""" + + # Read out grid size + n_total_wealth_grid = model_config["tuning_params"]["n_total_wealth_grid"] + + # Check if second continuous state exists and read out array size + continuous_states_info = model_config["continuous_states_info"] + if continuous_states_info["second_continuous_exists"]: + n_second_continuous_grid = continuous_states_info["n_second_continuous_grid"] + + value_solved = jnp.full( + (n_state_choices, n_second_continuous_grid, n_total_wealth_grid), + dtype=jnp.float64, + fill_value=jnp.nan, + ) + policy_solved = jnp.full( + 
(n_state_choices, n_second_continuous_grid, n_total_wealth_grid), + dtype=jnp.float64, + fill_value=jnp.nan, + ) + endog_grid_solved = jnp.full( + (n_state_choices, n_second_continuous_grid, n_total_wealth_grid), + dtype=jnp.float64, + fill_value=jnp.nan, + ) + else: + value_solved = jnp.full( + (n_state_choices, n_total_wealth_grid), + dtype=jnp.float64, + fill_value=jnp.nan, + ) + policy_solved = jnp.full( + (n_state_choices, n_total_wealth_grid), + dtype=jnp.float64, + fill_value=jnp.nan, + ) + endog_grid_solved = jnp.full( + (n_state_choices, n_total_wealth_grid), + dtype=jnp.float64, + fill_value=jnp.nan, + ) + + return value_solved, policy_solved, endog_grid_solved diff --git a/src/dcegm/solve_single_period.py b/src/dcegm/solve_single_period.py index ce7a4cd7..3f8384a9 100644 --- a/src/dcegm/solve_single_period.py +++ b/src/dcegm/solve_single_period.py @@ -15,6 +15,7 @@ def solve_single_period( cont_grids_next_period, model_funcs, income_shock_weights, + debug_info, ): """Solve a single period of the model using DCEGM.""" (value_solved, policy_solved, endog_grid_solved) = carry @@ -60,29 +61,58 @@ def solve_single_period( state_choice_mat_child, params ) - endog_grid_state_choice, policy_state_choice, value_state_choice = ( - solve_for_interpolated_values( - value_interpolated=value_interpolated, - marginal_utility_interpolated=marginal_utility_interpolated, - state_choice_mat=state_choice_mat, - child_state_idxs=child_states_to_integrate_stochastic, - states_to_choices_child_states=child_state_choices_to_aggr_choice, - params=params, - taste_shock_scale=taste_shock_scale, - taste_shock_scale_is_scalar=taste_shock_scale_is_scalar, - income_shock_weights=income_shock_weights, - continuous_grids_info=continuous_grids_info, - model_funcs=model_funcs, - ) + out_dict_period = solve_for_interpolated_values( + value_interpolated=value_interpolated, + marginal_utility_interpolated=marginal_utility_interpolated, + state_choice_mat=state_choice_mat, + 
child_state_idxs=child_states_to_integrate_stochastic, + states_to_choices_child_states=child_state_choices_to_aggr_choice, + params=params, + taste_shock_scale=taste_shock_scale, + taste_shock_scale_is_scalar=taste_shock_scale_is_scalar, + income_shock_weights=income_shock_weights, + continuous_grids_info=continuous_grids_info, + model_funcs=model_funcs, + debug_info=debug_info, ) - value_solved = value_solved.at[state_choices_idxs, :].set(value_state_choice) - policy_solved = policy_solved.at[state_choices_idxs, :].set(policy_state_choice) - endog_grid_solved = endog_grid_solved.at[state_choices_idxs, :].set( - endog_grid_state_choice - ) + # If we are not in the debug mode, we only return the solved values. + if debug_info is None: + + value_solved = value_solved.at[state_choices_idxs, :].set( + out_dict_period["value"] + ) + policy_solved = policy_solved.at[state_choices_idxs, :].set( + out_dict_period["policy"] + ) + endog_grid_solved = endog_grid_solved.at[state_choices_idxs, :].set( + out_dict_period["endog_grid"] + ) + carry = (value_solved, policy_solved, endog_grid_solved) - carry = (value_solved, policy_solved, endog_grid_solved) + else: + if "rescale_idx" in debug_info.keys(): + state_choices_idxs = state_choices_idxs - debug_info["rescale_idx"] + value_solved = value_solved.at[state_choices_idxs, :].set( + out_dict_period["value"] + ) + policy_solved = policy_solved.at[state_choices_idxs, :].set( + out_dict_period["policy"] + ) + endog_grid_solved = endog_grid_solved.at[state_choices_idxs, :].set( + out_dict_period["endog_grid"] + ) + if debug_info["return_candidates"]: + carry = ( + value_solved, + policy_solved, + endog_grid_solved, + out_dict_period["value_candidates"], + out_dict_period["policy_candidates"], + out_dict_period["endog_grid_candidates"], + ) + else: + carry = (value_solved, policy_solved, endog_grid_solved) return carry, () @@ -99,6 +129,7 @@ def solve_for_interpolated_values( income_shock_weights, continuous_grids_info, model_funcs, 
+ debug_info, ): # EGM step 2) # Aggregate the marginal utilities and expected values over all state-choice @@ -150,12 +181,20 @@ def solve_for_interpolated_values( has_second_continuous_state=continuous_grids_info["second_continuous_exists"], compute_upper_envelope_for_state_choice=model_funcs["compute_upper_envelope"], ) + out_dict = { + "endog_grid": endog_grid_state_choice, + "policy": policy_state_choice, + "value": value_state_choice, + } - return ( - endog_grid_state_choice, - policy_state_choice, - value_state_choice, - ) + # If candidates are requested, we additionally return them in the output dictionary. + if debug_info is not None: + if debug_info["return_candidates"]: + out_dict["endog_grid_candidates"] = endog_grid_candidate + out_dict["policy_candidates"] = policy_candidate + out_dict["value_candidates"] = value_candidate + + return out_dict def run_upper_envelope( diff --git a/tests/test_two_period_continuous_experience.py b/tests/test_two_period_continuous_experience.py index 98c7fe23..728fb7fc 100644 --- a/tests/test_two_period_continuous_experience.py +++ b/tests/test_two_period_continuous_experience.py @@ -9,10 +9,10 @@ import dcegm import dcegm.toy_models as toy_models -from dcegm.backward_induction import create_solution_container from dcegm.final_periods import solve_final_period from dcegm.law_of_motion import calc_cont_grids_next_period from dcegm.numerical_integration import quadrature_legendre +from dcegm.pre_processing.sol_container import create_solution_container from dcegm.solve_single_period import solve_for_interpolated_values MAX_WEALTH = 50 @@ -290,7 +290,7 @@ def create_test_inputs(): endog_grid_solved=endog_grid_solved, ) - endog_grid, policy, value_second_last = solve_for_interpolated_values( + out_dict_second_last = solve_for_interpolated_values( value_interpolated=value_interp_final_period, marginal_utility_interpolated=marginal_utility_final_last_period, state_choice_mat=last_two_period_batch_info_cont[ @@ -308,15 +308,22 @@ 
def create_test_inputs(): income_shock_weights=income_shock_weights, continuous_grids_info=model_config["continuous_states_info"], model_funcs=model_funcs_cont, + debug_info=None, ) idx_second_last = last_two_period_batch_info_cont[ "idx_state_choices_second_last_period" ] - value_solved = value_solved.at[idx_second_last, ...].set(value_second_last) - policy_solved = policy_solved.at[idx_second_last, ...].set(policy) - endog_grid_solved = endog_grid_solved.at[idx_second_last, ...].set(endog_grid) + value_solved = value_solved.at[idx_second_last, ...].set( + out_dict_second_last["value"] + ) + policy_solved = policy_solved.at[idx_second_last, ...].set( + out_dict_second_last["policy"] + ) + endog_grid_solved = endog_grid_solved.at[idx_second_last, ...].set( + out_dict_second_last["endog_grid"] + ) return ( value_solved, @@ -442,8 +449,8 @@ def _get_solve_last_two_periods_args(model, params, has_second_continuous_state) # Create solution containers for value, policy, and endogenous grids value_solved, policy_solved, endog_grid_solved = create_solution_container( - model_structure=model_structure, model_config=model_config, + n_state_choices=model_structure["state_choice_space"].shape[0], ) return ( From 79e164056c93fd9c778e473697d0c9a09c63971f Mon Sep 17 00:00:00 2001 From: MaxBlesch Date: Tue, 17 Jun 2025 15:12:49 +0200 Subject: [PATCH 02/34] Adding debug code. 
--- src/dcegm/backward_induction.py | 4 +- src/dcegm/interfaces/inspect_solution.py | 53 +++++++++++++++++-- src/dcegm/pre_processing/sol_container.py | 9 +--- .../test_two_period_continuous_experience.py | 3 +- 4 files changed, 57 insertions(+), 12 deletions(-) diff --git a/src/dcegm/backward_induction.py b/src/dcegm/backward_induction.py index 94c634ee..a75a9a2b 100644 --- a/src/dcegm/backward_induction.py +++ b/src/dcegm/backward_induction.py @@ -96,7 +96,9 @@ def backward_induction( policy_solved, endog_grid_solved, ) = create_solution_container( - model_config=model_config, + continuous_states_info=model_config["continuous_states_info"], + # Read out grid size + n_total_wealth_grid=model_config["tuning_params"]["n_total_wealth_grid"], n_state_choices=model_structure["state_choice_space"].shape[0], ) diff --git a/src/dcegm/interfaces/inspect_solution.py b/src/dcegm/interfaces/inspect_solution.py index 87d71a9c..94f16a3d 100644 --- a/src/dcegm/interfaces/inspect_solution.py +++ b/src/dcegm/interfaces/inspect_solution.py @@ -39,7 +39,7 @@ def partially_solve( model_funcs=model_funcs, ) - last_relevant_period = model_config["n_periods"] - n_periods - 1 + last_relevant_period = model_config["n_periods"] - n_periods relevant_state_choices_mask = ( model_structure["state_choice_space"][:, 0] >= last_relevant_period @@ -53,14 +53,61 @@ def partially_solve( policy_solved, endog_grid_solved, ) = create_solution_container( - model_config=model_config, + continuous_states_info=model_config["continuous_states_info"], + # Read out grid size + n_total_wealth_grid=model_config["tuning_params"]["n_total_wealth_grid"], n_state_choices=relevant_state_choice_space.shape[0], ) if return_candidates: + n_assets_end_of_period = model_config["continuous_states_info"][ + "assets_grid_end_of_period" + ].shape[0] (value_candidates, policy_candidates, endog_grid_candidates) = ( create_solution_container( - model_config=model_config, + 
continuous_states_info=model_config["continuous_states_info"], + n_total_wealth_grid=n_assets_end_of_period, n_state_choices=relevant_state_choice_space.shape[0], ) ) + + # Create debug information + debug_info = { + "return_candidates": return_candidates, + "rescale_idx": np.where(relevant_state_choices_mask)[0].min(), + } + ( + value_solved, + policy_solved, + endog_grid_solved, + value_candidates_second_last, + policy_candidates_second_last, + endog_grid_candidates_second_last, + ) = solve_last_two_periods( + params=params, + continuous_states_info=continuous_states_info, + cont_grids_next_period=cont_grids_next_period, + income_shock_weights=income_shock_weights, + model_funcs=model_funcs, + last_two_period_batch_info=batch_info["last_two_period_info"], + value_solved=value_solved, + policy_solved=policy_solved, + endog_grid_solved=endog_grid_solved, + debug_info=debug_info, + ) + if return_candidates: + idx_second_last = batch_info["last_two_period_info"][ + "idx_state_choices_second_last_period" + ] + idx_second_last_rescaled = idx_second_last - debug_info["rescale_idx"] + value_candidates = value_candidates.at[idx_second_last_rescaled, ...].set( + value_candidates_second_last + ) + policy_candidates = policy_candidates.at[idx_second_last_rescaled, ...].set( + policy_candidates_second_last, + ) + endog_grid_candidates = endog_grid_candidates.at[ + idx_second_last_rescaled, ... 
+ ].set(endog_grid_candidates_second_last) + + return value_solved, policy_solved, endog_grid_solved diff --git a/src/dcegm/pre_processing/sol_container.py b/src/dcegm/pre_processing/sol_container.py index 8efca2f0..68e4efe0 100644 --- a/src/dcegm/pre_processing/sol_container.py +++ b/src/dcegm/pre_processing/sol_container.py @@ -4,16 +4,11 @@ def create_solution_container( - model_config: Dict[str, Any], + continuous_states_info: Dict[str, Any], + n_total_wealth_grid: int, n_state_choices: int, ): """Create solution containers for value, policy, and endog_grid.""" - - # Read out grid size - n_total_wealth_grid = model_config["tuning_params"]["n_total_wealth_grid"] - - # Check if second continuous state exists and read out array size - continuous_states_info = model_config["continuous_states_info"] if continuous_states_info["second_continuous_exists"]: n_second_continuous_grid = continuous_states_info["n_second_continuous_grid"] diff --git a/tests/test_two_period_continuous_experience.py b/tests/test_two_period_continuous_experience.py index 728fb7fc..a081093c 100644 --- a/tests/test_two_period_continuous_experience.py +++ b/tests/test_two_period_continuous_experience.py @@ -449,7 +449,8 @@ def _get_solve_last_two_periods_args(model, params, has_second_continuous_state) # Create solution containers for value, policy, and endogenous grids value_solved, policy_solved, endog_grid_solved = create_solution_container( - model_config=model_config, + continuous_states_info=model_config["continuous_states_info"], + n_total_wealth_grid=model_config["tuning_params"]["n_total_wealth_grid"], n_state_choices=model_structure["state_choice_space"].shape[0], ) From 09b831336f7cb56bd98d926299dfe653553d8acf Mon Sep 17 00:00:00 2001 From: MaxBlesch Date: Tue, 17 Jun 2025 19:11:35 +0200 Subject: [PATCH 03/34] Further debug interfac.e --- src/dcegm/final_periods.py | 48 ++------ src/dcegm/interfaces/inspect_solution.py | 148 +++++++++++++++++++++-- src/dcegm/solve_single_period.py | 58 
++++----- 3 files changed, 174 insertions(+), 80 deletions(-) diff --git a/src/dcegm/final_periods.py b/src/dcegm/final_periods.py index b9f3caf3..a42787d5 100644 --- a/src/dcegm/final_periods.py +++ b/src/dcegm/final_periods.py @@ -48,17 +48,6 @@ def solve_last_two_periods( for all states, end of period assets, and income shocks. """ - - idx_state_choices_final_period = last_two_period_batch_info[ - "idx_state_choices_final_period" - ] - if debug_info is not None: - if "rescale_idx" in debug_info.keys(): - # If we want to rescale the idx, because we only solve part of the model, then to this first. - idx_state_choices_final_period = ( - idx_state_choices_final_period - debug_info["rescale_idx"] - ) - ( value_solved, policy_solved, @@ -66,7 +55,9 @@ def solve_last_two_periods( value_interp_final_period, marginal_utility_final_last_period, ) = solve_final_period( - idx_state_choices_final_period=idx_state_choices_final_period, + idx_state_choices_final_period=last_two_period_batch_info[ + "idx_state_choices_final_period" + ], idx_parent_states_final_period=last_two_period_batch_info[ "idxs_parent_states_final_period" ], @@ -118,17 +109,18 @@ def solve_last_two_periods( idx_second_last = last_two_period_batch_info["idx_state_choices_second_last_period"] + value_solved = value_solved.at[idx_second_last, ...].set( + out_dict_second_last["value"] + ) + policy_solved = policy_solved.at[idx_second_last, ...].set( + out_dict_second_last["policy"] + ) + endog_grid_solved = endog_grid_solved.at[idx_second_last, ...].set( + out_dict_second_last["endog_grid"] + ) + # If we do not call the function in debug mode. 
Assign everything and return if debug_info is None: - value_solved = value_solved.at[idx_second_last, ...].set( - out_dict_second_last["value"] - ) - policy_solved = policy_solved.at[idx_second_last, ...].set( - out_dict_second_last["policy"] - ) - endog_grid_solved = endog_grid_solved.at[idx_second_last, ...].set( - out_dict_second_last["endog_grid"] - ) return ( value_solved, policy_solved, @@ -136,20 +128,6 @@ def solve_last_two_periods( ) else: - if "rescale_idx" in debug_info.keys(): - # If we want to rescale the idx, because we only solve part of the model, then to this first. - idx_rescaled_second_last = idx_second_last - debug_info["rescale_idx"] - # And then assign to the solution containers. - value_solved = value_solved.at[idx_rescaled_second_last, ...].set( - out_dict_second_last["value"] - ) - policy_solved = policy_solved.at[idx_rescaled_second_last, ...].set( - out_dict_second_last["policy"] - ) - endog_grid_solved = endog_grid_solved.at[idx_rescaled_second_last, ...].set( - out_dict_second_last["endog_grid"] - ) - # If candidates are also needed to returned we return them additionally to the solution containers. if debug_info["return_candidates"]: return ( diff --git a/src/dcegm/interfaces/inspect_solution.py b/src/dcegm/interfaces/inspect_solution.py index 94f16a3d..b14488d7 100644 --- a/src/dcegm/interfaces/inspect_solution.py +++ b/src/dcegm/interfaces/inspect_solution.py @@ -1,3 +1,5 @@ +import copy + import jax.lax import jax.numpy as jnp import numpy as np @@ -5,6 +7,7 @@ from dcegm.final_periods import solve_last_two_periods from dcegm.law_of_motion import calc_cont_grids_next_period from dcegm.pre_processing.sol_container import create_solution_container +from dcegm.solve_single_period import solve_single_period def partially_solve( @@ -28,6 +31,10 @@ def partially_solve( return_candidates: If True, additionally return candidate solutions before applying the upper envelope. 
""" + batch_info_internal = copy.deepcopy(batch_info) + + if n_periods < 2: + raise ValueError("You must at least solve for two periods.") continuous_states_info = model_config["continuous_states_info"] @@ -38,7 +45,7 @@ def partially_solve( params=params, model_funcs=model_funcs, ) - + # Determine the last period we need to solve for. last_relevant_period = model_config["n_periods"] - n_periods relevant_state_choices_mask = ( @@ -71,11 +78,21 @@ def partially_solve( ) ) + # Determine rescale idx for reduced solution + rescale_idx = np.where(relevant_state_choices_mask)[0].min() + # Create debug information debug_info = { "return_candidates": return_candidates, - "rescale_idx": np.where(relevant_state_choices_mask)[0].min(), } + last_two_period_batch_info = batch_info_internal["last_two_period_info"] + # Rescale the indexes to save of the last two periods: + last_two_period_batch_info["idx_state_choices_final_period"] = ( + last_two_period_batch_info["idx_state_choices_final_period"] - rescale_idx + ) + last_two_period_batch_info["idx_state_choices_second_last_period"] = ( + last_two_period_batch_info["idx_state_choices_second_last_period"] - rescale_idx + ) ( value_solved, policy_solved, @@ -89,25 +106,134 @@ def partially_solve( cont_grids_next_period=cont_grids_next_period, income_shock_weights=income_shock_weights, model_funcs=model_funcs, - last_two_period_batch_info=batch_info["last_two_period_info"], + last_two_period_batch_info=last_two_period_batch_info, value_solved=value_solved, policy_solved=policy_solved, endog_grid_solved=endog_grid_solved, debug_info=debug_info, ) if return_candidates: - idx_second_last = batch_info["last_two_period_info"][ + idx_second_last = batch_info_internal["last_two_period_info"][ "idx_state_choices_second_last_period" ] - idx_second_last_rescaled = idx_second_last - debug_info["rescale_idx"] - value_candidates = value_candidates.at[idx_second_last_rescaled, ...].set( + value_candidates = value_candidates.at[idx_second_last, 
...].set( value_candidates_second_last ) - policy_candidates = policy_candidates.at[idx_second_last_rescaled, ...].set( + policy_candidates = policy_candidates.at[idx_second_last, ...].set( policy_candidates_second_last, ) - endog_grid_candidates = endog_grid_candidates.at[ - idx_second_last_rescaled, ... - ].set(endog_grid_candidates_second_last) + endog_grid_candidates = endog_grid_candidates.at[idx_second_last, ...].set( + endog_grid_candidates_second_last + ) + + if n_periods <= 2: + out_dict = { + "value": value_solved, + "policy": policy_solved, + "endog_grid": endog_grid_solved, + } + if return_candidates: + out_dict["value_candidates"] = value_candidates + out_dict["policy_candidates"] = policy_candidates + out_dict["endog_grid_candidates"] = endog_grid_candidates + + return out_dict + + stop_segment_loop = False + for id_segment in range(batch_info_internal["n_segments"]): + segment_info = batch_info_internal[f"batches_info_segment_{id_segment}"] + + n_batches_in_segment = segment_info["batches_state_choice_idx"].shape[0] - return value_solved, policy_solved, endog_grid_solved + for id_batch in range(n_batches_in_segment): + periods_batch = segment_info["state_choices"]["period"][id_batch, :] + + # Now there can be three cases: + # 1) All periods are smaller than the last relevant period. Then we stop the loop + # 2) Part of the periods are smaller than the last relevant period. Then we only solve for the partial state choices. + # 3) All periods are larger than the last relevant period. Then we solve for state choices. 
+ if (periods_batch < last_relevant_period).all(): + stop_segment_loop = True + break + elif (periods_batch < last_relevant_period).any(): + solve_mask = periods_batch >= last_relevant_period + state_choices_batch = { + key: segment_info["state_choices"][key][id_batch, solve_mask] + for key in segment_info["state_choices"].keys() + } + # We need to rescale the idx, because of saving + idx_to_solve = ( + segment_info["batches_state_choice_idx"][id_batch, solve_mask] + - rescale_idx + ) + child_states_to_integrate_stochastic = segment_info[ + "child_states_to_integrate_stochastic" + ][id_batch, solve_mask, :] + + else: + state_choices_batch = { + key: segment_info["state_choices"][key][id_batch, :] + for key in segment_info["state_choices"].keys() + } + # We need to rescale the idx, because of saving + idx_to_solve = ( + segment_info["batches_state_choice_idx"][id_batch, :] - rescale_idx + ) + child_states_to_integrate_stochastic = segment_info[ + "child_states_to_integrate_stochastic" + ][id_batch, :, :] + + state_choices_childs_batch = { + key: segment_info["state_choices_childs"][key][id_batch, :] + for key in segment_info["state_choices_childs"].keys() + } + xs = ( + idx_to_solve, + segment_info["child_state_choices_to_aggr_choice"][id_batch, :, :], + child_states_to_integrate_stochastic, + segment_info["child_state_choice_idxs_to_interp"][id_batch, :], + segment_info["child_states_idxs"][id_batch, :], + state_choices_batch, + state_choices_childs_batch, + ) + carry = (value_solved, policy_solved, endog_grid_solved) + single_period_out_dict = solve_single_period( + carry=carry, + xs=xs, + params=params, + continuous_grids_info=continuous_states_info, + cont_grids_next_period=cont_grids_next_period, + model_funcs=model_funcs, + income_shock_weights=income_shock_weights, + debug_info=debug_info, + ) + + value_solved = single_period_out_dict["value"] + policy_solved = single_period_out_dict["policy"] + endog_grid_solved = single_period_out_dict["endog_grid"] + + # 
If candidates are requested, we assign them to the solution container + if return_candidates: + value_candidates = value_candidates.at[idx_to_solve, ...].set( + single_period_out_dict["value_candidates"] + ) + policy_candidates = policy_candidates.at[idx_to_solve, ...].set( + single_period_out_dict["policy_candidates"] + ) + endog_grid_candidates = endog_grid_candidates.at[idx_to_solve, ...].set( + single_period_out_dict["endog_grid_candidates"] + ) + + if stop_segment_loop: + break + + out_dict = { + "value": value_solved, + "policy": policy_solved, + "endog_grid": endog_grid_solved, + } + if return_candidates: + out_dict["value_candidates"] = value_candidates + out_dict["policy_candidates"] = policy_candidates + out_dict["endog_grid_candidates"] = endog_grid_candidates + return out_dict diff --git a/src/dcegm/solve_single_period.py b/src/dcegm/solve_single_period.py index 3f8384a9..93499bc1 100644 --- a/src/dcegm/solve_single_period.py +++ b/src/dcegm/solve_single_period.py @@ -75,46 +75,36 @@ def solve_single_period( model_funcs=model_funcs, debug_info=debug_info, ) + value_solved = value_solved.at[state_choices_idxs, :].set(out_dict_period["value"]) + policy_solved = policy_solved.at[state_choices_idxs, :].set( + out_dict_period["policy"] + ) + endog_grid_solved = endog_grid_solved.at[state_choices_idxs, :].set( + out_dict_period["endog_grid"] + ) - # If we are not in the debug mode, we only return the solved values. + # If we are not in the debug mode, we only return the solution as a tuple and an empty tuple. 
if debug_info is None: - - value_solved = value_solved.at[state_choices_idxs, :].set( - out_dict_period["value"] - ) - policy_solved = policy_solved.at[state_choices_idxs, :].set( - out_dict_period["policy"] - ) - endog_grid_solved = endog_grid_solved.at[state_choices_idxs, :].set( - out_dict_period["endog_grid"] - ) carry = (value_solved, policy_solved, endog_grid_solved) + return carry, () else: - if "rescale_idx" in debug_info.keys(): - state_choices_idxs = state_choices_idxs - debug_info["rescale_idx"] - value_solved = value_solved.at[state_choices_idxs, :].set( - out_dict_period["value"] - ) - policy_solved = policy_solved.at[state_choices_idxs, :].set( - out_dict_period["policy"] - ) - endog_grid_solved = endog_grid_solved.at[state_choices_idxs, :].set( - out_dict_period["endog_grid"] - ) - if debug_info["return_candidates"]: - carry = ( - value_solved, - policy_solved, - endog_grid_solved, - out_dict_period["value_candidates"], - out_dict_period["policy_candidates"], - out_dict_period["endog_grid_candidates"], - ) - else: - carry = (value_solved, policy_solved, endog_grid_solved) + # In debug mode we return a dictionary. + out_dict = { + "value": value_solved, + "policy": policy_solved, + "endog_grid": endog_grid_solved, + } - return carry, () + # If candidates are requested, we add them + if debug_info["return_candidates"]: + out_dict = { + **out_dict, + "value_candidates": out_dict_period["value_candidates"], + "policy_candidates": out_dict_period["policy_candidates"], + "endog_grid_candidates": out_dict_period["endog_grid_candidates"], + } + return out_dict def solve_for_interpolated_values( From 7433585b4f38fac322129fd374d7bcd856827987 Mon Sep 17 00:00:00 2001 From: MaxBlesch Date: Wed, 18 Jun 2025 16:40:50 +0200 Subject: [PATCH 04/34] Child states including trans probs calc. 
--- src/dcegm/interfaces/model_class.py | 11 +++++++++++ src/dcegm/solve_single_period.py | 2 +- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/src/dcegm/interfaces/model_class.py b/src/dcegm/interfaces/model_class.py index 14450158..0e399713 100644 --- a/src/dcegm/interfaces/model_class.py +++ b/src/dcegm/interfaces/model_class.py @@ -322,6 +322,17 @@ def get_child_states(self, state, choice): } return pd.DataFrame(child_states) + def get_child_states_and_calc_trans_probs(self, state, choice, params): + """Get the child states for a given state and choice and calculate the + transition probabilities.""" + child_states_df = self.get_child_states(state, choice) + + trans_probs = self.model_funcs["compute_stochastic_transition_vec"]( + params=params, choice=choice, **state + ) + child_states_df["trans_probs"] = trans_probs + return child_states_df + def compute_law_of_motions(self, params): return calc_cont_grids_next_period( params=params, diff --git a/src/dcegm/solve_single_period.py b/src/dcegm/solve_single_period.py index 93499bc1..1ee2be5c 100644 --- a/src/dcegm/solve_single_period.py +++ b/src/dcegm/solve_single_period.py @@ -122,7 +122,7 @@ def solve_for_interpolated_values( debug_info, ): # EGM step 2) - # Aggregate the marginal utilities and expected values over all state-choice + # Aggregate the marginal utilities and expected values over all child state-choice # combinations and income shock draws marg_util, emax = aggregate_marg_utils_and_exp_values( value_state_choice_specific=value_interpolated, From f70987a36bf1791110c4f6dc04ba6fd0b340ace7 Mon Sep 17 00:00:00 2001 From: MaxBlesch Date: Sun, 22 Jun 2025 13:17:28 +0200 Subject: [PATCH 05/34] Fix options. 
--- src/dcegm/interfaces/inspect_solution.py | 8 +++- src/dcegm/interfaces/model_class.py | 57 ++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 1 deletion(-) diff --git a/src/dcegm/interfaces/inspect_solution.py b/src/dcegm/interfaces/inspect_solution.py index b14488d7..d3e4e63c 100644 --- a/src/dcegm/interfaces/inspect_solution.py +++ b/src/dcegm/interfaces/inspect_solution.py @@ -187,11 +187,17 @@ def partially_solve( key: segment_info["state_choices_childs"][key][id_batch, :] for key in segment_info["state_choices_childs"].keys() } + + child_state_choice_idxs_to_interp = ( + segment_info["child_state_choice_idxs_to_interp"][id_batch, :] + - rescale_idx + ) + xs = ( idx_to_solve, segment_info["child_state_choices_to_aggr_choice"][id_batch, :, :], child_states_to_integrate_stochastic, - segment_info["child_state_choice_idxs_to_interp"][id_batch, :], + child_state_choice_idxs_to_interp, segment_info["child_states_idxs"][id_batch, :], state_choices_batch, state_choices_childs_batch, diff --git a/src/dcegm/interfaces/model_class.py b/src/dcegm/interfaces/model_class.py index 0e399713..7948eca8 100644 --- a/src/dcegm/interfaces/model_class.py +++ b/src/dcegm/interfaces/model_class.py @@ -333,6 +333,63 @@ def get_child_states_and_calc_trans_probs(self, state, choice, params): child_states_df["trans_probs"] = trans_probs return child_states_df + def get_full_child_states_by_asset_id_and_probs( + self, state, choice, params, asset_id, second_continuous_id=None + ): + """Get the child states for a given state and choice and calculate the + transition probabilities.""" + if "map_state_choice_to_child_states" not in self.model_structure: + raise ValueError( + "For this function the model needs to be created with debug_info='all'" + ) + + child_idx = get_child_state_index_per_state_choice( + states=state, choice=choice, model_structure=self.model_structure + ) + state_space_dict = self.model_structure["state_space_dict"] + discrete_states_names = 
self.model_structure["discrete_states_names"] + child_states = { + key: state_space_dict[key][child_idx] for key in discrete_states_names + } + child_states_df = pd.DataFrame(child_states) + + child_continuous_states = self.compute_law_of_motions(params=params) + + if "second_continuous" in child_continuous_states.keys(): + if second_continuous_id is None: + raise ValueError("second_continuous_id must be provided.") + else: + quad_wealth = child_continuous_states["assets_begin_of_period"][ + child_idx, second_continuous_id, asset_id, : + ] + next_period_second_continuous = child_continuous_states[ + "second_continuous" + ][child_idx, second_continuous_id] + + second_continuous_name = self.model_config["continuous_states_info"][ + "second_continuous_state_name" + ] + child_states_df[second_continuous_name] = next_period_second_continuous + + else: + if second_continuous_id is not None: + raise ValueError("second_continuous_id must not be provided.") + else: + quad_wealth = child_continuous_states["assets_begin_of_period"][ + child_idx, asset_id, : + ] + + for id_quad in range(quad_wealth.shape[1]): + child_states_df[f"assets_begin_of_period_quad_point_{id_quad}"] = ( + quad_wealth[:, id_quad] + ) + + trans_probs = self.model_funcs["compute_stochastic_transition_vec"]( + params=params, choice=choice, **state + ) + child_states_df["trans_probs"] = trans_probs + return child_states_df + def compute_law_of_motions(self, params): return calc_cont_grids_next_period( params=params, From 2b5f277b7bdace19710f0d1260fb3031b10fdbc1 Mon Sep 17 00:00:00 2001 From: MaxBlesch Date: Mon, 28 Jul 2025 13:27:37 +0200 Subject: [PATCH 06/34] Added aux options to asset correction. 
--- src/dcegm/asset_correction.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/dcegm/asset_correction.py b/src/dcegm/asset_correction.py index f9bc2ddc..17a6e2b4 100644 --- a/src/dcegm/asset_correction.py +++ b/src/dcegm/asset_correction.py @@ -7,7 +7,7 @@ ) -def adjust_observed_assets(observed_states_dict, params, model_class): +def adjust_observed_assets(observed_states_dict, params, model_class, aux_outs=False): """Correct observed beginning of period assets data for likelihood estimation. Assets in empirical survey data is observed without the income of last period's @@ -47,7 +47,7 @@ def adjust_observed_assets(observed_states_dict, params, model_class): jnp.array(0.0, dtype=jnp.float64), params, model_funcs["compute_assets_begin_of_period"], - False, + aux_outs, ) else: @@ -60,7 +60,7 @@ def adjust_observed_assets(observed_states_dict, params, model_class): jnp.array(0.0, dtype=jnp.float64), params, model_funcs["compute_assets_begin_of_period"], - False, + aux_outs, ) return adjusted_assets From 9db7f39068e8e776a6e573e4410b8ce8ad75823b Mon Sep 17 00:00:00 2001 From: MaxBlesch Date: Fri, 8 Aug 2025 10:20:22 +0200 Subject: [PATCH 07/34] Streamlined choice prob calc --- src/dcegm/interfaces/interface.py | 49 ++---- src/dcegm/interfaces/sol_interface.py | 28 +++- src/dcegm/interpolation/interp_interfaces.py | 46 ++++++ src/dcegm/likelihood.py | 152 ++++++------------- 4 files changed, 131 insertions(+), 144 deletions(-) create mode 100644 src/dcegm/interpolation/interp_interfaces.py diff --git a/src/dcegm/interfaces/interface.py b/src/dcegm/interfaces/interface.py index a5da8813..bb580f01 100644 --- a/src/dcegm/interfaces/interface.py +++ b/src/dcegm/interfaces/interface.py @@ -7,13 +7,12 @@ from dcegm.interpolation.interp1d import ( interp1d_policy_and_value_on_wealth, interp_policy_on_wealth, - interp_value_on_wealth, ) from dcegm.interpolation.interp2d import ( interp2d_policy_and_value_on_wealth_and_regular_grid, 
interp2d_policy_on_wealth_and_regular_grid, - interp2d_value_on_wealth_and_regular_grid, ) +from dcegm.interpolation.interp_interfaces import interpolate_value_for_state_and_choice def get_n_state_choice_period(model): @@ -126,7 +125,7 @@ def policy_and_value_for_state_choice_vec( return policy, value -def value_for_state_choice_vec( +def value_for_state_and_choice( states, choice, params, @@ -156,7 +155,9 @@ def value_for_state_choice_vec( map_state_choice_to_index = model_structure["map_state_choice_to_index_with_proxy"] discrete_states_names = model_structure["discrete_states_names"] - if "dummy_stochastic" in discrete_states_names: + if ("dummy_stochastic" in discrete_states_names) & ( + "dummy_stochastic" not in states.keys() + ): state_choice_vec = { **states, "choice": choice, @@ -173,38 +174,18 @@ def value_for_state_choice_vec( state_choice_vec[st] for st in discrete_states_names + ["choice"] ) state_choice_index = map_state_choice_to_index[state_choice_tuple] - continuous_states_info = model_config["continuous_states_info"] - discount_factor = model_funcs["read_funcs"]["discount_factor"](params) - - compute_utility = model_funcs["compute_utility"] - if continuous_states_info["second_continuous_exists"]: - second_continuous = state_choice_vec[ - continuous_states_info["second_continuous_state_name"] - ] - - value = interp2d_value_on_wealth_and_regular_grid( - regular_grid=continuous_states_info["second_continuous_grid"], - wealth_grid=jnp.take(endog_grid_solved, state_choice_index, axis=0), - value_grid=jnp.take(value_solved, state_choice_index, axis=0), - regular_point_to_interp=second_continuous, - wealth_point_to_interp=state_choice_vec["assets_begin_of_period"], - compute_utility=compute_utility, - state_choice_vec=state_choice_vec, - params=params, - discount_factor=discount_factor, - ) - else: + value_grid_state_choice = jnp.take(value_solved, state_choice_index, axis=0) + endog_grid_state_choice = jnp.take(endog_grid_solved, state_choice_index, 
axis=0) - value = interp_value_on_wealth( - wealth=state_choice_vec["assets_begin_of_period"], - endog_grid=jnp.take(endog_grid_solved, state_choice_index, axis=0), - value=jnp.take(value_solved, state_choice_index, axis=0), - compute_utility=compute_utility, - state_choice_vec=state_choice_vec, - params=params, - discount_factor=discount_factor, - ) + value = interpolate_value_for_state_and_choice( + value_grid_state_choice=value_grid_state_choice, + endog_grid_state_choice=endog_grid_state_choice, + state_choice_vec=state_choice_vec, + params=params, + model_config=model_config, + model_funcs=model_funcs, + ) return value diff --git a/src/dcegm/interfaces/sol_interface.py b/src/dcegm/interfaces/sol_interface.py index f52b222a..e71f691c 100644 --- a/src/dcegm/interfaces/sol_interface.py +++ b/src/dcegm/interfaces/sol_interface.py @@ -3,7 +3,11 @@ from dcegm.interfaces.interface import ( policy_and_value_for_state_choice_vec, policy_for_state_choice_vec, - value_for_state_choice_vec, + value_for_state_and_choice, +) +from dcegm.likelihood import ( + calc_choice_probs_for_states, + get_state_choice_index_per_discrete_state, ) from dcegm.simulation.sim_utils import create_simulation_df from dcegm.simulation.simulate import simulate_all_periods @@ -86,7 +90,7 @@ def value_for_state_and_choice(self, state, choice): """ - return value_for_state_choice_vec( + return value_for_state_and_choice( states=state, choice=choice, model_config=self.model_config, @@ -157,3 +161,23 @@ def get_solution_for_discrete_state_choice(self, states, choice): policy_grid = jnp.take(self.policy, state_choice_index, axis=0) return endog_grid, value_grid, policy_grid + + def choice_probabilites_for_states(self, states): + + state_choice_idxs = get_state_choice_index_per_discrete_state( + states=states, + map_state_choice_to_index=self.model_structure[ + "map_state_choice_to_index_with_proxy" + ], + discrete_states_names=self.model_structure["discrete_states_names"], + ) + + return 
calc_choice_probs_for_states( + value_solved=self.value, + endog_grid_solved=self.endog_grid, + params=self.params, + states=states, + state_choice_indexes_for_states=state_choice_idxs, + model_config=self.model_config, + model_funcs=self.model_funcs, + ) diff --git a/src/dcegm/interpolation/interp_interfaces.py b/src/dcegm/interpolation/interp_interfaces.py new file mode 100644 index 00000000..42317aca --- /dev/null +++ b/src/dcegm/interpolation/interp_interfaces.py @@ -0,0 +1,46 @@ +from dcegm.interpolation.interp1d import interp_value_on_wealth +from dcegm.interpolation.interp2d import interp2d_value_on_wealth_and_regular_grid + + +def interpolate_value_for_state_and_choice( + value_grid_state_choice, + endog_grid_state_choice, + state_choice_vec, + params, + model_config, + model_funcs, +): + """Interpolate the value for a state and choice given the respective grids.""" + continuous_states_info = model_config["continuous_states_info"] + discount_factor = model_funcs["read_funcs"]["discount_factor"](params) + + compute_utility = model_funcs["compute_utility"] + + if continuous_states_info["second_continuous_exists"]: + second_continuous = state_choice_vec[ + continuous_states_info["second_continuous_state_name"] + ] + + value = interp2d_value_on_wealth_and_regular_grid( + regular_grid=continuous_states_info["second_continuous_grid"], + wealth_grid=endog_grid_state_choice, + value_grid=value_grid_state_choice, + regular_point_to_interp=second_continuous, + wealth_point_to_interp=state_choice_vec["assets_begin_of_period"], + compute_utility=compute_utility, + state_choice_vec=state_choice_vec, + params=params, + discount_factor=discount_factor, + ) + else: + + value = interp_value_on_wealth( + wealth=state_choice_vec["assets_begin_of_period"], + endog_grid=endog_grid_state_choice, + value=value_grid_state_choice, + compute_utility=compute_utility, + state_choice_vec=state_choice_vec, + params=params, + discount_factor=discount_factor, + ) + return value diff --git 
a/src/dcegm/likelihood.py b/src/dcegm/likelihood.py index e7b284d1..67b58239 100644 --- a/src/dcegm/likelihood.py +++ b/src/dcegm/likelihood.py @@ -16,8 +16,7 @@ calculate_choice_probs_and_unsqueezed_logsum, ) from dcegm.interfaces.inspect_structure import get_state_choice_index_per_discrete_state -from dcegm.interpolation.interp1d import interp_value_on_wealth -from dcegm.interpolation.interp2d import interp2d_value_on_wealth_and_regular_grid +from dcegm.interpolation.interp_interfaces import interpolate_value_for_state_and_choice def create_individual_likelihood_function( @@ -280,8 +279,8 @@ def calc_choice_prob_for_state_choices( value_solved=value_solved, endog_grid_solved=endog_grid_solved, params=params, - observed_states=states, - state_choice_indexes=state_choice_indexes, + states=states, + state_choice_indexes_for_states=state_choice_indexes, model_config=model_config, model_funcs=model_funcs, ) @@ -295,62 +294,54 @@ def calc_choice_probs_for_states( value_solved, endog_grid_solved, params, - observed_states, - state_choice_indexes, + states, + state_choice_indexes_for_states, model_config, model_funcs, ): - value_grid_agent = jnp.take( - value_solved, state_choice_indexes, axis=0, mode="fill", fill_value=jnp.nan + value_grid_states = jnp.take( + value_solved, + state_choice_indexes_for_states, + axis=0, + mode="fill", + fill_value=jnp.nan, + ) + endog_grid_states = jnp.take( + endog_grid_solved, state_choice_indexes_for_states, axis=0 ) - endog_grid_agent = jnp.take(endog_grid_solved, state_choice_indexes, axis=0) - - # Read out relevant model objects - continuous_states_info = model_config["continuous_states_info"] - choice_range = model_config["choices"] - if continuous_states_info["second_continuous_exists"]: - vectorized_interp2d = jax.vmap( - jax.vmap( - interp2d_value_for_state_in_each_choice, - in_axes=(None, None, 0, 0, 0, None, None, None), - ), - in_axes=(0, 0, 0, 0, None, None, None, None), - ) - # Extract second cont state name - 
second_continuous_state_name = continuous_states_info[ - "second_continuous_state_name" - ] - second_cont_value = observed_states[second_continuous_state_name] - - value_per_agent_interp = vectorized_interp2d( - observed_states, - second_cont_value, - endog_grid_agent, - value_grid_agent, - choice_range, - params, - continuous_states_info["second_continuous_grid"], - model_funcs, + def wrapper_interp_value_for_choice( + state, + value_grid_state_choice, + endog_grid_state_choice, + choice, + ): + state_choice_vec = {**state, "choice": choice} + + return interpolate_value_for_state_and_choice( + value_grid_state_choice=value_grid_state_choice, + endog_grid_state_choice=endog_grid_state_choice, + state_choice_vec=state_choice_vec, + params=params, + model_config=model_config, + model_funcs=model_funcs, ) - else: - vectorized_interp1d = jax.vmap( - jax.vmap( - interp1d_value_for_state_in_each_choice, - in_axes=(None, 0, 0, 0, None, None), - ), - in_axes=(0, 0, 0, None, None, None), - ) + # Read out choice range to loop over + choice_range = model_config["choices"] - value_per_agent_interp = vectorized_interp1d( - observed_states, - endog_grid_agent, - value_grid_agent, - choice_range, - params, - model_funcs, - ) + value_per_agent_interp = jax.vmap( + jax.vmap( + wrapper_interp_value_for_choice, + in_axes=(None, 0, 0, 0), + ), + in_axes=(0, 0, 0, None), + )( + states, + endog_grid_states, + value_grid_states, + choice_range, + ) if model_funcs["taste_shock_function"]["taste_shock_scale_is_scalar"]: taste_shock_scale = model_funcs["taste_shock_function"][ @@ -361,7 +352,7 @@ def calc_choice_probs_for_states( "taste_shock_scale_per_state" ] taste_shock_scale = vmap(taste_shock_scale_per_state_func, in_axes=(0, None))( - observed_states, params + states, params ) taste_shock_scale = taste_shock_scale[:, None] @@ -372,61 +363,6 @@ def calc_choice_probs_for_states( return choice_prob_across_choices -def interp2d_value_for_state_in_each_choice( - state, - second_cont_state, 
- endog_grid_agent, - value_agent, - choice, - params, - regular_grid, - model_funcs, -): - state_choice_vec = {**state, "choice": choice} - - compute_utility = model_funcs["compute_utility"] - discount_factor = model_funcs["read_funcs"]["discount_factor"](params) - - value_interp = interp2d_value_on_wealth_and_regular_grid( - regular_grid=regular_grid, - wealth_grid=endog_grid_agent, - value_grid=value_agent, - regular_point_to_interp=second_cont_state, - wealth_point_to_interp=state["assets_begin_of_period"], - compute_utility=compute_utility, - state_choice_vec=state_choice_vec, - params=params, - discount_factor=discount_factor, - ) - - return value_interp - - -def interp1d_value_for_state_in_each_choice( - state, - endog_grid_agent, - value_agent, - choice, - params, - model_funcs, -): - state_choice_vec = {**state, "choice": choice} - compute_utility = model_funcs["compute_utility"] - discount_factor = model_funcs["read_funcs"]["discount_factor"](params) - - value_interp = interp_value_on_wealth( - wealth=state["assets_begin_of_period"], - endog_grid=endog_grid_agent, - value=value_agent, - compute_utility=compute_utility, - state_choice_vec=state_choice_vec, - params=params, - discount_factor=discount_factor, - ) - - return value_interp - - def calculate_weights_for_each_state(params, weight_vars, model_specs, weight_func): """Calculate the weights for each state. 
From 7ccbc116eb4071eeebdded3c4cd7d7b492c18130 Mon Sep 17 00:00:00 2001 From: MaxBlesch Date: Thu, 14 Aug 2025 12:38:33 +0200 Subject: [PATCH 08/34] Restructure choice probs --- src/dcegm/interfaces/sol_interface.py | 23 ++++++++- src/dcegm/likelihood.py | 74 ++++++++++++++++++--------- 2 files changed, 71 insertions(+), 26 deletions(-) diff --git a/src/dcegm/interfaces/sol_interface.py b/src/dcegm/interfaces/sol_interface.py index e71f691c..b3aafec1 100644 --- a/src/dcegm/interfaces/sol_interface.py +++ b/src/dcegm/interfaces/sol_interface.py @@ -7,6 +7,7 @@ ) from dcegm.likelihood import ( calc_choice_probs_for_states, + choice_values_for_states, get_state_choice_index_per_discrete_state, ) from dcegm.simulation.sim_utils import create_simulation_df @@ -162,7 +163,7 @@ def get_solution_for_discrete_state_choice(self, states, choice): return endog_grid, value_grid, policy_grid - def choice_probabilites_for_states(self, states): + def choice_probabilities_for_states(self, states): state_choice_idxs = get_state_choice_index_per_discrete_state( states=states, @@ -175,9 +176,27 @@ def choice_probabilites_for_states(self, states): return calc_choice_probs_for_states( value_solved=self.value, endog_grid_solved=self.endog_grid, + state_choice_indexes=state_choice_idxs, + params=self.params, + states=states, + model_config=self.model_config, + model_funcs=self.model_funcs, + ) + + def choice_values_for_states(self, states): + state_choice_idxs = get_state_choice_index_per_discrete_state( + states=states, + map_state_choice_to_index=self.model_structure[ + "map_state_choice_to_index_with_proxy" + ], + discrete_states_names=self.model_structure["discrete_states_names"], + ) + return choice_values_for_states( + value_solved=self.value, + endog_grid_solved=self.endog_grid, + state_choice_indexes=state_choice_idxs, params=self.params, states=states, - state_choice_indexes_for_states=state_choice_idxs, model_config=self.model_config, model_funcs=self.model_funcs, ) diff --git 
a/src/dcegm/likelihood.py b/src/dcegm/likelihood.py index 67b58239..03dcd028 100644 --- a/src/dcegm/likelihood.py +++ b/src/dcegm/likelihood.py @@ -275,12 +275,13 @@ def calc_choice_prob_for_state_choices( and then interpolates the wealth at the beginning of period on them. """ + choice_prob_across_choices = calc_choice_probs_for_states( value_solved=value_solved, endog_grid_solved=endog_grid_solved, + state_choice_indexes=state_choice_indexes, params=params, states=states, - state_choice_indexes_for_states=state_choice_indexes, model_config=model_config, model_funcs=model_funcs, ) @@ -293,21 +294,64 @@ def calc_choice_prob_for_state_choices( def calc_choice_probs_for_states( value_solved, endog_grid_solved, + state_choice_indexes, + params, + states, + model_config, + model_funcs, +): + choice_values_per_state = choice_values_for_states( + value_solved=value_solved, + endog_grid_solved=endog_grid_solved, + state_choice_indexes=state_choice_indexes, + params=params, + states=states, + model_config=model_config, + model_funcs=model_funcs, + ) + + if model_funcs["taste_shock_function"]["taste_shock_scale_is_scalar"]: + taste_shock_scale = model_funcs["taste_shock_function"][ + "read_out_taste_shock_scale" + ](params) + else: + taste_shock_scale_per_state_func = model_funcs["taste_shock_function"][ + "taste_shock_scale_per_state" + ] + taste_shock_scale = vmap(taste_shock_scale_per_state_func, in_axes=(0, None))( + states, params + ) + taste_shock_scale = taste_shock_scale[:, None] + + choice_prob_across_choices, _, _ = calculate_choice_probs_and_unsqueezed_logsum( + choice_values_per_state=choice_values_per_state, + taste_shock_scale=taste_shock_scale, + ) + return choice_prob_across_choices + + +def choice_values_for_states( + value_solved, + endog_grid_solved, + state_choice_indexes, params, states, - state_choice_indexes_for_states, model_config, model_funcs, ): value_grid_states = jnp.take( value_solved, - state_choice_indexes_for_states, + state_choice_indexes, 
axis=0, mode="fill", fill_value=jnp.nan, ) endog_grid_states = jnp.take( - endog_grid_solved, state_choice_indexes_for_states, axis=0 + endog_grid_solved, + state_choice_indexes, + axis=0, + mode="fill", + fill_value=jnp.nan, ) def wrapper_interp_value_for_choice( @@ -330,7 +374,7 @@ def wrapper_interp_value_for_choice( # Read out choice range to loop over choice_range = model_config["choices"] - value_per_agent_interp = jax.vmap( + choice_values_per_state = jax.vmap( jax.vmap( wrapper_interp_value_for_choice, in_axes=(None, 0, 0, 0), @@ -342,25 +386,7 @@ def wrapper_interp_value_for_choice( value_grid_states, choice_range, ) - - if model_funcs["taste_shock_function"]["taste_shock_scale_is_scalar"]: - taste_shock_scale = model_funcs["taste_shock_function"][ - "read_out_taste_shock_scale" - ](params) - else: - taste_shock_scale_per_state_func = model_funcs["taste_shock_function"][ - "taste_shock_scale_per_state" - ] - taste_shock_scale = vmap(taste_shock_scale_per_state_func, in_axes=(0, None))( - states, params - ) - taste_shock_scale = taste_shock_scale[:, None] - - choice_prob_across_choices, _, _ = calculate_choice_probs_and_unsqueezed_logsum( - choice_values_per_state=value_per_agent_interp, - taste_shock_scale=taste_shock_scale, - ) - return choice_prob_across_choices + return choice_values_per_state def calculate_weights_for_each_state(params, weight_vars, model_specs, weight_func): From 6b1674b2f25520e94f21b731fb2d223c22b59829 Mon Sep 17 00:00:00 2001 From: MaxBlesch Date: Fri, 15 Aug 2025 12:35:37 +0200 Subject: [PATCH 09/34] Streamlined interface and added check --- .../two_period_model_tutorial.ipynb | 15 +- src/dcegm/egm/interpolate_marginal_utility.py | 6 +- ...nspect_structure.py => index_functions.py} | 23 +- src/dcegm/interfaces/interface.py | 215 ++++++------------ src/dcegm/interfaces/interface_checks.py | 123 ++++++++++ src/dcegm/interfaces/model_class.py | 16 +- src/dcegm/interfaces/sol_interface.py | 79 +++---- 
src/dcegm/interpolation/interp1d.py | 46 ++-- src/dcegm/interpolation/interp_interfaces.py | 93 +++++++- src/dcegm/likelihood.py | 4 +- src/dcegm/simulation/sim_utils.py | 12 +- src/dcegm/simulation/simulate.py | 4 +- ...iscrete_versus_continuous_experience.ipynb | 47 ---- tests/sandbox/time_functions_jax.ipynb | 140 ++++++++++-- ...t_discrete_versus_continuous_experience.py | 24 +- tests/test_replication.py | 6 +- tests/test_sparse_stochastic_and_batch_sep.py | 2 +- tests/test_utility_second_continuous.py | 6 +- tests/test_varying_shock_scale.py | 6 +- 19 files changed, 527 insertions(+), 340 deletions(-) rename src/dcegm/interfaces/{inspect_structure.py => index_functions.py} (67%) create mode 100644 src/dcegm/interfaces/interface_checks.py delete mode 100644 tests/sandbox/discrete_versus_continuous_experience.ipynb diff --git a/docs/source/background/two_period_model_tutorial.ipynb b/docs/source/background/two_period_model_tutorial.ipynb index 3d0fb177..6f588557 100644 --- a/docs/source/background/two_period_model_tutorial.ipynb +++ b/docs/source/background/two_period_model_tutorial.ipynb @@ -749,15 +749,10 @@ ] }, { + "metadata": {}, "cell_type": "code", - "execution_count": 17, - "metadata": { - "ExecuteTime": { - "end_time": "2025-06-30T09:22:14.977528Z", - "start_time": "2025-06-30T09:22:14.323938Z" - } - }, "outputs": [], + "execution_count": null, "source": [ "state_dict = {\n", " \"ltc\": initial_condition[\"health\"],\n", @@ -767,9 +762,9 @@ "}\n", "\n", "\n", - "cons_calc, value = solved_model.value_and_policy_for_state_and_choice(\n", - " state=state_dict,\n", - " choice=choice_in_period_0,\n", + "cons_calc, value = solved_model.value_and_policy_for_states_and_choices(\n", + " states=state_dict,\n", + " choices=choice_in_period_0,\n", ")" ] }, diff --git a/src/dcegm/egm/interpolate_marginal_utility.py b/src/dcegm/egm/interpolate_marginal_utility.py index 0d1e19e1..96217cdf 100644 --- a/src/dcegm/egm/interpolate_marginal_utility.py +++ 
b/src/dcegm/egm/interpolate_marginal_utility.py @@ -164,9 +164,9 @@ def interp1d_value_and_marg_util_for_state_choice( def interp_on_single_wealth_point(wealth_point): policy_interp, value_interp = interp1d_policy_and_value_on_wealth( wealth=wealth_point, - endog_grid=endog_grid_child_state_choice, - policy=policy_child_state_choice, - value=value_child_state_choice, + wealth_grid=endog_grid_child_state_choice, + policy_grid=policy_child_state_choice, + value_grid=value_child_state_choice, compute_utility=compute_utility, state_choice_vec=state_choice_vec, params=params, diff --git a/src/dcegm/interfaces/inspect_structure.py b/src/dcegm/interfaces/index_functions.py similarity index 67% rename from src/dcegm/interfaces/inspect_structure.py rename to src/dcegm/interfaces/index_functions.py index 2944e119..e3016391 100644 --- a/src/dcegm/interfaces/inspect_structure.py +++ b/src/dcegm/interfaces/index_functions.py @@ -1,7 +1,6 @@ -def get_child_state_index_per_state_choice(states, choice, model_structure): - states_choice_dict = {**states, "choice": choice} - state_choice_index = get_state_choice_index_per_discrete_state_and_choice( - model_structure, states_choice_dict +def get_child_state_index_per_states_and_choices(states, choices, model_structure): + state_choice_index = get_state_choice_index_per_discrete_states_and_choices( + model_structure, states, choices ) child_states = model_structure["map_state_choice_to_child_states"][ @@ -11,7 +10,7 @@ def get_child_state_index_per_state_choice(states, choice, model_structure): return child_states -def get_state_choice_index_per_discrete_state( +def get_state_choice_index_per_discrete_states( states, map_state_choice_to_index, discrete_states_names ): """Get the state-choice index for a given set of discrete states. 
@@ -33,26 +32,26 @@ def get_state_choice_index_per_discrete_state( return indexes[0] -def get_state_choice_index_per_discrete_state_and_choice( - model_structure, state_choice_dict +def get_state_choice_index_per_discrete_states_and_choices( + model_structure, states, choices ): """Get the state-choice index for a given set of discrete states and a choice. Args: - model (dict): A dictionary representing the model. Must contain - 'model_structure' with a 'map_state_choice_to_index_with_proxy' - and 'discrete_states_names'. - state_choice_dict (dict): Dictionary containing discrete states and + model_structure (dict): Model structure containing all information on the structure of the model. + states (dict): Dictionary containing discrete states and the choice. Returns: int: The index corresponding to the specified discrete states and choice. """ + state_choices = {"choice": choices, **states} + map_state_choice_to_index = model_structure["map_state_choice_to_index_with_proxy"] discrete_states_names = model_structure["discrete_states_names"] state_choice_tuple = tuple( - state_choice_dict[st] for st in discrete_states_names + ["choice"] + state_choices[st] for st in discrete_states_names + ["choice"] ) state_choice_index = map_state_choice_to_index[state_choice_tuple] diff --git a/src/dcegm/interfaces/interface.py b/src/dcegm/interfaces/interface.py index bb580f01..360c4b6a 100644 --- a/src/dcegm/interfaces/interface.py +++ b/src/dcegm/interfaces/interface.py @@ -4,15 +4,15 @@ import jax.numpy as jnp import pandas as pd -from dcegm.interpolation.interp1d import ( - interp1d_policy_and_value_on_wealth, - interp_policy_on_wealth, +from dcegm.interfaces.index_functions import ( + get_state_choice_index_per_discrete_states_and_choices, ) -from dcegm.interpolation.interp2d import ( - interp2d_policy_and_value_on_wealth_and_regular_grid, - interp2d_policy_on_wealth_and_regular_grid, +from dcegm.interfaces.interface_checks import check_states_and_choices +from 
dcegm.interpolation.interp_interfaces import ( + interpolate_policy_and_value_for_state_and_choice, + interpolate_policy_for_state_and_choice, + interpolate_value_for_state_and_choice, ) -from dcegm.interpolation.interp_interfaces import interpolate_value_for_state_and_choice def get_n_state_choice_period(model): @@ -34,9 +34,9 @@ def get_n_state_choice_period(model): ) -def policy_and_value_for_state_choice_vec( +def policy_and_value_for_states_and_choices( states, - choice, + choices, params, endog_grid_solved, value_solved, @@ -66,68 +66,38 @@ def policy_and_value_for_state_choice_vec( choice. """ - # ToDo: Check if states contains relevant structure - map_state_choice_to_index = model_structure["map_state_choice_to_index_with_proxy"] - discrete_states_names = model_structure["discrete_states_names"] - - if "dummy_stochastic" in discrete_states_names: - state_choice_vec = { - **states, - "choice": choice, - "dummy_stochastic": 0, - } - - else: - state_choice_vec = { - **states, - "choice": choice, - } - - state_choice_tuple = tuple( - state_choice_vec[st] for st in discrete_states_names + ["choice"] + + state_choice_idx = get_state_choice_index_per_discrete_states_and_choices( + states=states, choices=choices, model_structure=model_structure ) - state_choice_index = map_state_choice_to_index[state_choice_tuple] - continuous_states_info = model_config["continuous_states_info"] - - compute_utility = model_funcs["compute_utility"] - discount_factor = model_funcs["read_funcs"]["discount_factor"](params) - - if continuous_states_info["second_continuous_exists"]: - - second_continuous = state_choice_vec[ - continuous_states_info["second_continuous_state_name"] - ] - - policy, value = interp2d_policy_and_value_on_wealth_and_regular_grid( - regular_grid=continuous_states_info["second_continuous_grid"], - wealth_grid=jnp.take(endog_grid_solved, state_choice_index, axis=0), - value_grid=jnp.take(value_solved, state_choice_index, axis=0), - 
policy_grid=jnp.take(policy_solved, state_choice_index, axis=0), - regular_point_to_interp=second_continuous, - wealth_point_to_interp=state_choice_vec["assets_begin_of_period"], - compute_utility=compute_utility, - state_choice_vec=state_choice_vec, - params=params, - discount_factor=discount_factor, - ) - else: - policy, value = interp1d_policy_and_value_on_wealth( - wealth=state_choice_vec["assets_begin_of_period"], - endog_grid=jnp.take(endog_grid_solved, state_choice_index, axis=0), - policy=jnp.take(policy_solved, state_choice_index, axis=0), - value=jnp.take(value_solved, state_choice_index, axis=0), - compute_utility=compute_utility, - state_choice_vec=state_choice_vec, - params=params, - discount_factor=discount_factor, - ) + endog_grid_state_choice = jnp.take(endog_grid_solved, state_choice_idx, axis=0) + value_grid_state_choice = jnp.take(value_solved, state_choice_idx, axis=0) + policy_grid_state_choice = jnp.take(policy_solved, state_choice_idx, axis=0) - return policy, value + state_choices = check_states_and_choices( + states=states, choices=choices, model_structure=model_structure + ) + policy, value = jax.vmap( + interpolate_policy_and_value_for_state_and_choice, + in_axes=(0, 0, 0, 0, None, None, None), + )( + value_grid_state_choice, + policy_grid_state_choice, + endog_grid_state_choice, + state_choices, + params, + model_config, + model_funcs, + ) + return ( + jnp.squeeze(policy), + jnp.squeeze(value), + ) def value_for_state_and_choice( states, - choice, + choices, params, endog_grid_solved, value_solved, @@ -152,47 +122,32 @@ def value_for_state_and_choice( float: The value at the given state and choice. 
""" - map_state_choice_to_index = model_structure["map_state_choice_to_index_with_proxy"] - discrete_states_names = model_structure["discrete_states_names"] - - if ("dummy_stochastic" in discrete_states_names) & ( - "dummy_stochastic" not in states.keys() - ): - state_choice_vec = { - **states, - "choice": choice, - "dummy_stochastic": 0, - } - - else: - state_choice_vec = { - **states, - "choice": choice, - } - - state_choice_tuple = tuple( - state_choice_vec[st] for st in discrete_states_names + ["choice"] - ) - state_choice_index = map_state_choice_to_index[state_choice_tuple] - - value_grid_state_choice = jnp.take(value_solved, state_choice_index, axis=0) - endog_grid_state_choice = jnp.take(endog_grid_solved, state_choice_index, axis=0) - - value = interpolate_value_for_state_and_choice( - value_grid_state_choice=value_grid_state_choice, - endog_grid_state_choice=endog_grid_state_choice, - state_choice_vec=state_choice_vec, - params=params, - model_config=model_config, - model_funcs=model_funcs, + state_choice_idx = get_state_choice_index_per_discrete_states_and_choices( + states=states, choices=choices, model_structure=model_structure ) + endog_grid_state_choice = jnp.take(endog_grid_solved, state_choice_idx, axis=0) + value_grid_state_choice = jnp.take(value_solved, state_choice_idx, axis=0) - return value + state_choices = check_states_and_choices( + states=states, choices=choices, model_structure=model_structure + ) + value = jax.vmap( + interpolate_value_for_state_and_choice, + in_axes=(0, 0, 0, None, None, None), + )( + value_grid_state_choice, + endog_grid_state_choice, + state_choices, + params, + model_config, + model_funcs, + ) + return jnp.squeeze(value) def policy_for_state_choice_vec( states, - choice, + choices, endog_grid_solved, policy_solved, model_structure, @@ -214,51 +169,25 @@ def policy_for_state_choice_vec( float: The policy at the given state and choice. 
""" - map_state_choice_to_index = model_structure["map_state_choice_to_index_with_proxy"] - discrete_states_names = model_structure["discrete_states_names"] - - if "dummy_stochastic" in discrete_states_names: - state_choice_vec = { - **states, - "choice": choice, - "dummy_stochastic": 0, - } - - else: - state_choice_vec = { - **states, - "choice": choice, - } - - state_choice_tuple = tuple( - state_choice_vec[st] for st in discrete_states_names + ["choice"] + state_choice_idx = get_state_choice_index_per_discrete_states_and_choices( + states=states, choices=choices, model_structure=model_structure ) - state_choice_index = map_state_choice_to_index[state_choice_tuple] - continuous_states_info = model_config["continuous_states_info"] - - if continuous_states_info["second_continuous_exists"]: - second_continuous = states[ - continuous_states_info["second_continuous_state_name"] - ] - - policy = interp2d_policy_on_wealth_and_regular_grid( - regular_grid=model_config["continuous_states_info"][ - "second_continuous_grid" - ], - wealth_grid=jnp.take(endog_grid_solved, state_choice_index, axis=0), - policy_grid=jnp.take(policy_solved, state_choice_index, axis=0), - regular_point_to_interp=second_continuous, - wealth_point_to_interp=states["assets_begin_of_period"], - ) + endog_grid_state_choice = jnp.take(endog_grid_solved, state_choice_idx, axis=0) + policy_grid_state_choice = jnp.take(policy_solved, state_choice_idx, axis=0) - else: - policy = interp_policy_on_wealth( - wealth=states["assets_begin_of_period"], - endog_grid=jnp.take(endog_grid_solved, state_choice_index, axis=0), - policy=jnp.take(policy_solved, state_choice_index, axis=0), - ) - - return policy + state_choices = check_states_and_choices( + states=states, choices=choices, model_structure=model_structure + ) + policy = jax.vmap( + interpolate_policy_for_state_and_choice, + in_axes=(0, 0, 0, None), + )( + policy_grid_state_choice, + endog_grid_state_choice, + state_choices, + model_config, + ) + return 
jnp.squeeze(policy) def validate_stochastic_transition(params, model_config, model_funcs, model_structure): diff --git a/src/dcegm/interfaces/interface_checks.py b/src/dcegm/interfaces/interface_checks.py new file mode 100644 index 00000000..4062fb20 --- /dev/null +++ b/src/dcegm/interfaces/interface_checks.py @@ -0,0 +1,123 @@ +import numpy as np +from jax import numpy as jnp + + +def check_states_and_choices(states, choices, model_structure): + """Check if the states and choices are valid according to the model structure. + + Args: + states (dict): Dictionary containing state values. + choices (int): The choice value. + model_structure (dict): Model structure containing information on + discrete states and choices. + + Returns: + bool: True if the states and choices are valid, False otherwise. + + """ + discrete_states_names = model_structure["discrete_states_names"] + if "dummy_stochastic" in discrete_states_names: + if "dummy_stochastic" not in states.keys(): + need_to_add_dummy = True + # Check if all discrete states are present in states, except for the dummy stochastic state + observed_discrete_states = list( + set(discrete_states_names) - {"dummy_stochastic"} + ) + + else: + need_to_add_dummy = False + observed_discrete_states = discrete_states_names.copy() + + else: + need_to_add_dummy = False + observed_discrete_states = discrete_states_names.copy() + + if not all(state in states.keys() for state in observed_discrete_states): + raise ValueError("States should contain all discrete states.") + + # We start checking the dimensions: + # First check if the states are arrays or integers. If integers, all including choices need to be integers + # and we convert them to arrays. Determine first dimension of choice + if isinstance(choices, float): + raise ValueError("Choices should be integers or arrays, not floats. 
") + # Check if choices is a single integer or numpy integers + elif isinstance(choices, (int, np.integer)): + choices = np.array([choices]) + single_state = True + # Now check if all states are integers + if not all( + isinstance(states[key], (int, np.integer)) + for key in observed_discrete_states + ): + raise ValueError( + "Discrete states should be integers or arrays. " + "As choices is a single integer, all states must be integers as well." + ) + else: + states = {key: np.array([value]) for key, value in states.items()} + + elif isinstance(choices, (np.ndarray, jnp.ndarray)): + if choices.ndim == 0: + # Check if choices has dtype int + if choices.dtype in [int, np.integer]: + raise ValueError( + "Choices should be integers or arrays with integer dtype." + ) + + choices = np.array([choices]) + single_state = True + # Now check if all states have dimension 0 as well + if not all(states[key].ndim == 0 for key in states.keys()): + raise ValueError( + "All states and choices must have the same dimension. Choices is dimension 0." + ) + # All observed discrete states must have dtype int as well + if not all( + states[key].dtype in [int, np.integer] + for key in observed_discrete_states + ): + raise ValueError( + "Discrete states should be integers or arrays with integer dtype. " + ) + states = {key: np.array([value]) for key, value in states.items()} + elif choices.ndim == 1: + # Check if choices has dtype int + if not np.issubdtype(choices.dtype, np.integer): + raise ValueError( + "Choices should be integers or arrays with integer dtype." + ) + single_state = False + # Check if all states are arrays with the same dimension as choices + if not all( + states[key].ndim == 1 and states[key].shape[0] == choices.shape[0] + for key in states.keys() + ): + raise ValueError( + "All states and choices must have the same dimension. Choices is dimension 1." 
+ ) + # All observed discrete states must have dtype int as well + if not all( + np.issubdtype(states[key].dtype, np.integer) + for key in observed_discrete_states + ): + raise ValueError( + "Discrete states should be integers or arrays with integer dtype. " + ) + else: + raise ValueError( + "Choices should be integers or arrays with dimension 0 or 1." + ) + else: + raise ValueError("Choices should be integers or arrays with dimension 0 or 1.") + + if need_to_add_dummy: + if single_state: + states["dummy_stochastic"] = np.array([0]) + else: + states["dummy_stochastic"] = np.zeros(choices.shape[0], dtype=int) + + state_choices = { + **states, + "choice": choices, + } + return state_choices diff --git a/src/dcegm/interfaces/model_class.py b/src/dcegm/interfaces/model_class.py index 7948eca8..8d6bd36d 100644 --- a/src/dcegm/interfaces/model_class.py +++ b/src/dcegm/interfaces/model_class.py @@ -6,9 +6,9 @@ import pandas as pd from dcegm.backward_induction import backward_induction -from dcegm.interfaces.inspect_structure import ( - get_child_state_index_per_state_choice, - get_state_choice_index_per_discrete_state, +from dcegm.interfaces.index_functions import ( + get_child_state_index_per_states_and_choices, + get_state_choice_index_per_discrete_states, ) from dcegm.interfaces.interface import validate_stochastic_transition from dcegm.interfaces.sol_interface import model_solved @@ -300,7 +300,7 @@ def validate_exogenous(self, params): def get_state_choices_idx(self, states): """Get the indices of the state choices for given states.""" - return get_state_choice_index_per_discrete_state( + return get_state_choice_index_per_discrete_states( states=states, map_state_choice_to_index=self.model_structure["map_state_choice_to_index"], discrete_states_names=self.model_structure["discrete_states_names"], @@ -312,8 +312,8 @@ def get_child_states(self, state, choice): "For this function the model needs to be created with debug_info='all'" ) - child_idx = 
get_child_state_index_per_state_choice( - states=state, choice=choice, model_structure=self.model_structure + child_idx = get_child_state_index_per_states_and_choices( + states=state, choices=choice, model_structure=self.model_structure ) state_space_dict = self.model_structure["state_space_dict"] discrete_states_names = self.model_structure["discrete_states_names"] @@ -343,8 +343,8 @@ def get_full_child_states_by_asset_id_and_probs( "For this function the model needs to be created with debug_info='all'" ) - child_idx = get_child_state_index_per_state_choice( - states=state, choice=choice, model_structure=self.model_structure + child_idx = get_child_state_index_per_states_and_choices( + states=state, choices=choice, model_structure=self.model_structure ) state_space_dict = self.model_structure["state_space_dict"] discrete_states_names = self.model_structure["discrete_states_names"] diff --git a/src/dcegm/interfaces/sol_interface.py b/src/dcegm/interfaces/sol_interface.py index b3aafec1..1bfabc3a 100644 --- a/src/dcegm/interfaces/sol_interface.py +++ b/src/dcegm/interfaces/sol_interface.py @@ -1,14 +1,18 @@ import jax.numpy as jnp +from dcegm.interfaces.index_functions import ( + get_state_choice_index_per_discrete_states_and_choices, +) from dcegm.interfaces.interface import ( - policy_and_value_for_state_choice_vec, + policy_and_value_for_states_and_choices, policy_for_state_choice_vec, value_for_state_and_choice, ) +from dcegm.interfaces.interface_checks import check_states_and_choices from dcegm.likelihood import ( calc_choice_probs_for_states, choice_values_for_states, - get_state_choice_index_per_discrete_state, + get_state_choice_index_per_discrete_states, ) from dcegm.simulation.sim_utils import create_simulation_df from dcegm.simulation.simulate import simulate_all_periods @@ -56,20 +60,20 @@ def simulate(self, states_initial, seed): ) return create_simulation_df(sim_dict) - def value_and_policy_for_state_and_choice(self, state, choice): + def 
value_and_policy_for_states_and_choices(self, states, choices): """Get the value and policy for a given state and choice. Args: - state: The state for which to get the value and policy. - choice: The choice for which to get the value and policy. + states: The state for which to get the value and policy. + choices: The choice for which to get the value and policy. Returns: A tuple containing the value and policy for the given state and choice. """ - return policy_and_value_for_state_choice_vec( - states=state, - choice=choice, + return policy_and_value_for_states_and_choices( + states=states, + choices=choices, model_config=self.model_config, model_structure=self.model_structure, model_funcs=self.model_funcs, @@ -79,12 +83,12 @@ def value_and_policy_for_state_and_choice(self, state, choice): policy_solved=self.policy, ) - def value_for_state_and_choice(self, state, choice): + def value_for_states_and_choices(self, states, choices): """Get the value for a given state and choice. Args: - state: The state for which to get the value. - choice: The choice for which to get the value. + states: The state for which to get the value. + choices: The choice for which to get the value. Returns: The value for the given state and choice. @@ -92,8 +96,8 @@ def value_for_state_and_choice(self, state, choice): """ return value_for_state_and_choice( - states=state, - choice=choice, + states=states, + choices=choices, model_config=self.model_config, model_structure=self.model_structure, model_funcs=self.model_funcs, @@ -102,12 +106,12 @@ def value_for_state_and_choice(self, state, choice): value_solved=self.value, ) - def policy_for_state_and_choice(self, state, choice): + def policy_for_states_and_choices(self, states, choices): """Get the policy for a given state and choice. Args: - state: The state for which to get the policy. - choice: The choice for which to get the policy. + states: The state for which to get the policy. + choices: The choice for which to get the policy. 
Returns: The policy for the given state and choice. @@ -115,47 +119,36 @@ def policy_for_state_and_choice(self, state, choice): """ return policy_for_state_choice_vec( - states=state, - choice=choice, + states=states, + choices=choices, model_config=self.model_config, model_structure=self.model_structure, endog_grid_solved=self.endog_grid, policy_solved=self.policy, ) - def get_solution_for_discrete_state_choice(self, states, choice): + def get_solution_for_discrete_state_choice(self, states, choices): """Get the solution container for a given discrete state and choice combination. Args: states: The state for which to get the solution. - choice: The choice for which to get the solution. + choices: The choice for which to get the solution. Returns: A tuple containing the wealth grid, value grid, and policy grid for the given state and choice. """ # Get the value and policy for a given state and choice. - - map_state_choice_to_index = self.model_structure[ - "map_state_choice_to_index_with_proxy" - ] - discrete_states_names = self.model_structure["discrete_states_names"] - - if "dummy_stochastic" in discrete_states_names: - state_choice_vec = { - **states, - "choice": choice, - "dummy_stochastic": 0, - } - else: - state_choice_vec = { - **states, - "choice": choice, - } - - state_choice_tuple = tuple( - state_choice_vec[state] for state in discrete_states_names + ["choice"] + state_choice_index = get_state_choice_index_per_discrete_states_and_choices( + model_structure=self.model_structure, + states=states, + choices=choices, + ) + # Check if the states and choices are valid according to the model structure. 
+ check_states_and_choices( + states=states, + choices=choices, + model_structure=self.model_structure, ) - state_choice_index = map_state_choice_to_index[state_choice_tuple] endog_grid = jnp.take(self.endog_grid, state_choice_index, axis=0) value_grid = jnp.take(self.value, state_choice_index, axis=0) @@ -165,7 +158,7 @@ def get_solution_for_discrete_state_choice(self, states, choice): def choice_probabilities_for_states(self, states): - state_choice_idxs = get_state_choice_index_per_discrete_state( + state_choice_idxs = get_state_choice_index_per_discrete_states( states=states, map_state_choice_to_index=self.model_structure[ "map_state_choice_to_index_with_proxy" @@ -184,7 +177,7 @@ def choice_probabilities_for_states(self, states): ) def choice_values_for_states(self, states): - state_choice_idxs = get_state_choice_index_per_discrete_state( + state_choice_idxs = get_state_choice_index_per_discrete_states( states=states, map_state_choice_to_index=self.model_structure[ "map_state_choice_to_index_with_proxy" diff --git a/src/dcegm/interpolation/interp1d.py b/src/dcegm/interpolation/interp1d.py index 4eed0272..6f7d73a9 100644 --- a/src/dcegm/interpolation/interp1d.py +++ b/src/dcegm/interpolation/interp1d.py @@ -38,9 +38,9 @@ def get_index_high_and_low(x, x_new): def interp1d_policy_and_value_on_wealth( wealth: float | jnp.ndarray, - endog_grid: jnp.ndarray, - policy: jnp.ndarray, - value: jnp.ndarray, + wealth_grid: jnp.ndarray, + policy_grid: jnp.ndarray, + value_grid: jnp.ndarray, compute_utility: Callable, state_choice_vec: Dict[str, int], params: Dict[str, float], @@ -50,9 +50,9 @@ def interp1d_policy_and_value_on_wealth( Args: wealth (float | jnp.ndarray): New wealth point(s) to interpolate. - endog_grid (jnp.ndarray): Solved endogenous wealth grid. - policy (jnp.ndarray): Solved policy function. - value (jnp.ndarray): Solved value function. + wealth_grid (jnp.ndarray): Solved endogenous wealth grid. + policy_grid (jnp.ndarray): Solved policy function. 
+ value_grid (jnp.ndarray): Solved value function. state_choice_vec (Dict): Dictionary containing a single state and choice. params (Dict): Dictionary containing the model parameters. @@ -65,25 +65,25 @@ def interp1d_policy_and_value_on_wealth( """ # For all choices, the wealth is the same in the solution - ind_high, ind_low = get_index_high_and_low(x=endog_grid, x_new=wealth) + ind_high, ind_low = get_index_high_and_low(x=wealth_grid, x_new=wealth) policy_interp = linear_interpolation_formula( - y_high=policy[ind_high], - y_low=policy[ind_low], - x_high=endog_grid[ind_high], - x_low=endog_grid[ind_low], + y_high=policy_grid[ind_high], + y_low=policy_grid[ind_low], + x_high=wealth_grid[ind_high], + x_low=wealth_grid[ind_low], x_new=wealth, ) value_interp = interp_value_and_check_creditconstraint( - value_high=value[ind_high], - wealth_high=endog_grid[ind_high], - value_low=value[ind_low], - wealth_low=endog_grid[ind_low], + value_high=value_grid[ind_high], + wealth_high=wealth_grid[ind_high], + value_low=value_grid[ind_low], + wealth_low=wealth_grid[ind_low], new_wealth=wealth, compute_utility=compute_utility, - endog_grid_min=endog_grid[1], - value_at_zero_wealth=value[0], + endog_grid_min=wealth_grid[1], + value_at_zero_wealth=value_grid[0], state_choice_vec=state_choice_vec, params=params, discount_factor=discount_factor, @@ -94,7 +94,7 @@ def interp1d_policy_and_value_on_wealth( def interp_value_on_wealth( wealth: float | jnp.ndarray, - endog_grid: jnp.ndarray, + wealth_grid: jnp.ndarray, value: jnp.ndarray, compute_utility: Callable, state_choice_vec: Dict[str, int], @@ -105,7 +105,7 @@ def interp_value_on_wealth( Args: wealth (float): New wealth point to interpolate. - endog_grid (jnp.ndarray): Solved endogenous wealth grid. + wealth_grid (jnp.ndarray): Solved endogenous wealth grid. value (jnp.ndarray): Solved value function. state_choice_vec (Dict): Dictionary containing a single state and choice. params (Dict): Dictionary containing the model parameters. 
@@ -115,16 +115,16 @@ def interp_value_on_wealth( """ - ind_high, ind_low = get_index_high_and_low(x=endog_grid, x_new=wealth) + ind_high, ind_low = get_index_high_and_low(x=wealth_grid, x_new=wealth) value_interp = interp_value_and_check_creditconstraint( value_high=value[ind_high], - wealth_high=endog_grid[ind_high], + wealth_high=wealth_grid[ind_high], value_low=value[ind_low], - wealth_low=endog_grid[ind_low], + wealth_low=wealth_grid[ind_low], new_wealth=wealth, compute_utility=compute_utility, - endog_grid_min=endog_grid[1], + endog_grid_min=wealth_grid[1], value_at_zero_wealth=value[0], state_choice_vec=state_choice_vec, params=params, diff --git a/src/dcegm/interpolation/interp_interfaces.py b/src/dcegm/interpolation/interp_interfaces.py index 42317aca..f0a66b0b 100644 --- a/src/dcegm/interpolation/interp_interfaces.py +++ b/src/dcegm/interpolation/interp_interfaces.py @@ -1,5 +1,13 @@ -from dcegm.interpolation.interp1d import interp_value_on_wealth -from dcegm.interpolation.interp2d import interp2d_value_on_wealth_and_regular_grid +from dcegm.interpolation.interp1d import ( + interp1d_policy_and_value_on_wealth, + interp_policy_on_wealth, + interp_value_on_wealth, +) +from dcegm.interpolation.interp2d import ( + interp2d_policy_and_value_on_wealth_and_regular_grid, + interp2d_policy_on_wealth_and_regular_grid, + interp2d_value_on_wealth_and_regular_grid, +) def interpolate_value_for_state_and_choice( @@ -36,7 +44,7 @@ def interpolate_value_for_state_and_choice( value = interp_value_on_wealth( wealth=state_choice_vec["assets_begin_of_period"], - endog_grid=endog_grid_state_choice, + wealth_grid=endog_grid_state_choice, value=value_grid_state_choice, compute_utility=compute_utility, state_choice_vec=state_choice_vec, @@ -44,3 +52,82 @@ def interpolate_value_for_state_and_choice( discount_factor=discount_factor, ) return value + + +def interpolate_policy_for_state_and_choice( + policy_grid_state_choice, + endog_grid_state_choice, + state_choice_vec, + 
model_config, +): + """Interpolate the value for a state and choice given the respective grids.""" + continuous_states_info = model_config["continuous_states_info"] + + if continuous_states_info["second_continuous_exists"]: + second_continuous = state_choice_vec[ + continuous_states_info["second_continuous_state_name"] + ] + + policy = interp2d_policy_on_wealth_and_regular_grid( + regular_grid=continuous_states_info["second_continuous_grid"], + wealth_grid=endog_grid_state_choice, + policy_grid=policy_grid_state_choice, + regular_point_to_interp=second_continuous, + wealth_point_to_interp=state_choice_vec["assets_begin_of_period"], + ) + + else: + policy = interp_policy_on_wealth( + wealth=state_choice_vec["assets_begin_of_period"], + endog_grid=endog_grid_state_choice, + policy=policy_grid_state_choice, + ) + + return policy + + +def interpolate_policy_and_value_for_state_and_choice( + value_grid_state_choice, + policy_grid_state_choice, + endog_grid_state_choice, + state_choice_vec, + params, + model_config, + model_funcs, +): + continuous_states_info = model_config["continuous_states_info"] + + compute_utility = model_funcs["compute_utility"] + discount_factor = model_funcs["read_funcs"]["discount_factor"](params) + + if continuous_states_info["second_continuous_exists"]: + + second_continuous = state_choice_vec[ + continuous_states_info["second_continuous_state_name"] + ] + + policy, value = interp2d_policy_and_value_on_wealth_and_regular_grid( + regular_grid=continuous_states_info["second_continuous_grid"], + wealth_grid=endog_grid_state_choice, + policy_grid=policy_grid_state_choice, + value_grid=value_grid_state_choice, + regular_point_to_interp=second_continuous, + wealth_point_to_interp=state_choice_vec["assets_begin_of_period"], + compute_utility=compute_utility, + state_choice_vec=state_choice_vec, + params=params, + discount_factor=discount_factor, + ) + else: + policy, value = interp1d_policy_and_value_on_wealth( + 
wealth=state_choice_vec["assets_begin_of_period"], + wealth_grid=endog_grid_state_choice, + policy_grid=policy_grid_state_choice, + value_grid=value_grid_state_choice, + compute_utility=compute_utility, + state_choice_vec=state_choice_vec, + params=params, + discount_factor=discount_factor, + ) + + return policy, value diff --git a/src/dcegm/likelihood.py b/src/dcegm/likelihood.py index 03dcd028..c622968f 100644 --- a/src/dcegm/likelihood.py +++ b/src/dcegm/likelihood.py @@ -15,7 +15,7 @@ from dcegm.egm.aggregate_marginal_utility import ( calculate_choice_probs_and_unsqueezed_logsum, ) -from dcegm.interfaces.inspect_structure import get_state_choice_index_per_discrete_state +from dcegm.interfaces.index_functions import get_state_choice_index_per_discrete_states from dcegm.interpolation.interp_interfaces import interpolate_value_for_state_and_choice @@ -236,7 +236,7 @@ def create_partial_choice_prob_calculation( model_config, model_funcs, ): - discrete_observed_state_choice_indexes = get_state_choice_index_per_discrete_state( + discrete_observed_state_choice_indexes = get_state_choice_index_per_discrete_states( states=observed_states, map_state_choice_to_index=model_structure[ "map_state_choice_to_index_with_proxy" diff --git a/src/dcegm/simulation/sim_utils.py b/src/dcegm/simulation/sim_utils.py index 3f3a802d..fe2e25b7 100644 --- a/src/dcegm/simulation/sim_utils.py +++ b/src/dcegm/simulation/sim_utils.py @@ -4,7 +4,7 @@ from jax import numpy as jnp from jax import vmap -from dcegm.interfaces.inspect_structure import get_state_choice_index_per_discrete_state +from dcegm.interfaces.index_functions import get_state_choice_index_per_discrete_states from dcegm.interpolation.interp1d import interp1d_policy_and_value_on_wealth from dcegm.interpolation.interp2d import ( interp2d_policy_and_value_on_wealth_and_regular_grid, @@ -34,7 +34,7 @@ def interpolate_policy_and_value_for_all_agents( if continuous_state_beginning_of_period is not None: - discrete_state_choice_indexes 
= get_state_choice_index_per_discrete_state( + discrete_state_choice_indexes = get_state_choice_index_per_discrete_states( states=discrete_states_beginning_of_period, map_state_choice_to_index=map_state_choice_to_index, discrete_states_names=discrete_states_names, @@ -93,7 +93,7 @@ def interpolate_policy_and_value_for_all_agents( return policy_agent, value_agent else: - discrete_state_choice_indexes = get_state_choice_index_per_discrete_state( + discrete_state_choice_indexes = get_state_choice_index_per_discrete_states( states=discrete_states_beginning_of_period, map_state_choice_to_index=map_state_choice_to_index, discrete_states_names=discrete_states_names, @@ -303,9 +303,9 @@ def interp1d_policy_and_value_function( policy_interp, value_interp = interp1d_policy_and_value_on_wealth( wealth=wealth_beginning_of_period, - endog_grid=endog_grid_agent, - policy=policy_agent, - value=value_agent, + wealth_grid=endog_grid_agent, + policy_grid=policy_agent, + value_grid=value_agent, compute_utility=compute_utility, state_choice_vec=state_choice_vec, params=params, diff --git a/src/dcegm/simulation/simulate.py b/src/dcegm/simulation/simulate.py index 5886dc3f..c0ab9a92 100644 --- a/src/dcegm/simulation/simulate.py +++ b/src/dcegm/simulation/simulate.py @@ -7,7 +7,7 @@ import numpy as np from jax import vmap -from dcegm.interfaces.inspect_structure import get_state_choice_index_per_discrete_state +from dcegm.interfaces.index_functions import get_state_choice_index_per_discrete_states from dcegm.simulation.random_keys import draw_random_keys_for_seed from dcegm.simulation.sim_utils import ( compute_final_utility_for_each_choice, @@ -300,7 +300,7 @@ def simulate_final_period( params, compute_utility_final, ) - state_choice_indexes = get_state_choice_index_per_discrete_state( + state_choice_indexes = get_state_choice_index_per_discrete_states( states=states_begin_of_final_period, map_state_choice_to_index=map_state_choice_to_index, discrete_states_names=discrete_states_names, 
diff --git a/tests/sandbox/discrete_versus_continuous_experience.ipynb b/tests/sandbox/discrete_versus_continuous_experience.ipynb deleted file mode 100644 index 71a1b96e..00000000 --- a/tests/sandbox/discrete_versus_continuous_experience.ipynb +++ /dev/null @@ -1,47 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "0ed8eee6-6946-46ed-9064-f0be2aebb19c", - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import jax.numpy as jnp\n", - "import jax" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "dd0bbeea-7c35-45ff-ba12-bb4bcfce97bf", - "metadata": {}, - "outputs": [], - "source": [ - "OPTIONS = {}" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.14" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/tests/sandbox/time_functions_jax.ipynb b/tests/sandbox/time_functions_jax.ipynb index d888f51c..c642478d 100644 --- a/tests/sandbox/time_functions_jax.ipynb +++ b/tests/sandbox/time_functions_jax.ipynb @@ -1,12 +1,108 @@ { "cells": [ + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-15T09:32:08.943565Z", + "start_time": "2025-08-15T09:32:08.305907Z" + } + }, + "cell_type": "code", + "source": [ + "import jax\n", + "import numpy as np" + ], + "id": "3939a0c83c7a102b", + "outputs": [], + "execution_count": 1 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-15T09:33:06.685423Z", + "start_time": "2025-08-15T09:33:06.572482Z" + } + }, + "cell_type": "code", + "source": [ + "def func_a(x, y):\n", + " return x + y\n", + "\n", + "jax.vmap(func_a, in_axes=(0, None))(np.array([2]), 3)" + ], + "id": "83f45f46db8be341", + 
"outputs": [ + { + "data": { + "text/plain": [ + "Array([5], dtype=int32)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 5 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-15T09:52:23.042715Z", + "start_time": "2025-08-15T09:52:23.037510Z" + } + }, + "cell_type": "code", + "source": "isinstance(np.array(2), np.ndarray)", + "id": "d2b3690f1f318672", + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 6 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-15T09:54:14.537765Z", + "start_time": "2025-08-15T09:54:14.530364Z" + } + }, + "cell_type": "code", + "source": [ + "# Check if array is a scalar\n", + "np.array(2).dtype == int" + ], + "id": "d66fc223e808a7e2", + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 10 + }, { "cell_type": "code", "id": "35ab80e3", "metadata": { "ExecuteTime": { - "end_time": "2025-04-10T09:20:40.430908Z", - "start_time": "2025-04-10T09:20:39.276843Z" + "end_time": "2025-08-15T09:52:31.700874Z", + "start_time": "2025-08-15T09:52:30.996171Z" } }, "source": [ @@ -20,8 +116,20 @@ "import numpy as np\n", "from tests.utils.markov_simulator import markov_simulator" ], - "outputs": [], - "execution_count": 2 + "outputs": [ + { + "ename": "ModuleNotFoundError", + "evalue": "No module named 'tests'", + "output_type": "error", + "traceback": [ + "\u001B[31m---------------------------------------------------------------------------\u001B[39m", + "\u001B[31mModuleNotFoundError\u001B[39m Traceback (most recent call last)", + "\u001B[36mCell\u001B[39m\u001B[36m \u001B[39m\u001B[32mIn[8]\u001B[39m\u001B[32m, line 9\u001B[39m\n\u001B[32m 7\u001B[39m \u001B[38;5;28;01mimport\u001B[39;00m\u001B[38;5;250m 
\u001B[39m\u001B[34;01mjax\u001B[39;00m\u001B[34;01m.\u001B[39;00m\u001B[34;01mnumpy\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;28;01mas\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[34;01mjnp\u001B[39;00m\n\u001B[32m 8\u001B[39m \u001B[38;5;28;01mimport\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[34;01mnumpy\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;28;01mas\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[34;01mnp\u001B[39;00m\n\u001B[32m----> \u001B[39m\u001B[32m9\u001B[39m \u001B[38;5;28;01mfrom\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[34;01mtests\u001B[39;00m\u001B[34;01m.\u001B[39;00m\u001B[34;01mutils\u001B[39;00m\u001B[34;01m.\u001B[39;00m\u001B[34;01mmarkov_simulator\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;28;01mimport\u001B[39;00m markov_simulator\n", + "\u001B[31mModuleNotFoundError\u001B[39m: No module named 'tests'" + ] + } + ], + "execution_count": 8 }, { "metadata": { @@ -197,19 +305,19 @@ "evalue": "len() of unsized object", "output_type": "error", "traceback": [ - "\u001b[31m---------------------------------------------------------------------------\u001b[39m", - "\u001b[31mIndexError\u001b[39m Traceback (most recent call last)", - "\u001b[36mFile \u001b[39m\u001b[32m~/micromamba/envs/dcegm/lib/python3.11/site-packages/jax/_src/core.py:1896\u001b[39m, in \u001b[36mShapedArray._len\u001b[39m\u001b[34m(self, ignored_tracer)\u001b[39m\n\u001b[32m 1895\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1896\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mshape\u001b[49m\u001b[43m[\u001b[49m\u001b[32;43m0\u001b[39;49m\u001b[43m]\u001b[49m\n\u001b[32m 1897\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mIndexError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m err:\n", - "\u001b[31mIndexError\u001b[39m: tuple index out of range", + 
"\u001B[31m---------------------------------------------------------------------------\u001B[39m", + "\u001B[31mIndexError\u001B[39m Traceback (most recent call last)", + "\u001B[36mFile \u001B[39m\u001B[32m~/micromamba/envs/dcegm/lib/python3.11/site-packages/jax/_src/core.py:1896\u001B[39m, in \u001B[36mShapedArray._len\u001B[39m\u001B[34m(self, ignored_tracer)\u001B[39m\n\u001B[32m 1895\u001B[39m \u001B[38;5;28;01mtry\u001B[39;00m:\n\u001B[32m-> \u001B[39m\u001B[32m1896\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43mshape\u001B[49m\u001B[43m[\u001B[49m\u001B[32;43m0\u001B[39;49m\u001B[43m]\u001B[49m\n\u001B[32m 1897\u001B[39m \u001B[38;5;28;01mexcept\u001B[39;00m \u001B[38;5;167;01mIndexError\u001B[39;00m \u001B[38;5;28;01mas\u001B[39;00m err:\n", + "\u001B[31mIndexError\u001B[39m: tuple index out of range", "\nThe above exception was the direct cause of the following exception:\n", - "\u001b[31mTypeError\u001b[39m Traceback (most recent call last)", - "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[14]\u001b[39m\u001b[32m, line 3\u001b[39m\n\u001b[32m 1\u001b[39m jit_g = jit(\u001b[38;5;28;01mlambda\u001b[39;00m x, y: g(f, x, y))\n\u001b[32m 2\u001b[39m jit_g_aux = jit(\u001b[38;5;28;01mlambda\u001b[39;00m x, y: g(f_aux, x, y))\n\u001b[32m----> \u001b[39m\u001b[32m3\u001b[39m \u001b[43mjit_g\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtest_a\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtest_b\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 4\u001b[39m jit_g_aux(test_a, test_b)\n", - " \u001b[31m[... 
skipping hidden 13 frame]\u001b[39m\n", - "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[14]\u001b[39m\u001b[32m, line 1\u001b[39m, in \u001b[36m\u001b[39m\u001b[34m(x, y)\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m jit_g = jit(\u001b[38;5;28;01mlambda\u001b[39;00m x, y: \u001b[43mg\u001b[49m\u001b[43m(\u001b[49m\u001b[43mf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[32m 2\u001b[39m jit_g_aux = jit(\u001b[38;5;28;01mlambda\u001b[39;00m x, y: g(f_aux, x, y))\n\u001b[32m 3\u001b[39m jit_g(test_a, test_b)\n", - "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[4]\u001b[39m\u001b[32m, line 10\u001b[39m, in \u001b[36mg\u001b[39m\u001b[34m(func, x, y)\u001b[39m\n\u001b[32m 8\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mg\u001b[39m(func, x, y):\n\u001b[32m 9\u001b[39m func_val = func(x, y)\n\u001b[32m---> \u001b[39m\u001b[32m10\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28;43mlen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mfunc_val\u001b[49m\u001b[43m)\u001b[49m == \u001b[32m2\u001b[39m:\n\u001b[32m 11\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m func_val[\u001b[32m0\u001b[39m]\n\u001b[32m 12\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n", - " \u001b[31m[... 
skipping hidden 1 frame]\u001b[39m\n", - "\u001b[36mFile \u001b[39m\u001b[32m~/micromamba/envs/dcegm/lib/python3.11/site-packages/jax/_src/core.py:1898\u001b[39m, in \u001b[36mShapedArray._len\u001b[39m\u001b[34m(self, ignored_tracer)\u001b[39m\n\u001b[32m 1896\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m.shape[\u001b[32m0\u001b[39m]\n\u001b[32m 1897\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mIndexError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m err:\n\u001b[32m-> \u001b[39m\u001b[32m1898\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[33m\"\u001b[39m\u001b[33mlen() of unsized object\u001b[39m\u001b[33m\"\u001b[39m) \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01merr\u001b[39;00m\n", - "\u001b[31mTypeError\u001b[39m: len() of unsized object" + "\u001B[31mTypeError\u001B[39m Traceback (most recent call last)", + "\u001B[36mCell\u001B[39m\u001B[36m \u001B[39m\u001B[32mIn[14]\u001B[39m\u001B[32m, line 3\u001B[39m\n\u001B[32m 1\u001B[39m jit_g = jit(\u001B[38;5;28;01mlambda\u001B[39;00m x, y: g(f, x, y))\n\u001B[32m 2\u001B[39m jit_g_aux = jit(\u001B[38;5;28;01mlambda\u001B[39;00m x, y: g(f_aux, x, y))\n\u001B[32m----> \u001B[39m\u001B[32m3\u001B[39m \u001B[43mjit_g\u001B[49m\u001B[43m(\u001B[49m\u001B[43mtest_a\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mtest_b\u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 4\u001B[39m jit_g_aux(test_a, test_b)\n", + " \u001B[31m[... 
skipping hidden 13 frame]\u001B[39m\n", + "\u001B[36mCell\u001B[39m\u001B[36m \u001B[39m\u001B[32mIn[14]\u001B[39m\u001B[32m, line 1\u001B[39m, in \u001B[36m\u001B[39m\u001B[34m(x, y)\u001B[39m\n\u001B[32m----> \u001B[39m\u001B[32m1\u001B[39m jit_g = jit(\u001B[38;5;28;01mlambda\u001B[39;00m x, y: \u001B[43mg\u001B[49m\u001B[43m(\u001B[49m\u001B[43mf\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mx\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43my\u001B[49m\u001B[43m)\u001B[49m)\n\u001B[32m 2\u001B[39m jit_g_aux = jit(\u001B[38;5;28;01mlambda\u001B[39;00m x, y: g(f_aux, x, y))\n\u001B[32m 3\u001B[39m jit_g(test_a, test_b)\n", + "\u001B[36mCell\u001B[39m\u001B[36m \u001B[39m\u001B[32mIn[4]\u001B[39m\u001B[32m, line 10\u001B[39m, in \u001B[36mg\u001B[39m\u001B[34m(func, x, y)\u001B[39m\n\u001B[32m 8\u001B[39m \u001B[38;5;28;01mdef\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[34mg\u001B[39m(func, x, y):\n\u001B[32m 9\u001B[39m func_val = func(x, y)\n\u001B[32m---> \u001B[39m\u001B[32m10\u001B[39m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28;43mlen\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43mfunc_val\u001B[49m\u001B[43m)\u001B[49m == \u001B[32m2\u001B[39m:\n\u001B[32m 11\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m func_val[\u001B[32m0\u001B[39m]\n\u001B[32m 12\u001B[39m \u001B[38;5;28;01melse\u001B[39;00m:\n", + " \u001B[31m[... 
skipping hidden 1 frame]\u001B[39m\n", + "\u001B[36mFile \u001B[39m\u001B[32m~/micromamba/envs/dcegm/lib/python3.11/site-packages/jax/_src/core.py:1898\u001B[39m, in \u001B[36mShapedArray._len\u001B[39m\u001B[34m(self, ignored_tracer)\u001B[39m\n\u001B[32m 1896\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m.shape[\u001B[32m0\u001B[39m]\n\u001B[32m 1897\u001B[39m \u001B[38;5;28;01mexcept\u001B[39;00m \u001B[38;5;167;01mIndexError\u001B[39;00m \u001B[38;5;28;01mas\u001B[39;00m err:\n\u001B[32m-> \u001B[39m\u001B[32m1898\u001B[39m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mTypeError\u001B[39;00m(\u001B[33m\"\u001B[39m\u001B[33mlen() of unsized object\u001B[39m\u001B[33m\"\u001B[39m) \u001B[38;5;28;01mfrom\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[34;01merr\u001B[39;00m\n", + "\u001B[31mTypeError\u001B[39m: len() of unsized object" ] } ], diff --git a/tests/test_discrete_versus_continuous_experience.py b/tests/test_discrete_versus_continuous_experience.py index 3587afea..2c65545f 100644 --- a/tests/test_discrete_versus_continuous_experience.py +++ b/tests/test_discrete_versus_continuous_experience.py @@ -138,24 +138,24 @@ def test_replication_discrete_versus_continuous_experience( states_cont["assets_begin_of_period"] = wealth_to_test - value_cont_interp = model_solved_cont.value_for_state_and_choice( - state=states_cont, - choice=choice, + value_cont_interp = model_solved_cont.value_for_states_and_choices( + states=states_cont, + choices=choice, ) - policy_cont_interp = model_solved_cont.policy_for_state_and_choice( - state=states_cont, - choice=choice, + policy_cont_interp = model_solved_cont.policy_for_states_and_choices( + states=states_cont, + choices=choice, ) states_disc["assets_begin_of_period"] = wealth_to_test - value_disc_interp = model_solved_disc.value_for_state_and_choice( - state=states_disc, - choice=choice, + value_disc_interp = model_solved_disc.value_for_states_and_choices( + states=states_disc, + 
choices=choice, ) - policy_disc_interp = model_solved_disc.policy_for_state_and_choice( - state=states_disc, - choice=choice, + policy_disc_interp = model_solved_disc.policy_for_states_and_choices( + states=states_disc, + choices=choice, ) # policy_cont_interp, = ( diff --git a/tests/test_replication.py b/tests/test_replication.py index f07259f9..ade553f0 100644 --- a/tests/test_replication.py +++ b/tests/test_replication.py @@ -81,9 +81,9 @@ def test_benchmark_models(model_name): "assets_begin_of_period": wealth_grid_to_test, } policy_calc_interp, value_calc_interp = ( - model_solved.value_and_policy_for_state_and_choice( - state=state, - choice=choice, + model_solved.value_and_policy_for_states_and_choices( + states=state, + choices=choice, ) ) diff --git a/tests/test_sparse_stochastic_and_batch_sep.py b/tests/test_sparse_stochastic_and_batch_sep.py index daa47246..ba9707c0 100644 --- a/tests/test_sparse_stochastic_and_batch_sep.py +++ b/tests/test_sparse_stochastic_and_batch_sep.py @@ -127,7 +127,7 @@ def test_benchmark_models(): } (endog_grid_full, policy_full, value_full) = ( model_solved_full.get_solution_for_discrete_state_choice( - states=states_dict, choice=state_choices_sparse[:, -1] + states=states_dict, choices=state_choices_sparse[:, -1] ) ) diff --git a/tests/test_utility_second_continuous.py b/tests/test_utility_second_continuous.py index 1e026eb8..705fab75 100644 --- a/tests/test_utility_second_continuous.py +++ b/tests/test_utility_second_continuous.py @@ -388,9 +388,9 @@ def test_replication_discrete_versus_continuous_experience( policy_disc_interp, value_disc_interp = interp1d_policy_and_value_on_wealth( wealth=jnp.array(wealth_to_test), - endog_grid=endog_grid_disc[idx_state_choice_disc], - policy=policy_disc[idx_state_choice_disc], - value=value_disc[idx_state_choice_disc], + wealth_grid=endog_grid_disc[idx_state_choice_disc], + policy_grid=policy_disc[idx_state_choice_disc], + value_grid=value_disc[idx_state_choice_disc], 
compute_utility=model_disc.model_funcs["compute_utility"], state_choice_vec=state_choice_disc_dict, params=PARAMS, diff --git a/tests/test_varying_shock_scale.py b/tests/test_varying_shock_scale.py index 156f0bf6..9e2b1368 100644 --- a/tests/test_varying_shock_scale.py +++ b/tests/test_varying_shock_scale.py @@ -70,9 +70,9 @@ def test_benchmark_models(): ( policy_calc_interp, value_calc_interp, - ) = model_solved.value_and_policy_for_state_and_choice( - state=state, - choice=choice, + ) = model_solved.value_and_policy_for_states_and_choices( + states=state, + choices=choice, ) aaae(policy_expec_interp, policy_calc_interp) From 297de360a3251c539d5c67d8b8791b6545a82501 Mon Sep 17 00:00:00 2001 From: MaxBlesch Date: Fri, 15 Aug 2025 14:54:11 +0200 Subject: [PATCH 10/34] Finalize. --- src/dcegm/interfaces/interface.py | 26 +++++++++++++----------- src/dcegm/interfaces/interface_checks.py | 2 +- src/dcegm/interfaces/sol_interface.py | 16 ++++++++------- 3 files changed, 24 insertions(+), 20 deletions(-) diff --git a/src/dcegm/interfaces/interface.py b/src/dcegm/interfaces/interface.py index 360c4b6a..6a665755 100644 --- a/src/dcegm/interfaces/interface.py +++ b/src/dcegm/interfaces/interface.py @@ -66,17 +66,17 @@ def policy_and_value_for_states_and_choices( choice. 
""" + state_choices = check_states_and_choices( + states=states, choices=choices, model_structure=model_structure + ) state_choice_idx = get_state_choice_index_per_discrete_states_and_choices( - states=states, choices=choices, model_structure=model_structure + states=state_choices, choices=choices, model_structure=model_structure ) endog_grid_state_choice = jnp.take(endog_grid_solved, state_choice_idx, axis=0) value_grid_state_choice = jnp.take(value_solved, state_choice_idx, axis=0) policy_grid_state_choice = jnp.take(policy_solved, state_choice_idx, axis=0) - state_choices = check_states_and_choices( - states=states, choices=choices, model_structure=model_structure - ) policy, value = jax.vmap( interpolate_policy_and_value_for_state_and_choice, in_axes=(0, 0, 0, 0, None, None, None), @@ -122,15 +122,16 @@ def value_for_state_and_choice( float: The value at the given state and choice. """ - state_choice_idx = get_state_choice_index_per_discrete_states_and_choices( + state_choices = check_states_and_choices( states=states, choices=choices, model_structure=model_structure ) + + state_choice_idx = get_state_choice_index_per_discrete_states_and_choices( + states=state_choices, choices=choices, model_structure=model_structure + ) endog_grid_state_choice = jnp.take(endog_grid_solved, state_choice_idx, axis=0) value_grid_state_choice = jnp.take(value_solved, state_choice_idx, axis=0) - state_choices = check_states_and_choices( - states=states, choices=choices, model_structure=model_structure - ) value = jax.vmap( interpolate_value_for_state_and_choice, in_axes=(0, 0, 0, None, None, None), @@ -169,15 +170,16 @@ def policy_for_state_choice_vec( float: The policy at the given state and choice. 
""" - state_choice_idx = get_state_choice_index_per_discrete_states_and_choices( + state_choices = check_states_and_choices( states=states, choices=choices, model_structure=model_structure ) + + state_choice_idx = get_state_choice_index_per_discrete_states_and_choices( + states=state_choices, choices=choices, model_structure=model_structure + ) endog_grid_state_choice = jnp.take(endog_grid_solved, state_choice_idx, axis=0) policy_grid_state_choice = jnp.take(policy_solved, state_choice_idx, axis=0) - state_choices = check_states_and_choices( - states=states, choices=choices, model_structure=model_structure - ) policy = jax.vmap( interpolate_policy_for_state_and_choice, in_axes=(0, 0, 0, None), diff --git a/src/dcegm/interfaces/interface_checks.py b/src/dcegm/interfaces/interface_checks.py index 4062fb20..61412f17 100644 --- a/src/dcegm/interfaces/interface_checks.py +++ b/src/dcegm/interfaces/interface_checks.py @@ -12,7 +12,7 @@ def check_states_and_choices(states, choices, model_structure): discrete states and choices. Returns: - bool: True if the states and choices are valid, False otherwise. + state_choices (dict): Dictionary containing the states and choices, """ discrete_states_names = model_structure["discrete_states_names"] diff --git a/src/dcegm/interfaces/sol_interface.py b/src/dcegm/interfaces/sol_interface.py index 1bfabc3a..f4f656bd 100644 --- a/src/dcegm/interfaces/sol_interface.py +++ b/src/dcegm/interfaces/sol_interface.py @@ -137,19 +137,21 @@ def get_solution_for_discrete_state_choice(self, states, choices): A tuple containing the wealth grid, value grid, and policy grid for the given state and choice. """ - # Get the value and policy for a given state and choice. - state_choice_index = get_state_choice_index_per_discrete_states_and_choices( - model_structure=self.model_structure, - states=states, - choices=choices, - ) # Check if the states and choices are valid according to the model structure. 
- check_states_and_choices( + state_choices = check_states_and_choices( states=states, choices=choices, model_structure=self.model_structure, ) + # Get the value and policy for a given state and choice. We use state choices as states as it is not important + # that these are missing. + state_choice_index = get_state_choice_index_per_discrete_states_and_choices( + model_structure=self.model_structure, + states=state_choices, + choices=state_choices["choice"], + ) + endog_grid = jnp.take(self.endog_grid, state_choice_index, axis=0) value_grid = jnp.take(self.value, state_choice_index, axis=0) policy_grid = jnp.take(self.policy, state_choice_index, axis=0) From 3081f27c4d3fa2d76fc9b4c1f7bc4447b5d78c5f Mon Sep 17 00:00:00 2001 From: MaxBlesch Date: Mon, 18 Aug 2025 19:22:18 +0200 Subject: [PATCH 11/34] Fixed it --- src/dcegm/likelihood.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dcegm/likelihood.py b/src/dcegm/likelihood.py index c622968f..2d1fc8ab 100644 --- a/src/dcegm/likelihood.py +++ b/src/dcegm/likelihood.py @@ -382,8 +382,8 @@ def wrapper_interp_value_for_choice( in_axes=(0, 0, 0, None), )( states, - endog_grid_states, value_grid_states, + endog_grid_states, choice_range, ) return choice_values_per_state From f8cad0e16b4e46371b24573511500528bad53037 Mon Sep 17 00:00:00 2001 From: MaxBlesch Date: Wed, 20 Aug 2025 20:42:48 +0200 Subject: [PATCH 12/34] Modularizte --- src/dcegm/likelihood.py | 64 ++++++++++++++++++++++++++++------------- 1 file changed, 44 insertions(+), 20 deletions(-) diff --git a/src/dcegm/likelihood.py b/src/dcegm/likelihood.py index 2d1fc8ab..69fed8a5 100644 --- a/src/dcegm/likelihood.py +++ b/src/dcegm/likelihood.py @@ -33,26 +33,16 @@ def create_individual_likelihood_function( use_probability_of_observed_states=True, ): - if unobserved_state_specs is None: - choice_prob_func = create_partial_choice_prob_calculation( - observed_states=observed_states, - observed_choices=observed_choices, - 
model_structure=model_structure, - model_config=model_config, - model_funcs=model_funcs, - ) - else: - - choice_prob_func = create_choice_prob_func_unobserved_states( - model_structure=model_structure, - model_config=model_config, - model_funcs=model_funcs, - model_specs=model_specs, - observed_states=observed_states, - observed_choices=observed_choices, - unobserved_state_specs=unobserved_state_specs, - use_probability_of_observed_states=use_probability_of_observed_states, - ) + choice_prob_func = create_choice_prob_function( + model_structure=model_structure, + model_config=model_config, + model_funcs=model_funcs, + model_specs=model_specs, + observed_states=observed_states, + observed_choices=observed_choices, + unobserved_state_specs=unobserved_state_specs, + use_probability_of_observed_states=use_probability_of_observed_states, + ) def individual_likelihood(params): params_update = params_all.copy() @@ -82,6 +72,40 @@ def individual_likelihood(params): return jax.jit(individual_likelihood) +def create_choice_prob_function( + model_structure, + model_config, + model_funcs, + model_specs, + observed_states, + observed_choices, + unobserved_state_specs, + use_probability_of_observed_states, +): + if unobserved_state_specs is None: + choice_prob_func = create_partial_choice_prob_calculation( + observed_states=observed_states, + observed_choices=observed_choices, + model_structure=model_structure, + model_config=model_config, + model_funcs=model_funcs, + ) + else: + + choice_prob_func = create_choice_prob_func_unobserved_states( + model_structure=model_structure, + model_config=model_config, + model_funcs=model_funcs, + model_specs=model_specs, + observed_states=observed_states, + observed_choices=observed_choices, + unobserved_state_specs=unobserved_state_specs, + use_probability_of_observed_states=use_probability_of_observed_states, + ) + + return choice_prob_func + + def create_choice_prob_func_unobserved_states( model_structure, model_config, From 
5b7bffc8814ee2f398b36797902ddd38e1490997 Mon Sep 17 00:00:00 2001 From: MaxBlesch Date: Thu, 11 Sep 2025 13:30:12 +0200 Subject: [PATCH 13/34] Removed some more numpy --- src/dcegm/backward_induction.py | 1 - src/dcegm/egm/aggregate_marginal_utility.py | 6 ++-- src/dcegm/egm/solve_euler_equation.py | 33 ++++++++++----------- src/dcegm/likelihood.py | 12 ++++++-- 4 files changed, 29 insertions(+), 23 deletions(-) diff --git a/src/dcegm/backward_induction.py b/src/dcegm/backward_induction.py index a75a9a2b..5182f7a5 100644 --- a/src/dcegm/backward_induction.py +++ b/src/dcegm/backward_induction.py @@ -4,7 +4,6 @@ import jax.lax import jax.numpy as jnp -import numpy as np from dcegm.final_periods import solve_last_two_periods from dcegm.law_of_motion import calc_cont_grids_next_period diff --git a/src/dcegm/egm/aggregate_marginal_utility.py b/src/dcegm/egm/aggregate_marginal_utility.py index 08ec2af5..a707ecda 100644 --- a/src/dcegm/egm/aggregate_marginal_utility.py +++ b/src/dcegm/egm/aggregate_marginal_utility.py @@ -1,13 +1,12 @@ from typing import Tuple import jax.numpy as jnp -import numpy as np def aggregate_marg_utils_and_exp_values( value_state_choice_specific: jnp.ndarray, marg_util_state_choice_specific: jnp.ndarray, - reshape_state_choice_vec_to_mat: np.ndarray, + reshape_state_choice_vec_to_mat: jnp.ndarray, taste_shock_scale, taste_shock_scale_is_scalar, income_shock_weights: jnp.ndarray, @@ -47,11 +46,12 @@ def aggregate_marg_utils_and_exp_values( mode="fill", fill_value=jnp.nan, ) + # If taste shock is not scalar, we select from the array, # where we have for each choice a taste shock scale one. 
They are by construction # the same for all choices in a state if not taste_shock_scale_is_scalar: - one_choice_per_state = np.min(reshape_state_choice_vec_to_mat, axis=1) + one_choice_per_state = jnp.min(reshape_state_choice_vec_to_mat, axis=1) taste_shock_scale = jnp.take( taste_shock_scale, one_choice_per_state, diff --git a/src/dcegm/egm/solve_euler_equation.py b/src/dcegm/egm/solve_euler_equation.py index b1a10c06..3cc3f8da 100644 --- a/src/dcegm/egm/solve_euler_equation.py +++ b/src/dcegm/egm/solve_euler_equation.py @@ -2,21 +2,20 @@ from typing import Callable, Dict, Tuple -import numpy as np from jax import numpy as jnp from jax import vmap def calculate_candidate_solutions_from_euler_equation( - continuous_grids_info: np.ndarray, + continuous_grids_info: jnp.ndarray, marg_util_next: jnp.ndarray, emax_next: jnp.ndarray, - state_choice_mat: np.ndarray, - idx_post_decision_child_states: np.ndarray, + state_choice_mat: jnp.ndarray, + idx_post_decision_child_states: jnp.ndarray, model_funcs: Dict[str, Callable], has_second_continuous_state: bool, params: Dict[str, float], -) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: +) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: """Calculate candidates for the optimal policy and value function.""" feasible_marg_utils_child = jnp.take( @@ -78,14 +77,14 @@ def calculate_candidate_solutions_from_euler_equation( def compute_optimal_policy_and_value_wrapper( - marg_util_next: np.ndarray, - emax_next: np.ndarray, - second_continuous_grid: np.ndarray, - assets_grid_end_of_period: np.ndarray, + marg_util_next: jnp.ndarray, + emax_next: jnp.ndarray, + second_continuous_grid: jnp.ndarray, + assets_grid_end_of_period: jnp.ndarray, state_choice_vec: Dict, model_funcs: Dict[str, Callable], params: Dict[str, float], -) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: +) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: """Write second continuous grid point into state_choice_vec.""" 
state_choice_vec["continuous_state"] = second_continuous_grid @@ -100,13 +99,13 @@ def compute_optimal_policy_and_value_wrapper( def compute_optimal_policy_and_value( - marg_util_next: np.ndarray, - emax_next: np.ndarray, - assets_grid_end_of_period: np.ndarray, + marg_util_next: jnp.ndarray, + emax_next: jnp.ndarray, + assets_grid_end_of_period: jnp.ndarray, state_choice_vec: Dict, model_funcs: Dict[str, Callable], params: Dict[str, float], -) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: +) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: """Compute optimal child-state- and choice-specific policy and value function. Given the marginal utilities of possible child states and next period wealth, we @@ -173,14 +172,14 @@ def compute_optimal_policy_and_value( def solve_euler_equation( state_choice_vec: dict, - marg_util_next: np.ndarray, - emax_next: np.ndarray, + marg_util_next: jnp.ndarray, + emax_next: jnp.ndarray, compute_inverse_marginal_utility: Callable, compute_stochastic_transition_vec: Callable, params: Dict[str, float], discount_factor: float, interest_rate: float, -) -> Tuple[np.ndarray, np.ndarray]: +) -> Tuple[jnp.ndarray, jnp.ndarray]: """Solve the Euler equation for given discrete choice and child states. 
We integrate over the exogenous process and income uncertainty and diff --git a/src/dcegm/likelihood.py b/src/dcegm/likelihood.py index 69fed8a5..0475ab96 100644 --- a/src/dcegm/likelihood.py +++ b/src/dcegm/likelihood.py @@ -26,7 +26,7 @@ def create_individual_likelihood_function( model_specs, backwards_induction, observed_states: Dict[str, int], - observed_choices: np.array, + observed_choices, params_all, unobserved_state_specs=None, return_model_solution=False, @@ -112,7 +112,7 @@ def create_choice_prob_func_unobserved_states( model_funcs, model_specs, observed_states: Dict[str, int], - observed_choices: np.array, + observed_choices, unobserved_state_specs, use_probability_of_observed_states=True, ): @@ -191,6 +191,8 @@ def create_choice_prob_func_unobserved_states( observed_weights[observed_bools[state_name]] /= n_state_values + observed_weights = jnp.asarray(observed_weights) + # Create a list of partial choice probability functions for each unique # combination of unobserved states. 
partial_choice_probs_unobserved_states = [] @@ -215,6 +217,12 @@ def create_choice_prob_func_unobserved_states( n_obs = len(observed_choices) + # Use jax tree map to make only jax arrays of possible states and weighting vars + possible_states = jax.tree_map(lambda x: jnp.asarray(x), possible_states) + weighting_vars_for_possible_states = jax.tree_map( + lambda x: jnp.asarray(x), weighting_vars_for_possible_states + ) + def choice_prob_func(value_in, endog_grid_in, params_in): choice_probs_final = jnp.zeros(n_obs, dtype=jnp.float64) integrate_out_weights = jnp.zeros(n_obs, dtype=jnp.float64) From a36353c21d52cf5babfb4133243681a4afcd42bc Mon Sep 17 00:00:00 2001 From: MaxBlesch Date: Thu, 11 Sep 2025 15:47:10 +0200 Subject: [PATCH 14/34] More jax arrays --- src/dcegm/pre_processing/setup_model.py | 12 +++++++++--- src/dcegm/pre_processing/shared.py | 7 +++++++ tests/test_replication.py | 9 +++++---- 3 files changed, 21 insertions(+), 7 deletions(-) diff --git a/src/dcegm/pre_processing/setup_model.py b/src/dcegm/pre_processing/setup_model.py index a2842d35..49e16350 100644 --- a/src/dcegm/pre_processing/setup_model.py +++ b/src/dcegm/pre_processing/setup_model.py @@ -2,6 +2,7 @@ from typing import Callable, Dict import jax +import jax.numpy as jnp from dcegm.pre_processing.batches.batch_creation import create_batches_and_information from dcegm.pre_processing.check_model_config import check_model_config_and_process @@ -15,7 +16,10 @@ from dcegm.pre_processing.model_structure.stochastic_states import ( create_stochastic_state_mapping, ) -from dcegm.pre_processing.shared import create_array_with_smallest_int_dtype +from dcegm.pre_processing.shared import ( + create_array_with_smallest_int_dtype, + try_jax_array, +) def create_model_dict( @@ -109,12 +113,14 @@ def create_model_dict( model_structure.pop("map_state_choice_to_child_states") model_structure.pop("map_state_choice_to_index") + batch_info = jax.tree.map(create_array_with_smallest_int_dtype, batch_info) 
print("Model setup complete.\n") return { "model_config": model_config_processed, "model_funcs": model_funcs, - "model_structure": model_structure, - "batch_info": jax.tree.map(create_array_with_smallest_int_dtype, batch_info), + # Model structure are also lists, therefore we use try function + "model_structure": jax.tree.map(try_jax_array, model_structure), + "batch_info": jax.tree.map(jnp.asarray, batch_info), } diff --git a/src/dcegm/pre_processing/shared.py b/src/dcegm/pre_processing/shared.py index cffcf7e0..4ff716e9 100644 --- a/src/dcegm/pre_processing/shared.py +++ b/src/dcegm/pre_processing/shared.py @@ -65,3 +65,10 @@ def get_smallest_int_type(n_values): for dtype in uint_types: if np.iinfo(dtype).max >= n_values: return dtype + + +def try_jax_array(x): + try: + return jnp.asarray(x) + except: + return x diff --git a/tests/test_replication.py b/tests/test_replication.py index ade553f0..1a27b958 100644 --- a/tests/test_replication.py +++ b/tests/test_replication.py @@ -74,16 +74,17 @@ def test_benchmark_models(model_name): policy_expec_interp = linear_interpolation_with_extrapolation( x_new=wealth_grid_to_test, x=policy_expec[0], y=policy_expec[1] ) - + lagged_choice = state_choice_space_to_test[state_choice_idx, 1] state = { - "period": period, - "lagged_choice": state_choice_space_to_test[state_choice_idx, 1], + "period": jnp.ones_like(wealth_grid_to_test, dtype=int) * period, + "lagged_choice": jnp.ones_like(wealth_grid_to_test, dtype=int) + * lagged_choice, "assets_begin_of_period": wealth_grid_to_test, } policy_calc_interp, value_calc_interp = ( model_solved.value_and_policy_for_states_and_choices( states=state, - choices=choice, + choices=jnp.ones_like(wealth_grid_to_test, dtype=int) * choice, ) ) From e579b43d3744c6b6cea036fe3965299a024897aa Mon Sep 17 00:00:00 2001 From: MaxBlesch Date: Thu, 11 Sep 2025 16:14:34 +0200 Subject: [PATCH 15/34] More tree utils --- src/dcegm/interfaces/model_class.py | 4 ++- src/dcegm/likelihood.py | 4 +-- 
src/dcegm/pre_processing/check_model_specs.py | 34 +++++++------------ .../process_model_functions.py | 25 ++++++++------ tests/test_varying_shock_scale.py | 8 +++-- 5 files changed, 37 insertions(+), 38 deletions(-) diff --git a/src/dcegm/interfaces/model_class.py b/src/dcegm/interfaces/model_class.py index 8d6bd36d..fc6e8cd7 100644 --- a/src/dcegm/interfaces/model_class.py +++ b/src/dcegm/interfaces/model_class.py @@ -24,6 +24,7 @@ create_model_dict_and_save, load_model_dict, ) +from dcegm.pre_processing.shared import try_jax_array from dcegm.simulation.sim_utils import create_simulation_df from dcegm.simulation.simulate import simulate_all_periods @@ -46,7 +47,6 @@ def __init__( ): """Setup the model and check if load or save is required.""" - self.model_specs = model_specs if model_load_path is not None: model_dict = load_model_dict( model_config=model_config, @@ -85,6 +85,8 @@ def __init__( debug_info=debug_info, ) + self.model_specs = jax.tree_util.tree_map(try_jax_array, model_specs) + self.model_config = model_dict["model_config"] self.model_funcs = model_dict["model_funcs"] self.model_structure = model_dict["model_structure"] diff --git a/src/dcegm/likelihood.py b/src/dcegm/likelihood.py index 0475ab96..7acf59ee 100644 --- a/src/dcegm/likelihood.py +++ b/src/dcegm/likelihood.py @@ -218,8 +218,8 @@ def create_choice_prob_func_unobserved_states( n_obs = len(observed_choices) # Use jax tree map to make only jax arrays of possible states and weighting vars - possible_states = jax.tree_map(lambda x: jnp.asarray(x), possible_states) - weighting_vars_for_possible_states = jax.tree_map( + possible_states = jax.tree_util.tree_map(lambda x: jnp.asarray(x), possible_states) + weighting_vars_for_possible_states = jax.tree_util.tree_map( lambda x: jnp.asarray(x), weighting_vars_for_possible_states ) diff --git a/src/dcegm/pre_processing/check_model_specs.py b/src/dcegm/pre_processing/check_model_specs.py index b5281bac..098b49df 100644 --- 
a/src/dcegm/pre_processing/check_model_specs.py +++ b/src/dcegm/pre_processing/check_model_specs.py @@ -11,51 +11,41 @@ def extract_model_specs_info(model_specs): # discount_factor processing if "discount_factor" in model_specs: - read_func_discount_factor = lambda params: jnp.asarray( - model_specs["discount_factor"] - ) + discount_factor = jnp.asarray(model_specs["discount_factor"]) + read_func_discount_factor = lambda params: discount_factor discount_factor_in_params = False else: - read_func_discount_factor = lambda params: jnp.asarray( - params["discount_factor"] - ) + read_func_discount_factor = lambda params: params["discount_factor"] discount_factor_in_params = True # interest_rate processing if "interest_rate" in model_specs: # Check if interest_rate is a scalar - read_func_interest_rate = lambda params: jnp.asarray( - model_specs["interest_rate"] - ) + interest_rate = jnp.asarray(model_specs["interest_rate"]) + read_func_interest_rate = lambda params: interest_rate interest_rate_in_params = False else: - read_func_interest_rate = lambda params: jnp.asarray(params["interest_rate"]) + read_func_interest_rate = lambda params: params["interest_rate"] interest_rate_in_params = True # income shock std processing ("income_shock_std") if "income_shock_std" in model_specs: # Check if income_shock_std is a scalar - read_func_income_shock_std = lambda params: jnp.asarray( - model_specs["income_shock_std"] - ) + income_shock_std = jnp.asarray(model_specs["income_shock_std"]) + read_func_income_shock_std = lambda params: income_shock_std income_shock_std_in_params = False else: - read_func_income_shock_std = lambda params: jnp.asarray( - params["income_shock_std"] - ) + read_func_income_shock_std = lambda params: params["income_shock_std"] income_shock_std_in_params = True # income shock std processing ("income_shock_std") if "income_shock_mean" in model_specs: # Check if income_shock_std is a scalar - read_func_income_shock_mean = lambda params: jnp.asarray( - 
model_specs["income_shock_mean"] - ) + income_shock_mean = jnp.asarray(model_specs["income_shock_mean"]) + read_func_income_shock_mean = lambda params: income_shock_mean income_shock_mean_in_params = False else: - read_func_income_shock_mean = lambda params: jnp.asarray( - params["income_shock_mean"] - ) + read_func_income_shock_mean = lambda params: params["income_shock_mean"] income_shock_mean_in_params = True specs_read_funcs = { diff --git a/src/dcegm/pre_processing/model_functions/process_model_functions.py b/src/dcegm/pre_processing/model_functions/process_model_functions.py index eca89668..44a0467e 100644 --- a/src/dcegm/pre_processing/model_functions/process_model_functions.py +++ b/src/dcegm/pre_processing/model_functions/process_model_functions.py @@ -1,5 +1,6 @@ from typing import Callable, Dict, Optional +import jax import jax.numpy as jnp from dcegm.pre_processing.model_functions.taste_shock_function import ( @@ -13,6 +14,7 @@ ) from dcegm.pre_processing.shared import ( determine_function_arguments_and_partial_model_specs, + try_jax_array, ) @@ -70,23 +72,26 @@ def process_model_functions_and_extract_info( "second_continuous_state_name" ] + # We use this for functions which are called later in the jitted code + model_specs_jax = jax.tree_util.tree_map(try_jax_array, model_specs) + # Process mandatory functions. 
Start with utility functions compute_utility = determine_function_arguments_and_partial_model_specs( func=utility_functions["utility"], - model_specs=model_specs, + model_specs=model_specs_jax, continuous_state_name=second_continuous_state_name, ) compute_marginal_utility = determine_function_arguments_and_partial_model_specs( func=utility_functions["marginal_utility"], - model_specs=model_specs, + model_specs=model_specs_jax, continuous_state_name=second_continuous_state_name, ) compute_inverse_marginal_utility = ( determine_function_arguments_and_partial_model_specs( func=utility_functions["inverse_marginal_utility"], - model_specs=model_specs, + model_specs=model_specs_jax, continuous_state_name=second_continuous_state_name, ) ) @@ -99,14 +104,14 @@ def process_model_functions_and_extract_info( # Final period utility functions compute_utility_final = determine_function_arguments_and_partial_model_specs( func=utility_functions_final_period["utility"], - model_specs=model_specs, + model_specs=model_specs_jax, continuous_state_name=second_continuous_state_name, ) compute_marginal_utility_final = ( determine_function_arguments_and_partial_model_specs( func=utility_functions_final_period["marginal_utility"], - model_specs=model_specs, + model_specs=model_specs_jax, continuous_state_name=second_continuous_state_name, ) ) @@ -121,12 +126,12 @@ def process_model_functions_and_extract_info( create_stochastic_transition_function( stochastic_states_transitions, model_config=model_config, - model_specs=model_specs, + model_specs=model_specs_jax, continuous_state_name=second_continuous_state_name, ) ) - # Now state space functions + # Now state space functions - here we use the old model_specs state_specific_choice_set, next_period_deterministic_state, sparsity_condition = ( process_state_space_functions( state_space_functions, @@ -137,7 +142,7 @@ def process_model_functions_and_extract_info( ) next_period_continuous_state = process_second_continuous_update_function( - 
second_continuous_state_name, state_space_functions, model_specs=model_specs + second_continuous_state_name, state_space_functions, model_specs=model_specs_jax ) # Budget equation @@ -145,7 +150,7 @@ def process_model_functions_and_extract_info( determine_function_arguments_and_partial_model_specs( func=budget_constraint, continuous_state_name=second_continuous_state_name, - model_specs=model_specs, + model_specs=model_specs_jax, ) ) @@ -158,7 +163,7 @@ def process_model_functions_and_extract_info( taste_shock_function_processed, taste_shock_scale_in_params = ( process_shock_functions( shock_functions, - model_specs, + model_specs_jax, continuous_state_name=second_continuous_state_name, ) ) diff --git a/tests/test_varying_shock_scale.py b/tests/test_varying_shock_scale.py index 9e2b1368..c36dc91b 100644 --- a/tests/test_varying_shock_scale.py +++ b/tests/test_varying_shock_scale.py @@ -61,10 +61,12 @@ def test_benchmark_models(): policy_expec_interp = linear_interpolation_with_extrapolation( x_new=wealth_grid_to_test, x=policy_expec[0], y=policy_expec[1] ) + lagged_choice = state_choice_space_to_test[state_choice_idx, 1] state = { - "period": period, - "lagged_choice": state_choice_space_to_test[state_choice_idx, 1], + "period": jnp.ones_like(wealth_grid_to_test, dtype=int) * period, + "lagged_choice": jnp.ones_like(wealth_grid_to_test, dtype=int) + * lagged_choice, "assets_begin_of_period": wealth_grid_to_test, } ( @@ -72,7 +74,7 @@ def test_benchmark_models(): value_calc_interp, ) = model_solved.value_and_policy_for_states_and_choices( states=state, - choices=choice, + choices=jnp.ones_like(wealth_grid_to_test, dtype=int) * choice, ) aaae(policy_expec_interp, policy_calc_interp) From 948da431b02683e952d8a97797666e97c0cbe046 Mon Sep 17 00:00:00 2001 From: Maximilian Blesch Date: Thu, 11 Sep 2025 18:54:12 +0200 Subject: [PATCH 16/34] Colab (#185) --- src/dcegm/backward_induction.py | 116 +++++++----------- src/dcegm/interfaces/model_class.py | 63 +++++++--- 
src/dcegm/likelihood.py | 10 +- src/dcegm/numerical_integration.py | 7 +- .../alternative_sim_functions.py | 14 ++- .../process_model_functions.py | 5 +- .../model_functions/taste_shock_function.py | 10 +- tests/test_replication.py | 2 +- 8 files changed, 126 insertions(+), 101 deletions(-) diff --git a/src/dcegm/backward_induction.py b/src/dcegm/backward_induction.py index 5182f7a5..af59fbf7 100644 --- a/src/dcegm/backward_induction.py +++ b/src/dcegm/backward_induction.py @@ -2,6 +2,7 @@ from typing import Any, Callable, Dict, Tuple +import jax import jax.lax import jax.numpy as jnp @@ -24,72 +25,35 @@ def backward_induction( Args: params (dict): Dictionary containing the model parameters. - options (dict): Dictionary containing the model options. - period_specific_state_objects (np.ndarray): Dictionary containing - period-specific state and state-choice objects, with the following keys: - - "state_choice_mat" (jnp.ndarray) - - "idx_state_of_state_choice" (jnp.ndarray) - - "reshape_state_choice_vec_to_mat" (callable) - - "transform_between_state_and_state_choice_vec" (callable) - exog_savings_grid (np.ndarray): 1d array of shape (n_grid_wealth,) - containing the exogenous savings grid. - has_second_continuous_state (bool): Boolean indicating whether the model - features a second continuous state variable. If False, the only - continuous state variable is consumption/savings. - state_space (np.ndarray): 2d array of shape (n_states, n_state_variables + 1) - which serves as a collection of all possible states. By convention, - the first column must contain the period and the last column the - exogenous processes. Any other state variables are in between. - E.g. if the two state variables are period and lagged choice and all choices - are admissible in each period, the shape of the state space array is - (n_periods * n_choices, 3). 
- state_choice_space (np.ndarray): 2d array of shape - (n_feasible_states, n_state_and_exog_variables + 1) containing all - feasible state-choice combinations. By convention, the second to last - column contains the exogenous process. The last column always contains the - choice to be made (which is not a state variable). income_shock_draws_unscaled (np.ndarray): 1d array of shape (n_quad_points,) containing the Hermite quadrature points unscaled. income_shock_weights (np.ndarrray): 1d array of shape (n_stochastic_quad_points) with weights for each stoachstic shock draw. - n_periods (int): Number of periods. - model_funcs (dict): Dictionary containing following model functions: - - compute_marginal_utility (callable): User-defined function to compute the - agent's marginal utility. The input ```params``` is already partialled - in. - - compute_inverse_marginal_utility (Callable): Function for calculating the - inverse marginal utiFality, which takes the marginal utility as only - input. - - compute_next_period_wealth (callable): User-defined function to compute - the agent's wealth of the next period (t + 1). The inputs - ```saving```, ```shock```, ```params``` and ```options``` - are already partialled in. - - transition_vector_by_state (Callable): Partialled transition function - return transition vector for each state. - - final_period_partial (Callable): Partialled function for calculating the - consumption as well as value function and marginal utility in the final - period. - compute_upper_envelope (Callable): Function for calculating the upper - envelope of the policy and value function. If the number of discrete - choices is 1, this function is a dummy function that returns the policy - and value function as is, without performing a fast upper envelope - scan. + model_config (dict): Dictionary containing the model configuration. + model_funcs (dict): Dictionary containing model functions. + model_structure (dict): Dictionary containing model structure. 
+ batch_info (dict): Dictionary containing batch information. Returns: - dict: Dictionary containing the period-specific endog_grid, policy, and value + Tuple: Tuple containing the period-specific endog_grid, policy, and value from the backward induction. """ continuous_states_info = model_config["continuous_states_info"] - cont_grids_next_period = calc_cont_grids_next_period( - model_structure=model_structure, - model_config=model_config, - income_shock_draws_unscaled=income_shock_draws_unscaled, - params=params, - model_funcs=model_funcs, + # + calc_grids_jit = jax.jit( + lambda income_shock_draws, params_inner: calc_cont_grids_next_period( + model_structure=model_structure, + model_config=model_config, + income_shock_draws_unscaled=income_shock_draws, + params=params_inner, + model_funcs=model_funcs, + ) ) + cont_grids_next_period = calc_grids_jit(income_shock_draws_unscaled, params) + ( value_solved, policy_solved, @@ -101,32 +65,42 @@ def backward_induction( n_state_choices=model_structure["state_choice_space"].shape[0], ) - # Solve the last two periods. We do this separately as the marginal utility of - # the child states in the last period is calculated from the marginal utility - # function of the bequest function, which might differ. 
+ # Solve the last two periods using lambda to capture static arguments + solve_last_two_period_jit = jax.jit( + lambda params_inner, cont_grids, weights, val_solved, pol_solved, endog_solved: solve_last_two_periods( + params=params_inner, + continuous_states_info=continuous_states_info, + cont_grids_next_period=cont_grids, + income_shock_weights=weights, + model_funcs=model_funcs, + last_two_period_batch_info=batch_info["last_two_period_info"], + value_solved=val_solved, + policy_solved=pol_solved, + endog_grid_solved=endog_solved, + debug_info=None, + ) + ) + ( value_solved, policy_solved, endog_grid_solved, - ) = solve_last_two_periods( - params=params, - continuous_states_info=continuous_states_info, - cont_grids_next_period=cont_grids_next_period, - income_shock_weights=income_shock_weights, - model_funcs=model_funcs, - last_two_period_batch_info=batch_info["last_two_period_info"], - value_solved=value_solved, - policy_solved=policy_solved, - endog_grid_solved=endog_grid_solved, - debug_info=None, + ) = solve_last_two_period_jit( + params, + cont_grids_next_period, + income_shock_weights, + value_solved, + policy_solved, + endog_grid_solved, ) # If it is a two period model we are done. 
if batch_info["two_period_model"]: return value_solved, policy_solved, endog_grid_solved - def partial_single_period(carry, xs): - return solve_single_period( + # Create JIT-compiled single period solver using lambda + solve_single_period_jit = jax.jit( + lambda carry, xs: solve_single_period( carry=carry, xs=xs, params=params, @@ -136,6 +110,10 @@ def partial_single_period(carry, xs): income_shock_weights=income_shock_weights, debug_info=None, ) + ) + + def partial_single_period(carry, xs): + return solve_single_period_jit(carry, xs) for id_segment in range(batch_info["n_segments"]): segment_info = batch_info[f"batches_info_segment_{id_segment}"] diff --git a/src/dcegm/interfaces/model_class.py b/src/dcegm/interfaces/model_class.py index fc6e8cd7..e3625e0f 100644 --- a/src/dcegm/interfaces/model_class.py +++ b/src/dcegm/interfaces/model_class.py @@ -101,6 +101,15 @@ def __init__( self.income_shock_draws_unscaled = income_shock_draws_unscaled self.income_shock_weights = income_shock_weights + if alternative_sim_specifications is not None: + self.alternative_sim_funcs = generate_alternative_sim_functions( + model_specs=model_specs, + model_specs_jax=self.model_specs, + **alternative_sim_specifications, + ) + else: + self.alternative_sim_funcs = None + backward_jit = jax.jit( partial( backward_induction, @@ -115,14 +124,18 @@ def __init__( self.backward_induction_jit = backward_jit - if alternative_sim_specifications is not None: - self.alternative_sim_funcs = generate_alternative_sim_functions( - model_specs=model_specs, **alternative_sim_specifications - ) - else: - self.alternative_sim_funcs = None + def backward_induction_inner_jit(self, params): + return backward_induction( + params=params, + income_shock_draws_unscaled=self.income_shock_draws_unscaled, + income_shock_weights=self.income_shock_weights, + model_config=self.model_config, + batch_info=self.batch_info, + model_funcs=self.model_funcs, + model_structure=self.model_structure, + ) - def solve(self, 
params, load_sol_path=None, save_sol_path=None): + def solve(self, params, load_sol_path=None, save_sol_path=None, slow_version=False): """Solve a discrete-continuous life-cycle model using the DC-EGM algorithm. Args: @@ -151,8 +164,15 @@ def solve(self, params, load_sol_path=None, save_sol_path=None): if load_sol_path is not None: sol_dict = pkl.load(open(load_sol_path, "rb")) else: - # Solve the model - value, policy, endog_grid = self.backward_induction_jit(params_processed) + if slow_version: + value, policy, endog_grid = self.backward_induction_inner_jit( + params_processed + ) + else: + # Solve the model + value, policy, endog_grid = self.backward_induction_jit( + params_processed + ) sol_dict = { "value": value, "policy": policy, @@ -177,6 +197,7 @@ def solve_and_simulate( seed, load_sol_path=None, save_sol_path=None, + slow_version=False, ): """Solve the model and simulate it. @@ -197,8 +218,16 @@ def solve_and_simulate( if load_sol_path is not None: sol_dict = pkl.load(open(load_sol_path, "rb")) else: - # Solve the model - value, policy, endog_grid = self.backward_induction_jit(params_processed) + if slow_version: + value, policy, endog_grid = self.backward_induction_inner_jit( + params_processed + ) + else: + # Solve the model + value, policy, endog_grid = self.backward_induction_jit( + params_processed + ) + sol_dict = { "value": value, "policy": policy, @@ -228,6 +257,7 @@ def get_solve_and_simulate_func( self, states_initial, seed, + slow_version=False, ): sim_func = lambda params, value, policy, endog_gid: simulate_all_periods( @@ -258,10 +288,13 @@ def solve_and_simulate_function_to_jit(params): return sim_dict - jit_solve_simulate = jax.jit(solve_and_simulate_function_to_jit) + if slow_version: + solve_simulate_func = solve_and_simulate_function_to_jit + else: + solve_simulate_func = jax.jit(solve_and_simulate_function_to_jit) def solve_and_simulate_function(params): - sim_dict = jit_solve_simulate(params) + sim_dict = solve_simulate_func(params) df 
= create_simulation_df(sim_dict) return df @@ -275,6 +308,7 @@ def create_experimental_ll_func( unobserved_state_specs=None, return_model_solution=False, use_probability_of_observed_states=True, + slow_version=False, ): return create_individual_likelihood_function( @@ -282,13 +316,14 @@ def create_experimental_ll_func( model_config=self.model_config, model_funcs=self.model_funcs, model_specs=self.model_specs, - backwards_induction=self.backward_induction_jit, + backwards_induction_inner_jit=self.backward_induction_inner_jit, observed_states=observed_states, observed_choices=observed_choices, params_all=params_all, unobserved_state_specs=unobserved_state_specs, return_model_solution=return_model_solution, use_probability_of_observed_states=use_probability_of_observed_states, + slow_version=slow_version, ) def validate_exogenous(self, params): diff --git a/src/dcegm/likelihood.py b/src/dcegm/likelihood.py index 7acf59ee..f8ae76e0 100644 --- a/src/dcegm/likelihood.py +++ b/src/dcegm/likelihood.py @@ -24,13 +24,14 @@ def create_individual_likelihood_function( model_config, model_funcs, model_specs, - backwards_induction, + backwards_induction_inner_jit, observed_states: Dict[str, int], observed_choices, params_all, unobserved_state_specs=None, return_model_solution=False, use_probability_of_observed_states=True, + slow_version=False, ): choice_prob_func = create_choice_prob_function( @@ -48,7 +49,7 @@ def individual_likelihood(params): params_update = params_all.copy() params_update.update(params) - value, policy, endog_grid = backwards_induction(params_update) + value, policy, endog_grid = backwards_induction_inner_jit(params_update) choice_probs = choice_prob_func( value_in=value, @@ -69,7 +70,10 @@ def individual_likelihood(params): else: return neg_likelihood_contributions - return jax.jit(individual_likelihood) + if slow_version: + return individual_likelihood + else: + return jax.jit(individual_likelihood) def create_choice_prob_function( diff --git 
a/src/dcegm/numerical_integration.py b/src/dcegm/numerical_integration.py index 8c895449..014d1cbf 100644 --- a/src/dcegm/numerical_integration.py +++ b/src/dcegm/numerical_integration.py @@ -1,5 +1,6 @@ from typing import Tuple +import jax.numpy as jnp import numpy as np from scipy.special import roots_hermite, roots_sh_legendre from scipy.stats import norm @@ -33,10 +34,10 @@ def quadrature_hermite( quad_points_scaled = quad_points * np.sqrt(2) * income_shock_std quad_weights *= 1 / np.sqrt(np.pi) - return quad_points_scaled, quad_weights + return jnp.asarray(quad_points_scaled), jnp.asarray(quad_weights) -def quadrature_legendre(n_quad_points: int) -> Tuple[np.ndarray, np.ndarray]: +def quadrature_legendre(n_quad_points: int) -> Tuple[jnp.ndarray, jnp.ndarray]: """Return the Gauss-Legendre quadrature points and weights. The stochastic Gauss-Legendre quadrature points are shifted points @@ -58,4 +59,4 @@ def quadrature_legendre(n_quad_points: int) -> Tuple[np.ndarray, np.ndarray]: quad_points, quad_weights = roots_sh_legendre(n_quad_points) quad_points_normal = norm.ppf(quad_points) - return quad_points_normal, quad_weights + return jnp.asarray(quad_points_normal), jnp.asarray(quad_weights) diff --git a/src/dcegm/pre_processing/alternative_sim_functions.py b/src/dcegm/pre_processing/alternative_sim_functions.py index bc106881..2e7c60ce 100644 --- a/src/dcegm/pre_processing/alternative_sim_functions.py +++ b/src/dcegm/pre_processing/alternative_sim_functions.py @@ -26,6 +26,7 @@ def generate_alternative_sim_functions( model_config: Dict, model_specs: Dict, + model_specs_jax: Dict, state_space_functions: Dict[str, Callable], budget_constraint: Callable, shock_functions: Dict[str, Callable] = None, @@ -53,6 +54,7 @@ def generate_alternative_sim_functions( model_funcs, _ = process_alternative_sim_functions( model_config=model_config, model_specs=model_specs, + model_specs_jax=model_specs_jax, state_space_functions=state_space_functions, 
budget_constraint=budget_constraint, shock_functions=shock_functions, @@ -80,6 +82,7 @@ def generate_alternative_sim_functions( def process_alternative_sim_functions( model_config: Dict, model_specs: Dict, + model_specs_jax: Dict, stochastic_states_transition, state_space_functions: Dict[str, Callable], budget_constraint: Callable, @@ -138,7 +141,7 @@ def process_alternative_sim_functions( create_stochastic_transition_function( stochastic_states_transition, model_config=model_config, - model_specs=model_specs, + model_specs=model_specs_jax, continuous_state_name=second_continuous_state_name, ) ) @@ -154,7 +157,7 @@ def process_alternative_sim_functions( ) next_period_continuous_state = process_second_continuous_update_function( - second_continuous_state_name, state_space_functions, model_specs=model_specs + second_continuous_state_name, state_space_functions, model_specs=model_specs_jax ) # Budget equation @@ -162,7 +165,7 @@ def process_alternative_sim_functions( determine_function_arguments_and_partial_model_specs( func=budget_constraint, continuous_state_name=second_continuous_state_name, - model_specs=model_specs, + model_specs=model_specs_jax, ) ) @@ -174,8 +177,9 @@ def process_alternative_sim_functions( taste_shock_function_processed, taste_shock_scale_in_params = ( process_shock_functions( - shock_functions, - model_specs, + shock_functions=shock_functions, + model_specs=model_specs, + model_specs_jax=model_specs_jax, continuous_state_name=second_continuous_state_name, ) ) diff --git a/src/dcegm/pre_processing/model_functions/process_model_functions.py b/src/dcegm/pre_processing/model_functions/process_model_functions.py index 44a0467e..354eef4c 100644 --- a/src/dcegm/pre_processing/model_functions/process_model_functions.py +++ b/src/dcegm/pre_processing/model_functions/process_model_functions.py @@ -162,8 +162,9 @@ def process_model_functions_and_extract_info( taste_shock_function_processed, taste_shock_scale_in_params = ( process_shock_functions( - 
shock_functions, - model_specs_jax, + shock_functions=shock_functions, + model_specs=model_specs, + model_specs_jax=model_specs_jax, continuous_state_name=second_continuous_state_name, ) ) diff --git a/src/dcegm/pre_processing/model_functions/taste_shock_function.py b/src/dcegm/pre_processing/model_functions/taste_shock_function.py index 13251fc6..d674f8c9 100644 --- a/src/dcegm/pre_processing/model_functions/taste_shock_function.py +++ b/src/dcegm/pre_processing/model_functions/taste_shock_function.py @@ -5,13 +5,15 @@ ) -def process_shock_functions(shock_functions, model_specs, continuous_state_name): +def process_shock_functions( + shock_functions, model_specs, model_specs_jax, continuous_state_name +): taste_shock_function_processed = {} shock_functions = {} if shock_functions is None else shock_functions if "taste_shock_scale_per_state" in shock_functions.keys(): taste_shock_scale_per_state = get_taste_shock_function_for_state( draw_function_taste_shocks=shock_functions["taste_shock_scale_per_state"], - model_specs=model_specs, + model_specs=model_specs_jax, continuous_state_name=continuous_state_name, ) taste_shock_function_processed["taste_shock_scale_per_state"] = ( @@ -28,10 +30,10 @@ def process_shock_functions(shock_functions, model_specs, continuous_state_name) f"Lambda is not a scalar. If there is no draw function provided, " f"lambda must be a scalar. Got {lambda_val}." 
) - read_func = lambda params: jnp.asarray([model_specs["taste_shock_scale"]]) + read_func = lambda params: model_specs_jax["taste_shock_scale"] taste_shock_scale_in_params = False else: - read_func = lambda params: jnp.asarray([params["taste_shock_scale"]]) + read_func = lambda params: params["taste_shock_scale"] taste_shock_scale_in_params = True diff --git a/tests/test_replication.py b/tests/test_replication.py index 1a27b958..580e28dd 100644 --- a/tests/test_replication.py +++ b/tests/test_replication.py @@ -43,7 +43,7 @@ def test_benchmark_models(model_name): **model_funcs, ) - model_solved = model.solve(params) + model_solved = model.solve(params, slow_version=True) From 6a9213464222dfa5fa97ff66424ae4ea72c3c481 Mon Sep 17 00:00:00 2001 From: MaxBlesch Date: Thu, 11 Sep 2025 19:11:21 +0200 Subject: [PATCH 17/34] Removed always partialling --- src/dcegm/interfaces/model_class.py | 56 ++++++++++++----------------- 1 file changed, 23 insertions(+), 33 deletions(-) diff --git a/src/dcegm/interfaces/model_class.py b/src/dcegm/interfaces/model_class.py index e3625e0f..8ccbefbb 100644 --- a/src/dcegm/interfaces/model_class.py +++ b/src/dcegm/interfaces/model_class.py @@ -110,6 +110,18 @@ def __init__( else: self.alternative_sim_funcs = None + def backward_induction_inner_jit(self, params): + return backward_induction( + params=params, + income_shock_draws_unscaled=self.income_shock_draws_unscaled, + income_shock_weights=self.income_shock_weights, + model_config=self.model_config, + batch_info=self.batch_info, + model_funcs=self.model_funcs, + model_structure=self.model_structure, + ) + + def get_fast_solve_func(self): backward_jit = jax.jit( partial( backward_induction, @@ -122,20 +134,9 @@ def __init__( ) ) - self.backward_induction_jit = backward_jit - - def backward_induction_inner_jit(self, params): - return backward_induction( - params=params, - 
income_shock_draws_unscaled=self.income_shock_draws_unscaled, - income_shock_weights=self.income_shock_weights, - model_config=self.model_config, - batch_info=self.batch_info, - model_funcs=self.model_funcs, - model_structure=self.model_structure, - ) + return backward_jit - def solve(self, params, load_sol_path=None, save_sol_path=None, slow_version=False): + def solve(self, params, load_sol_path=None, save_sol_path=None): """Solve a discrete-continuous life-cycle model using the DC-EGM algorithm. Args: @@ -164,15 +165,9 @@ def solve(self, params, load_sol_path=None, save_sol_path=None, slow_version=Fal if load_sol_path is not None: sol_dict = pkl.load(open(load_sol_path, "rb")) else: - if slow_version: - value, policy, endog_grid = self.backward_induction_inner_jit( - params_processed - ) - else: - # Solve the model - value, policy, endog_grid = self.backward_induction_jit( - params_processed - ) + value, policy, endog_grid = self.backward_induction_inner_jit( + params_processed + ) sol_dict = { "value": value, "policy": policy, @@ -197,7 +192,6 @@ def solve_and_simulate( seed, load_sol_path=None, save_sol_path=None, - slow_version=False, ): """Solve the model and simulate it. 
@@ -218,15 +212,9 @@ def solve_and_simulate( if load_sol_path is not None: sol_dict = pkl.load(open(load_sol_path, "rb")) else: - if slow_version: - value, policy, endog_grid = self.backward_induction_inner_jit( - params_processed - ) - else: - # Solve the model - value, policy, endog_grid = self.backward_induction_jit( - params_processed - ) + value, policy, endog_grid = self.backward_induction_inner_jit( + params_processed + ) sol_dict = { "value": value, @@ -277,7 +265,9 @@ def get_solve_and_simulate_func( def solve_and_simulate_function_to_jit(params): params_processed = process_params(params, self.params_check_info) # Solve the model - value, policy, endog_grid = self.backward_induction_jit(params_processed) + value, policy, endog_grid = self.backward_induction_inner_jit( + params_processed + ) sim_dict = sim_func( params=params_processed, From b715fcb610f86c986214f8d328426f8ea5989278 Mon Sep 17 00:00:00 2001 From: MaxBlesch Date: Thu, 11 Sep 2025 20:09:55 +0200 Subject: [PATCH 18/34] Fix --- tests/test_replication.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_replication.py b/tests/test_replication.py index 580e28dd..d1edd8d5 100644 --- a/tests/test_replication.py +++ b/tests/test_replication.py @@ -3,13 +3,13 @@ import jax.numpy as jnp import pytest +from interp1d_auxiliary import ( + linear_interpolation_with_extrapolation, +) from numpy.testing import assert_array_almost_equal as aaae import dcegm import dcegm.toy_models as toy_models -from tests.utils.interp1d_auxiliary import ( - linear_interpolation_with_extrapolation, -) # Obtain the test directory of the package TEST_DIR = Path(__file__).parent @@ -43,7 +43,7 @@ def test_benchmark_models(model_name): **model_funcs, ) - model_solved = model.solve(params, slow_version=True) + model_solved = model.solve(params) policy_expected = pickle.load( (REPLICATION_TEST_RESOURCES_DIR / f"{model_name}" / "policy.pkl").open("rb") From 0650b14219ff909b780c7e7c0650d6b435c17b52 
Mon Sep 17 00:00:00 2001 From: MaxBlesch Date: Thu, 11 Sep 2025 23:15:58 +0200 Subject: [PATCH 19/34] jax --- src/dcegm/pre_processing/check_model_config.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/dcegm/pre_processing/check_model_config.py b/src/dcegm/pre_processing/check_model_config.py index ba41cb67..47083118 100644 --- a/src/dcegm/pre_processing/check_model_config.py +++ b/src/dcegm/pre_processing/check_model_config.py @@ -93,9 +93,9 @@ def check_model_config_and_process(model_config): second_continuous_state_name ) - second_continuous_state_grid = continuous_states_grids[ - second_continuous_state_name - ] + second_continuous_state_grid = jnp.asarray( + continuous_states_grids[second_continuous_state_name] + ) continuous_states_info["second_continuous_grid"] = second_continuous_state_grid # ToDo: Check if grid is array or list and monotonic increasing From 8d65dd29e78c054256fd005fdecafa6bc614cff2 Mon Sep 17 00:00:00 2001 From: MaxBlesch Date: Fri, 12 Sep 2025 00:52:16 +0200 Subject: [PATCH 20/34] Remove one jit --- src/dcegm/backward_induction.py | 23 +++++++++-------------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/src/dcegm/backward_induction.py b/src/dcegm/backward_induction.py index af59fbf7..a6bb1b87 100644 --- a/src/dcegm/backward_induction.py +++ b/src/dcegm/backward_induction.py @@ -99,22 +99,17 @@ def backward_induction( return value_solved, policy_solved, endog_grid_solved # Create JIT-compiled single period solver using lambda - solve_single_period_jit = jax.jit( - lambda carry, xs: solve_single_period( - carry=carry, - xs=xs, - params=params, - continuous_grids_info=continuous_states_info, - cont_grids_next_period=cont_grids_next_period, - model_funcs=model_funcs, - income_shock_weights=income_shock_weights, - debug_info=None, - ) + partial_single_period = lambda carry, xs: solve_single_period( + carry=carry, + xs=xs, + params=params, + continuous_grids_info=continuous_states_info, + 
cont_grids_next_period=cont_grids_next_period, + model_funcs=model_funcs, + income_shock_weights=income_shock_weights, + debug_info=None, ) - def partial_single_period(carry, xs): - return solve_single_period_jit(carry, xs) - for id_segment in range(batch_info["n_segments"]): segment_info = batch_info[f"batches_info_segment_{id_segment}"] From 4bc9b79b7a91c329978de50be8ad49331070f236 Mon Sep 17 00:00:00 2001 From: MaxBlesch Date: Fri, 19 Sep 2025 15:39:17 +0200 Subject: [PATCH 21/34] Weight func separate! --- src/dcegm/likelihood.py | 41 +++++++++++++++++++++++++++++++++++------ 1 file changed, 35 insertions(+), 6 deletions(-) diff --git a/src/dcegm/likelihood.py b/src/dcegm/likelihood.py index f8ae76e0..bb6f2df7 100644 --- a/src/dcegm/likelihood.py +++ b/src/dcegm/likelihood.py @@ -43,6 +43,7 @@ def create_individual_likelihood_function( observed_choices=observed_choices, unobserved_state_specs=unobserved_state_specs, use_probability_of_observed_states=use_probability_of_observed_states, + return_weight_func=False, ) def individual_likelihood(params): @@ -85,6 +86,7 @@ def create_choice_prob_function( observed_choices, unobserved_state_specs, use_probability_of_observed_states, + return_weight_func, ): if unobserved_state_specs is None: choice_prob_func = create_partial_choice_prob_calculation( @@ -105,6 +107,7 @@ def create_choice_prob_function( observed_choices=observed_choices, unobserved_state_specs=unobserved_state_specs, use_probability_of_observed_states=use_probability_of_observed_states, + return_weight_func=return_weight_func, ) return choice_prob_func @@ -119,16 +122,14 @@ def create_choice_prob_func_unobserved_states( observed_choices, unobserved_state_specs, use_probability_of_observed_states=True, + return_weight_func=False, ): unobserved_state_names = unobserved_state_specs["observed_bools_states"].keys() observed_bools = unobserved_state_specs["observed_bools_states"] # Create weighting vars by extracting states and choices - weighting_vars = 
unobserved_state_specs["state_choices_weighing"]["states"] - weighting_vars["choice"] = unobserved_state_specs["state_choices_weighing"][ - "choices" - ] + weighting_vars = unobserved_state_specs["weighting_vars"] # Add unobserved states with appendix new and bools indicating if state is observed for state_name in unobserved_state_names: @@ -175,7 +176,7 @@ def create_choice_prob_func_unobserved_states( for possible_state in possible_states: possible_state[state_name][unobserved_state_bool] = state_value new_possible_states.append(copy.deepcopy(possible_state)) - # Same for pre period states + # Same for variables to weight function for weighting_vars in weighting_vars_for_possible_states: weighting_vars[state_name + "_new"][unobserved_state_bool] = state_value new_weighting_vars_for_possible_states.append( @@ -262,7 +263,35 @@ def choice_prob_func(value_in, endog_grid_in, params_in): return choice_probs_final - return choice_prob_func + def weight_only_func(params_in): + weights = np.zeros((n_obs, len(possible_states)), dtype=np.float64) + count = 0 + for partial_choice_prob, unobserved_state, weighting_vars in zip( + partial_choice_probs_unobserved_states, + possible_states, + weighting_vars_for_possible_states, + ): + unobserved_weights = jax.vmap( + partial_weight_func, + in_axes=(None, 0), + )( + params_in, + weighting_vars, + ) + + weights[:, count] = unobserved_weights + count += 1 + return ( + weights, + observed_weights, + possible_states, + weighting_vars_for_possible_states, + ) + + if return_weight_func: + return choice_prob_func, weight_only_func + else: + return choice_prob_func def create_partial_choice_prob_calculation( From c39a66d45bce6ec71d3521da7745ad6e8f4ac964 Mon Sep 17 00:00:00 2001 From: MaxBlesch Date: Sun, 21 Sep 2025 20:33:02 +0200 Subject: [PATCH 22/34] Fix --- src/dcegm/interfaces/index_functions.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/dcegm/interfaces/index_functions.py 
b/src/dcegm/interfaces/index_functions.py index e3016391..2a2e5645 100644 --- a/src/dcegm/interfaces/index_functions.py +++ b/src/dcegm/interfaces/index_functions.py @@ -1,3 +1,6 @@ +import numpy as np + + def get_child_state_index_per_states_and_choices(states, choices, model_structure): state_choice_index = get_state_choice_index_per_discrete_states_and_choices( model_structure, states, choices @@ -28,6 +31,15 @@ def get_state_choice_index_per_discrete_states( indexes = map_state_choice_to_index[ tuple((states[key],) for key in discrete_states_names) ] + max_values_per_state = {key: np.max(states[key]) for key in discrete_states_names} + # Check that max value does not exceed the dimension + dim = map_state_choice_to_index.shape + for i, key in enumerate(discrete_states_names): + if max_values_per_state[key] > dim[i] - 1: + raise ValueError( + f"Max value of state {key} exceeds the dimension of the model." + ) + # As the code above generates a dummy dimension in the first index, remove it return indexes[0] From 0bd462355c4a94a33d56c69d50ad3133398c16ee Mon Sep 17 00:00:00 2001 From: MaxBlesch Date: Mon, 22 Sep 2025 00:27:23 +0200 Subject: [PATCH 23/34] comment out --- src/dcegm/interfaces/index_functions.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/dcegm/interfaces/index_functions.py b/src/dcegm/interfaces/index_functions.py index 2a2e5645..f4a84e7a 100644 --- a/src/dcegm/interfaces/index_functions.py +++ b/src/dcegm/interfaces/index_functions.py @@ -31,14 +31,15 @@ def get_state_choice_index_per_discrete_states( indexes = map_state_choice_to_index[ tuple((states[key],) for key in discrete_states_names) ] - max_values_per_state = {key: np.max(states[key]) for key in discrete_states_names} - # Check that max value does not exceed the dimension - dim = map_state_choice_to_index.shape - for i, key in enumerate(discrete_states_names): - if max_values_per_state[key] > dim[i] - 1: - raise ValueError( - f"Max value of state 
{key} exceeds the dimension of the model." - ) + # Need flag to only evaluate in non jit mode + # max_values_per_state = {key: np.max(states[key]) for key in discrete_states_names} + # # Check that max value does not exceed the dimension + # dim = map_state_choice_to_index.shape + # for i, key in enumerate(discrete_states_names): + # if max_values_per_state[key] > dim[i] - 1: + # raise ValueError( + # f"Max value of state {key} exceeds the dimension of the model." + # ) # As the code above generates a dummy dimension in the first index, remove it return indexes[0] From 3a237e26d558061a3b8b04434f95c81f33e8176e Mon Sep 17 00:00:00 2001 From: MaxBlesch Date: Wed, 24 Sep 2025 11:28:54 +0200 Subject: [PATCH 24/34] Align interface --- docs/source/background/two_period_model_tutorial.ipynb | 2 +- src/dcegm/interfaces/sol_interface.py | 2 +- tests/test_replication.py | 2 +- tests/test_varying_shock_scale.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/background/two_period_model_tutorial.ipynb b/docs/source/background/two_period_model_tutorial.ipynb index 6f588557..572354ab 100644 --- a/docs/source/background/two_period_model_tutorial.ipynb +++ b/docs/source/background/two_period_model_tutorial.ipynb @@ -762,7 +762,7 @@ "}\n", "\n", "\n", - "cons_calc, value = solved_model.value_and_policy_for_states_and_choices(\n", + "cons_calc, value = solved_model.policy_and_value_for_states_and_choices(\n", " states=state_dict,\n", " choices=choice_in_period_0,\n", ")" diff --git a/src/dcegm/interfaces/sol_interface.py b/src/dcegm/interfaces/sol_interface.py index f4f656bd..4c32550f 100644 --- a/src/dcegm/interfaces/sol_interface.py +++ b/src/dcegm/interfaces/sol_interface.py @@ -60,7 +60,7 @@ def simulate(self, states_initial, seed): ) return create_simulation_df(sim_dict) - def value_and_policy_for_states_and_choices(self, states, choices): + def policy_and_value_for_states_and_choices(self, states, choices): """Get the value and policy for a 
given state and choice. Args: diff --git a/tests/test_replication.py b/tests/test_replication.py index d1edd8d5..ed97715b 100644 --- a/tests/test_replication.py +++ b/tests/test_replication.py @@ -82,7 +82,7 @@ def test_benchmark_models(model_name): "assets_begin_of_period": wealth_grid_to_test, } policy_calc_interp, value_calc_interp = ( - model_solved.value_and_policy_for_states_and_choices( + model_solved.policy_and_value_for_states_and_choices( states=state, choices=jnp.ones_like(wealth_grid_to_test, dtype=int) * choice, ) diff --git a/tests/test_varying_shock_scale.py b/tests/test_varying_shock_scale.py index c36dc91b..0dc8c9fd 100644 --- a/tests/test_varying_shock_scale.py +++ b/tests/test_varying_shock_scale.py @@ -72,7 +72,7 @@ def test_benchmark_models(): ( policy_calc_interp, value_calc_interp, - ) = model_solved.value_and_policy_for_states_and_choices( + ) = model_solved.policy_and_value_for_states_and_choices( states=state, choices=jnp.ones_like(wealth_grid_to_test, dtype=int) * choice, ) From 54abffdcf937b2f7721a42d2127e08c529b70552 Mon Sep 17 00:00:00 2001 From: MaxBlesch Date: Wed, 24 Sep 2025 22:57:56 +0200 Subject: [PATCH 25/34] Done --- src/dcegm/interfaces/model_class.py | 22 +++++++++++++++++++++- src/dcegm/interfaces/sol_interface.py | 26 ++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/src/dcegm/interfaces/model_class.py b/src/dcegm/interfaces/model_class.py index 8ccbefbb..187f6f4f 100644 --- a/src/dcegm/interfaces/model_class.py +++ b/src/dcegm/interfaces/model_class.py @@ -86,6 +86,7 @@ def __init__( ) self.model_specs = jax.tree_util.tree_map(try_jax_array, model_specs) + self.specs_without_jax = model_specs self.model_config = model_dict["model_config"] self.model_funcs = model_dict["model_funcs"] @@ -103,13 +104,32 @@ def __init__( if alternative_sim_specifications is not None: self.alternative_sim_funcs = generate_alternative_sim_functions( - model_specs=model_specs, + 
model_specs=self.specs_without_jax, model_specs_jax=self.model_specs, **alternative_sim_specifications, ) else: self.alternative_sim_funcs = None + def set_alternative_sim_funcs( + self, alternative_sim_specifications, alternative_specs=None + ): + if alternative_specs is None: + self.alternative_sim_specs = self.model_specs + alternative_specs_without_jax = self.specs_without_jax + else: + self.alternative_sim_specs = jax.tree_util.tree_map( + try_jax_array, alternative_specs + ) + alternative_specs_without_jax = alternative_specs + + alternative_sim_funcs = generate_alternative_sim_functions( + model_specs=alternative_specs_without_jax, + model_specs_jax=self.alternative_sim_specs, + **alternative_sim_specifications, + ) + self.alternative_sim_funcs = alternative_sim_funcs + def backward_induction_inner_jit(self, params): return backward_induction( params=params, diff --git a/src/dcegm/interfaces/sol_interface.py b/src/dcegm/interfaces/sol_interface.py index 4c32550f..b4f44a62 100644 --- a/src/dcegm/interfaces/sol_interface.py +++ b/src/dcegm/interfaces/sol_interface.py @@ -1,3 +1,4 @@ +import jax import jax.numpy as jnp from dcegm.interfaces.index_functions import ( @@ -14,6 +15,10 @@ choice_values_for_states, get_state_choice_index_per_discrete_states, ) +from dcegm.pre_processing.alternative_sim_functions import ( + generate_alternative_sim_functions, +) +from dcegm.pre_processing.shared import try_jax_array from dcegm.simulation.sim_utils import create_simulation_df from dcegm.simulation.simulate import simulate_all_periods @@ -41,8 +46,29 @@ def __init__( self.model_structure = model.model_structure self.model_funcs = model.model_funcs self.model_specs = model.model_specs + self.specs_without_jax = model.specs_without_jax self.alternative_sim_funcs = model.alternative_sim_funcs + def set_alternative_sim_funcs( + self, alternative_sim_specifications, alternative_specs=None + ): + if alternative_specs is None: + self.alternative_sim_specs = self.model_specs + 
alternative_specs_without_jax = self.specs_without_jax + else: + self.alternative_sim_specs = jax.tree_util.tree_map( + try_jax_array, alternative_specs + ) + alternative_specs_without_jax = alternative_specs + + alternative_sim_funcs = generate_alternative_sim_functions( + model_specs=alternative_specs_without_jax, + model_specs_jax=self.alternative_sim_specs, + **alternative_sim_specifications, + ) + self.model.alternative_sim_funcs = alternative_sim_funcs + self.alternative_sim_funcs = alternative_sim_funcs + def simulate(self, states_initial, seed): sim_dict = simulate_all_periods( From 10db3941ab6e80ac9549e2096325e2d237bedeb8 Mon Sep 17 00:00:00 2001 From: MaxBlesch Date: Thu, 2 Oct 2025 11:24:57 +0200 Subject: [PATCH 26/34] more interface --- src/dcegm/interfaces/interface.py | 115 ++++++++++++++++++++++++++ src/dcegm/interfaces/sol_interface.py | 49 ++++++++++- src/dcegm/likelihood.py | 63 +------------- 3 files changed, 165 insertions(+), 62 deletions(-) diff --git a/src/dcegm/interfaces/interface.py b/src/dcegm/interfaces/interface.py index 6a665755..965706ff 100644 --- a/src/dcegm/interfaces/interface.py +++ b/src/dcegm/interfaces/interface.py @@ -3,6 +3,7 @@ import jax import jax.numpy as jnp import pandas as pd +from jax import numpy as jnp from dcegm.interfaces.index_functions import ( get_state_choice_index_per_discrete_states_and_choices, @@ -289,3 +290,117 @@ def stochastic_transition_vec(state_choice_vec_dict, func, params): """ return func(**state_choice_vec_dict, params=params) + + +def choice_values_for_states( + value_solved, + endog_grid_solved, + state_choice_indexes, + params, + states, + model_config, + model_funcs, +): + value_grid_states = jnp.take( + value_solved, + state_choice_indexes, + axis=0, + mode="fill", + fill_value=jnp.nan, + ) + endog_grid_states = jnp.take( + endog_grid_solved, + state_choice_indexes, + axis=0, + mode="fill", + fill_value=jnp.nan, + ) + + def wrapper_interp_value_for_choice( + state, + 
value_grid_state_choice, + endog_grid_state_choice, + choice, + ): + state_choice_vec = {**state, "choice": choice} + + return interpolate_value_for_state_and_choice( + value_grid_state_choice=value_grid_state_choice, + endog_grid_state_choice=endog_grid_state_choice, + state_choice_vec=state_choice_vec, + params=params, + model_config=model_config, + model_funcs=model_funcs, + ) + + # Read out choice range to loop over + choice_range = model_config["choices"] + + choice_values_per_state = jax.vmap( + jax.vmap( + wrapper_interp_value_for_choice, + in_axes=(None, 0, 0, 0), + ), + in_axes=(0, 0, 0, None), + )( + states, + value_grid_states, + endog_grid_states, + choice_range, + ) + return choice_values_per_state + + +def choice_policies_for_states( + policy_solved, + endog_grid_solved, + state_choice_indexes, + states, + model_config, +): + policy_grid_states = jnp.take( + policy_solved, + state_choice_indexes, + axis=0, + mode="fill", + fill_value=jnp.nan, + ) + endog_grid_states = jnp.take( + endog_grid_solved, + state_choice_indexes, + axis=0, + mode="fill", + fill_value=jnp.nan, + ) + + def wrapper_interp_value_for_choice( + state, + policy_grid_state_choice, + endog_grid_state_choice, + choice, + ): + state_choice_vec = {**state, "choice": choice} + + return interpolate_policy_for_state_and_choice( + policy_grid_state_choice=policy_grid_state_choice, + endog_grid_state_choice=endog_grid_state_choice, + state_choice_vec=state_choice_vec, + model_config=model_config, + ) + + # Read out choice range to loop over + choice_range = model_config["choices"] + + choice_values_per_state = jax.vmap( + jax.vmap( + wrapper_interp_value_for_choice, + in_axes=(None, 0, 0, 0), + ), + in_axes=(0, 0, 0, None), + )( + states, + policy_grid_states, + endog_grid_states, + choice_range, + ) + return choice_values_per_state diff --git a/src/dcegm/interfaces/sol_interface.py b/src/dcegm/interfaces/sol_interface.py index b4f44a62..d0f3270f 100644 --- 
a/src/dcegm/interfaces/sol_interface.py +++ b/src/dcegm/interfaces/sol_interface.py @@ -5,6 +5,8 @@ get_state_choice_index_per_discrete_states_and_choices, ) from dcegm.interfaces.interface import ( + choice_policies_for_states, + choice_values_for_states, policy_and_value_for_states_and_choices, policy_for_state_choice_vec, value_for_state_and_choice, @@ -12,7 +14,6 @@ from dcegm.interfaces.interface_checks import check_states_and_choices from dcegm.likelihood import ( calc_choice_probs_for_states, - choice_values_for_states, get_state_choice_index_per_discrete_states, ) from dcegm.pre_processing.alternative_sim_functions import ( @@ -186,6 +187,16 @@ def get_solution_for_discrete_state_choice(self, states, choices): def choice_probabilities_for_states(self, states): + # To check structure, add dummy choice for now and delete afterwards. + # Error messages will be misleading though. + state_choices = check_states_and_choices( + states=states, + choices=states["period"], + model_structure=self.model_structure, + ) + state_choices.pop("choice") + states = state_choices + state_choice_idxs = get_state_choice_index_per_discrete_states( states=states, map_state_choice_to_index=self.model_structure[ @@ -205,6 +216,16 @@ def choice_probabilities_for_states(self, states): ) def choice_values_for_states(self, states): + # To check structure, add dummy choice for now and delete afterwards. + # Error messages will be misleading though. 
+ state_choices = check_states_and_choices( + states=states, + choices=states["period"], + model_structure=self.model_structure, + ) + state_choices.pop("choice") + states = state_choices + state_choice_idxs = get_state_choice_index_per_discrete_states( states=states, map_state_choice_to_index=self.model_structure[ @@ -221,3 +242,29 @@ def choice_values_for_states(self, states): model_config=self.model_config, model_funcs=self.model_funcs, ) + + def choice_policies_for_states(self, states): + # To check structure, add dummy choice for now and delete afterwards. + # Error messages will be misleading though. + state_choices = check_states_and_choices( + states=states, + choices=states["period"], + model_structure=self.model_structure, + ) + state_choices.pop("choice") + states = state_choices + + state_choice_idxs = get_state_choice_index_per_discrete_states( + states=states, + map_state_choice_to_index=self.model_structure[ + "map_state_choice_to_index_with_proxy" + ], + discrete_states_names=self.model_structure["discrete_states_names"], + ) + return choice_policies_for_states( + policy_solved=self.policy, + endog_grid_solved=self.endog_grid, + state_choice_indexes=state_choice_idxs, + states=states, + model_config=self.model_config, + ) diff --git a/src/dcegm/likelihood.py b/src/dcegm/likelihood.py index bb6f2df7..3c9c66bc 100644 --- a/src/dcegm/likelihood.py +++ b/src/dcegm/likelihood.py @@ -5,7 +5,7 @@ """ import copy -from typing import Any, Dict +from typing import Dict import jax import jax.numpy as jnp @@ -16,7 +16,7 @@ calculate_choice_probs_and_unsqueezed_logsum, ) from dcegm.interfaces.index_functions import get_state_choice_index_per_discrete_states -from dcegm.interpolation.interp_interfaces import interpolate_value_for_state_and_choice +from dcegm.interfaces.interface import choice_values_for_states def create_individual_likelihood_function( @@ -395,65 +395,6 @@ def calc_choice_probs_for_states( return choice_prob_across_choices -def 
choice_values_for_states( - value_solved, - endog_grid_solved, - state_choice_indexes, - params, - states, - model_config, - model_funcs, -): - value_grid_states = jnp.take( - value_solved, - state_choice_indexes, - axis=0, - mode="fill", - fill_value=jnp.nan, - ) - endog_grid_states = jnp.take( - endog_grid_solved, - state_choice_indexes, - axis=0, - mode="fill", - fill_value=jnp.nan, - ) - - def wrapper_interp_value_for_choice( - state, - value_grid_state_choice, - endog_grid_state_choice, - choice, - ): - state_choice_vec = {**state, "choice": choice} - - return interpolate_value_for_state_and_choice( - value_grid_state_choice=value_grid_state_choice, - endog_grid_state_choice=endog_grid_state_choice, - state_choice_vec=state_choice_vec, - params=params, - model_config=model_config, - model_funcs=model_funcs, - ) - - # Read out choice range to loop over - choice_range = model_config["choices"] - - choice_values_per_state = jax.vmap( - jax.vmap( - wrapper_interp_value_for_choice, - in_axes=(None, 0, 0, 0), - ), - in_axes=(0, 0, 0, None), - )( - states, - value_grid_states, - endog_grid_states, - choice_range, - ) - return choice_values_per_state - - def calculate_weights_for_each_state(params, weight_vars, model_specs, weight_func): """Calculate the weights for each state. From 41c142be2a5048842beb29dae23d34983e91b5e6 Mon Sep 17 00:00:00 2001 From: MaxBlesch Date: Mon, 27 Oct 2025 15:27:41 +0100 Subject: [PATCH 27/34] Interfacing --- src/dcegm/interfaces/interface.py | 4 ++-- src/dcegm/interfaces/model_class.py | 8 +++++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/dcegm/interfaces/interface.py b/src/dcegm/interfaces/interface.py index 965706ff..20a92518 100644 --- a/src/dcegm/interfaces/interface.py +++ b/src/dcegm/interfaces/interface.py @@ -16,7 +16,7 @@ ) -def get_n_state_choice_period(model): +def get_n_state_choice_period(model_structure): """Get the number of state-choice periods from the model. 
Args: @@ -29,7 +29,7 @@ def get_n_state_choice_period(model): """ return ( - pd.Series(model["model_structure"]["state_choice_space"][:, 0]) + pd.Series(model_structure["state_choice_space"][:, 0]) .value_counts() .sort_index() ) diff --git a/src/dcegm/interfaces/model_class.py b/src/dcegm/interfaces/model_class.py index 187f6f4f..82486bbc 100644 --- a/src/dcegm/interfaces/model_class.py +++ b/src/dcegm/interfaces/model_class.py @@ -10,7 +10,10 @@ get_child_state_index_per_states_and_choices, get_state_choice_index_per_discrete_states, ) -from dcegm.interfaces.interface import validate_stochastic_transition +from dcegm.interfaces.interface import ( + get_n_state_choice_period, + validate_stochastic_transition, +) from dcegm.interfaces.sol_interface import model_solved from dcegm.law_of_motion import calc_cont_grids_next_period from dcegm.likelihood import create_individual_likelihood_function @@ -445,3 +448,6 @@ def compute_law_of_motions(self, params): model_funcs=self.model_funcs, income_shock_draws_unscaled=self.income_shock_draws_unscaled, ) + + def get_n_state_choices_per_period(self): + return get_n_state_choice_period(self.model_structure) From ed738af0389659f7fc1a16d980e33a963f5e9236 Mon Sep 17 00:00:00 2001 From: Maximilian Blesch Date: Tue, 13 Jan 2026 13:19:02 +0100 Subject: [PATCH 28/34] Add sparse stochastic transitions (#189) --- src/dcegm/asset_correction.py | 44 ++- src/dcegm/final_periods.py | 2 +- src/dcegm/interfaces/model_class.py | 11 +- src/dcegm/interpolation/interp2d.py | 2 +- src/dcegm/law_of_motion.py | 196 +++++------- .../batches/last_two_periods.py | 4 +- .../pre_processing/batches/single_segment.py | 27 +- .../model_structure/stochastic_states.py | 141 ++++++++- src/dcegm/pre_processing/setup_model.py | 37 +++ src/dcegm/simulation/sim_utils.py | 32 +- tests/test_interpolation.py | 3 +- tests/test_sparse_stochastic_transitions.py | 205 +++++++++++++ tests/test_stochastic_transitions.py | 285 +----------------- 13 files changed, 513 
insertions(+), 476 deletions(-) create mode 100644 tests/test_sparse_stochastic_transitions.py diff --git a/src/dcegm/asset_correction.py b/src/dcegm/asset_correction.py index 17a6e2b4..f15a9837 100644 --- a/src/dcegm/asset_correction.py +++ b/src/dcegm/asset_correction.py @@ -2,8 +2,7 @@ from jax import vmap from dcegm.law_of_motion import ( - calc_assets_beginning_of_period_2cont_vec, - calc_beginning_of_period_assets_1cont_vec, + calc_beginning_of_period_assets_for_single_state, ) @@ -37,30 +36,23 @@ def adjust_observed_assets(observed_states_dict, params, model_class, aux_outs=F second_cont_state_vars = observed_states_dict[second_cont_state_name] observed_states_dict_int.pop(second_cont_state_name) - adjusted_assets = vmap( - calc_assets_beginning_of_period_2cont_vec, - in_axes=(0, 0, 0, None, None, None, None), - )( - observed_states_dict_int, - second_cont_state_vars, - assets_end_last_period, - jnp.array(0.0, dtype=jnp.float64), - params, - model_funcs["compute_assets_begin_of_period"], - aux_outs, - ) - + all_states = { + **observed_states_dict_int, + "continuous_state": second_cont_state_vars, + } else: - adjusted_assets = vmap( - calc_beginning_of_period_assets_1cont_vec, - in_axes=(0, 0, None, None, None, None), - )( - observed_states_dict, - assets_end_last_period, - jnp.array(0.0, dtype=jnp.float64), - params, - model_funcs["compute_assets_begin_of_period"], - aux_outs, - ) + all_states = observed_states_dict_int + + adjusted_assets = vmap( + calc_beginning_of_period_assets_for_single_state, + in_axes=(0, 0, None, None, None, None), + )( + all_states, + assets_end_last_period, + jnp.array(0.0, dtype=jnp.float64), + params, + model_funcs["compute_assets_begin_of_period"], + aux_outs, + ) return adjusted_assets diff --git a/src/dcegm/final_periods.py b/src/dcegm/final_periods.py index a42787d5..a3a2a2f4 100644 --- a/src/dcegm/final_periods.py +++ b/src/dcegm/final_periods.py @@ -485,7 +485,7 @@ def calc_value_and_budget_for_each_gridpoint( 
wealth_final_period = calc_assets_beginning_of_period_2cont_vec( state_vec=state_vec, continuous_state_beginning_of_period=second_continuous_state, - asset_grid_point_end_of_previous_period=asset_grid_point_end_of_previous_period, + asset_end_of_previous_period=asset_grid_point_end_of_previous_period, income_shock_draw=jnp.array(0.0), params=params, compute_assets_begin_of_period=compute_assets_begin_of_period, diff --git a/src/dcegm/interfaces/model_class.py b/src/dcegm/interfaces/model_class.py index 82486bbc..336a5162 100644 --- a/src/dcegm/interfaces/model_class.py +++ b/src/dcegm/interfaces/model_class.py @@ -47,8 +47,14 @@ def __init__( debug_info: str = None, model_save_path: str = None, model_load_path: str = None, + use_stochastic_sparsity: bool = False, ): - """Setup the model and check if load or save is required.""" + """Setup the model and check if load or save is required. + + Args: + use_stochastic_sparsity (bool, optional): EXPERIMENTAL: Use stochastic transition sparsity. 
+ + """ if model_load_path is not None: model_dict = load_model_dict( @@ -61,6 +67,7 @@ def __init__( stochastic_states_transitions=stochastic_states_transitions, shock_functions=shock_functions, path=model_load_path, + use_stochastic_sparsity=use_stochastic_sparsity, ) elif model_save_path is not None: model_dict = create_model_dict_and_save( @@ -74,6 +81,7 @@ def __init__( shock_functions=shock_functions, path=model_save_path, debug_info=debug_info, + use_stochastic_sparsity=use_stochastic_sparsity, ) else: model_dict = create_model_dict( @@ -86,6 +94,7 @@ def __init__( stochastic_states_transitions=stochastic_states_transitions, shock_functions=shock_functions, debug_info=debug_info, + use_stochastic_sparsity=use_stochastic_sparsity, ) self.model_specs = jax.tree_util.tree_map(try_jax_array, model_specs) diff --git a/src/dcegm/interpolation/interp2d.py b/src/dcegm/interpolation/interp2d.py index cfadd788..8dcfe71e 100644 --- a/src/dcegm/interpolation/interp2d.py +++ b/src/dcegm/interpolation/interp2d.py @@ -193,7 +193,7 @@ def interp2d_policy_on_wealth_and_regular_grid( to interpolate. Returns: - float: The interpolated value of the policy function at the given + jnp.ndarray | float: The interpolated value of the policy function at the given (regular, wealth) point. 
""" diff --git a/src/dcegm/law_of_motion.py b/src/dcegm/law_of_motion.py index d5698c4d..b3c28d2c 100644 --- a/src/dcegm/law_of_motion.py +++ b/src/dcegm/law_of_motion.py @@ -23,6 +23,9 @@ def calc_cont_grids_next_period( income_shock_draws_unscaled * income_shock_std + income_shock_mean ) + # Generate result dict + cont_grids_next_period = {} + if continuous_states_info["second_continuous_exists"]: continuous_state_next_period = calculate_continuous_state( discrete_states_beginning_of_period=state_space_dict, @@ -30,73 +33,83 @@ def calc_cont_grids_next_period( params=params, compute_continuous_state=model_funcs["next_period_continuous_state"], ) + # Fill in result dict + cont_grids_next_period["second_continuous"] = continuous_state_next_period - # Extra dimension for continuous state - assets_beginning_of_next_period = calc_assets_beginning_of_period_2cont( - discrete_states_beginning_of_next_period=state_space_dict, - continuous_state_beginning_of_next_period=continuous_state_next_period, - assets_grid_end_of_period=continuous_states_info[ - "assets_grid_end_of_period" - ], - income_shocks=income_shocks_scaled, - params=params, - compute_assets_begin_of_period=model_funcs[ - "compute_assets_begin_of_period" - ], - ) - - cont_grids_next_period = { - "assets_begin_of_period": assets_beginning_of_next_period, - "second_continuous": continuous_state_next_period, + # Prepare dict used to calculate beginning of period assets + state_specific_grids = { + "states": state_space_dict, + "continuous_state": continuous_state_next_period, } - else: - assets_begin_of_next_period = calc_beginning_of_period_assets_1cont( - discrete_states_beginning_of_period=state_space_dict, - assets_grid_end_of_period=continuous_states_info[ - "assets_grid_end_of_period" - ], - income_shocks_current_period=income_shocks_scaled, + state_specific_grids = { + "states": state_space_dict, + } + + def fix_assets_and_shocks_for_broadcast( + states, + asset_end_of_previous_period, + income_draw, + 
): + assets_begin_of_period = calc_beginning_of_period_assets_for_single_state( + state_vec=states, + asset_end_of_previous_period=asset_end_of_previous_period, + income_shock_draw=income_draw, params=params, compute_assets_begin_of_period=model_funcs[ "compute_assets_begin_of_period" ], + aux_outs=False, ) - cont_grids_next_period = { - "assets_begin_of_period": assets_begin_of_next_period, - } + return assets_begin_of_period - return cont_grids_next_period - - -def calc_beginning_of_period_assets_1cont( - discrete_states_beginning_of_period, - assets_grid_end_of_period, - income_shocks_current_period, - params, - compute_assets_begin_of_period, -): - assets_begin_of_period = vmap( + broadcast_function = lambda states: vmap( vmap( - vmap( - calc_beginning_of_period_assets_1cont_vec, - in_axes=(None, None, 0, None, None, None), # income shocks - ), - in_axes=(None, 0, None, None, None, None), # assets + fix_assets_and_shocks_for_broadcast, + in_axes=(None, None, 0), # income shocks ), - in_axes=(0, None, None, None, None, None), # discrete states + in_axes=(None, 0, None), # assets )( - discrete_states_beginning_of_period, - assets_grid_end_of_period, - income_shocks_current_period, - params, - compute_assets_begin_of_period, - False, + states, + continuous_states_info["assets_grid_end_of_period"], + income_shocks_scaled, ) - return assets_begin_of_period + + final_args = () + # Default is no chaining of vmaps. 
Then I add consequently vmap over specific grids + vmap_chain = broadcast_function + + for grid_name in state_specific_grids.keys(): + if grid_name != "states": + # Use default argument to capture current values + vmap_chain = add_vmap_chain_for_grid(vmap_chain, grid_name) + final_args += (state_specific_grids[grid_name],) + + final_args = (state_specific_grids["states"],) + final_args + assets_begin_of_next_period = vmap(vmap_chain)(*final_args) + cont_grids_next_period["assets_begin_of_period"] = assets_begin_of_next_period + return cont_grids_next_period -def calc_beginning_of_period_assets_1cont_vec( +def add_vmap_chain_for_grid(inner_func, gname): + """The function adds a vmap layer for a specific grid. + + It vmaps over the remaining dimension of the grid. So if we have a grid that is + (n_discrete_states, n_grid_points), we can later vmap over the discrete states and + this function will add the n_grid_points dimension to be vmapped over. The function + only expects later the grid to arrive in n_grid_points. So we can also use the + function in the final period calculation. 
+ + """ + + def grid_wrapper(states, new_state_grid): + all_states = {**states, gname: new_state_grid} + return inner_func(all_states) + + return vmap(grid_wrapper, in_axes=(None, 0)) + + +def calc_beginning_of_period_assets_for_single_state( state_vec, asset_end_of_previous_period, income_shock_draw, @@ -124,22 +137,23 @@ def calc_beginning_of_period_assets_1cont_vec( def calc_assets_beginning_of_period_2cont_vec( state_vec, continuous_state_beginning_of_period, - asset_grid_point_end_of_previous_period, + asset_end_of_previous_period, income_shock_draw, params, compute_assets_begin_of_period, aux_outs, ): - - out_budget = compute_assets_begin_of_period( + all_states = { **state_vec, - continuous_state=continuous_state_beginning_of_period, - asset_end_of_previous_period=asset_grid_point_end_of_previous_period, - income_shock_previous_period=income_shock_draw, + "continuous_state": continuous_state_beginning_of_period, + } + checked_out = calc_beginning_of_period_assets_for_single_state( + state_vec=all_states, + asset_end_of_previous_period=asset_end_of_previous_period, + income_shock_draw=income_shock_draw, params=params, - ) - checked_out = check_budget_equation_and_return_wealth_plus_optional_aux( - out_budget, optional_aux=aux_outs + compute_assets_begin_of_period=compute_assets_begin_of_period, + aux_outs=aux_outs, ) return checked_out @@ -179,39 +193,6 @@ def calc_continuous_state_for_each_grid_point( return out -def calc_assets_beginning_of_period_2cont( - discrete_states_beginning_of_next_period, - continuous_state_beginning_of_next_period, - assets_grid_end_of_period, - income_shocks, - params, - compute_assets_begin_of_period, -): - - assets_begin_of_period = vmap( - vmap( - vmap( - vmap( - calc_assets_beginning_of_period_2cont_vec, - in_axes=(None, None, None, 0, None, None, None), # income shocks - ), - in_axes=(None, None, 0, None, None, None, None), # assets - ), - in_axes=(None, 0, None, None, None, None, None), # continuous state - ), - in_axes=(0, 
0, None, None, None, None, None), # discrete states - )( - discrete_states_beginning_of_next_period, - continuous_state_beginning_of_next_period, - assets_grid_end_of_period, - income_shocks, - params, - compute_assets_begin_of_period, - False, - ) - return assets_begin_of_period - - # ===================================================================================== # Simulation # ===================================================================================== @@ -226,7 +207,7 @@ def calculate_assets_begin_of_period_for_all_agents( ): """Simulation.""" assets_begin_of_next_period = vmap( - calc_beginning_of_period_assets_1cont_vec, + calc_beginning_of_period_assets_for_single_state, in_axes=(0, 0, 0, None, None, None), )( states_beginning_of_period, @@ -256,28 +237,3 @@ def calculate_second_continuous_state_for_all_agents( compute_continuous_state, ) return continuous_state_beginning_of_next_period - - -def calc_assets_begin_of_period_for_all_agents( - states_beginning_of_period, - continuous_state_beginning_of_period, - assets_end_of_period, - income_shocks_of_period, - params, - compute_assets_begin_of_period, -): - """Simulation.""" - - assets_begin_of_next_period, aux_dict = vmap( - calc_assets_beginning_of_period_2cont_vec, - in_axes=(0, 0, 0, 0, None, None, None), - )( - states_beginning_of_period, - continuous_state_beginning_of_period, - assets_end_of_period, - income_shocks_of_period, - params, - compute_assets_begin_of_period, - True, - ) - return assets_begin_of_next_period, aux_dict diff --git a/src/dcegm/pre_processing/batches/last_two_periods.py b/src/dcegm/pre_processing/batches/last_two_periods.py index fcaba81c..8638014d 100644 --- a/src/dcegm/pre_processing/batches/last_two_periods.py +++ b/src/dcegm/pre_processing/batches/last_two_periods.py @@ -74,13 +74,13 @@ def add_last_two_period_information( "child_states_second_last_period": child_states_second_last_period, } + state_choice_space_dict = model_structure["state_choice_space_dict"] # 
Also add state choice mat as dictionary for each of the two periods for idx, period_name in [ (idx_state_choice_final_period, "final"), (idx_state_choice_second_last_period, "second_last"), ]: last_two_period_info[f"state_choice_mat_{period_name}_period"] = { - key: state_choice_space[:, i][idx] - for i, key in enumerate(discrete_states_names + ["choice"]) + key: var[idx] for key, var in state_choice_space_dict.items() } return last_two_period_info diff --git a/src/dcegm/pre_processing/batches/single_segment.py b/src/dcegm/pre_processing/batches/single_segment.py index d480cc2a..de7db506 100644 --- a/src/dcegm/pre_processing/batches/single_segment.py +++ b/src/dcegm/pre_processing/batches/single_segment.py @@ -11,6 +11,7 @@ def create_single_segment_of_batches(bool_state_choices_to_batch, model_structur """ state_choice_space = model_structure["state_choice_space"] + state_choice_space_dict = model_structure["state_choice_space_dict"] state_space = model_structure["state_space"] discrete_states_names = model_structure["discrete_states_names"] @@ -48,9 +49,8 @@ def create_single_segment_of_batches(bool_state_choices_to_batch, model_structur child_states_to_integrate_stochastic_list, child_state_choices_to_aggr_choice_list, child_state_choice_idxs_to_interp_list, - state_choice_space, + state_choice_space_dict, map_state_choice_to_parent_state, - discrete_states_names, ) single_batch_segment_info = prepare_and_align_batch_arrays( @@ -58,7 +58,7 @@ def create_single_segment_of_batches(bool_state_choices_to_batch, model_structur child_states_to_integrate_stochastic_list, child_state_choices_to_aggr_choice_list, child_state_choice_idxs_to_interp_list, - state_choice_space, + state_choice_space_dict, map_state_choice_to_parent_state, discrete_states_names, ) @@ -74,9 +74,8 @@ def correct_for_uneven_last_batch( child_states_to_integrate_stochastic_list, child_state_choices_to_aggr_choice_list, child_state_choice_idxs_to_interp_list, - state_choice_space, + 
state_choice_space_dict, map_state_choice_to_parent_state, - discrete_states_names, ): """Check if the last batch has the same length as the others. @@ -108,12 +107,11 @@ def correct_for_uneven_last_batch( last_child_state_idx_interp = child_state_choice_idxs_to_interp_list[-1] last_state_choices = { - key: state_choice_space[:, i][last_batch] - for i, key in enumerate(discrete_states_names + ["choice"]) + key: var[last_batch] for key, var in state_choice_space_dict.items() } last_state_choices_childs = { - key: state_choice_space[:, i][last_child_state_idx_interp] - for i, key in enumerate(discrete_states_names + ["choice"]) + key: var[last_child_state_idx_interp] + for key, var in state_choice_space_dict.items() } last_parent_state_idx_of_state_choice = map_state_choice_to_parent_state[ last_child_state_idx_interp @@ -154,7 +152,7 @@ def prepare_and_align_batch_arrays( child_states_to_integrate_stochastic_list, child_state_choices_to_aggr_choice_list, child_state_choice_idxs_to_interp_list, - state_choice_space, + state_choice_space_dict, map_state_choice_to_parent_state, discrete_states_names, ): @@ -165,15 +163,14 @@ def prepare_and_align_batch_arrays( """ # Get out of bound state choice idx, by taking the number of state choices + 1 - out_of_bounds_state_choice_idx = state_choice_space.shape[0] + 1 + out_of_bounds_state_choice_idx = state_choice_space_dict["period"].shape[0] + 1 # First convert batch information batch_array = np.array(batches_list) child_states_to_integrate_exog = np.array(child_states_to_integrate_stochastic_list) state_choices_batches = { - key: state_choice_space[:, i][batch_array] - for i, key in enumerate(discrete_states_names + ["choice"]) + key: var[batch_array] for key, var in state_choice_space_dict.items() } # Now create the child state arrays. 
As these can have different shapes than the @@ -192,8 +189,8 @@ def prepare_and_align_batch_arrays( child_state_choice_idxs_to_interp ] state_choices_childs = { - key: state_choice_space[:, i][child_state_choice_idxs_to_interp] - for i, key in enumerate(discrete_states_names + ["choice"]) + key: var[child_state_choice_idxs_to_interp] + for key, var in state_choice_space_dict.items() } batch_info = { diff --git a/src/dcegm/pre_processing/model_structure/stochastic_states.py b/src/dcegm/pre_processing/model_structure/stochastic_states.py index f1ca380f..8c4ebdab 100644 --- a/src/dcegm/pre_processing/model_structure/stochastic_states.py +++ b/src/dcegm/pre_processing/model_structure/stochastic_states.py @@ -1,11 +1,14 @@ +import inspect from functools import partial from typing import Callable +import jax import numpy as np from jax import numpy as jnp from dcegm.pre_processing.model_structure.shared import span_subspace from dcegm.pre_processing.shared import ( + create_array_with_smallest_int_dtype, determine_function_arguments_and_partial_model_specs, ) @@ -24,15 +27,17 @@ def create_stochastic_transition_function( compute_stochastic_transition_vec = return_dummy_stochastic_transition func_dict = {} else: - func_list, func_dict = process_stochastic_transitions( + func_dict = process_stochastic_transitions( stochastic_states_transitions, model_config=model_config, model_specs=model_specs, continuous_state_name=continuous_state_name, ) + trans_func_list = [func_dict[name] for name in func_dict.keys()] + compute_stochastic_transition_vec = partial( - get_stochastic_transition_vec, transition_funcs=func_list + get_stochastic_transition_vec, transition_funcs=trans_func_list ) return compute_stochastic_transition_vec, func_dict @@ -46,9 +51,6 @@ def process_stochastic_transitions( Args: options (dict): Options dictionary. - Returns: - tuple: Tuple of exogenous processes. 
- """ func_list = [] @@ -68,7 +70,7 @@ def process_stochastic_transitions( else: raise ValueError(f"Stochastic transition function {name} is not callable. ") - return func_list, func_dict + return func_dict def get_stochastic_transition_vec(transition_funcs, params, **state_choice_vars): @@ -117,3 +119,130 @@ def process_stochastic_model_specifications(model_config): stochastic_state_space = np.array([[0]], dtype=np.uint8) return stochastic_state_names, stochastic_state_space + + +def create_sparse_stochastic_trans_map( + model_structure, model_funcs, model_config_processed, from_saved=False +): + """Create sparse mapping from state-choice to stochastic states.""" + state_choice_dict = model_structure["state_choice_space_dict"] + stochastic_transitions_dict = model_funcs["processed_stochastic_funcs"] + threshold = 1e-6 + + # Add index to state_choice_dict + n_state_choices = len(state_choice_dict[next(iter(state_choice_dict))]) + sparse_index_functions = [] + spares_stoch_trans_funcs = [] + trans_func_dict = {} + + for stoch_name, stoch_states in model_config_processed["stochastic_states"].items(): + if stoch_name == "dummy_stochastic": + continue + + has_params = ( + "params" + in inspect.signature(stochastic_transitions_dict[stoch_name]).parameters + ) + n_states = len(stoch_states) + trans_func = stochastic_transitions_dict[stoch_name] + + if has_params: + index_eval = lambda index, n=n_states: np.ones(n) / n + sparse_index_functions.append(index_eval) + spares_stoch_trans_funcs += [trans_func] + trans_func_dict[stoch_name] = trans_func + else: + + eval_func = lambda state_choice, f=trans_func: f(**state_choice) + + # Compute transitions and find indices to keep + single_transitions = jax.vmap(eval_func)(state_choice_dict) + zero_mask = single_transitions < threshold + + max_n_zeros = zero_mask.sum(axis=1).min() + if max_n_zeros == 0: + # No sparsity for this state + index_eval = lambda index, n=n_states: np.ones(n) / n + 
sparse_index_functions.append(index_eval) + spares_stoch_trans_funcs += [trans_func] + trans_func_dict[stoch_name] = trans_func + continue + n_keep = zero_mask.shape[1] - max_n_zeros + + keep_mask = ~zero_mask + keep_indices = np.where(keep_mask, np.arange(keep_mask.shape[1]), -1) + + # Get sorted positions and the original indices + sort_order = np.argsort(keep_indices, axis=1) + indices_to_keep = np.take_along_axis(keep_indices, sort_order, axis=1)[ + :, -n_keep: + ] + + # For positions with -1, use the original position from sort_order + original_positions = sort_order[:, -n_keep:] + indices_to_keep = np.where( + indices_to_keep == -1, original_positions, indices_to_keep + ) + # Sort again to get indices in ascending order + indices_to_keep = np.sort(indices_to_keep, axis=1) + indices_to_keep = jnp.asarray(indices_to_keep) + indices_to_keep = create_array_with_smallest_int_dtype(indices_to_keep) + + def create_sparse_func(trans_f, indices_keep): + def sparse_trans_func(**kwargs): + index_to_keep = indices_keep[kwargs["index"]] + return trans_f(**kwargs)[index_to_keep] + + return sparse_trans_func + + sparse_trans_func = create_sparse_func(trans_func, indices_to_keep) + + spares_stoch_trans_funcs += [sparse_trans_func] + trans_func_dict[stoch_name] = sparse_trans_func + + # Create sparse eval function that returns NaN for deleted states + def create_nan_padded_eval(indices_keep, n_total): + def nan_padded_eval(index): + result = jnp.full(n_total, jnp.nan) + result = result.at[indices_keep[index]].set(1.0) + return result + + return nan_padded_eval + + index_eval = create_nan_padded_eval(indices_to_keep, n_states) + sparse_index_functions.append(index_eval) + + compute_stochastic_transition_vec = partial( + get_stochastic_transition_vec, transition_funcs=spares_stoch_trans_funcs + ) + if from_saved: + return compute_stochastic_transition_vec, trans_func_dict + + # Evaluate kronecker product with NaNs + def kronecker_with_index(idx): + trans_vector = 
sparse_index_functions[0](idx) + for func in sparse_index_functions[1:]: + trans_vector = jnp.kron(trans_vector, func(idx)) + return trans_vector + + all_transitions = jax.vmap(kronecker_with_index)( + jnp.arange(n_state_choices), + ) + + # Find non-NaN positions (states to keep) + keep_mask = ~np.isnan(all_transitions) + + # Select non-NaN child states directly + sparse_child_states_mapping = model_structure["map_state_choice_to_child_states"][ + keep_mask + ].reshape(keep_mask.shape[0], -1) + state_choice_dict_with_idx = { + **state_choice_dict, + "index": jnp.arange(n_state_choices), + } + return ( + sparse_child_states_mapping, + state_choice_dict_with_idx, + compute_stochastic_transition_vec, + trans_func_dict, + ) diff --git a/src/dcegm/pre_processing/setup_model.py b/src/dcegm/pre_processing/setup_model.py index 49e16350..6de10304 100644 --- a/src/dcegm/pre_processing/setup_model.py +++ b/src/dcegm/pre_processing/setup_model.py @@ -14,6 +14,7 @@ from dcegm.pre_processing.model_structure.model_structure import create_model_structure from dcegm.pre_processing.model_structure.state_space import create_state_space from dcegm.pre_processing.model_structure.stochastic_states import ( + create_sparse_stochastic_trans_map, create_stochastic_state_mapping, ) from dcegm.pre_processing.shared import ( @@ -32,6 +33,7 @@ def create_model_dict( stochastic_states_transitions: Dict[str, Callable] = None, shock_functions: Dict[str, Callable] = None, debug_info: str = None, + use_stochastic_sparsity=False, ): """Set up the model for dcegm. 
@@ -94,6 +96,26 @@ def create_model_dict( model_funcs=model_funcs, ) + if use_stochastic_sparsity: + n_stochastic_original = model_structure[ + "map_state_choice_to_child_states" + ].shape[1] + ( + model_structure["map_state_choice_to_child_states"], + model_structure["state_choice_space_dict"], + model_funcs["compute_stochastic_transition_vec"], + model_funcs["sparse_processed_stochastic_funcs"], + ) = create_sparse_stochastic_trans_map( + model_structure=model_structure, + model_funcs=model_funcs, + model_config_processed=model_config_processed, + from_saved=False, + ) + n_sparse = model_structure["map_state_choice_to_child_states"].shape[1] + print( + f"Stochastic transition mapping sparsified from {n_stochastic_original} to {n_sparse} " + ) + model_funcs["stochastic_state_mapping"] = create_stochastic_state_mapping( model_structure["stochastic_state_space"], model_structure["stochastic_states_names"], @@ -135,6 +157,7 @@ def create_model_dict_and_save( shock_functions: Dict[str, Callable] = None, path: str = "model.pkl", debug_info=None, + use_stochastic_sparsity=False, ): """Set up the model and save. 
@@ -154,6 +177,7 @@ def create_model_dict_and_save( stochastic_states_transitions=stochastic_states_transitions, shock_functions=shock_functions, debug_info=debug_info, + use_stochastic_sparsity=use_stochastic_sparsity, ) dict_to_save = { @@ -175,6 +199,7 @@ def load_model_dict( stochastic_states_transitions: Dict[str, Callable] = None, shock_functions: Dict[str, Callable] = None, path: str = "model.pkl", + use_stochastic_sparsity=False, ): """Load the model from file.""" @@ -203,6 +228,18 @@ def load_model_dict( **specs_params_info, } + # Save full and then create sparsity + if use_stochastic_sparsity: + ( + model["model_funcs"]["compute_stochastic_transition_vec"], + model["model_funcs"]["sparse_processed_stochastic_funcs"], + ) = create_sparse_stochastic_trans_map( + model_structure=model["model_structure"], + model_funcs=model["model_funcs"], + model_config_processed=model["model_config"], + from_saved=True, + ) + model["model_funcs"]["stochastic_state_mapping"] = create_stochastic_state_mapping( stochastic_state_space=model["model_structure"]["stochastic_state_space"], stochastic_state_names=model["model_structure"]["stochastic_states_names"], diff --git a/src/dcegm/simulation/sim_utils.py b/src/dcegm/simulation/sim_utils.py index 1b3d9667..38670b49 100644 --- a/src/dcegm/simulation/sim_utils.py +++ b/src/dcegm/simulation/sim_utils.py @@ -10,7 +10,6 @@ interp2d_policy_and_value_on_wealth_and_regular_grid, ) from dcegm.law_of_motion import ( - calc_assets_begin_of_period_for_all_agents, calculate_assets_begin_of_period_for_all_agents, calculate_second_continuous_state_for_all_agents, ) @@ -204,28 +203,23 @@ def transition_to_next_period( compute_continuous_state=model_funcs_sim["next_period_continuous_state"], ) - assets_beginning_of_next_period, budget_aux = ( - calc_assets_begin_of_period_for_all_agents( - states_beginning_of_period=discrete_states_next_period, - continuous_state_beginning_of_period=continuous_state_next_period, - 
assets_end_of_period=assets_end_of_period, - income_shocks_of_period=income_shocks_next_period, - params=params, - compute_assets_begin_of_period=next_period_wealth, - ) - ) + all_states_next_period = { + **discrete_states_next_period, + "continuous_state": continuous_state_next_period, + } else: + all_states_next_period = discrete_states_next_period.copy() continuous_state_next_period = None - assets_beginning_of_next_period, budget_aux = ( - calculate_assets_begin_of_period_for_all_agents( - states_beginning_of_period=discrete_states_next_period, - asset_grid_point_end_of_previous_period=assets_end_of_period, - income_shocks_of_period=income_shocks_next_period, - params=params, - compute_assets_begin_of_period=next_period_wealth, - ) + assets_beginning_of_next_period, budget_aux = ( + calculate_assets_begin_of_period_for_all_agents( + states_beginning_of_period=all_states_next_period, + asset_grid_point_end_of_previous_period=assets_end_of_period, + income_shocks_of_period=income_shocks_next_period, + params=params, + compute_assets_begin_of_period=next_period_wealth, ) + ) return ( assets_beginning_of_next_period, diff --git a/tests/test_interpolation.py b/tests/test_interpolation.py index d3dddab7..d14e5653 100644 --- a/tests/test_interpolation.py +++ b/tests/test_interpolation.py @@ -170,7 +170,8 @@ def test_interp2d_against_scipy(test_cases, test_id): jnp.array(test_y), ) - aaae(policy_interp_jax, policy_interp_scipy, decimal=7) + # First element is out of bounds. 
We use internally extrapolation, while griddata returns a nan + aaae(policy_interp_jax[1:], policy_interp_scipy[1:]) @pytest.mark.parametrize("test_id", range(N_TEST_CASES)) diff --git a/tests/test_sparse_stochastic_transitions.py b/tests/test_sparse_stochastic_transitions.py new file mode 100644 index 00000000..f6ff6a16 --- /dev/null +++ b/tests/test_sparse_stochastic_transitions.py @@ -0,0 +1,205 @@ +"""Test sparse stochastic transitions.""" + +import os + +import jax +import jax.numpy as jnp +import numpy as np +from numpy.testing import assert_array_almost_equal as aaae + +import dcegm +from dcegm.toy_models.cons_ret_model_dcegm_paper import ( + budget_constraint, + create_final_period_utility_function_dict, + create_state_space_function_dict, + create_utility_function_dict, +) + + +def prob_exog_health_father(health_mother): + """Sparse transition: can only go to 2 out of 3 states depending on mother's health.""" + # health_mother == 0: can go to states 0, 1 (not 2) + # health_mother == 1: can go to states 1, 2 (not 0) + # health_mother == 2: can go to states 0, 2 (not 1) + prob_good_health = ( + (health_mother == 0) * 0.7 + + (health_mother == 1) * 0.0 + + (health_mother == 2) * 0.3 + ) + prob_medium_health = ( + (health_mother == 0) * 0.3 + + (health_mother == 1) * 0.6 + + (health_mother == 2) * 0.0 + ) + prob_bad_health = ( + (health_mother == 0) * 0.0 + + (health_mother == 1) * 0.4 + + (health_mother == 2) * 0.7 + ) + return jnp.array([prob_good_health, prob_medium_health, prob_bad_health]) + + +def prob_exog_health_mother(health_father): + """Sparse transition: can only go to 2 out of 3 states depending on father's health.""" + # health_father == 0: can go to states 0, 1 (not 2) + # health_father == 1: can go to states 0, 2 (not 1) + # health_father == 2: can go to states 1, 2 (not 0) + prob_good_health = ( + (health_father == 0) * 0.8 + + (health_father == 1) * 0.4 + + (health_father == 2) * 0.0 + ) + prob_medium_health = ( + (health_father == 0) * 0.2 + 
+ (health_father == 1) * 0.0 + + (health_father == 2) * 0.3 + ) + prob_bad_health = ( + (health_father == 0) * 0.0 + + (health_father == 1) * 0.6 + + (health_father == 2) * 0.7 + ) + return jnp.array([prob_good_health, prob_medium_health, prob_bad_health]) + + +def prob_exog_health_child(health_child, params): + """Compute transition probabilities for a child's health.""" + prob_good_health = (health_child == 0) * 0.7 + (health_child == 1) * 0.1 + prob_medium_health = (health_child == 0) * 0.3 + (health_child == 1) * 0.9 + return jnp.array([prob_good_health, prob_medium_health]) + + +def prob_exog_health_grandma(health_grandma): + """Compute transition probabilities for a grandmother's health.""" + # This function has not every state reachable from every other state. This should not be reduced + # in sparsity. + prob_good_health = (health_grandma == 0) * 1.0 + (health_grandma == 1) * 0.15 + return jnp.array([prob_good_health, 1 - prob_good_health]) + + +def util_new( + consumption, + choice, + params, + health_mother, + health_father, + health_child, + health_grandma, +): + + utility_consumption = jax.lax.select( + jnp.allclose(params["rho"], 1), + jnp.log(consumption), + (consumption ** (1 - params["rho"]) - 1) / (1 - params["rho"]), + ) + + utility = ( + utility_consumption + - (1 - choice) * params["delta"] + + 2 * health_mother + + 1.5 * health_father + + health_child + + 0.5 * health_grandma + + (1 - health_grandma) * 5 + ) + + return utility + + +def test_sparse_stochastic_transitions(): + """Test that solving with sparse transitions gives same results.""" + + params = { + "rho": 2, + "delta": 0.5, + "discount_factor": 0.95, + "taste_shock_scale": 1, + "income_shock_std": 1, + "income_shock_mean": 0.0, + "interest_rate": 0.05, + "constant": 1, + "exp": 0.1, + "exp_squared": -0.01, + "consumption_floor": 0.5, + } + + model_specs = {"n_choices": 2, "taste_shock_scale": 1, "min_age": 20} + + model_config = { + "n_quad_points": 5, + "n_periods": 10, + "choices": 
np.arange(2), + "deterministic_states": { + "married": [0, 1], + }, + "continuous_states": { + "assets_end_of_period": np.linspace(0, 50, 100), + }, + "stochastic_states": { + "health_mother": [0, 1, 2], + "health_grandma": [0, 1], + "health_father": [0, 1, 2], + "health_child": [0, 1], + }, + } + + stochastic_state_transitions = { + "health_mother": prob_exog_health_mother, + "health_grandma": prob_exog_health_grandma, + "health_child": prob_exog_health_child, + "health_father": prob_exog_health_father, + } + + # Setup model first time + model_1 = dcegm.setup_model( + model_config=model_config, + model_specs=model_specs, + state_space_functions=create_state_space_function_dict(), + utility_functions=create_utility_function_dict(), + utility_functions_final_period=create_final_period_utility_function_dict(), + budget_constraint=budget_constraint, + stochastic_states_transitions=stochastic_state_transitions, + use_stochastic_sparsity=False, + ) + + # Set up with saving and also with loading: + # Setup model second time with same config + model_2 = dcegm.setup_model( + model_config=model_config, + model_specs=model_specs, + state_space_functions=create_state_space_function_dict(), + utility_functions=create_utility_function_dict(), + utility_functions_final_period=create_final_period_utility_function_dict(), + budget_constraint=budget_constraint, + stochastic_states_transitions=stochastic_state_transitions, + model_save_path="model_stoch_sparse.pkl", + use_stochastic_sparsity=True, + ) + + model_3 = dcegm.setup_model( + model_config=model_config, + model_specs=model_specs, + state_space_functions=create_state_space_function_dict(), + utility_functions=create_utility_function_dict(), + utility_functions_final_period=create_final_period_utility_function_dict(), + budget_constraint=budget_constraint, + stochastic_states_transitions=stochastic_state_transitions, + model_load_path="model_stoch_sparse.pkl", + use_stochastic_sparsity=True, + ) + + # Solve both models + 
model_solved_1 = model_1.solve(params=params) + model_solved_2 = model_2.solve(params=params) + model_solved_3 = model_3.solve(params=params) + + # Test that solutions are identical + aaae(model_solved_1.endog_grid, model_solved_2.endog_grid) + aaae(model_solved_1.policy, model_solved_2.policy) + aaae(model_solved_1.value, model_solved_2.value) + aaae(model_solved_1.endog_grid, model_solved_3.endog_grid) + aaae(model_solved_1.policy, model_solved_3.policy) + aaae(model_solved_1.value, model_solved_3.value) + + # Clean up saved model file + os.remove("model_stoch_sparse.pkl") diff --git a/tests/test_stochastic_transitions.py b/tests/test_stochastic_transitions.py index 26eaabeb..65d4486a 100644 --- a/tests/test_stochastic_transitions.py +++ b/tests/test_stochastic_transitions.py @@ -1,286 +1,3 @@ -# """Test module for exogenous processes.""" - -# import copy -# from itertools import product - -# import jax.numpy as jnp -# import numpy as np -# import pytest -# from numpy.testing import assert_almost_equal as aaae - -# from dcegm.interface import validate_stochastic_states -# from dcegm.pre_processing.check_options import check_options_and_set_defaults -# from dcegm.pre_processing.model_functions import process_model_functions -# from dcegm.pre_processing.model_structure.stochastic_states import ( -# create_stochastic_states_mapping, -# ) -# from dcegm.pre_processing.model_structure.model_structure import create_model_structure -# from dcegm.pre_processing.setup_model import setup_model -# from toy_models.cons_ret_model_dcegm_paper.budget_constraint import budget_constraint -# from toy_models.cons_ret_model_dcegm_paper.state_space_objects import ( -# create_state_space_function_dict, -# ) -# from toy_models.cons_ret_model_dcegm_paper.utility_functions import ( -# create_final_period_utility_function_dict, -# create_utility_function_dict, -# ) - - -# def trans_prob_care_demand(health_state, params): -# prob_care_demand = ( -# (health_state == 0) * 
params["care_demand_good_health"] -# + (health_state == 1) * params["care_demand_medium_health"] -# + (health_state == 2) * params["care_demand_bad_health"] -# ) - -# return prob_care_demand - - -# def prob_exog_health_father(health_mother, params): -# prob_good_health = ( -# (health_mother == 0) * 0.7 -# + (health_mother == 1) * 0.3 -# + (health_mother == 2) * 0.2 -# ) -# prob_medium_health = ( -# (health_mother == 0) * 0.2 -# + (health_mother == 1) * 0.5 -# + (health_mother == 2) * 0.2 -# ) -# prob_bad_health = ( -# (health_mother == 0) * 0.1 -# + (health_mother == 1) * 0.2 -# + (health_mother == 2) * 0.6 -# ) - -# return jnp.array([prob_good_health, prob_medium_health, prob_bad_health]) - - -# def prob_exog_health_mother(health_father, params): -# prob_good_health = ( -# (health_father == 0) * 0.7 -# + (health_father == 1) * 0.3 -# + (health_father == 2) * 0.2 -# ) -# prob_medium_health = ( -# (health_father == 0) * 0.2 -# + (health_father == 1) * 0.5 -# + (health_father == 2) * 0.2 -# ) -# prob_bad_health = ( -# (health_father == 0) * 0.1 -# + (health_father == 1) * 0.2 -# + (health_father == 2) * 0.6 -# ) - -# return jnp.array([prob_good_health, prob_medium_health, prob_bad_health]) - - -# def prob_exog_health_child(health_child, params): -# prob_good_health = (health_child == 0) * 0.7 + (health_child == 1) * 0.1 -# prob_medium_health = (health_child == 0) * 0.3 + (health_child == 1) * 0.9 - -# return jnp.array([prob_good_health, prob_medium_health]) - - -# def prob_exog_health_grandma(health_grandma, params): -# prob_good_health = (health_grandma == 0) * 0.8 + (health_grandma == 1) * 0.15 -# prob_medium_health = (health_grandma == 0) * 0.2 + (health_grandma == 1) * 0.85 - -# return jnp.array([prob_good_health, prob_medium_health]) - - -# EXOG_STATE_GRID = [0, 1, 2] -# EXOG_STATE_GRID_SMALL = [0, 1] - - -# @pytest.mark.parametrize( -# "health_state_mother, health_state_father, health_state_child, health_state_grandma", -# product( -# EXOG_STATE_GRID, 
EXOG_STATE_GRID, EXOG_STATE_GRID_SMALL, EXOG_STATE_GRID_SMALL -# ), -# ) -# def test_exog_processes( -# health_state_mother, health_state_father, health_state_child, health_state_grandma -# ): -# params = { -# "rho": 0.5, -# "delta": 0.5, -# "interest_rate": 0.02, -# "ltc_cost": 5, -# "wage_avg": 8, -# "income_shock_std": 1, -# "taste_shock_scale": 1, -# "ltc_prob": 0.3, -# "discount_factor": 0.95, -# } - -# options = { -# "model_params": { -# "quadrature_points_stochastic": 5, -# "n_choices": 2, -# }, -# "state_space": { -# "n_periods": 2, -# "choices": np.arange(2), -# "deterministic_states": { -# "married": [0, 1], -# }, -# "continuous_states": { -# "assets_end_of_period": np.linspace(0, 50, 100), -# }, -# "stochastic_states": { -# "health_mother": { -# "transition": prob_exog_health_mother, -# "states": [0, 1, 2], -# }, -# "health_father": { -# "transition": prob_exog_health_father, -# "states": [0, 1, 2], -# }, -# "health_child": { -# "transition": prob_exog_health_child, -# "states": [0, 1], -# }, -# "health_grandma": { -# "transition": prob_exog_health_grandma, -# "states": [0, 1], -# }, -# }, -# }, -# } - -# options = check_options_and_set_defaults(options) - -# model = setup_model( -# options, -# state_space_functions=create_state_space_function_dict(), -# utility_functions=create_utility_function_dict(), -# utility_functions_final_period=create_final_period_utility_function_dict(), -# budget_constraint=budget_constraint, -# ) -# model_funcs = model["model_funcs"] -# model_structure = model["model_structure"] - -# stochastic_state_mapping = create_stochastic_states_mapping( -# model_structure["stochastic_state_space"].astype(np.int16), -# model_structure["stochastic_states_names"], -# ) - -# # Test the interface validation function for exogenous processes -# invalid_model = copy.deepcopy(model) -# with pytest.raises( -# ValueError, match="does not return float transition probabilities" -# ): -# 
invalid_model["model_funcs"]["processed_stochastic_funcs"]["health_mother"] = ( -# lambda **kwargs: jnp.array([1, 3, 4]) -# ) # Returns an array instead of a float -# validate_stochastic_states(invalid_model, params) - -# with pytest.raises( -# ValueError, match="does not return non-negative transition probabilities" -# ): -# invalid_model["model_funcs"]["processed_stochastic_funcs"]["health_mother"] = ( -# lambda **kwargs: jnp.array([0.7, -0.3, 0.6]) -# ) # Contains negative values -# validate_stochastic_states(invalid_model, params) - -# with pytest.raises( -# ValueError, match="does not return transition probabilities less or equal to 1" -# ): -# invalid_model["model_funcs"]["processed_stochastic_funcs"]["health_mother"] = ( -# lambda **kwargs: jnp.array([0.7, 1.3, 0.6]) -# ) # Contains values geq 1 -# validate_stochastic_states(invalid_model, params) - -# with pytest.raises( -# ValueError, match="does not return the correct number of transitions" -# ): -# invalid_model["model_funcs"]["processed_stochastic_funcs"]["health_mother"] = ( -# lambda **kwargs: jnp.array([0.7, 0.3]) -# ) # Wrong number of states (only 2 instead of 3) -# validate_stochastic_states(invalid_model, params) - -# with pytest.raises(ValueError, match="transition probabilities do not sum to 1"): -# invalid_model["model_funcs"]["processed_stochastic_funcs"]["health_mother"] = ( -# lambda **kwargs: jnp.array([0.6, 0.3, 0.2]) -# ) # Doesn't sum to 1 -# validate_stochastic_states(invalid_model, params) - -# # Check if valid model passes -# assert validate_stochastic_states(model, params) - -# # Check if mapping works -# mother_bad_health = np.where(model_structure["stochastic_state_space"][:, 0] == 2)[0] - -# for exog_state in mother_bad_health: -# assert stochastic_state_mapping(exog_proc_state=exog_state)["health_mother"] == 2 - -# # Now check probabilities -# state_choices_test = { -# "period": 0, -# "lagged_choice": 0, -# "married": 0, -# "health_mother": health_state_mother, -# 
"health_father": health_state_father, -# "health_child": health_state_child, -# "health_grandma": health_state_grandma, -# "choice": 0, -# } -# prob_vector = model_funcs["compute_stochastic_transition_vec"]( -# params=params, **state_choices_test -# ) -# prob_mother_health = model_funcs["processed_stochastic_funcs"]["health_mother"]( -# params=params, **state_choices_test -# ) -# prob_father_health = model_funcs["processed_stochastic_funcs"]["health_father"]( -# params=params, **state_choices_test -# ) -# prob_child_health = model_funcs["processed_stochastic_funcs"]["health_child"]( -# params=params, **state_choices_test -# ) -# prob_grandma_health = model_funcs["processed_stochastic_funcs"]["health_grandma"]( -# params=params, **state_choices_test -# ) - -# for exog_val, prob in enumerate(prob_vector): -# child_prob_states = stochastic_state_mapping(exog_val) -# prob_mother = prob_mother_health[child_prob_states["health_mother"]] -# prob_father = prob_father_health[child_prob_states["health_father"]] -# prob_child = prob_child_health[child_prob_states["health_child"]] -# prob_grandma = prob_grandma_health[child_prob_states["health_grandma"]] -# prob_expec = prob_mother * prob_father * prob_child * prob_grandma -# aaae(prob, prob_expec) - - -# def test_nested_exog_process(): -# """Tests that nested exogenous transition probs are calculated correctly. 
- -# >>> 0.3 * 0.8 + 0.3 * 0.7 + 0.4 * 0.6 -# 0.69 -# >>> 0.3 * 0.2 + 0.3 * 0.3 + 0.4 * 0.4 -# 0.31000000000000005 - -# """ -# params = { -# "care_demand_good_health": 0.2, -# "care_demand_medium_health": 0.3, -# "care_demand_bad_health": 0.4, -# } - -# trans_probs_health = jnp.array([0.3, 0.3, 0.4]) - -# prob_care_good = trans_prob_care_demand(health_state=0, params=params) -# prob_care_medium = trans_prob_care_demand(health_state=1, params=params) -# prob_care_bad = trans_prob_care_demand(health_state=2, params=params) - -# _trans_probs_care_demand = jnp.array( -# [prob_care_good, prob_care_medium, prob_care_bad] -# ) -# joint_trans_prob = trans_probs_health @ _trans_probs_care_demand -# expected = 0.3 * 0.2 + 0.3 * 0.3 + 0.4 * 0.4 - -# aaae(joint_trans_prob, expected) """Test module for exogenous processes.""" import copy @@ -355,7 +72,7 @@ def prob_exog_health_father(health_mother, params): return jnp.array([prob_good_health, prob_medium_health, prob_bad_health]) -def prob_exog_health_mother(health_father, params): +def prob_exog_health_mother(health_father): """Compute transition probabilities for mother's health, given father's health. 
Args: From da278893dcecedec8a112042896e126ed44b7954 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 13 Jan 2026 14:22:00 +0100 Subject: [PATCH 29/34] [pre-commit.ci] pre-commit autoupdate (#190) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 703de93f..5aa655fb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,7 +10,7 @@ repos: hooks: - id: yamllint - repo: https://github.com/lyz-code/yamlfix - rev: 1.19.0 + rev: 1.19.1 hooks: - id: yamlfix - repo: https://github.com/pre-commit/pre-commit-hooks @@ -59,7 +59,7 @@ repos: # hooks: # - id: setup-cfg-fmt - repo: https://github.com/psf/black-pre-commit-mirror - rev: 25.11.0 + rev: 25.12.0 hooks: - id: black language_version: python3.12 From e1d8e815eb43b49c0f9e8f7fa60f7797ffaa117b Mon Sep 17 00:00:00 2001 From: MaxBlesch Date: Tue, 13 Jan 2026 17:08:20 +0100 Subject: [PATCH 30/34] Partial test --- .pre-commit-config.yaml | 2 +- environment.yml | 1 - src/dcegm/interfaces/model_class.py | 15 +++++++++++ tests/test_partial.py | 39 +++++++++++++++++++++++++++++ 4 files changed, 55 insertions(+), 2 deletions(-) create mode 100644 tests/test_partial.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5aa655fb..80e662a7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -62,7 +62,7 @@ repos: rev: 25.12.0 hooks: - id: black - language_version: python3.12 + language_version: python3.13 # - repo: https://github.com/charliermarsh/ruff-pre-commit # rev: v0.0.282 # hooks: diff --git a/environment.yml b/environment.yml index 1839acdd..f08c689a 100644 --- a/environment.yml +++ b/environment.yml @@ -22,7 +22,6 @@ dependencies: - flake8 - jupyterlab - matplotlib - - pdbpp - pre-commit - setuptools_scm - toml diff 
--git a/src/dcegm/interfaces/model_class.py b/src/dcegm/interfaces/model_class.py index 336a5162..a753899b 100644 --- a/src/dcegm/interfaces/model_class.py +++ b/src/dcegm/interfaces/model_class.py @@ -10,6 +10,7 @@ get_child_state_index_per_states_and_choices, get_state_choice_index_per_discrete_states, ) +from dcegm.interfaces.inspect_solution import partially_solve from dcegm.interfaces.interface import ( get_n_state_choice_period, validate_stochastic_transition, @@ -460,3 +461,17 @@ def compute_law_of_motions(self, params): def get_n_state_choices_per_period(self): return get_n_state_choice_period(self.model_structure) + + def solve_partially(self, params, n_periods, return_candidates=False): + + return partially_solve( + income_shock_draws_unscaled=self.income_shock_draws_unscaled, + income_shock_weights=self.income_shock_weights, + model_config=self.model_config, + batch_info=self.batch_info, + model_funcs=self.model_funcs, + model_structure=self.model_structure, + params=params, + n_periods=n_periods, + return_candidates=return_candidates, + ) diff --git a/tests/test_partial.py b/tests/test_partial.py new file mode 100644 index 00000000..2c81b50d --- /dev/null +++ b/tests/test_partial.py @@ -0,0 +1,39 @@ +from pathlib import Path + +from numpy.testing import assert_array_almost_equal as aaae + +import dcegm +import dcegm.toy_models as toy_models + +# Obtain the test directory of the package +TEST_DIR = Path(__file__).parent + + +def test_partial_solve_func(): + model_funcs = toy_models.load_example_model_functions("dcegm_paper") + + model_name = "retirement_with_shocks" + params, model_specs, model_config = ( + toy_models.load_example_params_model_specs_and_config( + "dcegm_paper_" + model_name + ) + ) + + model = dcegm.setup_model( + model_config=model_config, + model_specs=model_specs, + **model_funcs, + ) + + model_solved = model.solve(params) + + partial_sol = model.solve_partially( + params=params, + n_periods=model_config["n_periods"], + 
return_candidates=True, + ) + + # Now without loop + aaae(model_solved.policy, partial_sol["policy"]) + aaae(model_solved.value, partial_sol["value"]) + aaae(model_solved.endog_grid, partial_sol["endog_grid"]) From 125aba146021d2903bf20c4cb90bab23b2c796a0 Mon Sep 17 00:00:00 2001 From: MaxBlesch Date: Tue, 13 Jan 2026 17:20:37 +0100 Subject: [PATCH 31/34] More tests --- ...tial.py => test_partial_and_interfaces.py} | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) rename tests/{test_partial.py => test_partial_and_interfaces.py} (50%) diff --git a/tests/test_partial.py b/tests/test_partial_and_interfaces.py similarity index 50% rename from tests/test_partial.py rename to tests/test_partial_and_interfaces.py index 2c81b50d..5aaaca23 100644 --- a/tests/test_partial.py +++ b/tests/test_partial_and_interfaces.py @@ -1,5 +1,6 @@ from pathlib import Path +import numpy as np from numpy.testing import assert_array_almost_equal as aaae import dcegm @@ -37,3 +38,30 @@ def test_partial_solve_func(): aaae(model_solved.policy, partial_sol["policy"]) aaae(model_solved.value, partial_sol["value"]) aaae(model_solved.endog_grid, partial_sol["endog_grid"]) + + state_choices = model_solved.model_structure["state_choice_space"] + choices = state_choices[:, -1] + states_dict = { + state: state_choices[:, id] + for id, state in enumerate( + model_solved.model_structure["discrete_states_names"] + ) + } + states_dict["assets_begin_of_period"] = model_solved.endog_grid[:, 5] + value_states_all_choices = model_solved.choice_values_for_states(states=states_dict) + + # Take in each row the value corresponding to the choice made + value_choices = value_states_all_choices[ + np.arange(value_states_all_choices.shape[0]), choices + ] + + aaae(model_solved.value[:, 5], value_choices) + + # Same for policies + policy_states_all_choices = model_solved.choice_policies_for_states( + states=states_dict + ) + policy_choices = policy_states_all_choices[ + 
np.arange(policy_states_all_choices.shape[0]), choices + ] + aaae(model_solved.policy[:, 5], policy_choices) From ba2b67f37de84345b3570615bdc435b8081018f3 Mon Sep 17 00:00:00 2001 From: MaxBlesch Date: Tue, 13 Jan 2026 17:24:01 +0100 Subject: [PATCH 32/34] More tests --- src/dcegm/interfaces/inspect_solution.py | 19 +++++++++++-------- tests/test_partial_and_interfaces.py | 10 ++++++++++ 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/src/dcegm/interfaces/inspect_solution.py b/src/dcegm/interfaces/inspect_solution.py index d3e4e63c..0bac8b28 100644 --- a/src/dcegm/interfaces/inspect_solution.py +++ b/src/dcegm/interfaces/inspect_solution.py @@ -93,14 +93,7 @@ def partially_solve( last_two_period_batch_info["idx_state_choices_second_last_period"] = ( last_two_period_batch_info["idx_state_choices_second_last_period"] - rescale_idx ) - ( - value_solved, - policy_solved, - endog_grid_solved, - value_candidates_second_last, - policy_candidates_second_last, - endog_grid_candidates_second_last, - ) = solve_last_two_periods( + last_two_period_sols = solve_last_two_periods( params=params, continuous_states_info=continuous_states_info, cont_grids_next_period=cont_grids_next_period, @@ -113,6 +106,14 @@ def partially_solve( debug_info=debug_info, ) if return_candidates: + ( + value_solved, + policy_solved, + endog_grid_solved, + value_candidates_second_last, + policy_candidates_second_last, + endog_grid_candidates_second_last, + ) = last_two_period_sols idx_second_last = batch_info_internal["last_two_period_info"][ "idx_state_choices_second_last_period" ] @@ -125,6 +126,8 @@ def partially_solve( endog_grid_candidates = endog_grid_candidates.at[idx_second_last, ...].set( endog_grid_candidates_second_last ) + else: + value_solved, policy_solved, endog_grid_solved = last_two_period_sols if n_periods <= 2: out_dict = { diff --git a/tests/test_partial_and_interfaces.py b/tests/test_partial_and_interfaces.py index 5aaaca23..ed2af21b 100644 --- 
a/tests/test_partial_and_interfaces.py +++ b/tests/test_partial_and_interfaces.py @@ -39,6 +39,16 @@ def test_partial_solve_func(): aaae(model_solved.value, partial_sol["value"]) aaae(model_solved.endog_grid, partial_sol["endog_grid"]) + partial_sol_2 = model.solve_partially( + params=params, + n_periods=model_config["n_periods"], + return_candidates=False, + ) + + aaae(model_solved.policy, partial_sol_2["policy"]) + aaae(model_solved.value, partial_sol_2["value"]) + aaae(model_solved.endog_grid, partial_sol_2["endog_grid"]) + state_choices = model_solved.model_structure["state_choice_space"] choices = state_choices[:, -1] states_dict = { From 3791fad6ef6d9dc54ddd788cc2eeee32c9d6881c Mon Sep 17 00:00:00 2001 From: MaxBlesch Date: Tue, 13 Jan 2026 17:29:39 +0100 Subject: [PATCH 33/34] cover --- tests/test_partial_and_interfaces.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/test_partial_and_interfaces.py b/tests/test_partial_and_interfaces.py index ed2af21b..563cfca5 100644 --- a/tests/test_partial_and_interfaces.py +++ b/tests/test_partial_and_interfaces.py @@ -75,3 +75,10 @@ def test_partial_solve_func(): np.arange(policy_states_all_choices.shape[0]), choices ] aaae(model_solved.policy[:, 5], policy_choices) + + value_solved_fast, policy_solved_fast, endog_grid_solved_fast = ( + model.get_fast_solve_func()(params) + ) + aaae(model_solved.value, value_solved_fast) + aaae(model_solved.policy, policy_solved_fast) + aaae(model_solved.endog_grid, endog_grid_solved_fast) From 6d22fc5f9e1b629d5f4c8bf5b4702d68e7f094ce Mon Sep 17 00:00:00 2001 From: MaxBlesch Date: Tue, 13 Jan 2026 17:35:51 +0100 Subject: [PATCH 34/34] loads more tests --- tests/test_biased_sim.py | 49 ++++++++++++++++++++++++++++++++++++---- 1 file changed, 44 insertions(+), 5 deletions(-) diff --git a/tests/test_biased_sim.py b/tests/test_biased_sim.py index d42bdb37..5a87f925 100644 --- a/tests/test_biased_sim.py +++ b/tests/test_biased_sim.py @@ -1,12 +1,10 @@ import jax.numpy as jnp 
import numpy as np +import pandas as pd import pytest import dcegm import dcegm.toy_models as toy_models -from dcegm.pre_processing.setup_model import create_model_dict -from dcegm.simulation.sim_utils import create_simulation_df -from dcegm.simulation.simulate import simulate_all_periods def utility_crra( @@ -96,7 +94,7 @@ def test_sim_and_sol_model(model_configs): "stochastic_states_transitions": stochastic_states_transitions, } - model_sol = dcegm.setup_model( + model = dcegm.setup_model( model_config=model_configs["solution"], model_specs=model_specs, state_space_functions=model_funcs["state_space_functions"], @@ -119,12 +117,53 @@ def test_sim_and_sol_model(model_configs): "assets_begin_of_period": np.ones(n_agents, dtype=float) * 10, } - df = model_sol.solve_and_simulate( + df = model.solve_and_simulate( params=params, states_initial=states_initial, seed=123, ) + ################################## + # First compare with other ways to set up alternative sim specifications + ################################## + model_init = dcegm.setup_model( + model_config=model_configs["solution"], + model_specs=model_specs, + state_space_functions=model_funcs["state_space_functions"], + utility_functions=utility_functions, + utility_functions_final_period=model_funcs["utility_functions_final_period"], + budget_constraint=model_funcs["budget_constraint"], + ) + model_init.set_alternative_sim_funcs(alternative_sim_specifications=alt_model_specs) + + df_2 = model_init.solve_and_simulate( + params=params, + states_initial=states_initial, + seed=123, + ) + pd.testing.assert_frame_equal(df, df_2) + + model_init = dcegm.setup_model( + model_config=model_configs["solution"], + model_specs=model_specs, + state_space_functions=model_funcs["state_space_functions"], + utility_functions=utility_functions, + utility_functions_final_period=model_funcs["utility_functions_final_period"], + budget_constraint=model_funcs["budget_constraint"], + ) + + model_init.set_alternative_sim_funcs( + 
alternative_sim_specifications=alt_model_specs, + alternative_specs=model_specs, + ) + + df_3 = model_init.solve_and_simulate( + params=params, + states_initial=states_initial, + seed=123, + ) + pd.testing.assert_frame_equal(df, df_3) + ########################################### # Compare marriage shares as they must be governed # by the transition matrix in the simulation