Skip to content

Commit d889e5e

Browse files
MaxBleschsegsell
andauthored
JAx and numba new (#21)
Co-authored-by: Sebastian Gsell <sebastian.gsell93@gmail.com>
1 parent 85830bd commit d889e5e

18 files changed

Lines changed: 63 additions & 55 deletions

.pre-commit-config.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ repos:
66
- id: check-useless-excludes
77
# - id: identity # Prints all files passed to pre-commits. Debugging.
88
- repo: https://github.com/adrienverge/yamllint.git
9-
rev: v1.37.1
9+
rev: v1.38.0
1010
hooks:
1111
- id: yamllint
1212
- repo: https://github.com/lyz-code/yamlfix
@@ -42,7 +42,7 @@ repos:
4242
- id: python-use-type-annotations
4343
- id: text-unicode-replacement-char
4444
- repo: https://github.com/pycqa/isort
45-
rev: 7.0.0
45+
rev: 8.0.1
4646
hooks:
4747
- id: isort
4848
name: isort
@@ -55,13 +55,13 @@ repos:
5555
# args:
5656
# - --py37-plus
5757
- repo: https://github.com/psf/black-pre-commit-mirror
58-
rev: 25.12.0
58+
rev: 26.1.0
5959
hooks:
6060
- id: black
6161
language_version: python3.12
6262
exclude: tests/utils/fast_upper_envelope_org.py
6363
- repo: https://github.com/astral-sh/ruff-pre-commit
64-
rev: v0.14.10
64+
rev: v0.15.4
6565
hooks:
6666
- id: ruff
6767
# exclude: |

docs/time_period2_ops.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
import numpy as np
99
from numba import njit
1010

11-
import upper_envelope as upenv
11+
import upper_envelope.jax as ue_jax
12+
import upper_envelope.numba as ue_numba
1213

1314
jax.config.update("jax_enable_x64", True)
1415

@@ -60,7 +61,7 @@ def value_func_jax(consumption, choice, params):
6061

6162

6263
def fues_jax_partial(endog, pol, val, exp_val_zero):
63-
return upenv.fues_jax(
64+
return ue_jax.fues_jax(
6465
endog_grid=jnp.asarray(endog),
6566
policy=jnp.asarray(pol),
6667
value=jnp.asarray(val),
@@ -103,7 +104,7 @@ def fues_jax_partial(endog, pol, val, exp_val_zero):
103104

104105

105106
def drued_jorg_jax_partial(endog, pol, val, m_grid, exp_val_zero):
106-
return upenv.drued_jorg_jax(
107+
return ue_jax.drued_jorg_jax(
107108
endog_grid=endog,
108109
policy=pol,
109110
value=val,
@@ -168,7 +169,7 @@ def drued_jorg_jax_partial(endog, pol, val, m_grid, exp_val_zero):
168169
# Numba FUES
169170
start = time.time()
170171
jax.block_until_ready(
171-
upenv.fues_numba(
172+
ue_numba.fues_numba(
172173
endog_grid=policy_egm[0, 1:],
173174
policy=policy_egm[1, 1:],
174175
value=value_egm[1, 1:],
@@ -184,7 +185,7 @@ def drued_jorg_jax_partial(endog, pol, val, m_grid, exp_val_zero):
184185
for _ in range(n_runs):
185186
start = time.time()
186187
jax.block_until_ready(
187-
upenv.fues_numba(
188+
ue_numba.fues_numba(
188189
endog_grid=policy_egm[0, 1:],
189190
policy=policy_egm[1, 1:],
190191
value=value_egm[1, 1:],
@@ -201,7 +202,7 @@ def drued_jorg_jax_partial(endog, pol, val, m_grid, exp_val_zero):
201202
# Numba DRUED-JORG
202203
start = time.time()
203204
jax.block_until_ready(
204-
upenv.drued_jorg_numba(
205+
ue_numba.drued_jorg_numba(
205206
endog_grid=policy_egm[0, 1:],
206207
policy=policy_egm[1, 1:],
207208
value=value_egm[1, 1:],
@@ -218,7 +219,7 @@ def drued_jorg_jax_partial(endog, pol, val, m_grid, exp_val_zero):
218219
for _ in range(n_runs):
219220
start = time.time()
220221
jax.block_until_ready(
221-
upenv.drued_jorg_numba(
222+
ue_numba.drued_jorg_numba(
222223
endog_grid=policy_egm[0, 1:],
223224
policy=policy_egm[1, 1:],
224225
value=value_egm[1, 1:],

docs/tutorials/ue_drued_jorg.ipynb

Lines changed: 11 additions & 11 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,12 @@
22
# Project metadata
33
# ======================================================================================
44
[project]
5-
name = "upper_envelope"
6-
description = "Upper envelope scan for dynamic discrete-continuous life cycle models."
7-
version = "0.1.3"
8-
requires-python = ">=3.10"
5+
name = "upper_envelope"
6+
description = "Upper envelope scan for dynamic discrete-continuous life cycle models."
7+
dynamic = ["version"]
8+
requires-python = ">=3.10"
99
dependencies = [
1010
"numpy",
11-
"pandas",
12-
"scipy",
13-
"jax"
1411
]
1512
keywords = [
1613
"Dynamic programming",
@@ -32,7 +29,7 @@ classifiers = [
3229
"Topic :: Scientific/Engineering",
3330
]
3431
authors = [
35-
{ name="Max Blesch", email="maximilian.blesch@hu-berlin.de" },
32+
{ name="Max Blesch", email="maxblesch@gmail.com" },
3633
{ name="Sebastian Gsell", email="gsell.sebastian@gmail.com" },
3734
]
3835
maintainers = [
@@ -55,7 +52,7 @@ Github = "https://github.com/OpensourceEconomics/upper-envelope"
5552
# ======================================================================================
5653

5754
[build-system]
58-
requires = ["hatchling", "hatch_vcs"]
55+
requires = ["hatchling", "hatch-vcs"]
5956
build-backend = "hatchling.build"
6057

6158
[tool.hatch.build.hooks.vcs]

src/upper_envelope/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +0,0 @@
1-
from upper_envelope.drued_jorg_jax import drued_jorg_jax
2-
from upper_envelope.drued_jorg_numba import drued_jorg_numba
3-
from upper_envelope.fues_jax.fues_jax import fues_jax, fues_jax_unconstrained
4-
from upper_envelope.fues_numba.fues_numba import fues_numba, fues_numba_unconstrained

src/upper_envelope/jax/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from upper_envelope.jax.drued_jorg_jax import drued_jorg_jax
2+
from upper_envelope.jax.fues_jax.fues_jax import fues_jax, fues_jax_unconstrained
File renamed without changes.
File renamed without changes.

src/upper_envelope/fues_jax/fues_jax.py renamed to src/upper_envelope/jax/fues_jax/fues_jax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import jax.numpy as jnp
1414
from jax import vmap
1515

16-
from upper_envelope.fues_jax.check_and_scan_funcs import (
16+
from upper_envelope.jax.fues_jax.check_and_scan_funcs import (
1717
determine_cases_and_conduct_necessary_scans,
1818
)
1919
from upper_envelope.math_funcs import calc_intersection_and_extrapolate_policy

0 commit comments

Comments
 (0)