88import numpy as np
99from numba import njit
1010
11- import upper_envelope as upenv
11+ import upper_envelope .jax as ue_jax
12+ import upper_envelope .numba as ue_numba
1213
1314jax .config .update ("jax_enable_x64" , True )
1415
@@ -60,7 +61,7 @@ def value_func_jax(consumption, choice, params):
6061
6162
6263def fues_jax_partial (endog , pol , val , exp_val_zero ):
63- return upenv .fues_jax (
64+ return ue_jax .fues_jax (
6465 endog_grid = jnp .asarray (endog ),
6566 policy = jnp .asarray (pol ),
6667 value = jnp .asarray (val ),
@@ -103,7 +104,7 @@ def fues_jax_partial(endog, pol, val, exp_val_zero):
103104
104105
105106def drued_jorg_jax_partial (endog , pol , val , m_grid , exp_val_zero ):
106- return upenv .drued_jorg_jax (
107+ return ue_jax .drued_jorg_jax (
107108 endog_grid = endog ,
108109 policy = pol ,
109110 value = val ,
@@ -168,7 +169,7 @@ def drued_jorg_jax_partial(endog, pol, val, m_grid, exp_val_zero):
168169# Numba FUES
169170start = time .time ()
170171jax .block_until_ready (
171- upenv .fues_numba (
172+ ue_numba .fues_numba (
172173 endog_grid = policy_egm [0 , 1 :],
173174 policy = policy_egm [1 , 1 :],
174175 value = value_egm [1 , 1 :],
@@ -184,7 +185,7 @@ def drued_jorg_jax_partial(endog, pol, val, m_grid, exp_val_zero):
184185for _ in range (n_runs ):
185186 start = time .time ()
186187 jax .block_until_ready (
187- upenv .fues_numba (
188+ ue_numba .fues_numba (
188189 endog_grid = policy_egm [0 , 1 :],
189190 policy = policy_egm [1 , 1 :],
190191 value = value_egm [1 , 1 :],
@@ -201,7 +202,7 @@ def drued_jorg_jax_partial(endog, pol, val, m_grid, exp_val_zero):
201202# Numba DRUED-JORG
202203start = time .time ()
203204jax .block_until_ready (
204- upenv .drued_jorg_numba (
205+ ue_numba .drued_jorg_numba (
205206 endog_grid = policy_egm [0 , 1 :],
206207 policy = policy_egm [1 , 1 :],
207208 value = value_egm [1 , 1 :],
@@ -218,7 +219,7 @@ def drued_jorg_jax_partial(endog, pol, val, m_grid, exp_val_zero):
218219for _ in range (n_runs ):
219220 start = time .time ()
220221 jax .block_until_ready (
221- upenv .drued_jorg_numba (
222+ ue_numba .drued_jorg_numba (
222223 endog_grid = policy_egm [0 , 1 :],
223224 policy = policy_egm [1 , 1 :],
224225 value = value_egm [1 , 1 :],
0 commit comments