def split_structure_and_batch_info(model_structure, batch_info):
    """Split model structure and batch info into jittable and non-jittable parts.

    The static (non-array) entries cannot be traced by jax.jit, so they are
    pulled out and kept in plain-Python containers fixed in the closure; the
    remaining (large) array entries are passed through the jitted function
    call instead of being baked into the compiled function.

    Args:
        model_structure: Dict of model-structure entries. Must contain the
            keys listed in ``struct_keys_not_for_jit``.
        batch_info: Dict of batch information. Must contain
            "two_period_model" and, if that flag is False, "n_segments" plus
            one "batches_info_segment_{i}" sub-dict per segment, each with a
            "batches_cover_all" entry.

    Returns:
        Tuple ``(model_structure_jit, batch_info_jit, model_structure_non_jit,
        batch_info_non_jit)``. The inputs are left unmodified.
    """
    struct_keys_not_for_jit = [
        "discrete_states_names",
        "state_names_without_stochastic",
        "stochastic_states_names",
    ]
    model_structure_non_jit = {
        key: model_structure[key] for key in struct_keys_not_for_jit
    }
    # Everything else goes into the jitted call.
    model_structure_jit = {
        key: value
        for key, value in model_structure.items()
        if key not in struct_keys_not_for_jit
    }

    batch_info_jit = batch_info.copy()
    batch_info_non_jit = {
        "two_period_model": batch_info["two_period_model"],
    }
    batch_info_jit.pop("two_period_model", None)

    # If it is not a two period model, there is more static information.
    if not batch_info["two_period_model"]:
        batch_info_non_jit["n_segments"] = batch_info["n_segments"]
        batch_info_jit.pop("n_segments", None)
        for batch_id in range(batch_info_non_jit["n_segments"]):
            batch_key = f"batches_info_segment_{batch_id}"
            segment = batch_info[batch_key]
            batch_info_non_jit[batch_key] = {
                "batches_cover_all": segment["batches_cover_all"],
            }
            # Rebuild the segment dict without the static flag instead of
            # popping on the shallow copy, which would mutate the caller's
            # nested batch_info dict (it is shared with batch_info_jit).
            batch_info_jit[batch_key] = {
                key: value
                for key, value in segment.items()
                if key != "batches_cover_all"
            }

    return (
        model_structure_jit,
        batch_info_jit,
        model_structure_non_jit,
        batch_info_non_jit,
    )


def merge_non_jit_and_jit_model_structure(model_structure_jit, model_structure_non_jit):
    """Merge jittable and non-jittable parts back into one model_structure."""
    return {
        **model_structure_jit,
        **model_structure_non_jit,
    }


def merg_non_jit_batch_info_and_jit_batch_info(batch_info_jit, batch_info_non_jit):
    """Merge jittable and non-jittable parts back into one batch_info.

    NOTE(review): the historical "merg_" spelling is kept because callers
    import the function under this name; new code should use the correctly
    spelled alias defined below. The inputs are left unmodified.
    """
    batch_info = {
        **batch_info_jit,
        "two_period_model": batch_info_non_jit["two_period_model"],
    }
    if not batch_info_non_jit["two_period_model"]:
        batch_info["n_segments"] = batch_info_non_jit["n_segments"]
        for batch_id in range(batch_info_non_jit["n_segments"]):
            batch_key = f"batches_info_segment_{batch_id}"
            # Re-attach the static flag on a fresh dict rather than writing
            # into the segment dict shared with batch_info_jit.
            batch_info[batch_key] = {
                **batch_info_jit[batch_key],
                "batches_cover_all": batch_info_non_jit[batch_key][
                    "batches_cover_all"
                ],
            }
    return batch_info


# Backward-compatible, correctly spelled alias for the typo'd public name.
merge_non_jit_batch_info_and_jit_batch_info = (
    merg_non_jit_batch_info_and_jit_batch_info
)
alternative_sim_funcs - - def backward_induction_inner_jit(self, params): - return backward_induction( - params=params, - income_shock_draws_unscaled=self.income_shock_draws_unscaled, - income_shock_weights=self.income_shock_weights, - model_config=self.model_config, - batch_info=self.batch_info, - model_funcs=self.model_funcs, - model_structure=self.model_structure, - ) - - def get_fast_solve_func(self): - backward_jit = jax.jit( - partial( - backward_induction, - income_shock_draws_unscaled=self.income_shock_draws_unscaled, - income_shock_weights=self.income_shock_weights, - model_config=self.model_config, - batch_info=self.batch_info, - model_funcs=self.model_funcs, - model_structure=self.model_structure, - ) - ) - - return backward_jit - def solve(self, params, load_sol_path=None, save_sol_path=None): """Solve a discrete-continuous life-cycle model using the DC-EGM algorithm. @@ -198,8 +159,14 @@ def solve(self, params, load_sol_path=None, save_sol_path=None): if load_sol_path is not None: sol_dict = pkl.load(open(load_sol_path, "rb")) else: - value, policy, endog_grid = self.backward_induction_inner_jit( - params_processed + value, policy, endog_grid = backward_induction( + params=params_processed, + income_shock_draws_unscaled=self.income_shock_draws_unscaled, + income_shock_weights=self.income_shock_weights, + model_config=self.model_config, + model_funcs=self.model_funcs, + model_structure=self.model_structure, + batch_info=self.batch_info, ) sol_dict = { "value": value, @@ -245,8 +212,14 @@ def solve_and_simulate( if load_sol_path is not None: sol_dict = pkl.load(open(load_sol_path, "rb")) else: - value, policy, endog_grid = self.backward_induction_inner_jit( - params_processed + value, policy, endog_grid = backward_induction( + params=params_processed, + income_shock_draws_unscaled=self.income_shock_draws_unscaled, + income_shock_weights=self.income_shock_weights, + model_config=self.model_config, + model_funcs=self.model_funcs, + 
model_structure=self.model_structure, + batch_info=self.batch_info, ) sol_dict = { @@ -274,14 +247,69 @@ def solve_and_simulate( sim_df = create_simulation_df(sim_dict) return sim_df + def get_solve_func(self): + """Create a fast function for solving that is jit compiled in the first call.""" + + ( + model_structure_for_jit, + batch_info_for_jit, + model_structure_non_jit, + batch_info_non_jit, + ) = split_structure_and_batch_info(self.model_structure, self.batch_info) + + def solve_function_to_jit(params, model_structure_jit, batch_info_jit): + params_processed = process_params(params, self.params_check_info) + + # Merge back parts together. The non_jit objects are fixed in the closure. + model_structure = merge_non_jit_and_jit_model_structure( + model_structure_jit, model_structure_non_jit + ) + batch_info = merg_non_jit_batch_info_and_jit_batch_info( + batch_info_jit, batch_info_non_jit + ) + + # Solve the model. + value, policy, endog_grid = backward_induction( + params=params_processed, + model_structure=model_structure, + batch_info=batch_info, + income_shock_draws_unscaled=self.income_shock_draws_unscaled, + income_shock_weights=self.income_shock_weights, + model_config=self.model_config, + model_funcs=self.model_funcs, + ) + + return value, policy, endog_grid + + solve_func = jax.jit(solve_function_to_jit) + + # Generate the function. The user only needs to provide params, but we call with the objects for jit. 
+ def solve_function(params): + """Solve the model for given params.""" + value, policy, endog_grid = solve_func( + params, model_structure_for_jit, batch_info_for_jit + ) + model_solved_class = model_solved( + model=self, + params=params, + value=value, + policy=policy, + endog_grid=endog_grid, + ) + return model_solved_class + + return solve_function + def get_solve_and_simulate_func( self, states_initial, seed, - slow_version=False, ): + """Create a fast function for solving and simulation that is jit compiled in the + first call.""" - sim_func = lambda params, value, policy, endog_gid: simulate_all_periods( + # Fix everything except params, solution of the model and model_structure which contains large arrays. + sim_func = lambda params, value, policy, endog_gid, model_structure: simulate_all_periods( states_initial=states_initial, n_periods=self.model_config["n_periods"], params=params, @@ -290,16 +318,40 @@ def get_solve_and_simulate_func( policy_solved=policy, value_solved=value, model_config=self.model_config, - model_structure=self.model_structure, + model_structure=model_structure, model_funcs=self.model_funcs, alt_model_funcs_sim=self.alternative_sim_funcs, ) - def solve_and_simulate_function_to_jit(params): + ( + model_structure_for_jit, + batch_info_for_jit, + model_structure_non_jit, + batch_info_non_jit, + ) = split_structure_and_batch_info(self.model_structure, self.batch_info) + + def solve_and_simulate_function_to_jit( + params, model_structure_jit, batch_info_jit + ): params_processed = process_params(params, self.params_check_info) - # Solve the model - value, policy, endog_grid = self.backward_induction_inner_jit( - params_processed + + # Merge back parts together. The non_jit objects are fixed in the closure. + model_structure = merge_non_jit_and_jit_model_structure( + model_structure_jit, model_structure_non_jit + ) + batch_info = merg_non_jit_batch_info_and_jit_batch_info( + batch_info_jit, batch_info_non_jit + ) + + # Solve the model. 
+ value, policy, endog_grid = backward_induction( + params=params_processed, + model_structure=model_structure, + batch_info=batch_info, + income_shock_draws_unscaled=self.income_shock_draws_unscaled, + income_shock_weights=self.income_shock_weights, + model_config=self.model_config, + model_funcs=self.model_funcs, ) sim_dict = sim_func( @@ -307,17 +359,18 @@ def solve_and_simulate_function_to_jit(params): value=value, policy=policy, endog_gid=endog_grid, + model_structure=model_structure, ) return sim_dict - if slow_version: - solve_simulate_func = solve_and_simulate_function_to_jit - else: - solve_simulate_func = jax.jit(solve_and_simulate_function_to_jit) + solve_simulate_func = jax.jit(solve_and_simulate_function_to_jit) + # Generate the function. The user only needs to provide params, but we call with the objects for jit. def solve_and_simulate_function(params): - sim_dict = solve_simulate_func(params) + sim_dict = solve_simulate_func( + params, model_structure_for_jit, batch_info_for_jit + ) df = create_simulation_df(sim_dict) return df @@ -335,11 +388,13 @@ def create_experimental_ll_func( ): return create_individual_likelihood_function( + income_shock_draws_unscaled=self.income_shock_draws_unscaled, + income_shock_weights=self.income_shock_weights, + batch_info=self.batch_info, model_structure=self.model_structure, model_config=self.model_config, model_funcs=self.model_funcs, model_specs=self.model_specs, - backwards_induction_inner_jit=self.backward_induction_inner_jit, observed_states=observed_states, observed_choices=observed_choices, params_all=params_all, @@ -475,3 +530,22 @@ def solve_partially(self, params, n_periods, return_candidates=False): n_periods=n_periods, return_candidates=return_candidates, ) + + def set_alternative_sim_funcs( + self, alternative_sim_specifications, alternative_specs=None + ): + if alternative_specs is None: + self.alternative_sim_specs = self.model_specs + alternative_specs_without_jax = self.specs_without_jax + else: + 
self.alternative_sim_specs = jax.tree_util.tree_map( + try_jax_array, alternative_specs + ) + alternative_specs_without_jax = alternative_specs + + alternative_sim_funcs = generate_alternative_sim_functions( + model_specs=alternative_specs_without_jax, + model_specs_jax=self.alternative_sim_specs, + **alternative_sim_specifications, + ) + self.alternative_sim_funcs = alternative_sim_funcs diff --git a/src/dcegm/likelihood.py b/src/dcegm/likelihood.py index 3c9c66bc..2bc29f38 100644 --- a/src/dcegm/likelihood.py +++ b/src/dcegm/likelihood.py @@ -12,19 +12,27 @@ import numpy as np from jax import vmap +from dcegm.backward_induction import backward_induction from dcegm.egm.aggregate_marginal_utility import ( calculate_choice_probs_and_unsqueezed_logsum, ) from dcegm.interfaces.index_functions import get_state_choice_index_per_discrete_states from dcegm.interfaces.interface import choice_values_for_states +from dcegm.interfaces.jit_large_arrays import ( + merg_non_jit_batch_info_and_jit_batch_info, + merge_non_jit_and_jit_model_structure, + split_structure_and_batch_info, +) def create_individual_likelihood_function( + income_shock_draws_unscaled, + income_shock_weights, + batch_info, model_structure, model_config, model_funcs, model_specs, - backwards_induction_inner_jit, observed_states: Dict[str, int], observed_choices, params_all, @@ -34,7 +42,7 @@ def create_individual_likelihood_function( slow_version=False, ): - choice_prob_func = create_choice_prob_function( + choice_prob_func, data_from_observed_states = create_choice_prob_function( model_structure=model_structure, model_config=model_config, model_funcs=model_funcs, @@ -46,16 +54,40 @@ def create_individual_likelihood_function( return_weight_func=False, ) - def individual_likelihood(params): + ( + model_structure_for_jit, + batch_info_for_jit, + model_structure_non_jit, + batch_info_non_jit, + ) = split_structure_and_batch_info(model_structure, batch_info) + + def individual_likelihood_to_jit(params, 
model_structure_jit, batch_info_jit): params_update = params_all.copy() params_update.update(params) - value, policy, endog_grid = backwards_induction_inner_jit(params_update) + # Merge back parts together. The non_jit objects are fixed in the closure. + model_structure_merged = merge_non_jit_and_jit_model_structure( + model_structure_jit, model_structure_non_jit + ) + batch_info_merged = merg_non_jit_batch_info_and_jit_batch_info( + batch_info_jit, batch_info_non_jit + ) + + value, policy, endog_grid = backward_induction( + params=params_update, + income_shock_draws_unscaled=income_shock_draws_unscaled, + income_shock_weights=income_shock_weights, + model_config=model_config, + model_funcs=model_funcs, + model_structure=model_structure_merged, + batch_info=batch_info_merged, + ) choice_probs = choice_prob_func( value_in=value, endog_grid_in=endog_grid, params_in=params_update, + data_from_observed=data_from_observed_states, ) # Negative ll contributions are positive numbers. The smaller the better the fit # Add high fixed punishment for not explained choices @@ -72,9 +104,18 @@ def individual_likelihood(params): return neg_likelihood_contributions if slow_version: - return individual_likelihood + likelihood_function_int = individual_likelihood_to_jit else: - return jax.jit(individual_likelihood) + likelihood_function_int = jax.jit(individual_likelihood_to_jit) + + def likelihood_function(params): + return likelihood_function_int( + params=params, + model_structure_jit=model_structure_for_jit, + batch_info_jit=batch_info_for_jit, + ) + + return likelihood_function def create_choice_prob_function( @@ -89,28 +130,32 @@ def create_choice_prob_function( return_weight_func, ): if unobserved_state_specs is None: - choice_prob_func = create_partial_choice_prob_calculation( - observed_states=observed_states, - observed_choices=observed_choices, - model_structure=model_structure, - model_config=model_config, - model_funcs=model_funcs, + choice_prob_func, 
data_from_observed_states = ( + create_partial_choice_prob_calculation( + observed_states=observed_states, + observed_choices=observed_choices, + model_structure=model_structure, + model_config=model_config, + model_funcs=model_funcs, + ) ) else: - choice_prob_func = create_choice_prob_func_unobserved_states( - model_structure=model_structure, - model_config=model_config, - model_funcs=model_funcs, - model_specs=model_specs, - observed_states=observed_states, - observed_choices=observed_choices, - unobserved_state_specs=unobserved_state_specs, - use_probability_of_observed_states=use_probability_of_observed_states, - return_weight_func=return_weight_func, + choice_prob_func, data_from_observed_states = ( + create_choice_prob_func_unobserved_states( + model_structure=model_structure, + model_config=model_config, + model_funcs=model_funcs, + model_specs=model_specs, + observed_states=observed_states, + observed_choices=observed_choices, + unobserved_state_specs=unobserved_state_specs, + use_probability_of_observed_states=use_probability_of_observed_states, + return_weight_func=return_weight_func, + ) ) - return choice_prob_func + return choice_prob_func, data_from_observed_states def create_choice_prob_func_unobserved_states( @@ -201,16 +246,18 @@ def create_choice_prob_func_unobserved_states( # Create a list of partial choice probability functions for each unique # combination of unobserved states. 
partial_choice_probs_unobserved_states = [] + data_for_unobserved_states = [] for states in possible_states: - partial_choice_probs_unobserved_states.append( - create_partial_choice_prob_calculation( - observed_states=states, - observed_choices=observed_choices, - model_structure=model_structure, - model_config=model_config, - model_funcs=model_funcs, - ) + choice_func, data = create_partial_choice_prob_calculation( + observed_states=states, + observed_choices=observed_choices, + model_structure=model_structure, + model_config=model_config, + model_funcs=model_funcs, ) + partial_choice_probs_unobserved_states.append(choice_func) + data_for_unobserved_states.append(data) + partial_weight_func = ( lambda params_in, weight_vars: calculate_weights_for_each_state( params=params_in, @@ -228,12 +275,12 @@ def create_choice_prob_func_unobserved_states( lambda x: jnp.asarray(x), weighting_vars_for_possible_states ) - def choice_prob_func(value_in, endog_grid_in, params_in): + def choice_prob_func(value_in, endog_grid_in, params_in, data_from_observed): choice_probs_final = jnp.zeros(n_obs, dtype=jnp.float64) integrate_out_weights = jnp.zeros(n_obs, dtype=jnp.float64) - for partial_choice_prob, unobserved_state, weighting_vars in zip( + for partial_choice_prob, data_for_choice_func, weighting_vars in zip( partial_choice_probs_unobserved_states, - possible_states, + data_from_observed, weighting_vars_for_possible_states, ): unobserved_weights = jax.vmap( @@ -248,6 +295,7 @@ def choice_prob_func(value_in, endog_grid_in, params_in): value_in=value_in, endog_grid_in=endog_grid_in, params_in=params_in, + data_from_observed=data_for_choice_func, ) weighted_choice_prob = jnp.nan_to_num( @@ -266,11 +314,7 @@ def choice_prob_func(value_in, endog_grid_in, params_in): def weight_only_func(params_in): weights = np.zeros((n_obs, len(possible_states)), dtype=np.float64) count = 0 - for partial_choice_prob, unobserved_state, weighting_vars in zip( - partial_choice_probs_unobserved_states, 
- possible_states, - weighting_vars_for_possible_states, - ): + for weighting_vars in weighting_vars_for_possible_states: unobserved_weights = jax.vmap( partial_weight_func, in_axes=(None, 0), @@ -289,9 +333,9 @@ def weight_only_func(params_in): ) if return_weight_func: - return choice_prob_func, weight_only_func + return choice_prob_func, weight_only_func, data_for_unobserved_states else: - return choice_prob_func + return choice_prob_func, data_for_unobserved_states def create_partial_choice_prob_calculation( @@ -309,19 +353,27 @@ def create_partial_choice_prob_calculation( discrete_states_names=model_structure["discrete_states_names"], ) - def partial_choice_prob_func(value_in, endog_grid_in, params_in): + data_from_observed_wrapped = ( + observed_states, + observed_choices, + discrete_observed_state_choice_indexes, + ) + + def partial_choice_prob_func( + value_in, endog_grid_in, params_in, data_from_observed + ): return calc_choice_prob_for_state_choices( value_solved=value_in, endog_grid_solved=endog_grid_in, params=params_in, - states=observed_states, - choices=observed_choices, - state_choice_indexes=discrete_observed_state_choice_indexes, + states=data_from_observed[0], + choices=data_from_observed[1], + state_choice_indexes=data_from_observed[2], model_config=model_config, model_funcs=model_funcs, ) - return partial_choice_prob_func + return partial_choice_prob_func, data_from_observed_wrapped def calc_choice_prob_for_state_choices( diff --git a/tests/test_partial_and_interfaces.py b/tests/test_partial_and_interfaces.py index 563cfca5..4ecd52fc 100644 --- a/tests/test_partial_and_interfaces.py +++ b/tests/test_partial_and_interfaces.py @@ -76,9 +76,7 @@ def test_partial_solve_func(): ] aaae(model_solved.policy[:, 5], policy_choices) - value_solved_fast, policy_solved_fast, endog_grid_solved_fast = ( - model.get_fast_solve_func()(params) - ) - aaae(model_solved.value, value_solved_fast) - aaae(model_solved.policy, policy_solved_fast) - 
aaae(model_solved.endog_grid, endog_grid_solved_fast) + model_solved_fast = model.get_solve_func()(params) + aaae(model_solved.value, model_solved_fast.value) + aaae(model_solved.policy, model_solved_fast.policy) + aaae(model_solved.endog_grid, model_solved_fast.endog_grid)