Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -144,3 +144,6 @@ jax_max.ipynb

# Debugging files should be ignored
*.png

# Local agent notes
AI_instructions/
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ repos:
- id: check-useless-excludes
# - id: identity # Prints all files passed to pre-commits. Debugging.
- repo: https://github.com/adrienverge/yamllint.git
rev: v1.37.1
rev: v1.38.0
hooks:
- id: yamllint
- repo: https://github.com/lyz-code/yamlfix
Expand Down Expand Up @@ -48,7 +48,7 @@ repos:
# args:
# - --py37-plus
- repo: https://github.com/pycqa/isort
rev: 7.0.0
rev: 8.0.1
hooks:
- id: isort
name: isort
Expand All @@ -59,7 +59,7 @@ repos:
# hooks:
# - id: setup-cfg-fmt
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 25.12.0
rev: 26.1.0
hooks:
- id: black
language_version: python3.13
Expand Down
13 changes: 8 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
[project]
name = "dcegm"
description = "Python package for solving and simulating finite-horizon stochastic discrete-continuous dynamic choice models using the DC-EGM algorithm from Iskhakov, Jørgensen, Rust, and Schjerning (QE, 2017)."
version = "0.1.0.dev0"
dynamic = ["version"]
requires-python = ">=3.10"
dependencies = [
"numpy",
Expand Down Expand Up @@ -54,9 +54,12 @@ Github = "https://github.com/OpenSourceEconomics/dcegm"
# Build system configuration
# ======================================================================================
[build-system]
requires = ["hatchling"]
requires = ["hatchling", "hatch-vcs"]
build-backend = "hatchling.build"

[tool.hatch.build.hooks.vcs]
version-file = "src/dcegm/_version.py"

[tool.hatch.build.targets.sdist]
exclude = ["tests"]
only-packages = true
Expand All @@ -65,12 +68,12 @@ only-packages = true
only-include = ["src"]
sources = ["src"]

[tool.hatch.version]
source = "vcs"

[tool.hatch.metadata]
allow-direct-references = true

[tool.setuptools.package-data]
"dcegm" = ["templates/**/*"]


# ======================================================================================
# Misc configuration
Expand Down
2 changes: 1 addition & 1 deletion src/dcegm/interfaces/inspect_solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def partially_solve(
n_assets_end_of_period = model_config["continuous_states_info"][
"assets_grid_end_of_period"
].shape[0]
(value_candidates, policy_candidates, endog_grid_candidates) = (
value_candidates, policy_candidates, endog_grid_candidates = (
create_solution_container(
continuous_states_info=model_config["continuous_states_info"],
n_total_wealth_grid=n_assets_end_of_period,
Expand Down
6 changes: 2 additions & 4 deletions src/dcegm/pre_processing/check_model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,14 +131,12 @@ def check_model_config_and_process(model_config):
n_assets_end_of_period * (1 + tuning_params["extra_wealth_grid_factor"])
< n_assets_end_of_period + tuning_params["n_constrained_points_to_add"]
):
raise ValueError(
f"""\n\n
raise ValueError(f"""\n\n
When preparing the tuning parameters for the upper
envelope, we found the following contradicting parameters: \n
The extra wealth grid factor of {tuning_params["extra_wealth_grid_factor"]} is too small
to cover the {tuning_params["n_constrained_points_to_add"]} wealth points which are added in
the credit constrained part of the wealth grid. \n\n"""
)
the credit constrained part of the wealth grid. \n\n""")
tuning_params["n_total_wealth_grid"] = int(
n_assets_end_of_period * (1 + tuning_params["extra_wealth_grid_factor"])
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from jax import numpy as jnp
from upper_envelope import fues_jax
from upper_envelope.jax import fues_jax


def create_upper_envelope_function(model_config, continuous_state=None):
Expand Down
2 changes: 1 addition & 1 deletion src/dcegm/solve_single_period.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def solve_single_period(
debug_info,
):
"""Solve a single period of the model using DCEGM."""
(value_solved, policy_solved, endog_grid_solved) = carry
value_solved, policy_solved, endog_grid_solved = carry

(
state_choices_idxs,
Expand Down
109 changes: 96 additions & 13 deletions src/dcegm/templates/simplemodel/run_example.ipynb

Large diffs are not rendered by default.

17 changes: 16 additions & 1 deletion src/dcegm/templates/simplemodel/run_example.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# packages needed
import jax.numpy as jnp
import matplotlib.pyplot as plt
from jax import config

# import model_funcs
Expand Down Expand Up @@ -27,6 +28,7 @@
"exp_squared": -0.0002,
# Shock parameters of income
"income_shock_std": 0.35,
"income_shock_mean": 0.0,
"taste_shock_scale": 0.2,
"interest_rate": 0.05,
"consumption_floor": 0.001,
Expand Down Expand Up @@ -73,4 +75,17 @@
"assets_begin_of_period": jnp.ones(n_agents) * 10,
}

model_solved.simulate(states_initial=states_initial, seed=42)
sim_df = model_solved.simulate(states_initial=states_initial, seed=42)

sim_df.groupby("period").choice.value_counts().unstack().plot(
kind="bar",
stacked=True,
title="Choice by period",
xlabel="Period",
ylabel="Count",
figsize=(10, 5),
rot=0,
)
# label choices work and retire in legend
plt.legend(["Work", "Retire"])
plt.show()
1 change: 0 additions & 1 deletion tests/sandbox/jax_timeit_large_toy_model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
"import numpy as np\n",
"import time\n",
"\n",
"\n",
"TEST_RESOURCES_DIR = \"../resources/\""
]
},
Expand Down
37 changes: 20 additions & 17 deletions tests/sandbox/time_functions_jax.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
"def func_a(x, y):\n",
" return x + y\n",
"\n",
"\n",
"jax.vmap(func_a, in_axes=(0, None))(np.array([2]), 3)"
],
"id": "83f45f46db8be341",
Expand All @@ -53,7 +54,9 @@
}
},
"cell_type": "code",
"source": "isinstance(np.array(2), np.ndarray)",
"source": [
"isinstance(np.array(2), np.ndarray)"
],
"id": "d2b3690f1f318672",
"outputs": [
{
Expand Down Expand Up @@ -122,10 +125,10 @@
"evalue": "No module named 'tests'",
"output_type": "error",
"traceback": [
"\u001B[31m---------------------------------------------------------------------------\u001B[39m",
"\u001B[31mModuleNotFoundError\u001B[39m Traceback (most recent call last)",
"\u001B[36mCell\u001B[39m\u001B[36m \u001B[39m\u001B[32mIn[8]\u001B[39m\u001B[32m, line 9\u001B[39m\n\u001B[32m 7\u001B[39m \u001B[38;5;28;01mimport\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[34;01mjax\u001B[39;00m\u001B[34;01m.\u001B[39;00m\u001B[34;01mnumpy\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;28;01mas\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[34;01mjnp\u001B[39;00m\n\u001B[32m 8\u001B[39m \u001B[38;5;28;01mimport\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[34;01mnumpy\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;28;01mas\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[34;01mnp\u001B[39;00m\n\u001B[32m----> \u001B[39m\u001B[32m9\u001B[39m \u001B[38;5;28;01mfrom\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[34;01mtests\u001B[39;00m\u001B[34;01m.\u001B[39;00m\u001B[34;01mutils\u001B[39;00m\u001B[34;01m.\u001B[39;00m\u001B[34;01mmarkov_simulator\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;28;01mimport\u001B[39;00m markov_simulator\n",
"\u001B[31mModuleNotFoundError\u001B[39m: No module named 'tests'"
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
"\u001b[31mModuleNotFoundError\u001b[39m Traceback (most recent call last)",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[8]\u001b[39m\u001b[32m, line 9\u001b[39m\n\u001b[32m 7\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mjax\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mnumpy\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mjnp\u001b[39;00m\n\u001b[32m 8\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnumpy\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnp\u001b[39;00m\n\u001b[32m----> \u001b[39m\u001b[32m9\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mtests\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mutils\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mmarkov_simulator\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m markov_simulator\n",
"\u001b[31mModuleNotFoundError\u001b[39m: No module named 'tests'"
]
}
],
Expand Down Expand Up @@ -305,19 +308,19 @@
"evalue": "len() of unsized object",
"output_type": "error",
"traceback": [
"\u001B[31m---------------------------------------------------------------------------\u001B[39m",
"\u001B[31mIndexError\u001B[39m Traceback (most recent call last)",
"\u001B[36mFile \u001B[39m\u001B[32m~/micromamba/envs/dcegm/lib/python3.11/site-packages/jax/_src/core.py:1896\u001B[39m, in \u001B[36mShapedArray._len\u001B[39m\u001B[34m(self, ignored_tracer)\u001B[39m\n\u001B[32m 1895\u001B[39m \u001B[38;5;28;01mtry\u001B[39;00m:\n\u001B[32m-> \u001B[39m\u001B[32m1896\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43mshape\u001B[49m\u001B[43m[\u001B[49m\u001B[32;43m0\u001B[39;49m\u001B[43m]\u001B[49m\n\u001B[32m 1897\u001B[39m \u001B[38;5;28;01mexcept\u001B[39;00m \u001B[38;5;167;01mIndexError\u001B[39;00m \u001B[38;5;28;01mas\u001B[39;00m err:\n",
"\u001B[31mIndexError\u001B[39m: tuple index out of range",
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
"\u001b[31mIndexError\u001b[39m Traceback (most recent call last)",
"\u001b[36mFile \u001b[39m\u001b[32m~/micromamba/envs/dcegm/lib/python3.11/site-packages/jax/_src/core.py:1896\u001b[39m, in \u001b[36mShapedArray._len\u001b[39m\u001b[34m(self, ignored_tracer)\u001b[39m\n\u001b[32m 1895\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1896\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mshape\u001b[49m\u001b[43m[\u001b[49m\u001b[32;43m0\u001b[39;49m\u001b[43m]\u001b[49m\n\u001b[32m 1897\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mIndexError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m err:\n",
"\u001b[31mIndexError\u001b[39m: tuple index out of range",
"\nThe above exception was the direct cause of the following exception:\n",
"\u001B[31mTypeError\u001B[39m Traceback (most recent call last)",
"\u001B[36mCell\u001B[39m\u001B[36m \u001B[39m\u001B[32mIn[14]\u001B[39m\u001B[32m, line 3\u001B[39m\n\u001B[32m 1\u001B[39m jit_g = jit(\u001B[38;5;28;01mlambda\u001B[39;00m x, y: g(f, x, y))\n\u001B[32m 2\u001B[39m jit_g_aux = jit(\u001B[38;5;28;01mlambda\u001B[39;00m x, y: g(f_aux, x, y))\n\u001B[32m----> \u001B[39m\u001B[32m3\u001B[39m \u001B[43mjit_g\u001B[49m\u001B[43m(\u001B[49m\u001B[43mtest_a\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mtest_b\u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 4\u001B[39m jit_g_aux(test_a, test_b)\n",
" \u001B[31m[... skipping hidden 13 frame]\u001B[39m\n",
"\u001B[36mCell\u001B[39m\u001B[36m \u001B[39m\u001B[32mIn[14]\u001B[39m\u001B[32m, line 1\u001B[39m, in \u001B[36m<lambda>\u001B[39m\u001B[34m(x, y)\u001B[39m\n\u001B[32m----> \u001B[39m\u001B[32m1\u001B[39m jit_g = jit(\u001B[38;5;28;01mlambda\u001B[39;00m x, y: \u001B[43mg\u001B[49m\u001B[43m(\u001B[49m\u001B[43mf\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mx\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43my\u001B[49m\u001B[43m)\u001B[49m)\n\u001B[32m 2\u001B[39m jit_g_aux = jit(\u001B[38;5;28;01mlambda\u001B[39;00m x, y: g(f_aux, x, y))\n\u001B[32m 3\u001B[39m jit_g(test_a, test_b)\n",
"\u001B[36mCell\u001B[39m\u001B[36m \u001B[39m\u001B[32mIn[4]\u001B[39m\u001B[32m, line 10\u001B[39m, in \u001B[36mg\u001B[39m\u001B[34m(func, x, y)\u001B[39m\n\u001B[32m 8\u001B[39m \u001B[38;5;28;01mdef\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[34mg\u001B[39m(func, x, y):\n\u001B[32m 9\u001B[39m func_val = func(x, y)\n\u001B[32m---> \u001B[39m\u001B[32m10\u001B[39m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28;43mlen\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43mfunc_val\u001B[49m\u001B[43m)\u001B[49m == \u001B[32m2\u001B[39m:\n\u001B[32m 11\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m func_val[\u001B[32m0\u001B[39m]\n\u001B[32m 12\u001B[39m \u001B[38;5;28;01melse\u001B[39;00m:\n",
" \u001B[31m[... skipping hidden 1 frame]\u001B[39m\n",
"\u001B[36mFile \u001B[39m\u001B[32m~/micromamba/envs/dcegm/lib/python3.11/site-packages/jax/_src/core.py:1898\u001B[39m, in \u001B[36mShapedArray._len\u001B[39m\u001B[34m(self, ignored_tracer)\u001B[39m\n\u001B[32m 1896\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m.shape[\u001B[32m0\u001B[39m]\n\u001B[32m 1897\u001B[39m \u001B[38;5;28;01mexcept\u001B[39;00m \u001B[38;5;167;01mIndexError\u001B[39;00m \u001B[38;5;28;01mas\u001B[39;00m err:\n\u001B[32m-> \u001B[39m\u001B[32m1898\u001B[39m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mTypeError\u001B[39;00m(\u001B[33m\"\u001B[39m\u001B[33mlen() of unsized object\u001B[39m\u001B[33m\"\u001B[39m) \u001B[38;5;28;01mfrom\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[34;01merr\u001B[39;00m\n",
"\u001B[31mTypeError\u001B[39m: len() of unsized object"
"\u001b[31mTypeError\u001b[39m Traceback (most recent call last)",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[14]\u001b[39m\u001b[32m, line 3\u001b[39m\n\u001b[32m 1\u001b[39m jit_g = jit(\u001b[38;5;28;01mlambda\u001b[39;00m x, y: g(f, x, y))\n\u001b[32m 2\u001b[39m jit_g_aux = jit(\u001b[38;5;28;01mlambda\u001b[39;00m x, y: g(f_aux, x, y))\n\u001b[32m----> \u001b[39m\u001b[32m3\u001b[39m \u001b[43mjit_g\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtest_a\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtest_b\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 4\u001b[39m jit_g_aux(test_a, test_b)\n",
" \u001b[31m[... skipping hidden 13 frame]\u001b[39m\n",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[14]\u001b[39m\u001b[32m, line 1\u001b[39m, in \u001b[36m<lambda>\u001b[39m\u001b[34m(x, y)\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m jit_g = jit(\u001b[38;5;28;01mlambda\u001b[39;00m x, y: \u001b[43mg\u001b[49m\u001b[43m(\u001b[49m\u001b[43mf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[32m 2\u001b[39m jit_g_aux = jit(\u001b[38;5;28;01mlambda\u001b[39;00m x, y: g(f_aux, x, y))\n\u001b[32m 3\u001b[39m jit_g(test_a, test_b)\n",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[4]\u001b[39m\u001b[32m, line 10\u001b[39m, in \u001b[36mg\u001b[39m\u001b[34m(func, x, y)\u001b[39m\n\u001b[32m 8\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mg\u001b[39m(func, x, y):\n\u001b[32m 9\u001b[39m func_val = func(x, y)\n\u001b[32m---> \u001b[39m\u001b[32m10\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28;43mlen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mfunc_val\u001b[49m\u001b[43m)\u001b[49m == \u001b[32m2\u001b[39m:\n\u001b[32m 11\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m func_val[\u001b[32m0\u001b[39m]\n\u001b[32m 12\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n",
" \u001b[31m[... skipping hidden 1 frame]\u001b[39m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/micromamba/envs/dcegm/lib/python3.11/site-packages/jax/_src/core.py:1898\u001b[39m, in \u001b[36mShapedArray._len\u001b[39m\u001b[34m(self, ignored_tracer)\u001b[39m\n\u001b[32m 1896\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m.shape[\u001b[32m0\u001b[39m]\n\u001b[32m 1897\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mIndexError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m err:\n\u001b[32m-> \u001b[39m\u001b[32m1898\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[33m\"\u001b[39m\u001b[33mlen() of unsized object\u001b[39m\u001b[33m\"\u001b[39m) \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01merr\u001b[39;00m\n",
"\u001b[31mTypeError\u001b[39m: len() of unsized object"
]
}
],
Expand Down
2 changes: 1 addition & 1 deletion tests/test_sparse_stochastic_and_batch_sep.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def test_benchmark_models():
state: state_choices_sparse[:, id]
for id, state in enumerate(discrete_states_names + ["choice"])
}
(endog_grid_full, policy_full, value_full) = (
endog_grid_full, policy_full, value_full = (
model_solved_full.get_solution_for_discrete_state_choice(
states=states_dict, choices=state_choices_sparse[:, -1]
)
Expand Down
Loading