From b7047d5625b9c14a14ea46633088558398bb317a Mon Sep 17 00:00:00 2001 From: MaxBlesch Date: Thu, 15 Jan 2026 17:07:35 +0100 Subject: [PATCH 1/4] Initial fix --- src/dcegm/interfaces/model_class.py | 70 +++++++++++++++++++++++++---- 1 file changed, 62 insertions(+), 8 deletions(-) diff --git a/src/dcegm/interfaces/model_class.py b/src/dcegm/interfaces/model_class.py index a753899b..e3209f96 100644 --- a/src/dcegm/interfaces/model_class.py +++ b/src/dcegm/interfaces/model_class.py @@ -1,5 +1,6 @@ import pickle as pkl from functools import partial +from grp import struct_group from typing import Callable, Dict import jax @@ -143,15 +144,15 @@ def set_alternative_sim_funcs( ) self.alternative_sim_funcs = alternative_sim_funcs - def backward_induction_inner_jit(self, params): + def backward_induction_inner_jit(self, params, model_structure, batch_info): return backward_induction( params=params, income_shock_draws_unscaled=self.income_shock_draws_unscaled, income_shock_weights=self.income_shock_weights, model_config=self.model_config, - batch_info=self.batch_info, + batch_info=batch_info, model_funcs=self.model_funcs, - model_structure=self.model_structure, + model_structure=model_structure, ) def get_fast_solve_func(self): @@ -281,7 +282,7 @@ def get_solve_and_simulate_func( slow_version=False, ): - sim_func = lambda params, value, policy, endog_gid: simulate_all_periods( + sim_func = lambda params, value, policy, endog_gid, model_structure: simulate_all_periods( states_initial=states_initial, n_periods=self.model_config["n_periods"], params=params, @@ -290,16 +291,68 @@ def get_solve_and_simulate_func( policy_solved=policy, value_solved=value, model_config=self.model_config, - model_structure=self.model_structure, + model_structure=model_structure, model_funcs=self.model_funcs, alt_model_funcs_sim=self.alternative_sim_funcs, ) - def solve_and_simulate_function_to_jit(params): + struct_keys_not_for_jit = [ + "discrete_states_names", + "state_names_without_stochastic", + "stochastic_states_names", + ] + model_structure_non_jit = { + key: self.model_structure[key] for key in struct_keys_not_for_jit + } + model_structure_jit = self.model_structure.copy() + # Remove non-jittable items + for key in struct_keys_not_for_jit: + model_structure_jit.pop(key, None) + + # Remove non-jittable items from batch_info + batch_info_jit = self.batch_info.copy() + batch_info_non_jit = { + "two_period_model": self.batch_info["two_period_model"], + } + batch_info_jit.pop("two_period_model", None) + # If it is not a two period model, there is more + if not self.batch_info["two_period_model"]: + batch_info_non_jit["n_segments"] = self.batch_info["n_segments"] + batch_info_jit.pop("n_segments", None) + for batch_id in range(batch_info_non_jit["n_segments"]): + batch_key = f"batches_info_segment_{batch_id}" + batch_info_non_jit[batch_key] = {} + batch_info_non_jit[batch_key]["batches_cover_all"] = self.batch_info[ + batch_key + ]["batches_cover_all"] + batch_info_jit[batch_key].pop("batches_cover_all", None) + + def solve_and_simulate_function_to_jit( + params, model_structure_jit, batch_info_jit + ): params_processed = process_params(params, self.params_check_info) + + model_structure = { + **model_structure_jit, + **model_structure_non_jit, + } + batch_info = { + **batch_info_jit, + "two_period_model": batch_info_non_jit["two_period_model"], + } + if not batch_info_non_jit["two_period_model"]: + batch_info["n_segments"] = batch_info_non_jit["n_segments"] + for batch_id in range(batch_info_non_jit["n_segments"]): + batch_key = f"batches_info_segment_{batch_id}" + batch_info[batch_key]["batches_cover_all"] = batch_info_non_jit[ + batch_key + ]["batches_cover_all"] + # Solve the model value, policy, endog_grid = self.backward_induction_inner_jit( - params_processed + params_processed, + model_structure=model_structure, + batch_info=batch_info, ) sim_dict = sim_func( @@ -307,6 +360,7 @@ def solve_and_simulate_function_to_jit(params): value=value, policy=policy, endog_gid=endog_grid, + model_structure=model_structure, ) return sim_dict @@ -317,7 +371,7 @@ def solve_and_simulate_function_to_jit(params): solve_simulate_func = jax.jit(solve_and_simulate_function_to_jit) def solve_and_simulate_function(params): - sim_dict = solve_simulate_func(params) + sim_dict = solve_simulate_func(params, model_structure_jit, batch_info_jit) df = create_simulation_df(sim_dict) return df From 4efd5063654e49458c85c5d55e23b7c86df9cb16 Mon Sep 17 00:00:00 2001 From: MaxBlesch Date: Fri, 16 Jan 2026 09:53:29 +0100 Subject: [PATCH 2/4] Done --- src/dcegm/interfaces/jit_large_arrays.py | 66 +++++++ src/dcegm/interfaces/model_class.py | 230 ++++++++++++----------- src/dcegm/likelihood.py | 15 +- tests/test_partial_and_interfaces.py | 10 +- 4 files changed, 208 insertions(+), 113 deletions(-) create mode 100644 src/dcegm/interfaces/jit_large_arrays.py diff --git a/src/dcegm/interfaces/jit_large_arrays.py b/src/dcegm/interfaces/jit_large_arrays.py new file mode 100644 index 00000000..682f3ff8 --- /dev/null +++ b/src/dcegm/interfaces/jit_large_arrays.py @@ -0,0 +1,66 @@ +def split_structure_and_batch_info(model_structure, batch_info): + """Splits the model structure and batch info into static parts, which we can not jit + compile and (large) arrays that we want to include in the function call for + jitting.""" + + struct_keys_not_for_jit = [ + "discrete_states_names", + "state_names_without_stochastic", + "stochastic_states_names", + ] + model_structure_non_jit = { + key: model_structure[key] for key in struct_keys_not_for_jit + } + model_structure_jit = model_structure.copy() + # Remove non-jittable items + for key in struct_keys_not_for_jit: + model_structure_jit.pop(key, None) + + # Remove non-jittable items from batch_info + batch_info_jit = batch_info.copy() + batch_info_non_jit = { + "two_period_model": batch_info["two_period_model"], + } + batch_info_jit.pop("two_period_model", None) + # If it is not a two period model, there is more + if not batch_info["two_period_model"]: + batch_info_non_jit["n_segments"] = batch_info["n_segments"] + batch_info_jit.pop("n_segments", None) + for batch_id in range(batch_info_non_jit["n_segments"]): + batch_key = f"batches_info_segment_{batch_id}" + batch_info_non_jit[batch_key] = {} + batch_info_non_jit[batch_key]["batches_cover_all"] = batch_info[batch_key][ + "batches_cover_all" + ] + batch_info_jit[batch_key].pop("batches_cover_all", None) + + return ( + model_structure_jit, + batch_info_jit, + model_structure_non_jit, + batch_info_non_jit, + ) + + +def merge_non_jit_and_jit_model_structure(model_structure_jit, model_structure_non_jit): + """Generate one model_structure to handle inside the package functions.""" + model_structure = { + **model_structure_jit, + **model_structure_non_jit, + } + return model_structure + + +def merg_non_jit_batch_info_and_jit_batch_info(batch_info_jit, batch_info_non_jit): + batch_info = { + **batch_info_jit, + "two_period_model": batch_info_non_jit["two_period_model"], + } + if not batch_info_non_jit["two_period_model"]: + batch_info["n_segments"] = batch_info_non_jit["n_segments"] + for batch_id in range(batch_info_non_jit["n_segments"]): + batch_key = f"batches_info_segment_{batch_id}" + batch_info[batch_key]["batches_cover_all"] = batch_info_non_jit[batch_key][ + "batches_cover_all" + ] + return batch_info diff --git a/src/dcegm/interfaces/model_class.py b/src/dcegm/interfaces/model_class.py index e3209f96..ea72ce01 100644 --- a/src/dcegm/interfaces/model_class.py +++ b/src/dcegm/interfaces/model_class.py @@ -16,6 +16,11 @@ get_n_state_choice_period, validate_stochastic_transition, ) +from dcegm.interfaces.jit_large_arrays import ( + merg_non_jit_batch_info_and_jit_batch_info, + merge_non_jit_and_jit_model_structure, + split_structure_and_batch_info, +) from dcegm.interfaces.sol_interface import model_solved from dcegm.law_of_motion import calc_cont_grids_next_period from dcegm.likelihood import create_individual_likelihood_function @@ -125,51 +130,6 @@ def __init__( else: self.alternative_sim_funcs = None - def set_alternative_sim_funcs( - self, alternative_sim_specifications, alternative_specs=None - ): - if alternative_specs is None: - self.alternative_sim_specs = self.model_specs - alternative_specs_without_jax = self.specs_without_jax - else: - self.alternative_sim_specs = jax.tree_util.tree_map( - try_jax_array, alternative_specs - ) - alternative_specs_without_jax = alternative_specs - - alternative_sim_funcs = generate_alternative_sim_functions( - model_specs=alternative_specs_without_jax, - model_specs_jax=self.alternative_sim_specs, - **alternative_sim_specifications, - ) - self.alternative_sim_funcs = alternative_sim_funcs - - def backward_induction_inner_jit(self, params, model_structure, batch_info): - return backward_induction( - params=params, - income_shock_draws_unscaled=self.income_shock_draws_unscaled, - income_shock_weights=self.income_shock_weights, - model_config=self.model_config, - batch_info=batch_info, - model_funcs=self.model_funcs, - model_structure=model_structure, - ) - - def get_fast_solve_func(self): - backward_jit = jax.jit( - partial( - backward_induction, - income_shock_draws_unscaled=self.income_shock_draws_unscaled, - income_shock_weights=self.income_shock_weights, - model_config=self.model_config, - batch_info=self.batch_info, - model_funcs=self.model_funcs, - model_structure=self.model_structure, - ) - ) - - return backward_jit - def solve(self, params, load_sol_path=None, save_sol_path=None): """Solve a discrete-continuous life-cycle model using the DC-EGM algorithm. @@ -199,8 +159,14 @@ def solve(self, params, load_sol_path=None, save_sol_path=None): if load_sol_path is not None: sol_dict = pkl.load(open(load_sol_path, "rb")) else: - value, policy, endog_grid = self.backward_induction_inner_jit( - params_processed + value, policy, endog_grid = backward_induction( + params=params_processed, + income_shock_draws_unscaled=self.income_shock_draws_unscaled, + income_shock_weights=self.income_shock_weights, + model_config=self.model_config, + model_funcs=self.model_funcs, + model_structure=self.model_structure, + batch_info=self.batch_info, ) sol_dict = { "value": value, @@ -246,8 +212,14 @@ def solve_and_simulate( if load_sol_path is not None: sol_dict = pkl.load(open(load_sol_path, "rb")) else: - value, policy, endog_grid = self.backward_induction_inner_jit( - params_processed + value, policy, endog_grid = backward_induction( + params=params_processed, + income_shock_draws_unscaled=self.income_shock_draws_unscaled, + income_shock_weights=self.income_shock_weights, + model_config=self.model_config, + model_funcs=self.model_funcs, + model_structure=self.model_structure, + batch_info=self.batch_info, ) sol_dict = { @@ -275,13 +247,68 @@ def solve_and_simulate( sim_df = create_simulation_df(sim_dict) return sim_df + def get_solve_func(self): + """Create a fast function for solving that is jit compiled in the first call.""" + + ( + model_structure_for_jit, + batch_info_for_jit, + model_structure_non_jit, + batch_info_non_jit, + ) = split_structure_and_batch_info(self.model_structure, self.batch_info) + + def solve_function_to_jit(params, model_structure_jit, batch_info_jit): + params_processed = process_params(params, self.params_check_info) + + # Merge back parts together. The non_jit objects are fixed in the closure. + model_structure = merge_non_jit_and_jit_model_structure( + model_structure_jit, model_structure_non_jit + ) + batch_info = merg_non_jit_batch_info_and_jit_batch_info( + batch_info_jit, batch_info_non_jit + ) + + # Solve the model. + value, policy, endog_grid = backward_induction( + params=params_processed, + model_structure=model_structure, + batch_info=batch_info, + income_shock_draws_unscaled=self.income_shock_draws_unscaled, + income_shock_weights=self.income_shock_weights, + model_config=self.model_config, + model_funcs=self.model_funcs, + ) + + return value, policy, endog_grid + + solve_func = jax.jit(solve_function_to_jit) + + # Generate the function. The user only needs to provide params, but we call with the objects for jit. + def solve_function(params): + """Solve the model for given params.""" + value, policy, endog_grid = solve_func( + params, model_structure_for_jit, batch_info_for_jit + ) + model_solved_class = model_solved( + model=self, + params=params, + value=value, + policy=policy, + endog_grid=endog_grid, + ) + return model_solved_class + + return solve_function + def get_solve_and_simulate_func( self, states_initial, seed, - slow_version=False, ): + """Create a fast function for solving and simulation that is jit compiled in the + first call.""" + # Fix everything except params, solution of the model and model_structure which contains large arrays. sim_func = lambda params, value, policy, endog_gid, model_structure: simulate_all_periods( states_initial=states_initial, n_periods=self.model_config["n_periods"], @@ -296,63 +323,35 @@ def get_solve_and_simulate_func( alt_model_funcs_sim=self.alternative_sim_funcs, ) - struct_keys_not_for_jit = [ - "discrete_states_names", - "state_names_without_stochastic", - "stochastic_states_names", - ] - model_structure_non_jit = { - key: self.model_structure[key] for key in struct_keys_not_for_jit - } - model_structure_jit = self.model_structure.copy() - # Remove non-jittable items - for key in struct_keys_not_for_jit: - model_structure_jit.pop(key, None) - - # Remove non-jittable items from batch_info - batch_info_jit = self.batch_info.copy() - batch_info_non_jit = { - "two_period_model": self.batch_info["two_period_model"], - } - batch_info_jit.pop("two_period_model", None) - # If it is not a two period model, there is more - if not self.batch_info["two_period_model"]: - batch_info_non_jit["n_segments"] = self.batch_info["n_segments"] - batch_info_jit.pop("n_segments", None) - for batch_id in range(batch_info_non_jit["n_segments"]): - batch_key = f"batches_info_segment_{batch_id}" - batch_info_non_jit[batch_key] = {} - batch_info_non_jit[batch_key]["batches_cover_all"] = self.batch_info[ - batch_key - ]["batches_cover_all"] - batch_info_jit[batch_key].pop("batches_cover_all", None) + ( + model_structure_for_jit, + batch_info_for_jit, + model_structure_non_jit, + batch_info_non_jit, + ) = split_structure_and_batch_info(self.model_structure, self.batch_info) def solve_and_simulate_function_to_jit( params, model_structure_jit, batch_info_jit ): params_processed = process_params(params, self.params_check_info) - model_structure = { - **model_structure_jit, - **model_structure_non_jit, - } - batch_info = { - **batch_info_jit, - "two_period_model": batch_info_non_jit["two_period_model"], - } - if not batch_info_non_jit["two_period_model"]: - batch_info["n_segments"] = batch_info_non_jit["n_segments"] - for batch_id in range(batch_info_non_jit["n_segments"]): - batch_key = f"batches_info_segment_{batch_id}" - batch_info[batch_key]["batches_cover_all"] = batch_info_non_jit[ - batch_key - ]["batches_cover_all"] - - # Solve the model - value, policy, endog_grid = self.backward_induction_inner_jit( - params_processed, + # Merge back parts together. The non_jit objects are fixed in the closure. + model_structure = merge_non_jit_and_jit_model_structure( + model_structure_jit, model_structure_non_jit + ) + batch_info = merg_non_jit_batch_info_and_jit_batch_info( + batch_info_jit, batch_info_non_jit + ) + + # Solve the model. + value, policy, endog_grid = backward_induction( + params=params_processed, model_structure=model_structure, batch_info=batch_info, + income_shock_draws_unscaled=self.income_shock_draws_unscaled, + income_shock_weights=self.income_shock_weights, + model_config=self.model_config, + model_funcs=self.model_funcs, ) sim_dict = sim_func( @@ -365,13 +364,13 @@ def solve_and_simulate_function_to_jit( return sim_dict - if slow_version: - solve_simulate_func = solve_and_simulate_function_to_jit - else: - solve_simulate_func = jax.jit(solve_and_simulate_function_to_jit) + solve_simulate_func = jax.jit(solve_and_simulate_function_to_jit) + # Generate the function. The user only needs to provide params, but we call with the objects for jit. def solve_and_simulate_function(params): - sim_dict = solve_simulate_func(params, model_structure_jit, batch_info_jit) + sim_dict = solve_simulate_func( + params, model_structure_for_jit, batch_info_for_jit + ) df = create_simulation_df(sim_dict) return df @@ -389,11 +388,13 @@ def create_experimental_ll_func( ): return create_individual_likelihood_function( + income_shock_draws_unscaled=self.income_shock_draws_unscaled, + income_shock_weights=self.income_shock_weights, + batch_info=self.batch_info, model_structure=self.model_structure, model_config=self.model_config, model_funcs=self.model_funcs, model_specs=self.model_specs, - backwards_induction_inner_jit=self.backward_induction_inner_jit, observed_states=observed_states, observed_choices=observed_choices, params_all=params_all, @@ -529,3 +530,22 @@ def solve_partially(self, params, n_periods, return_candidates=False): n_periods=n_periods, return_candidates=return_candidates, ) + + def set_alternative_sim_funcs( + self, alternative_sim_specifications, alternative_specs=None + ): + if alternative_specs is None: + self.alternative_sim_specs = self.model_specs + alternative_specs_without_jax = self.specs_without_jax + else: + self.alternative_sim_specs = jax.tree_util.tree_map( + try_jax_array, alternative_specs + ) + alternative_specs_without_jax = alternative_specs + + alternative_sim_funcs = generate_alternative_sim_functions( + model_specs=alternative_specs_without_jax, + model_specs_jax=self.alternative_sim_specs, + **alternative_sim_specifications, + ) + self.alternative_sim_funcs = alternative_sim_funcs diff --git a/src/dcegm/likelihood.py b/src/dcegm/likelihood.py index 3c9c66bc..f05ab010 100644 --- a/src/dcegm/likelihood.py +++ b/src/dcegm/likelihood.py @@ -12,6 +12,7 @@ import numpy as np from jax import vmap +from dcegm.backward_induction import backward_induction from dcegm.egm.aggregate_marginal_utility import ( calculate_choice_probs_and_unsqueezed_logsum, ) @@ -20,11 +21,13 @@ def create_individual_likelihood_function( + income_shock_draws_unscaled, + income_shock_weights, + batch_info, model_structure, model_config, model_funcs, model_specs, - backwards_induction_inner_jit, observed_states: Dict[str, int], observed_choices, params_all, @@ -50,7 +53,15 @@ def individual_likelihood(params): params_update = params_all.copy() params_update.update(params) - value, policy, endog_grid = backwards_induction_inner_jit(params_update) + value, policy, endog_grid = backward_induction( + params=params_update, + income_shock_draws_unscaled=income_shock_draws_unscaled, + income_shock_weights=income_shock_weights, + model_config=model_config, + model_funcs=model_funcs, + model_structure=model_structure, + batch_info=batch_info, + ) choice_probs = choice_prob_func( value_in=value, diff --git a/tests/test_partial_and_interfaces.py b/tests/test_partial_and_interfaces.py index 563cfca5..4ecd52fc 100644 --- a/tests/test_partial_and_interfaces.py +++ b/tests/test_partial_and_interfaces.py @@ -76,9 +76,7 @@ def test_partial_solve_func(): ] aaae(model_solved.policy[:, 5], policy_choices) - value_solved_fast, policy_solved_fast, endog_grid_solved_fast = ( - model.get_fast_solve_func()(params) - ) - aaae(model_solved.value, value_solved_fast) - aaae(model_solved.policy, policy_solved_fast) - aaae(model_solved.endog_grid, endog_grid_solved_fast) + model_solved_fast = model.get_solve_func()(params) + aaae(model_solved.value, model_solved_fast.value) + aaae(model_solved.policy, model_solved_fast.policy) + aaae(model_solved.endog_grid, model_solved_fast.endog_grid) From f43a0a4bb649eb61179aa3e3db1696440f9bcae4 Mon Sep 17 00:00:00 2001 From: MaxBlesch Date: Fri, 16 Jan 2026 10:39:58 +0100 Subject: [PATCH 3/4] implemented also in likelihood --- src/dcegm/likelihood.py | 133 ++++++++++++++++++++++++++-------------- 1 file changed, 87 insertions(+), 46 deletions(-) diff --git a/src/dcegm/likelihood.py b/src/dcegm/likelihood.py index f05ab010..9cc5859a 100644 --- a/src/dcegm/likelihood.py +++ b/src/dcegm/likelihood.py @@ -18,6 +18,11 @@ ) from dcegm.interfaces.index_functions import get_state_choice_index_per_discrete_states from dcegm.interfaces.interface import choice_values_for_states +from dcegm.interfaces.jit_large_arrays import ( + merg_non_jit_batch_info_and_jit_batch_info, + merge_non_jit_and_jit_model_structure, + split_structure_and_batch_info, +) def create_individual_likelihood_function( @@ -37,7 +42,7 @@ def create_individual_likelihood_function( slow_version=False, ): - choice_prob_func = create_choice_prob_function( + choice_prob_func, data_from_observed_states = create_choice_prob_function( model_structure=model_structure, model_config=model_config, model_funcs=model_funcs, @@ -49,24 +54,40 @@ def create_individual_likelihood_function( return_weight_func=False, ) - def individual_likelihood(params): + ( + model_structure_for_jit, + batch_info_for_jit, + model_structure_non_jit, + batch_info_non_jit, + ) = split_structure_and_batch_info(model_structure, batch_info) + + def individual_likelihood_to_jit(params, model_structure_jit, batch_info_jit): params_update = params_all.copy() params_update.update(params) + # Merge back parts together. The non_jit objects are fixed in the closure. + model_structure_merged = merge_non_jit_and_jit_model_structure( + model_structure_jit, model_structure_non_jit + ) + batch_info_merged = merg_non_jit_batch_info_and_jit_batch_info( + batch_info_jit, batch_info_non_jit + ) + value, policy, endog_grid = backward_induction( params=params_update, income_shock_draws_unscaled=income_shock_draws_unscaled, income_shock_weights=income_shock_weights, model_config=model_config, model_funcs=model_funcs, - model_structure=model_structure, - batch_info=batch_info, + model_structure=model_structure_merged, + batch_info=batch_info_merged, ) choice_probs = choice_prob_func( value_in=value, endog_grid_in=endog_grid, params_in=params_update, + data_from_observed=data_from_observed_states, ) # Negative ll contributions are positive numbers. The smaller the better the fit # Add high fixed punishment for not explained choices @@ -83,9 +104,18 @@ def individual_likelihood(params): return neg_likelihood_contributions if slow_version: - return individual_likelihood + likelihood_function_int = individual_likelihood_to_jit else: - return jax.jit(individual_likelihood) + likelihood_function_int = jax.jit(individual_likelihood_to_jit) + + def likelihood_function(params): + return likelihood_function_int( + params=params, + model_structure_jit=model_structure_for_jit, + batch_info_jit=batch_info_for_jit, + ) + + return likelihood_function def create_choice_prob_function( @@ -100,28 +130,32 @@ def create_choice_prob_function( return_weight_func, ): if unobserved_state_specs is None: - choice_prob_func = create_partial_choice_prob_calculation( - observed_states=observed_states, - observed_choices=observed_choices, - model_structure=model_structure, - model_config=model_config, - model_funcs=model_funcs, + choice_prob_func, data_from_observed_states = ( + create_partial_choice_prob_calculation( + observed_states=observed_states, + observed_choices=observed_choices, + model_structure=model_structure, + model_config=model_config, + model_funcs=model_funcs, + ) ) else: - choice_prob_func = create_choice_prob_func_unobserved_states( - model_structure=model_structure, - model_config=model_config, - model_funcs=model_funcs, - model_specs=model_specs, - observed_states=observed_states, - observed_choices=observed_choices, - unobserved_state_specs=unobserved_state_specs, - use_probability_of_observed_states=use_probability_of_observed_states, - return_weight_func=return_weight_func, + choice_prob_func, data_from_observed_states = ( + create_choice_prob_func_unobserved_states( + model_structure=model_structure, + model_config=model_config, + model_funcs=model_funcs, + model_specs=model_specs, + observed_states=observed_states, + observed_choices=observed_choices, + unobserved_state_specs=unobserved_state_specs, + use_probability_of_observed_states=use_probability_of_observed_states, + return_weight_func=return_weight_func, + ) ) - return choice_prob_func + return choice_prob_func, data_from_observed_states def create_choice_prob_func_unobserved_states( @@ -212,16 +246,18 @@ def create_choice_prob_func_unobserved_states( # Create a list of partial choice probability functions for each unique # combination of unobserved states. partial_choice_probs_unobserved_states = [] + data_for_unobserved_states = [] for states in possible_states: - partial_choice_probs_unobserved_states.append( - create_partial_choice_prob_calculation( - observed_states=states, - observed_choices=observed_choices, - model_structure=model_structure, - model_config=model_config, - model_funcs=model_funcs, - ) + choice_func, data = create_partial_choice_prob_calculation( + observed_states=states, + observed_choices=observed_choices, + model_structure=model_structure, + model_config=model_config, + model_funcs=model_funcs, ) + partial_choice_probs_unobserved_states.append(choice_func) + data_for_unobserved_states.append(data) + partial_weight_func = ( lambda params_in, weight_vars: calculate_weights_for_each_state( params=params_in, @@ -239,12 +275,12 @@ def create_choice_prob_func_unobserved_states( lambda x: jnp.asarray(x), weighting_vars_for_possible_states ) - def choice_prob_func(value_in, endog_grid_in, params_in): + def choice_prob_func(value_in, endog_grid_in, params_in, data_for_choice_funcs): choice_probs_final = jnp.zeros(n_obs, dtype=jnp.float64) integrate_out_weights = jnp.zeros(n_obs, dtype=jnp.float64) - for partial_choice_prob, unobserved_state, weighting_vars in zip( + for partial_choice_prob, data_for_choice_func, weighting_vars in zip( partial_choice_probs_unobserved_states, - possible_states, + data_for_choice_funcs, weighting_vars_for_possible_states, ): unobserved_weights = jax.vmap( @@ -259,6 +295,7 @@ def choice_prob_func(value_in, endog_grid_in, params_in): value_in=value_in, endog_grid_in=endog_grid_in, params_in=params_in, + data_from_observed=data_for_choice_func, ) weighted_choice_prob = jnp.nan_to_num( @@ -277,11 +314,7 @@ def choice_prob_func(value_in, endog_grid_in, params_in): def weight_only_func(params_in): weights = np.zeros((n_obs, len(possible_states)), dtype=np.float64) count = 0 - for partial_choice_prob, unobserved_state, weighting_vars in zip( - partial_choice_probs_unobserved_states, - possible_states, - weighting_vars_for_possible_states, - ): + for weighting_vars in weighting_vars_for_possible_states: unobserved_weights = jax.vmap( partial_weight_func, in_axes=(None, 0), @@ -300,9 +333,9 @@ def weight_only_func(params_in): ) if return_weight_func: - return choice_prob_func, weight_only_func + return choice_prob_func, weight_only_func, data_for_unobserved_states else: - return choice_prob_func + return choice_prob_func, data_for_unobserved_states def create_partial_choice_prob_calculation( @@ -320,19 +353,27 @@ def create_partial_choice_prob_calculation( discrete_states_names=model_structure["discrete_states_names"], ) - def partial_choice_prob_func(value_in, endog_grid_in, params_in): + data_from_observed_wrapped = ( + observed_states, + observed_choices, + discrete_observed_state_choice_indexes, + ) + + def partial_choice_prob_func( + value_in, endog_grid_in, params_in, data_from_observed + ): return calc_choice_prob_for_state_choices( value_solved=value_in, endog_grid_solved=endog_grid_in, params=params_in, - states=observed_states, - choices=observed_choices, - state_choice_indexes=discrete_observed_state_choice_indexes, + states=data_from_observed[0], + choices=data_from_observed[1], + state_choice_indexes=data_from_observed[2], model_config=model_config, model_funcs=model_funcs, ) - return partial_choice_prob_func + return partial_choice_prob_func, data_from_observed_wrapped def calc_choice_prob_for_state_choices( From 4f7377425ceb8625d33fb1561a485abd1720f2ab Mon Sep 17 00:00:00 2001 From: MaxBlesch Date: Fri, 16 Jan 2026 10:48:39 +0100 Subject: [PATCH 4/4] Fixed interface in ll --- src/dcegm/likelihood.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/dcegm/likelihood.py b/src/dcegm/likelihood.py index 9cc5859a..2bc29f38 100644 --- a/src/dcegm/likelihood.py +++ b/src/dcegm/likelihood.py @@ -275,12 +275,12 @@ def create_choice_prob_func_unobserved_states( lambda x: jnp.asarray(x), weighting_vars_for_possible_states ) - def choice_prob_func(value_in, endog_grid_in, params_in, data_for_choice_funcs): + def choice_prob_func(value_in, endog_grid_in, params_in, data_from_observed): choice_probs_final = jnp.zeros(n_obs, dtype=jnp.float64) integrate_out_weights = jnp.zeros(n_obs, dtype=jnp.float64) for partial_choice_prob, data_for_choice_func, weighting_vars in zip( partial_choice_probs_unobserved_states, - data_for_choice_funcs, + data_from_observed, weighting_vars_for_possible_states, ): unobserved_weights = jax.vmap(