@@ -333,6 +333,63 @@ def get_child_states_and_calc_trans_probs(self, state, choice, params):
333333 child_states_df ["trans_probs" ] = trans_probs
334334 return child_states_df
335335
336+ def get_full_child_states_by_asset_id_and_probs (
337+ self , state , choice , params , asset_id , second_continuous_id = None
338+ ):
339+ """Get the child states for a given state and choice and calculate the
340+ transition probabilities."""
341+ if "map_state_choice_to_child_states" not in self .model_structure :
342+ raise ValueError (
343+ "For this function the model needs to be created with debug_info='all'"
344+ )
345+
346+ child_idx = get_child_state_index_per_state_choice (
347+ states = state , choice = choice , model_structure = self .model_structure
348+ )
349+ state_space_dict = self .model_structure ["state_space_dict" ]
350+ discrete_states_names = self .model_structure ["discrete_states_names" ]
351+ child_states = {
352+ key : state_space_dict [key ][child_idx ] for key in discrete_states_names
353+ }
354+ child_states_df = pd .DataFrame (child_states )
355+
356+ child_continuous_states = self .compute_law_of_motions (params = params )
357+
358+ if "second_continuous" in child_continuous_states .keys ():
359+ if second_continuous_id is None :
360+ raise ValueError ("second_continuous_id must be provided." )
361+ else :
362+ quad_wealth = child_continuous_states ["assets_begin_of_period" ][
363+ child_idx , second_continuous_id , asset_id , :
364+ ]
365+ next_period_second_continuous = child_continuous_states [
366+ "second_continuous"
367+ ][child_idx , second_continuous_id ]
368+
369+ second_continuous_name = self .model_config ["continuous_states_info" ][
370+ "second_continuous_state_name"
371+ ]
372+ child_states_df [second_continuous_name ] = next_period_second_continuous
373+
374+ else :
375+ if second_continuous_id is not None :
376+ raise ValueError ("second_continuous_id must not be provided." )
377+ else :
378+ quad_wealth = child_continuous_states ["assets_begin_of_period" ][
379+ child_idx , asset_id , :
380+ ]
381+
382+ for id_quad in range (quad_wealth .shape [1 ]):
383+ child_states_df [f"assets_begin_of_period_quad_point_{ id_quad } " ] = (
384+ quad_wealth [:, id_quad ]
385+ )
386+
387+ trans_probs = self .model_funcs ["compute_stochastic_transition_vec" ](
388+ params = params , choice = choice , ** state
389+ )
390+ child_states_df ["trans_probs" ] = trans_probs
391+ return child_states_df
392+
336393 def compute_law_of_motions (self , params ):
337394 return calc_cont_grids_next_period (
338395 params = params ,
0 commit comments