diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 80e662a7..a4b9c958 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,7 +6,7 @@ repos: - id: check-useless-excludes # - id: identity # Prints all files passed to pre-commits. Debugging. - repo: https://github.com/adrienverge/yamllint.git - rev: v1.37.1 + rev: v1.38.0 hooks: - id: yamllint - repo: https://github.com/lyz-code/yamlfix @@ -48,7 +48,7 @@ repos: # args: # - --py37-plus - repo: https://github.com/pycqa/isort - rev: 7.0.0 + rev: 8.0.1 hooks: - id: isort name: isort @@ -59,7 +59,7 @@ repos: # hooks: # - id: setup-cfg-fmt - repo: https://github.com/psf/black-pre-commit-mirror - rev: 25.12.0 + rev: 26.1.0 hooks: - id: black language_version: python3.13 diff --git a/src/dcegm/interfaces/inspect_solution.py b/src/dcegm/interfaces/inspect_solution.py index 0bac8b28..52c9ad5f 100644 --- a/src/dcegm/interfaces/inspect_solution.py +++ b/src/dcegm/interfaces/inspect_solution.py @@ -70,7 +70,7 @@ def partially_solve( n_assets_end_of_period = model_config["continuous_states_info"][ "assets_grid_end_of_period" ].shape[0] - (value_candidates, policy_candidates, endog_grid_candidates) = ( + value_candidates, policy_candidates, endog_grid_candidates = ( create_solution_container( continuous_states_info=model_config["continuous_states_info"], n_total_wealth_grid=n_assets_end_of_period, diff --git a/src/dcegm/pre_processing/check_model_config.py b/src/dcegm/pre_processing/check_model_config.py index 47083118..718a9133 100644 --- a/src/dcegm/pre_processing/check_model_config.py +++ b/src/dcegm/pre_processing/check_model_config.py @@ -131,14 +131,12 @@ def check_model_config_and_process(model_config): n_assets_end_of_period * (1 + tuning_params["extra_wealth_grid_factor"]) < n_assets_end_of_period + tuning_params["n_constrained_points_to_add"] ): - raise ValueError( - f"""\n\n + raise ValueError(f"""\n\n When preparing the tuning parameters for the upper envelope, we found the following contradicting parameters: \n The extra wealth grid factor of {tuning_params["extra_wealth_grid_factor"]} is too small to cover the {tuning_params["n_constrained_points_to_add"]} wealth points which are added in - the credit constrained part of the wealth grid. \n\n""" - ) + the credit constrained part of the wealth grid. \n\n""") tuning_params["n_total_wealth_grid"] = int( n_assets_end_of_period * (1 + tuning_params["extra_wealth_grid_factor"]) ) diff --git a/src/dcegm/solve_single_period.py b/src/dcegm/solve_single_period.py index 1ee2be5c..7cf40435 100644 --- a/src/dcegm/solve_single_period.py +++ b/src/dcegm/solve_single_period.py @@ -18,7 +18,7 @@ def solve_single_period( debug_info, ): """Solve a single period of the model using DCEGM.""" - (value_solved, policy_solved, endog_grid_solved) = carry + value_solved, policy_solved, endog_grid_solved = carry ( state_choices_idxs, diff --git a/tests/sandbox/jax_timeit_large_toy_model.ipynb b/tests/sandbox/jax_timeit_large_toy_model.ipynb index 447338d6..10f0f1ff 100644 --- a/tests/sandbox/jax_timeit_large_toy_model.ipynb +++ b/tests/sandbox/jax_timeit_large_toy_model.ipynb @@ -25,7 +25,6 @@ "import numpy as np\n", "import time\n", "\n", - "\n", "TEST_RESOURCES_DIR = \"../resources/\"" ] }, diff --git a/tests/sandbox/time_functions_jax.ipynb b/tests/sandbox/time_functions_jax.ipynb index c642478d..74b72134 100644 --- a/tests/sandbox/time_functions_jax.ipynb +++ b/tests/sandbox/time_functions_jax.ipynb @@ -28,6 +28,7 @@ "def func_a(x, y):\n", " return x + y\n", "\n", + "\n", "jax.vmap(func_a, in_axes=(0, None))(np.array([2]), 3)" ], "id": "83f45f46db8be341", @@ -53,7 +54,9 @@ } }, "cell_type": "code", - "source": "isinstance(np.array(2), np.ndarray)", + "source": [ + "isinstance(np.array(2), np.ndarray)" + ], "id": "d2b3690f1f318672", "outputs": [ { @@ -122,10 +125,10 @@ "evalue": "No module named 'tests'", "output_type": "error", "traceback": [ - "\u001B[31m---------------------------------------------------------------------------\u001B[39m", - "\u001B[31mModuleNotFoundError\u001B[39m Traceback (most recent call last)", - "\u001B[36mCell\u001B[39m\u001B[36m \u001B[39m\u001B[32mIn[8]\u001B[39m\u001B[32m, line 9\u001B[39m\n\u001B[32m 7\u001B[39m \u001B[38;5;28;01mimport\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[34;01mjax\u001B[39;00m\u001B[34;01m.\u001B[39;00m\u001B[34;01mnumpy\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;28;01mas\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[34;01mjnp\u001B[39;00m\n\u001B[32m 8\u001B[39m \u001B[38;5;28;01mimport\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[34;01mnumpy\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;28;01mas\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[34;01mnp\u001B[39;00m\n\u001B[32m----> \u001B[39m\u001B[32m9\u001B[39m \u001B[38;5;28;01mfrom\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[34;01mtests\u001B[39;00m\u001B[34;01m.\u001B[39;00m\u001B[34;01mutils\u001B[39;00m\u001B[34;01m.\u001B[39;00m\u001B[34;01mmarkov_simulator\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;28;01mimport\u001B[39;00m markov_simulator\n", - "\u001B[31mModuleNotFoundError\u001B[39m: No module named 'tests'" + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mModuleNotFoundError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[8]\u001b[39m\u001b[32m, line 9\u001b[39m\n\u001b[32m 7\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mjax\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mnumpy\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mjnp\u001b[39;00m\n\u001b[32m 8\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnumpy\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnp\u001b[39;00m\n\u001b[32m----> \u001b[39m\u001b[32m9\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mtests\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mutils\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mmarkov_simulator\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m markov_simulator\n", + "\u001b[31mModuleNotFoundError\u001b[39m: No module named 'tests'" ] } ], @@ -305,19 +308,19 @@ "evalue": "len() of unsized object", "output_type": "error", "traceback": [ - "\u001B[31m---------------------------------------------------------------------------\u001B[39m", - "\u001B[31mIndexError\u001B[39m Traceback (most recent call last)", - "\u001B[36mFile \u001B[39m\u001B[32m~/micromamba/envs/dcegm/lib/python3.11/site-packages/jax/_src/core.py:1896\u001B[39m, in \u001B[36mShapedArray._len\u001B[39m\u001B[34m(self, ignored_tracer)\u001B[39m\n\u001B[32m 1895\u001B[39m \u001B[38;5;28;01mtry\u001B[39;00m:\n\u001B[32m-> \u001B[39m\u001B[32m1896\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43mshape\u001B[49m\u001B[43m[\u001B[49m\u001B[32;43m0\u001B[39;49m\u001B[43m]\u001B[49m\n\u001B[32m 1897\u001B[39m \u001B[38;5;28;01mexcept\u001B[39;00m \u001B[38;5;167;01mIndexError\u001B[39;00m \u001B[38;5;28;01mas\u001B[39;00m err:\n", - "\u001B[31mIndexError\u001B[39m: tuple index out of range", + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mIndexError\u001b[39m Traceback (most recent call last)", + "\u001b[36mFile \u001b[39m\u001b[32m~/micromamba/envs/dcegm/lib/python3.11/site-packages/jax/_src/core.py:1896\u001b[39m, in \u001b[36mShapedArray._len\u001b[39m\u001b[34m(self, ignored_tracer)\u001b[39m\n\u001b[32m 1895\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1896\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mshape\u001b[49m\u001b[43m[\u001b[49m\u001b[32;43m0\u001b[39;49m\u001b[43m]\u001b[49m\n\u001b[32m 1897\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mIndexError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m err:\n", + "\u001b[31mIndexError\u001b[39m: tuple index out of range", "\nThe above exception was the direct cause of the following exception:\n", - "\u001B[31mTypeError\u001B[39m Traceback (most recent call last)", - "\u001B[36mCell\u001B[39m\u001B[36m \u001B[39m\u001B[32mIn[14]\u001B[39m\u001B[32m, line 3\u001B[39m\n\u001B[32m 1\u001B[39m jit_g = jit(\u001B[38;5;28;01mlambda\u001B[39;00m x, y: g(f, x, y))\n\u001B[32m 2\u001B[39m jit_g_aux = jit(\u001B[38;5;28;01mlambda\u001B[39;00m x, y: g(f_aux, x, y))\n\u001B[32m----> \u001B[39m\u001B[32m3\u001B[39m \u001B[43mjit_g\u001B[49m\u001B[43m(\u001B[49m\u001B[43mtest_a\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mtest_b\u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 4\u001B[39m jit_g_aux(test_a, test_b)\n", - " \u001B[31m[... skipping hidden 13 frame]\u001B[39m\n", - "\u001B[36mCell\u001B[39m\u001B[36m \u001B[39m\u001B[32mIn[14]\u001B[39m\u001B[32m, line 1\u001B[39m, in \u001B[36m\u001B[39m\u001B[34m(x, y)\u001B[39m\n\u001B[32m----> \u001B[39m\u001B[32m1\u001B[39m jit_g = jit(\u001B[38;5;28;01mlambda\u001B[39;00m x, y: \u001B[43mg\u001B[49m\u001B[43m(\u001B[49m\u001B[43mf\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mx\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43my\u001B[49m\u001B[43m)\u001B[49m)\n\u001B[32m 2\u001B[39m jit_g_aux = jit(\u001B[38;5;28;01mlambda\u001B[39;00m x, y: g(f_aux, x, y))\n\u001B[32m 3\u001B[39m jit_g(test_a, test_b)\n", - "\u001B[36mCell\u001B[39m\u001B[36m \u001B[39m\u001B[32mIn[4]\u001B[39m\u001B[32m, line 10\u001B[39m, in \u001B[36mg\u001B[39m\u001B[34m(func, x, y)\u001B[39m\n\u001B[32m 8\u001B[39m \u001B[38;5;28;01mdef\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[34mg\u001B[39m(func, x, y):\n\u001B[32m 9\u001B[39m func_val = func(x, y)\n\u001B[32m---> \u001B[39m\u001B[32m10\u001B[39m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28;43mlen\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43mfunc_val\u001B[49m\u001B[43m)\u001B[49m == \u001B[32m2\u001B[39m:\n\u001B[32m 11\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m func_val[\u001B[32m0\u001B[39m]\n\u001B[32m 12\u001B[39m \u001B[38;5;28;01melse\u001B[39;00m:\n", - " \u001B[31m[... skipping hidden 1 frame]\u001B[39m\n", - "\u001B[36mFile \u001B[39m\u001B[32m~/micromamba/envs/dcegm/lib/python3.11/site-packages/jax/_src/core.py:1898\u001B[39m, in \u001B[36mShapedArray._len\u001B[39m\u001B[34m(self, ignored_tracer)\u001B[39m\n\u001B[32m 1896\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m.shape[\u001B[32m0\u001B[39m]\n\u001B[32m 1897\u001B[39m \u001B[38;5;28;01mexcept\u001B[39;00m \u001B[38;5;167;01mIndexError\u001B[39;00m \u001B[38;5;28;01mas\u001B[39;00m err:\n\u001B[32m-> \u001B[39m\u001B[32m1898\u001B[39m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mTypeError\u001B[39;00m(\u001B[33m\"\u001B[39m\u001B[33mlen() of unsized object\u001B[39m\u001B[33m\"\u001B[39m) \u001B[38;5;28;01mfrom\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[34;01merr\u001B[39;00m\n", - "\u001B[31mTypeError\u001B[39m: len() of unsized object" + "\u001b[31mTypeError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[14]\u001b[39m\u001b[32m, line 3\u001b[39m\n\u001b[32m 1\u001b[39m jit_g = jit(\u001b[38;5;28;01mlambda\u001b[39;00m x, y: g(f, x, y))\n\u001b[32m 2\u001b[39m jit_g_aux = jit(\u001b[38;5;28;01mlambda\u001b[39;00m x, y: g(f_aux, x, y))\n\u001b[32m----> \u001b[39m\u001b[32m3\u001b[39m \u001b[43mjit_g\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtest_a\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtest_b\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 4\u001b[39m jit_g_aux(test_a, test_b)\n", + " \u001b[31m[... skipping hidden 13 frame]\u001b[39m\n", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[14]\u001b[39m\u001b[32m, line 1\u001b[39m, in \u001b[36m\u001b[39m\u001b[34m(x, y)\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m jit_g = jit(\u001b[38;5;28;01mlambda\u001b[39;00m x, y: \u001b[43mg\u001b[49m\u001b[43m(\u001b[49m\u001b[43mf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[32m 2\u001b[39m jit_g_aux = jit(\u001b[38;5;28;01mlambda\u001b[39;00m x, y: g(f_aux, x, y))\n\u001b[32m 3\u001b[39m jit_g(test_a, test_b)\n", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[4]\u001b[39m\u001b[32m, line 10\u001b[39m, in \u001b[36mg\u001b[39m\u001b[34m(func, x, y)\u001b[39m\n\u001b[32m 8\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mg\u001b[39m(func, x, y):\n\u001b[32m 9\u001b[39m func_val = func(x, y)\n\u001b[32m---> \u001b[39m\u001b[32m10\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28;43mlen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mfunc_val\u001b[49m\u001b[43m)\u001b[49m == \u001b[32m2\u001b[39m:\n\u001b[32m 11\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m func_val[\u001b[32m0\u001b[39m]\n\u001b[32m 12\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n", + " \u001b[31m[... skipping hidden 1 frame]\u001b[39m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/micromamba/envs/dcegm/lib/python3.11/site-packages/jax/_src/core.py:1898\u001b[39m, in \u001b[36mShapedArray._len\u001b[39m\u001b[34m(self, ignored_tracer)\u001b[39m\n\u001b[32m 1896\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m.shape[\u001b[32m0\u001b[39m]\n\u001b[32m 1897\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mIndexError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m err:\n\u001b[32m-> \u001b[39m\u001b[32m1898\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[33m\"\u001b[39m\u001b[33mlen() of unsized object\u001b[39m\u001b[33m\"\u001b[39m) \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01merr\u001b[39;00m\n", + "\u001b[31mTypeError\u001b[39m: len() of unsized object" ] } ], diff --git a/tests/test_sparse_stochastic_and_batch_sep.py b/tests/test_sparse_stochastic_and_batch_sep.py index ba9707c0..4a2d4d22 100644 --- a/tests/test_sparse_stochastic_and_batch_sep.py +++ b/tests/test_sparse_stochastic_and_batch_sep.py @@ -125,7 +125,7 @@ def test_benchmark_models(): state: state_choices_sparse[:, id] for id, state in enumerate(discrete_states_names + ["choice"]) } - (endog_grid_full, policy_full, value_full) = ( + endog_grid_full, policy_full, value_full = ( model_solved_full.get_solution_for_discrete_state_choice( states=states_dict, choices=state_choices_sparse[:, -1] )