Skip to content

Commit f322853

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 307c079 commit f322853

2 files changed

Lines changed: 12 additions & 6 deletions

File tree

src/upper_envelope/upper_jor_drued.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def upper_jor_drued(
3232
3333
Returns arrays with the convention that index 0 corresponds to zero wealth:
3434
``value_out[0] = expected_value_zero_savings`` and ``endog_out[0] = policy_out[0] = 0``.
35+
3536
"""
3637

3738
if value_function_kwargs is None:

tests/test_upper_jor_drued.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
non-degenerate reference segments (strictly increasing in `m`)
1515
3) interpolate the reference onto `m_grid` using only those safe segments
1616
4) compare `upper_jor_drued` to that reference on the masked points
17+
1718
"""
1819

1920
from pathlib import Path
@@ -27,12 +28,13 @@
2728

2829
import upper_envelope as upenv
2930

30-
3131
TEST_DIR = Path(__file__).parent
3232
TEST_RESOURCES_DIR = TEST_DIR / "resources"
3333

3434

35-
def utility_crra(consumption: jnp.ndarray, choice: int, params: Dict[str, float]) -> jnp.ndarray:
35+
def utility_crra(
36+
consumption: jnp.ndarray, choice: int, params: Dict[str, float]
37+
) -> jnp.ndarray:
3638
utility_consumption = (consumption ** (1 - params["rho"]) - 1) / (1 - params["rho"])
3739
utility = utility_consumption - (1 - choice) * params["delta"]
3840
return utility
@@ -45,9 +47,10 @@ def interpolate_on_safe_reference_segments(
4547
):
4648
"""Interpolate reference (ref_m, ref_y) onto m_grid, ignoring unsafe segments.
4749
48-
A "safe" segment is any adjacent pair (ref_m[i], ref_m[i+1]) with ref_m[i+1] > ref_m[i].
49-
For each x in m_grid, we take the maximum interpolated value over all safe segments
50-
covering x. This avoids ambiguity around duplicated ref_m values.
50+
A "safe" segment is any adjacent pair (ref_m[i], ref_m[i+1]) with ref_m[i+1] >
51+
ref_m[i]. For each x in m_grid, we take the maximum interpolated value over all safe
52+
segments covering x. This avoids ambiguity around duplicated ref_m values.
53+
5154
"""
5255

5356
dm = ref_m[1:] - ref_m[:-1]
@@ -92,7 +95,9 @@ def test_upper_jor_drued_matches_fues_on_safe_segments(period, setup_model):
9295

9396
def value_func(consumption, choice, params):
9497
# Same convention as existing tests: includes continuation value.
95-
return utility_crra(consumption, choice, params) + params["beta"] * value_egm[1, 0]
98+
return (
99+
utility_crra(consumption, choice, params) + params["beta"] * value_egm[1, 0]
100+
)
96101

97102
ref_m, ref_c, ref_v = upenv.fues_jax(
98103
endog_grid=jnp.asarray(policy_egm[0, 1:]),

0 commit comments

Comments
 (0)