1414 non-degenerate reference segments (strictly increasing in `m`)
15153) interpolate the reference onto `m_grid` using only those safe segments
16164) compare `upper_jor_drued` to that reference on the masked points
17+
1718"""
1819
1920from pathlib import Path
2728
2829import upper_envelope as upenv
2930
30-
3131TEST_DIR = Path (__file__ ).parent
3232TEST_RESOURCES_DIR = TEST_DIR / "resources"
3333
3434
35- def utility_crra (consumption : jnp .ndarray , choice : int , params : Dict [str , float ]) -> jnp .ndarray :
35+ def utility_crra (
36+ consumption : jnp .ndarray , choice : int , params : Dict [str , float ]
37+ ) -> jnp .ndarray :
3638 utility_consumption = (consumption ** (1 - params ["rho" ]) - 1 ) / (1 - params ["rho" ])
3739 utility = utility_consumption - (1 - choice ) * params ["delta" ]
3840 return utility
@@ -45,9 +47,10 @@ def interpolate_on_safe_reference_segments(
4547):
4648 """Interpolate reference (ref_m, ref_y) onto m_grid, ignoring unsafe segments.
4749
48- A "safe" segment is any adjacent pair (ref_m[i], ref_m[i+1]) with ref_m[i+1] > ref_m[i].
49- For each x in m_grid, we take the maximum interpolated value over all safe segments
50- covering x. This avoids ambiguity around duplicated ref_m values.
50+ A "safe" segment is any adjacent pair (ref_m[i], ref_m[i+1]) with ref_m[i+1] >
51+ ref_m[i]. For each x in m_grid, we take the maximum interpolated value over all safe
52+ segments covering x. This avoids ambiguity around duplicated ref_m values.
53+
5154 """
5255
5356 dm = ref_m [1 :] - ref_m [:- 1 ]
@@ -92,7 +95,9 @@ def test_upper_jor_drued_matches_fues_on_safe_segments(period, setup_model):
9295
9396 def value_func (consumption , choice , params ):
9497 # Same convention as existing tests: includes continuation value.
95- return utility_crra (consumption , choice , params ) + params ["beta" ] * value_egm [1 , 0 ]
98+ return (
99+ utility_crra (consumption , choice , params ) + params ["beta" ] * value_egm [1 , 0 ]
100+ )
96101
97102 ref_m , ref_c , ref_v = upenv .fues_jax (
98103 endog_grid = jnp .asarray (policy_egm [0 , 1 :]),
0 commit comments