From b0a7ddee76e311f36c6dff73d1e40db97db7c299 Mon Sep 17 00:00:00 2001 From: Abel Abate Date: Tue, 31 Mar 2026 14:27:38 +0200 Subject: [PATCH 01/19] chore: add optree package --- pixi.lock | 10 +++++----- pyproject.toml | 1 + 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/pixi.lock b/pixi.lock index 4f9960725..cad1467ab 100644 --- a/pixi.lock +++ b/pixi.lock @@ -19402,8 +19402,8 @@ packages: timestamp: 1733688053334 - pypi: ./ name: optimagic - version: 0.5.4.dev5+g8654d6292.d20260312 - sha256: 72e7ed28837a3da869c13448a83bee607a85ab9f1e2d9dc5b5e7604d8ce2bf94 + version: 0.1.dev472+g7cc224247.d20260331 + sha256: 3fd4339b58c2f6c6ec2bec645fbce927d4e9a2180d17832e181c41f478e57915 requires_dist: - annotated-types>=0.4 - cloudpickle>=2.2 @@ -19525,7 +19525,7 @@ packages: license: Apache-2.0 license_family: Apache purls: - - pkg:pypi/optree?source=compressed-mapping + - pkg:pypi/optree?source=hash-mapping size: 378894 timestamp: 1771868468546 - conda: https://conda.anaconda.org/conda-forge/win-64/optree-0.19.0-py313hf069bd2_0.conda @@ -19541,7 +19541,7 @@ packages: license: Apache-2.0 license_family: Apache purls: - - pkg:pypi/optree?source=compressed-mapping + - pkg:pypi/optree?source=hash-mapping size: 385763 timestamp: 1771868441594 - conda: https://conda.anaconda.org/conda-forge/win-64/optree-0.19.0-py314h909e829_0.conda @@ -19557,7 +19557,7 @@ packages: license: Apache-2.0 license_family: Apache purls: - - pkg:pypi/optree?source=compressed-mapping + - pkg:pypi/optree?source=hash-mapping size: 394633 timestamp: 1771868448953 - pypi: https://files.pythonhosted.org/packages/00/04/c6f72daca5092e3117840a1b1e88dfc809cc1470cf0734890d0366b684a1/orjson-3.11.7-cp313-cp313-win_amd64.whl diff --git a/pyproject.toml b/pyproject.toml index af1ac9a23..0c6a4905d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -439,6 +439,7 @@ cma = ">=3.3" pygad = ">=3.2" pytorch-cpu = ">=2.2" ruff = ">=0.15.5,<0.16" +optree = ">=0.19.0,<0.20" [tool.pixi.pypi-dependencies] optimagic = { path = ".", editable = true } From d4d9d79571a0f0fcf7bb45dfd0c3b443ccc66fbf Mon Sep 17 00:00:00 2001 From: Abel Abate Date: Tue, 31 Mar 2026 14:28:43 +0200 Subject: [PATCH 02/19] chore: add methods for working with optree --- src/optimagic/parameters/tree_registry.py | 95 +++++++++++++++++++++++ 1 file changed, 95 insertions(+) diff --git a/src/optimagic/parameters/tree_registry.py b/src/optimagic/parameters/tree_registry.py index 808ad81ac..501f26dde 100644 --- a/src/optimagic/parameters/tree_registry.py +++ b/src/optimagic/parameters/tree_registry.py @@ -1,10 +1,13 @@ """Wrapper around pybaum get_registry to tailor it to optimagic.""" +from collections import OrderedDict from functools import partial from itertools import product import numpy as np +import optree import pandas as pd +from optree.pytree import PyTreeSpec from pybaum import get_registry as get_pybaum_registry @@ -88,3 +91,95 @@ def _index_element_to_string(element): res_string = str(element) return res_string + + +extended = "extended" + + +def tree_flatten(tree, is_leaf=None, registry=None): + if isinstance(tree, dict): + tree = OrderedDict(tree) + return optree.tree_flatten( + tree, is_leaf=is_leaf, namespace=extended if registry else "" + ) + + +def tree_just_flatten(tree, is_leaf=None, registry=None): + if isinstance(tree, dict): + tree = OrderedDict(tree) + + return optree.tree_leaves( + tree, + is_leaf, + namespace=extended if registry else "", + ) + + +def tree_unflatten(treedef, leaves, is_leaf=None, registry=None): + if not isinstance(treedef, PyTreeSpec): + if isinstance(treedef, dict): + treedef = OrderedDict(treedef) + _, treedef = optree.tree_flatten( + treedef, namespace=extended if registry else "" + ) + return optree.tree_unflatten(treespec=treedef, leaves=leaves) + + +def tree_map(func, tree, is_leaf=None, registry=None): + return optree.tree_map( + func, tree, is_leaf=is_leaf, namespace=extended if registry else "" + ) + + +def update_tree(tree, data_col): + return tree_map( + lambda node: ( + node + if not isinstance(node, pd.DataFrame) + else CustomDataFrame(node, data_col=data_col) + ), + tree, + ) + + +optree.register_pytree_node( + pd.Series, + lambda sr: ( + sr.tolist(), + {"index": sr.index, "name": sr.name}, + ), + lambda aux_data, leaves: pd.Series(leaves, **aux_data), + namespace=extended, +) + + +@optree.register_pytree_node_class(namespace=extended) +class CustomDataFrame: + def __init__(self, df, data_col): + self.df = df + self.data_col = data_col + + def __tree_flatten__(self): + return _flatten_df(self.df, self.data_col) + + @classmethod + def __tree_unflatten__(cls, aux_data, leaves): + return _unflatten_df( + aux_data=aux_data, leaves=leaves, data_col=aux_data["data_col"] + ) + + +optree.register_pytree_node( + pd.DataFrame, + partial(_flatten_df, data_col="value"), + partial(_unflatten_df, data_col="value"), + namespace=extended, +) + +optree.register_pytree_node( + np.ndarray, + lambda arr: (arr.flatten().tolist(), arr.shape), + lambda aux_data, leaves: np.array(leaves).reshape(aux_data), + namespace=extended, +) +# DONT FORGET JAX From 3bf5cd7fbfd95b1fefb4478d8ca71b2c0929dd8c Mon Sep 17 00:00:00 2001 From: Abel Abate Date: Tue, 31 Mar 2026 14:29:18 +0200 Subject: [PATCH 03/19] chore: replace pybaum tree methods with optree methods --- src/estimagic/bootstrap.py | 9 +++++++-- src/estimagic/estimate_msm.py | 4 ++-- src/estimagic/msm_weighting.py | 3 +-- src/estimagic/shared_covs.py | 7 +++++-- src/optimagic/benchmarking/run_benchmark.py | 3 +-- src/optimagic/differentiation/derivatives.py | 10 +++++++--- src/optimagic/examples/criterion_functions.py | 7 +++++-- src/optimagic/optimization/fun_value.py | 3 +-- src/optimagic/optimization/history.py | 4 ++-- src/optimagic/parameters/block_trees.py | 9 ++++++--- src/optimagic/parameters/bounds.py | 14 ++++++++++---- src/optimagic/parameters/nonlinear_constraints.py | 8 ++++++-- src/optimagic/parameters/process_selectors.py | 3 +-- src/optimagic/parameters/tree_conversion.py | 9 +++++++-- src/optimagic/visualization/history_plots.py | 9 +++++++-- src/optimagic/visualization/slice_plot.py | 3 +-- src/optimagic/visualization/slice_plot_3d.py | 3 +-- tests/estimagic/test_bootstrap_ci.py | 3 +-- .../test_estimate_msm_dict_params_and_moments.py | 3 +-- tests/optimagic/logging/test_logger.py | 4 ++-- tests/optimagic/optimization/test_history.py | 2 +- .../optimagic/optimization/test_params_versions.py | 3 +-- tests/optimagic/optimization/test_with_logging.py | 3 +-- tests/optimagic/parameters/test_block_trees.py | 2 +- .../parameters/test_nonlinear_constraints.py | 3 +-- .../optimagic/parameters/test_process_selectors.py | 8 ++++++-- tests/optimagic/parameters/test_tree_registry.py | 8 ++++++-- 27 files changed, 91 insertions(+), 56 deletions(-) diff --git a/src/estimagic/bootstrap.py b/src/estimagic/bootstrap.py index 76d75c4fb..6ab35f1a0 100644 --- a/src/estimagic/bootstrap.py +++ b/src/estimagic/bootstrap.py @@ -5,7 +5,7 @@ import numpy as np import pandas as pd -from pybaum import leaf_names, tree_flatten, tree_just_flatten, tree_unflatten +from pybaum import leaf_names from estimagic.bootstrap_ci import calculate_ci from estimagic.bootstrap_helpers import check_inputs @@ -13,7 +13,12 @@ from estimagic.shared_covs import calculate_estimation_summary from optimagic.batch_evaluators import joblib_batch_evaluator from optimagic.parameters.block_trees import matrix_to_block_tree -from optimagic.parameters.tree_registry import get_registry +from optimagic.parameters.tree_registry import ( + get_registry, + tree_flatten, + tree_just_flatten, + tree_unflatten, +) from optimagic.utilities import get_rng diff --git a/src/estimagic/estimate_msm.py b/src/estimagic/estimate_msm.py index bf17d8f73..62990a620 100644 --- a/src/estimagic/estimate_msm.py +++ b/src/estimagic/estimate_msm.py @@ -9,7 +9,7 @@ import numpy as np import pandas as pd -from pybaum import leaf_names, tree_just_flatten +from pybaum import leaf_names from estimagic.msm_covs import cov_optimal, cov_robust from estimagic.msm_sensitivity import ( @@ -51,7 +51,7 @@ from optimagic.parameters.bounds import Bounds, pre_process_bounds from optimagic.parameters.conversion import Converter, get_converter from optimagic.parameters.space_conversion import InternalParams -from optimagic.parameters.tree_registry import get_registry +from optimagic.parameters.tree_registry import get_registry, tree_just_flatten from optimagic.shared.check_option_dicts import ( check_optimization_options, ) diff --git a/src/estimagic/msm_weighting.py b/src/estimagic/msm_weighting.py index 991e9b54f..6be390130 100644 --- a/src/estimagic/msm_weighting.py +++ b/src/estimagic/msm_weighting.py @@ -2,12 +2,11 @@ import numpy as np import pandas as pd -from pybaum import tree_just_flatten from scipy.linalg import block_diag from estimagic.bootstrap import bootstrap from optimagic.parameters.block_trees import block_tree_to_matrix, matrix_to_block_tree -from optimagic.parameters.tree_registry import get_registry +from optimagic.parameters.tree_registry import get_registry, tree_just_flatten from optimagic.utilities import robust_inverse diff --git a/src/estimagic/shared_covs.py b/src/estimagic/shared_covs.py index c4cccc3a2..b95c3eb4a 100644 --- a/src/estimagic/shared_covs.py +++ b/src/estimagic/shared_covs.py @@ -3,10 +3,13 @@ import numpy as np import pandas as pd import scipy -from pybaum import tree_just_flatten, tree_unflatten from optimagic.parameters.block_trees import matrix_to_block_tree -from optimagic.parameters.tree_registry import get_registry +from optimagic.parameters.tree_registry import ( + get_registry, + tree_just_flatten, + tree_unflatten, +) def transform_covariance( diff --git a/src/optimagic/benchmarking/run_benchmark.py b/src/optimagic/benchmarking/run_benchmark.py index cd6d844c4..265b838d9 100644 --- a/src/optimagic/benchmarking/run_benchmark.py +++ b/src/optimagic/benchmarking/run_benchmark.py @@ -9,12 +9,11 @@ """ import numpy as np -from pybaum import tree_just_flatten from optimagic import batch_evaluators from optimagic.algorithms import AVAILABLE_ALGORITHMS from optimagic.optimization.optimize import minimize -from optimagic.parameters.tree_registry import get_registry +from optimagic.parameters.tree_registry import get_registry, tree_just_flatten def run_benchmark( diff --git a/src/optimagic/differentiation/derivatives.py b/src/optimagic/differentiation/derivatives.py index e2caf5daf..ac28b3af3 100644 --- a/src/optimagic/differentiation/derivatives.py +++ b/src/optimagic/differentiation/derivatives.py @@ -8,8 +8,6 @@ import numpy as np import pandas as pd from numpy.typing import NDArray -from pybaum import tree_flatten, tree_just_flatten, tree_unflatten -from pybaum import tree_just_flatten as tree_leaves from optimagic import batch_evaluators, deprecations from optimagic.config import DEFAULT_N_CORES @@ -22,7 +20,13 @@ from optimagic.differentiation.richardson_extrapolation import richardson_extrapolation from optimagic.parameters.block_trees import hessian_to_block_tree, matrix_to_block_tree from optimagic.parameters.bounds import Bounds, get_internal_bounds, pre_process_bounds -from optimagic.parameters.tree_registry import get_registry +from optimagic.parameters.tree_registry import ( + get_registry, + tree_flatten, + tree_just_flatten, + tree_unflatten, +) +from optimagic.parameters.tree_registry import tree_just_flatten as tree_leaves from optimagic.typing import BatchEvaluatorLiteral, PyTree diff --git a/src/optimagic/examples/criterion_functions.py b/src/optimagic/examples/criterion_functions.py index bb925d399..ae3708143 100644 --- a/src/optimagic/examples/criterion_functions.py +++ b/src/optimagic/examples/criterion_functions.py @@ -10,14 +10,17 @@ import numpy as np import pandas as pd from numpy.typing import NDArray -from pybaum import tree_just_flatten, tree_unflatten from optimagic import mark from optimagic.optimization.fun_value import ( FunctionValue, ) from optimagic.parameters.block_trees import matrix_to_block_tree -from optimagic.parameters.tree_registry import get_registry +from optimagic.parameters.tree_registry import ( + get_registry, + tree_just_flatten, + tree_unflatten, +) from optimagic.typing import PyTree REGISTRY = get_registry(extended=True) diff --git a/src/optimagic/optimization/fun_value.py b/src/optimagic/optimization/fun_value.py index aeb4f3dda..3cca7cbc2 100644 --- a/src/optimagic/optimization/fun_value.py +++ b/src/optimagic/optimization/fun_value.py @@ -5,10 +5,9 @@ import numpy as np from numpy.typing import NDArray -from pybaum import tree_just_flatten from optimagic.exceptions import InvalidFunctionError -from optimagic.parameters.tree_registry import get_registry +from optimagic.parameters.tree_registry import get_registry, tree_just_flatten from optimagic.typing import AggregationLevel, PyTree, Scalar from optimagic.utilities import isscalar diff --git a/src/optimagic/optimization/history.py b/src/optimagic/optimization/history.py index 73bea2d93..c9d3f1d86 100644 --- a/src/optimagic/optimization/history.py +++ b/src/optimagic/optimization/history.py @@ -6,9 +6,9 @@ import numpy as np import pandas as pd from numpy.typing import NDArray -from pybaum import leaf_names, tree_just_flatten +from pybaum import leaf_names -from optimagic.parameters.tree_registry import get_registry +from optimagic.parameters.tree_registry import get_registry, tree_just_flatten from optimagic.timing import CostModel from optimagic.typing import Direction, EvalTask, PyTree diff --git a/src/optimagic/parameters/block_trees.py b/src/optimagic/parameters/block_trees.py index 269898b0d..9e34b842b 100644 --- a/src/optimagic/parameters/block_trees.py +++ b/src/optimagic/parameters/block_trees.py @@ -2,10 +2,13 @@ import numpy as np import pandas as pd -from pybaum import tree_flatten, tree_unflatten -from pybaum import tree_just_flatten as tree_leaves -from optimagic.parameters.tree_registry import get_registry +from optimagic.parameters.tree_registry import ( + get_registry, + tree_flatten, + tree_unflatten, +) +from optimagic.parameters.tree_registry import tree_just_flatten as tree_leaves def matrix_to_block_tree(matrix, outer_tree, inner_tree): diff --git a/src/optimagic/parameters/bounds.py b/src/optimagic/parameters/bounds.py index 344dca4f4..fbdc2a559 100644 --- a/src/optimagic/parameters/bounds.py +++ b/src/optimagic/parameters/bounds.py @@ -5,12 +5,18 @@ import numpy as np from numpy.typing import NDArray -from pybaum import leaf_names, tree_map -from pybaum import tree_just_flatten as tree_leaves +from pybaum import leaf_names from scipy.optimize import Bounds as ScipyBounds from optimagic.exceptions import InvalidBoundsError -from optimagic.parameters.tree_registry import get_registry +from optimagic.parameters.tree_registry import ( + get_registry, + tree_map, + update_tree, +) +from optimagic.parameters.tree_registry import ( + tree_just_flatten as tree_leaves, +) from optimagic.typing import PyTree, PyTreeRegistry from optimagic.utilities import fast_numpy_full @@ -177,7 +183,7 @@ def _update_bounds_and_flatten( """ registry = get_registry(extended=True, data_col=kind) - flat_nan_tree = tree_leaves(nan_tree, registry=registry) + flat_nan_tree = tree_leaves(update_tree(nan_tree, data_col=kind), registry=registry) if bounds is not None: registry = get_registry(extended=True) diff --git a/src/optimagic/parameters/nonlinear_constraints.py b/src/optimagic/parameters/nonlinear_constraints.py index 0cdd8e345..1d8a022ae 100644 --- a/src/optimagic/parameters/nonlinear_constraints.py +++ b/src/optimagic/parameters/nonlinear_constraints.py @@ -4,13 +4,17 @@ import numpy as np import pandas as pd -from pybaum import tree_flatten, tree_just_flatten, tree_unflatten from optimagic.differentiation.derivatives import first_derivative from optimagic.exceptions import InvalidConstraintError, InvalidFunctionError from optimagic.optimization.algo_options import CONSTRAINTS_ABSOLUTE_TOLERANCE from optimagic.parameters.block_trees import block_tree_to_matrix -from optimagic.parameters.tree_registry import get_registry +from optimagic.parameters.tree_registry import ( + get_registry, + tree_flatten, + tree_just_flatten, + tree_unflatten, +) def process_nonlinear_constraints( diff --git a/src/optimagic/parameters/process_selectors.py b/src/optimagic/parameters/process_selectors.py index 8a9276852..8b05a3dab 100644 --- a/src/optimagic/parameters/process_selectors.py +++ b/src/optimagic/parameters/process_selectors.py @@ -3,11 +3,10 @@ import numpy as np import pandas as pd -from pybaum import tree_just_flatten from optimagic.constraints import Constraint from optimagic.exceptions import InvalidConstraintError -from optimagic.parameters.tree_registry import get_registry +from optimagic.parameters.tree_registry import get_registry, tree_just_flatten def process_selectors(constraints, params, tree_converter, param_names): diff --git a/src/optimagic/parameters/tree_conversion.py b/src/optimagic/parameters/tree_conversion.py index 2e29fd87e..b2cdb4cb5 100644 --- a/src/optimagic/parameters/tree_conversion.py +++ b/src/optimagic/parameters/tree_conversion.py @@ -1,12 +1,17 @@ from typing import Callable, NamedTuple import numpy as np -from pybaum import leaf_names, tree_flatten, tree_just_flatten, tree_unflatten +from pybaum import leaf_names from optimagic.exceptions import InvalidFunctionError from optimagic.parameters.block_trees import block_tree_to_matrix from optimagic.parameters.bounds import get_internal_bounds -from optimagic.parameters.tree_registry import get_registry +from optimagic.parameters.tree_registry import ( + get_registry, + tree_flatten, + tree_just_flatten, + tree_unflatten, +) from optimagic.typing import AggregationLevel diff --git a/src/optimagic/visualization/history_plots.py b/src/optimagic/visualization/history_plots.py index 72dc7ab07..fb1f365bc 100644 --- a/src/optimagic/visualization/history_plots.py +++ b/src/optimagic/visualization/history_plots.py @@ -5,14 +5,19 @@ from typing import Any, Callable, Literal import numpy as np -from pybaum import leaf_names, tree_flatten, tree_just_flatten, tree_unflatten +from pybaum import leaf_names from optimagic.config import DEFAULT_PALETTE from optimagic.logging.logger import LogReader, SQLiteLogOptions from optimagic.optimization.algorithm import Algorithm from optimagic.optimization.history import History from optimagic.optimization.optimize_result import OptimizeResult -from optimagic.parameters.tree_registry import get_registry +from optimagic.parameters.tree_registry import ( + get_registry, + tree_flatten, + tree_just_flatten, + tree_unflatten, +) from optimagic.typing import IterationHistory, PyTree from optimagic.visualization.backends import line_plot from optimagic.visualization.plotting_utilities import LineData, get_palette_cycle diff --git a/src/optimagic/visualization/slice_plot.py b/src/optimagic/visualization/slice_plot.py index 92802cf3f..fa0581692 100644 --- a/src/optimagic/visualization/slice_plot.py +++ b/src/optimagic/visualization/slice_plot.py @@ -5,7 +5,6 @@ import numpy as np import pandas as pd from numpy.typing import NDArray -from pybaum import tree_just_flatten import optimagic as om from optimagic import deprecations @@ -24,7 +23,7 @@ from optimagic.parameters.bounds import pre_process_bounds from optimagic.parameters.conversion import get_converter from optimagic.parameters.space_conversion import InternalParams -from optimagic.parameters.tree_registry import get_registry +from optimagic.parameters.tree_registry import get_registry, tree_just_flatten from optimagic.shared.process_user_function import infer_aggregation_level from optimagic.typing import AggregationLevel, PyTree from optimagic.visualization.backends import grid_line_plot, line_plot diff --git a/src/optimagic/visualization/slice_plot_3d.py b/src/optimagic/visualization/slice_plot_3d.py index 1b6a7fc90..50f3f175b 100644 --- a/src/optimagic/visualization/slice_plot_3d.py +++ b/src/optimagic/visualization/slice_plot_3d.py @@ -8,7 +8,6 @@ import plotly.graph_objects as go from numpy.typing import NDArray from plotly.subplots import make_subplots -from pybaum import tree_just_flatten from optimagic import deprecations from optimagic.batch_evaluators import process_batch_evaluator @@ -20,7 +19,7 @@ ) from optimagic.parameters.bounds import pre_process_bounds from optimagic.parameters.conversion import get_converter -from optimagic.parameters.tree_registry import get_registry +from optimagic.parameters.tree_registry import get_registry, tree_just_flatten from optimagic.shared.process_user_function import infer_aggregation_level from optimagic.typing import AggregationLevel diff --git a/tests/estimagic/test_bootstrap_ci.py b/tests/estimagic/test_bootstrap_ci.py index 64562438d..8610d348a 100644 --- a/tests/estimagic/test_bootstrap_ci.py +++ b/tests/estimagic/test_bootstrap_ci.py @@ -3,11 +3,10 @@ import numpy as np import pandas as pd import pytest -from pybaum import tree_just_flatten from estimagic.bootstrap_ci import calculate_ci, check_inputs from estimagic.bootstrap_samples import get_bootstrap_indices -from optimagic.parameters.tree_registry import get_registry +from optimagic.parameters.tree_registry import get_registry, tree_just_flatten from optimagic.utilities import get_rng diff --git a/tests/estimagic/test_estimate_msm_dict_params_and_moments.py b/tests/estimagic/test_estimate_msm_dict_params_and_moments.py index b1cbcd250..79a6314e6 100644 --- a/tests/estimagic/test_estimate_msm_dict_params_and_moments.py +++ b/tests/estimagic/test_estimate_msm_dict_params_and_moments.py @@ -3,10 +3,9 @@ import numpy as np import pandas as pd from numpy.testing import assert_array_almost_equal as aaae -from pybaum import tree_just_flatten from estimagic.estimate_msm import estimate_msm -from optimagic.parameters.tree_registry import get_registry +from optimagic.parameters.tree_registry import get_registry, tree_just_flatten def test_estimate_msm_dict_params_and_moments(): diff --git a/tests/optimagic/logging/test_logger.py b/tests/optimagic/logging/test_logger.py index ff099d55f..7ec170638 100644 --- a/tests/optimagic/logging/test_logger.py +++ b/tests/optimagic/logging/test_logger.py @@ -3,7 +3,7 @@ import numpy as np import pandas as pd import pytest -from pybaum import tree_equal, tree_just_flatten +from pybaum import tree_equal from optimagic.logging.logger import ( LogOptions, @@ -13,7 +13,7 @@ SQLiteLogReader, ) from optimagic.optimization.optimize import minimize -from optimagic.parameters.tree_registry import get_registry +from optimagic.parameters.tree_registry import get_registry, tree_just_flatten from optimagic.typing import Direction diff --git a/tests/optimagic/optimization/test_history.py b/tests/optimagic/optimization/test_history.py index cb03bc253..2dd4ed790 100644 --- a/tests/optimagic/optimization/test_history.py +++ b/tests/optimagic/optimization/test_history.py @@ -4,7 +4,6 @@ from numpy.testing import assert_array_almost_equal as aaae from numpy.testing import assert_array_equal from pandas.testing import assert_frame_equal -from pybaum import tree_map import optimagic as om from optimagic.optimization.history import ( @@ -19,6 +18,7 @@ _task_to_categorical, _validate_args_are_all_none_or_lists_of_same_length, ) +from optimagic.parameters.tree_registry import tree_map from optimagic.typing import Direction, EvalTask # ====================================================================================== diff --git a/tests/optimagic/optimization/test_params_versions.py b/tests/optimagic/optimization/test_params_versions.py index f3399cb12..39bbaac82 100644 --- a/tests/optimagic/optimization/test_params_versions.py +++ b/tests/optimagic/optimization/test_params_versions.py @@ -2,7 +2,6 @@ import pandas as pd import pytest from numpy.testing import assert_array_almost_equal as aaae -from pybaum import tree_just_flatten from optimagic.examples.criterion_functions import ( sos_gradient, @@ -11,7 +10,7 @@ sos_scalar, ) from optimagic.optimization.optimize import minimize -from optimagic.parameters.tree_registry import get_registry +from optimagic.parameters.tree_registry import get_registry, tree_just_flatten REGISTRY = get_registry(extended=True) diff --git a/tests/optimagic/optimization/test_with_logging.py b/tests/optimagic/optimization/test_with_logging.py index b279f5202..d8cc7bd35 100644 --- a/tests/optimagic/optimization/test_with_logging.py +++ b/tests/optimagic/optimization/test_with_logging.py @@ -11,7 +11,6 @@ import pandas as pd import pytest from numpy.testing import assert_array_almost_equal as aaae -from pybaum import tree_just_flatten from optimagic import mark from optimagic.examples.criterion_functions import ( @@ -21,7 +20,7 @@ from optimagic.logging.logger import SQLiteLogOptions from optimagic.logging.types import ExistenceStrategy from optimagic.optimization.optimize import minimize -from optimagic.parameters.tree_registry import get_registry +from optimagic.parameters.tree_registry import get_registry, tree_just_flatten @mark.least_squares diff --git a/tests/optimagic/parameters/test_block_trees.py b/tests/optimagic/parameters/test_block_trees.py index 08b2307cd..f2bc32523 100644 --- a/tests/optimagic/parameters/test_block_trees.py +++ b/tests/optimagic/parameters/test_block_trees.py @@ -3,7 +3,6 @@ import pytest from numpy.testing import assert_array_equal from pybaum import tree_equal -from pybaum import tree_just_flatten as tree_leaves from optimagic import second_derivative from optimagic.parameters.block_trees import ( @@ -13,6 +12,7 @@ matrix_to_block_tree, ) from optimagic.parameters.tree_registry import get_registry +from optimagic.parameters.tree_registry import tree_just_flatten as tree_leaves def test_matrix_to_block_tree_array_and_scalar(): diff --git a/tests/optimagic/parameters/test_nonlinear_constraints.py b/tests/optimagic/parameters/test_nonlinear_constraints.py index 2d0eeaa1c..6e428ed71 100644 --- a/tests/optimagic/parameters/test_nonlinear_constraints.py +++ b/tests/optimagic/parameters/test_nonlinear_constraints.py @@ -6,7 +6,6 @@ import pytest from numpy.testing import assert_array_equal from pandas.testing import assert_frame_equal -from pybaum import tree_just_flatten from optimagic.differentiation.numdiff_options import NumdiffOptions from optimagic.exceptions import InvalidConstraintError @@ -22,7 +21,7 @@ process_nonlinear_constraints, vector_as_list_of_scalar_constraints, ) -from optimagic.parameters.tree_registry import get_registry +from optimagic.parameters.tree_registry import get_registry, tree_just_flatten @dataclass diff --git a/tests/optimagic/parameters/test_process_selectors.py b/tests/optimagic/parameters/test_process_selectors.py index 7ad9c78e6..726864a5f 100644 --- a/tests/optimagic/parameters/test_process_selectors.py +++ b/tests/optimagic/parameters/test_process_selectors.py @@ -2,12 +2,16 @@ import pandas as pd import pytest from numpy.testing import assert_array_equal as aae -from pybaum import tree_flatten, tree_just_flatten, tree_unflatten from optimagic.exceptions import InvalidConstraintError from optimagic.parameters.process_selectors import process_selectors from optimagic.parameters.tree_conversion import TreeConverter -from optimagic.parameters.tree_registry import get_registry +from optimagic.parameters.tree_registry import ( + get_registry, + tree_flatten, + tree_just_flatten, + tree_unflatten, +) @pytest.mark.parametrize("constraints", [None, []]) diff --git a/tests/optimagic/parameters/test_tree_registry.py b/tests/optimagic/parameters/test_tree_registry.py index 6f7362538..3f617a267 100644 --- a/tests/optimagic/parameters/test_tree_registry.py +++ b/tests/optimagic/parameters/test_tree_registry.py @@ -2,9 +2,13 @@ import pandas as pd import pytest from pandas.testing import assert_frame_equal -from pybaum import leaf_names, tree_flatten, tree_unflatten +from pybaum import leaf_names -from optimagic.parameters.tree_registry import get_registry +from optimagic.parameters.tree_registry import ( + get_registry, + tree_flatten, + tree_unflatten, +) @pytest.fixture() From 5c3b56373b2da0d85df369784b78a10dc97de2ba Mon Sep 17 00:00:00 2001 From: Abel Abate Date: Tue, 31 Mar 2026 17:33:08 +0200 Subject: [PATCH 04/19] fix: update implementation for setting data_col attribute for dataframes --- src/optimagic/parameters/bounds.py | 6 +- src/optimagic/parameters/tree_registry.py | 58 ++++++++----------- .../parameters/test_tree_registry.py | 29 ++++++++++ 3 files changed, 58 insertions(+), 35 deletions(-) diff --git a/src/optimagic/parameters/bounds.py b/src/optimagic/parameters/bounds.py index fbdc2a559..1c7811a89 100644 --- a/src/optimagic/parameters/bounds.py +++ b/src/optimagic/parameters/bounds.py @@ -11,8 +11,8 @@ from optimagic.exceptions import InvalidBoundsError from optimagic.parameters.tree_registry import ( get_registry, + set_data_col_df_attribute, tree_map, - update_tree, ) from optimagic.parameters.tree_registry import ( tree_just_flatten as tree_leaves, @@ -183,7 +183,9 @@ def _update_bounds_and_flatten( """ registry = get_registry(extended=True, data_col=kind) - flat_nan_tree = tree_leaves(update_tree(nan_tree, data_col=kind), registry=registry) + flat_nan_tree = tree_leaves( + set_data_col_df_attribute(nan_tree, data_col=kind), registry=registry + ) if bounds is not None: registry = get_registry(extended=True) diff --git a/src/optimagic/parameters/tree_registry.py b/src/optimagic/parameters/tree_registry.py index 501f26dde..d72e6eae4 100644 --- a/src/optimagic/parameters/tree_registry.py +++ b/src/optimagic/parameters/tree_registry.py @@ -131,48 +131,40 @@ def tree_map(func, tree, is_leaf=None, registry=None): ) -def update_tree(tree, data_col): - return tree_map( - lambda node: ( - node - if not isinstance(node, pd.DataFrame) - else CustomDataFrame(node, data_col=data_col) - ), - tree, - ) +def set_data_col_df_attribute(tree, data_col): + def set_attr(node): + if isinstance(node, pd.DataFrame): + node = node.copy() + node.attrs["data_col"] = data_col + return node + return tree_map(set_attr, tree) -optree.register_pytree_node( - pd.Series, - lambda sr: ( - sr.tolist(), - {"index": sr.index, "name": sr.name}, - ), - lambda aux_data, leaves: pd.Series(leaves, **aux_data), - namespace=extended, -) +def _flatten_df_optree(df): + data_col = df.attrs.get("data_col", "value") + return _flatten_df(df, data_col=data_col) -@optree.register_pytree_node_class(namespace=extended) -class CustomDataFrame: - def __init__(self, df, data_col): - self.df = df - self.data_col = data_col - def __tree_flatten__(self): - return _flatten_df(self.df, self.data_col) - - @classmethod - def __tree_unflatten__(cls, aux_data, leaves): - return _unflatten_df( - aux_data=aux_data, leaves=leaves, data_col=aux_data["data_col"] - ) +def _unflatten_df_optree(aux_data, leaves): + data_col = aux_data["df"].attrs.get("data_col", "value") + return _unflatten_df(aux_data=aux_data, leaves=leaves, data_col=data_col) optree.register_pytree_node( pd.DataFrame, - partial(_flatten_df, data_col="value"), - partial(_unflatten_df, data_col="value"), + _flatten_df_optree, + _unflatten_df_optree, + namespace=extended, +) + +optree.register_pytree_node( + pd.Series, + lambda sr: ( + sr.tolist(), + {"index": sr.index, "name": sr.name}, + ), + lambda aux_data, leaves: pd.Series(leaves, **aux_data), namespace=extended, ) diff --git a/tests/optimagic/parameters/test_tree_registry.py b/tests/optimagic/parameters/test_tree_registry.py index 3f617a267..1bf4b95cf 100644 --- a/tests/optimagic/parameters/test_tree_registry.py +++ b/tests/optimagic/parameters/test_tree_registry.py @@ -6,6 +6,7 @@ from optimagic.parameters.tree_registry import ( get_registry, + set_data_col_df_attribute, tree_flatten, tree_unflatten, ) @@ -66,3 +67,31 @@ def test_leaf_names_partially_numeric_df(other_df): registry = get_registry(extended=True) names = leaf_names(other_df, registry=registry) assert names == ["alpha_b", "alpha_c", "beta_b", "beta_c", "gamma_b", "gamma_c"] + + +def test_set_data_col_attribute_assigns_attribute(value_df): + df = set_data_col_df_attribute(value_df, data_col="attr") + assert df.attrs.get("data_col") == "attr" + assert value_df.attrs.get("data_col") is None + + +def test_set_data_col_attribute_unflattened_tree_has_attribute(value_df): + registry = get_registry(extended=True) + df = set_data_col_df_attribute(value_df, data_col="attr") + tree, treedef = tree_flatten(df, registry=registry) + df = tree_unflatten(treedef, tree) + assert df.attrs.get("data_col") == "attr" + + +def test_set_data_col_attribute_returns_nan(value_df): + registry = get_registry(extended=True) + df = set_data_col_df_attribute(value_df, data_col="attr") + tree, treedef = tree_flatten(df, registry=registry) + assert all(np.isnan(value) for value in tree) + + +def test_set_data_col_attribute_returs_column_values(value_df): + registry = get_registry(extended=True) + df = set_data_col_df_attribute(value_df, data_col="a") + tree, treedef = tree_flatten(df, registry=registry) + assert tree == [0, 2, 4] From 51cc30fc0232e768b783a487698afe29f575a3dc Mon Sep 17 00:00:00 2001 From: Abel Abate Date: Wed, 1 Apr 2026 11:26:57 +0200 Subject: [PATCH 05/19] chore: replace leaf_names with optree method --- src/estimagic/bootstrap.py | 2 +- src/estimagic/estimate_msm.py | 7 ++-- src/optimagic/optimization/history.py | 7 ++-- src/optimagic/parameters/bounds.py | 2 +- src/optimagic/parameters/tree_conversion.py | 2 +- src/optimagic/parameters/tree_registry.py | 35 +++++++++++++++++-- src/optimagic/visualization/history_plots.py | 2 +- tests/estimagic/test_shared.py | 4 +-- .../parameters/test_tree_registry.py | 2 +- 9 files changed, 49 insertions(+), 14 deletions(-) diff --git a/src/estimagic/bootstrap.py b/src/estimagic/bootstrap.py index 6ab35f1a0..c84668e99 100644 --- a/src/estimagic/bootstrap.py +++ b/src/estimagic/bootstrap.py @@ -5,7 +5,6 @@ import numpy as np import pandas as pd -from pybaum import leaf_names from estimagic.bootstrap_ci import calculate_ci from estimagic.bootstrap_helpers import check_inputs @@ -15,6 +14,7 @@ from optimagic.parameters.block_trees import matrix_to_block_tree from optimagic.parameters.tree_registry import ( get_registry, + leaf_names, tree_flatten, tree_just_flatten, tree_unflatten, diff --git a/src/estimagic/estimate_msm.py b/src/estimagic/estimate_msm.py index 62990a620..d095eddcd 100644 --- a/src/estimagic/estimate_msm.py +++ b/src/estimagic/estimate_msm.py @@ -9,7 +9,6 @@ import numpy as np import pandas as pd -from pybaum import leaf_names from estimagic.msm_covs import cov_optimal, cov_robust from estimagic.msm_sensitivity import ( @@ -51,7 +50,11 @@ from optimagic.parameters.bounds import Bounds, pre_process_bounds from optimagic.parameters.conversion import Converter, get_converter from optimagic.parameters.space_conversion import InternalParams -from optimagic.parameters.tree_registry import get_registry, tree_just_flatten +from optimagic.parameters.tree_registry import ( + get_registry, + leaf_names, + tree_just_flatten, +) from optimagic.shared.check_option_dicts import ( check_optimization_options, ) diff --git a/src/optimagic/optimization/history.py b/src/optimagic/optimization/history.py index c9d3f1d86..a3e2069d2 100644 --- a/src/optimagic/optimization/history.py +++ b/src/optimagic/optimization/history.py @@ -6,9 +6,12 @@ import numpy as np import pandas as pd from numpy.typing import NDArray -from pybaum import leaf_names -from optimagic.parameters.tree_registry import get_registry, tree_just_flatten +from optimagic.parameters.tree_registry import ( + get_registry, + leaf_names, + tree_just_flatten, +) from optimagic.timing import CostModel from optimagic.typing import Direction, EvalTask, PyTree diff --git a/src/optimagic/parameters/bounds.py b/src/optimagic/parameters/bounds.py index 1c7811a89..eed12a819 100644 --- a/src/optimagic/parameters/bounds.py +++ b/src/optimagic/parameters/bounds.py @@ -5,12 +5,12 @@ import numpy as np from numpy.typing import NDArray -from pybaum import leaf_names from scipy.optimize import Bounds as ScipyBounds from optimagic.exceptions import InvalidBoundsError from optimagic.parameters.tree_registry import ( get_registry, + leaf_names, set_data_col_df_attribute, tree_map, ) diff --git a/src/optimagic/parameters/tree_conversion.py b/src/optimagic/parameters/tree_conversion.py index b2cdb4cb5..0dddd328e 100644 --- a/src/optimagic/parameters/tree_conversion.py +++ b/src/optimagic/parameters/tree_conversion.py @@ -1,13 +1,13 @@ from typing import Callable, NamedTuple import numpy as np -from pybaum import leaf_names from optimagic.exceptions import InvalidFunctionError from optimagic.parameters.block_trees import block_tree_to_matrix from optimagic.parameters.bounds import get_internal_bounds from optimagic.parameters.tree_registry import ( get_registry, + leaf_names, tree_flatten, tree_just_flatten, tree_unflatten, diff --git a/src/optimagic/parameters/tree_registry.py b/src/optimagic/parameters/tree_registry.py index d72e6eae4..0ab0ef0f4 100644 --- a/src/optimagic/parameters/tree_registry.py +++ b/src/optimagic/parameters/tree_registry.py @@ -1,5 +1,6 @@ """Wrapper around pybaum get_registry to tailor it to optimagic.""" +import itertools from collections import OrderedDict from functools import partial from itertools import product @@ -131,6 +132,12 @@ def tree_map(func, tree, is_leaf=None, registry=None): ) +def leaf_names(tree, is_leaf=None, registry=None, separator="_"): + leaves, treespec = tree_flatten(tree, is_leaf=is_leaf, registry=registry) + paths = treespec.paths() + return [separator.join(str(p) for p in path) for path in paths] + + def set_data_col_df_attribute(tree, data_col): def set_attr(node): if isinstance(node, pd.DataFrame): @@ -141,9 +148,31 @@ def set_attr(node): return tree_map(set_attr, tree) +def _get_names_pandas_dataframe(df): + index_strings = list(df.index.map(_index_element_to_string)) + out = ["_".join([loc, col]) for loc, col in product(index_strings, df.columns)] + return out + + +def _array_element_names(arr): + dim_names = [map(str, range(n)) for n in arr.shape] + names = list(map("_".join, itertools.product(*dim_names))) + return names + + def _flatten_df_optree(df): data_col = df.attrs.get("data_col", "value") - return _flatten_df(df, data_col=data_col) + is_value_df = "value" in df + if is_value_df: + flat = df.get(data_col, default=np.full(len(df), np.nan)).tolist() + else: + flat = df.to_numpy().flatten().tolist() + + aux_data = { + "is_value_df": is_value_df, + "df": df, + } + return flat, aux_data, _get_df_names(df) def _unflatten_df_optree(aux_data, leaves): @@ -163,6 +192,7 @@ def _unflatten_df_optree(aux_data, leaves): lambda sr: ( sr.tolist(), {"index": sr.index, "name": sr.name}, + list(sr.index.map(_index_element_to_string)), ), lambda aux_data, leaves: pd.Series(leaves, **aux_data), namespace=extended, @@ -170,8 +200,7 @@ def _unflatten_df_optree(aux_data, leaves): optree.register_pytree_node( np.ndarray, - lambda arr: (arr.flatten().tolist(), arr.shape), + lambda arr: (arr.flatten().tolist(), arr.shape, _array_element_names(arr)), lambda aux_data, leaves: np.array(leaves).reshape(aux_data), namespace=extended, ) -# DONT FORGET JAX diff --git a/src/optimagic/visualization/history_plots.py b/src/optimagic/visualization/history_plots.py index fb1f365bc..f6089ec00 100644 --- a/src/optimagic/visualization/history_plots.py +++ b/src/optimagic/visualization/history_plots.py @@ -5,7 +5,6 @@ from typing import Any, Callable, Literal import numpy as np -from pybaum import leaf_names from optimagic.config import DEFAULT_PALETTE from optimagic.logging.logger import LogReader, SQLiteLogOptions @@ -14,6 +13,7 @@ from optimagic.optimization.optimize_result import OptimizeResult from optimagic.parameters.tree_registry import ( get_registry, + leaf_names, tree_flatten, tree_just_flatten, tree_unflatten, diff --git a/tests/estimagic/test_shared.py b/tests/estimagic/test_shared.py index 9a4240c74..0c2dcce5a 100644 --- a/tests/estimagic/test_shared.py +++ b/tests/estimagic/test_shared.py @@ -4,7 +4,7 @@ import pandas as pd import pytest from numpy.testing import assert_array_almost_equal as aaae -from pybaum import leaf_names, tree_equal +from pybaum import tree_equal from estimagic.shared_covs import ( _to_numpy, @@ -15,7 +15,7 @@ transform_free_cov_to_cov, transform_free_values_to_params_tree, ) -from optimagic.parameters.tree_registry import get_registry +from optimagic.parameters.tree_registry import get_registry, leaf_names from optimagic.utilities import get_rng diff --git a/tests/optimagic/parameters/test_tree_registry.py b/tests/optimagic/parameters/test_tree_registry.py index 1bf4b95cf..600c93ad4 100644 --- a/tests/optimagic/parameters/test_tree_registry.py +++ b/tests/optimagic/parameters/test_tree_registry.py @@ -2,10 +2,10 @@ import pandas as pd import pytest from pandas.testing import assert_frame_equal -from pybaum import leaf_names from optimagic.parameters.tree_registry import ( get_registry, + leaf_names, set_data_col_df_attribute, tree_flatten, tree_unflatten, From 2edc2f259aba7bb8706c45aea0787f5ed020baf3 Mon Sep 17 00:00:00 2001 From: Abel Abate Date: Wed, 1 Apr 2026 11:29:28 +0200 Subject: [PATCH 06/19] chore: remove repeated OrderedDict check --- src/optimagic/parameters/tree_registry.py | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/src/optimagic/parameters/tree_registry.py b/src/optimagic/parameters/tree_registry.py index 0ab0ef0f4..2c64edd50 100644 --- a/src/optimagic/parameters/tree_registry.py +++ b/src/optimagic/parameters/tree_registry.py @@ -106,23 +106,13 @@ def tree_flatten(tree, is_leaf=None, registry=None): def tree_just_flatten(tree, is_leaf=None, registry=None): - if isinstance(tree, dict): - tree = OrderedDict(tree) - - return optree.tree_leaves( - tree, - is_leaf, - namespace=extended if registry else "", - ) + leaves, _ = tree_flatten(tree, is_leaf=is_leaf, registry=registry) + return leaves def tree_unflatten(treedef, leaves, is_leaf=None, registry=None): if not isinstance(treedef, PyTreeSpec): - if isinstance(treedef, dict): - treedef = OrderedDict(treedef) - _, treedef = optree.tree_flatten( - treedef, namespace=extended if registry else "" - ) + _, treedef = tree_flatten(treedef, is_leaf=is_leaf, registry=registry) return optree.tree_unflatten(treespec=treedef, leaves=leaves) From 1f07a0213f5c2ae810015abf46c5eebcddb831a2 Mon Sep 17 00:00:00 2001 From: Abel Abate Date: Wed, 1 Apr 2026 11:41:44 +0200 Subject: [PATCH 07/19] chore: move namespace variable to typing.py --- src/optimagic/parameters/tree_registry.py | 15 +++++++-------- src/optimagic/typing.py | 1 + 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/optimagic/parameters/tree_registry.py b/src/optimagic/parameters/tree_registry.py index 2c64edd50..ccb05632d 100644 --- a/src/optimagic/parameters/tree_registry.py +++ b/src/optimagic/parameters/tree_registry.py @@ -11,6 +11,8 @@ from optree.pytree import PyTreeSpec from pybaum import get_registry as get_pybaum_registry +from optimagic.typing import extended_namespace + def get_registry(extended=False, data_col="value"): """Return pytree registry. @@ -94,14 +96,11 @@ def _index_element_to_string(element): return res_string -extended = "extended" - - def tree_flatten(tree, is_leaf=None, registry=None): if isinstance(tree, dict): tree = OrderedDict(tree) return optree.tree_flatten( - tree, is_leaf=is_leaf, namespace=extended if registry else "" + tree, is_leaf=is_leaf, namespace=extended_namespace if registry else "" ) @@ -118,7 +117,7 @@ def tree_unflatten(treedef, leaves, is_leaf=None, registry=None): def tree_map(func, tree, is_leaf=None, registry=None): return optree.tree_map( - func, tree, is_leaf=is_leaf, namespace=extended if registry else "" + func, tree, is_leaf=is_leaf, namespace=extended_namespace if registry else "" ) @@ -174,7 +173,7 @@ def _unflatten_df_optree(aux_data, leaves): pd.DataFrame, _flatten_df_optree, _unflatten_df_optree, - namespace=extended, + namespace=extended_namespace, ) optree.register_pytree_node( @@ -185,12 +184,12 @@ def _unflatten_df_optree(aux_data, leaves): list(sr.index.map(_index_element_to_string)), ), lambda aux_data, leaves: pd.Series(leaves, **aux_data), - namespace=extended, + namespace=extended_namespace, ) optree.register_pytree_node( np.ndarray, lambda arr: (arr.flatten().tolist(), arr.shape, _array_element_names(arr)), lambda aux_data, leaves: np.array(leaves).reshape(aux_data), - namespace=extended, + namespace=extended_namespace, ) diff --git a/src/optimagic/typing.py b/src/optimagic/typing.py index 9b389ced2..795b98174 100644 --- a/src/optimagic/typing.py +++ b/src/optimagic/typing.py @@ -22,6 +22,7 @@ Scalar = Any T = TypeVar("T") +extended_namespace = "extended_namespace" class AggregationLevel(Enum): From def8aa3a6fab0fb3639acfd1c231c263704755df Mon Sep 17 00:00:00 2001 From: Abel Abate Date: Wed, 1 Apr 2026 17:42:45 +0200 Subject: [PATCH 08/19] chore: remove unused method --- src/optimagic/parameters/tree_registry.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/optimagic/parameters/tree_registry.py b/src/optimagic/parameters/tree_registry.py index ccb05632d..6f25a10f8 100644 --- a/src/optimagic/parameters/tree_registry.py +++ b/src/optimagic/parameters/tree_registry.py @@ -137,12 +137,6 @@ def set_attr(node): return tree_map(set_attr, tree) -def _get_names_pandas_dataframe(df): - index_strings = list(df.index.map(_index_element_to_string)) - out = ["_".join([loc, col]) for loc, col in product(index_strings, df.columns)] - return out - - def _array_element_names(arr): dim_names = [map(str, range(n)) for n in arr.shape] names = list(map("_".join, itertools.product(*dim_names))) From 109623f4f69faa053c9e5af6aa3b0989af1c2299 Mon Sep 17 00:00:00 2001 From: Abel Abate Date: Wed, 1 Apr 2026 22:57:30 +0200 Subject: [PATCH 09/19] chore: use optree context manager for ordering dict --- src/optimagic/parameters/tree_registry.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/optimagic/parameters/tree_registry.py b/src/optimagic/parameters/tree_registry.py index 6f25a10f8..6b37594e2 100644 --- a/src/optimagic/parameters/tree_registry.py +++ b/src/optimagic/parameters/tree_registry.py @@ -1,7 +1,6 @@ """Wrapper around pybaum get_registry to tailor it to optimagic.""" import itertools -from collections import OrderedDict from functools import partial from itertools import product @@ -97,11 +96,10 @@ def _index_element_to_string(element): def tree_flatten(tree, is_leaf=None, registry=None): - if isinstance(tree, dict): - tree = OrderedDict(tree) - return optree.tree_flatten( - tree, is_leaf=is_leaf, namespace=extended_namespace if registry else "" - ) + with optree.dict_insertion_ordered(True, namespace=extended_namespace): + return optree.tree_flatten( + tree, is_leaf=is_leaf, namespace=extended_namespace if registry else "" + ) def tree_just_flatten(tree, is_leaf=None, registry=None): @@ -122,7 +120,7 @@ def tree_map(func, tree, is_leaf=None, registry=None): def leaf_names(tree, is_leaf=None, registry=None, separator="_"): - leaves, treespec = tree_flatten(tree, is_leaf=is_leaf, registry=registry) + _, treespec = tree_flatten(tree, is_leaf=is_leaf, registry=registry) paths = treespec.paths() return [separator.join(str(p) for p in path) for path in paths] From e3eb3822970f499bcecbb42c27b566d250d71817 Mon Sep 17 00:00:00 2001 From: Abel Abate Date: Wed, 1 Apr 2026 22:59:15 +0200 Subject: [PATCH 10/19] chore: replace tree_equal method with optree impl --- src/optimagic/parameters/tree_registry.py | 32 +++++++++++++++++++ tests/estimagic/test_shared.py | 3 +- .../test_compare_derivatives_with_jax.py | 2 +- tests/optimagic/logging/test_logger.py | 7 ++-- .../optimagic/parameters/test_block_trees.py | 3 +- 5 files changed, 40 insertions(+), 7 deletions(-) diff --git a/src/optimagic/parameters/tree_registry.py b/src/optimagic/parameters/tree_registry.py index 6b37594e2..efa74dfa8 100644 --- a/src/optimagic/parameters/tree_registry.py +++ b/src/optimagic/parameters/tree_registry.py @@ -185,3 +185,35 @@ def _unflatten_df_optree(aux_data, leaves): lambda aux_data, leaves: np.array(leaves).reshape(aux_data), namespace=extended_namespace, ) + +EQUALITY_CHECKERS = {} +EQUALITY_CHECKERS[np.ndarray] = lambda a, b: bool((a == b).all()) +EQUALITY_CHECKERS[pd.Series] = lambda a, b: a.equals(b) +EQUALITY_CHECKERS[pd.DataFrame] = lambda a, b: a.equals(b) + + +def tree_equal(tree, other, is_leaf=None, registry=None, equality_checkers=None): + equality_checkers = ( + EQUALITY_CHECKERS + if equality_checkers is None + else {**EQUALITY_CHECKERS, **equality_checkers} + ) + + first_flat, first_treespec = tree_flatten(tree, is_leaf=is_leaf, registry=registry) + second_flat, second_treespec = tree_flatten( + other, is_leaf=is_leaf, registry=registry + ) + + first_names = leaf_names(tree, is_leaf=is_leaf, registry=registry) + second_names = leaf_names(tree, is_leaf=is_leaf, registry=registry) + + equal = first_names == second_names and first_treespec == second_treespec + + if equal: + for first, second in zip(first_flat, second_flat, strict=True): + check_func = equality_checkers.get(type(first), lambda a, b: a == b) + equal = equal and check_func(first, second) + if not equal: + break + + return equal diff --git a/tests/estimagic/test_shared.py b/tests/estimagic/test_shared.py index 0c2dcce5a..fded1a3d9 100644 --- a/tests/estimagic/test_shared.py +++ b/tests/estimagic/test_shared.py @@ -4,7 +4,6 @@ import pandas as pd import pytest from numpy.testing import assert_array_almost_equal as aaae -from pybaum import tree_equal from estimagic.shared_covs import ( _to_numpy, @@ -15,7 +14,7 @@ transform_free_cov_to_cov, transform_free_values_to_params_tree, ) -from optimagic.parameters.tree_registry import get_registry, leaf_names +from optimagic.parameters.tree_registry import get_registry, leaf_names, tree_equal from optimagic.utilities import get_rng diff --git a/tests/optimagic/differentiation/test_compare_derivatives_with_jax.py b/tests/optimagic/differentiation/test_compare_derivatives_with_jax.py index 87b5554d8..56cb39a0b 100644 --- a/tests/optimagic/differentiation/test_compare_derivatives_with_jax.py +++ b/tests/optimagic/differentiation/test_compare_derivatives_with_jax.py @@ -7,10 +7,10 @@ import numpy as np import pytest from numpy.testing import assert_array_almost_equal as aaae -from pybaum import tree_equal from optimagic.config import IS_JAX_INSTALLED from optimagic.differentiation.derivatives import first_derivative, second_derivative +from optimagic.parameters.tree_registry import tree_equal if not IS_JAX_INSTALLED: pytestmark = pytest.mark.skip(reason="jax is not installed.") diff --git a/tests/optimagic/logging/test_logger.py b/tests/optimagic/logging/test_logger.py index 7ec170638..35e9e6c9f 100644 --- a/tests/optimagic/logging/test_logger.py +++ b/tests/optimagic/logging/test_logger.py @@ -3,7 +3,6 @@ import numpy as np import pandas as pd import pytest -from pybaum import tree_equal from optimagic.logging.logger import ( LogOptions, @@ -13,7 +12,11 @@ SQLiteLogReader, ) from optimagic.optimization.optimize import minimize -from optimagic.parameters.tree_registry import get_registry, tree_just_flatten +from optimagic.parameters.tree_registry import ( + get_registry, + tree_equal, + tree_just_flatten, +) from optimagic.typing import Direction diff --git a/tests/optimagic/parameters/test_block_trees.py b/tests/optimagic/parameters/test_block_trees.py index f2bc32523..30c991c19 100644 --- a/tests/optimagic/parameters/test_block_trees.py +++ b/tests/optimagic/parameters/test_block_trees.py @@ -2,7 +2,6 @@ import pandas as pd import pytest from numpy.testing import assert_array_equal -from pybaum import tree_equal from optimagic import second_derivative from optimagic.parameters.block_trees import ( @@ -11,7 +10,7 @@ hessian_to_block_tree, matrix_to_block_tree, ) -from optimagic.parameters.tree_registry import get_registry +from optimagic.parameters.tree_registry import get_registry, tree_equal from optimagic.parameters.tree_registry import tree_just_flatten as tree_leaves From 5aa2b08ae8ce501728d1b90b61555b8de7753187 Mon Sep 17 00:00:00 2001 From: Abel Abate Date: Thu, 2 Apr 2026 11:33:30 +0200 Subject: [PATCH 11/19] chore: remove get_registry method and use namespace arugment --- src/estimagic/bootstrap.py | 30 ++--- src/estimagic/estimate_msm.py | 17 ++- src/estimagic/msm_weighting.py | 6 +- src/estimagic/shared_covs.py | 13 +- src/optimagic/benchmarking/run_benchmark.py | 7 +- src/optimagic/differentiation/derivatives.py | 40 +++---- src/optimagic/examples/criterion_functions.py | 9 +- src/optimagic/optimization/fun_value.py | 5 +- src/optimagic/optimization/history.py | 8 +- src/optimagic/parameters/block_trees.py | 12 +- src/optimagic/parameters/bounds.py | 24 ++-- .../parameters/nonlinear_constraints.py | 9 +- src/optimagic/parameters/process_selectors.py | 11 +- src/optimagic/parameters/tree_conversion.py | 29 +++-- src/optimagic/parameters/tree_registry.py | 113 +++++------------- src/optimagic/visualization/history_plots.py | 14 +-- src/optimagic/visualization/slice_plot.py | 5 +- src/optimagic/visualization/slice_plot_3d.py | 5 +- tests/estimagic/test_bootstrap_ci.py | 5 +- ...st_estimate_msm_dict_params_and_moments.py | 7 +- tests/estimagic/test_shared.py | 5 +- tests/optimagic/logging/test_logger.py | 7 +- .../optimization/test_params_versions.py | 20 ++-- .../optimization/test_with_logging.py | 5 +- .../optimagic/parameters/test_block_trees.py | 7 +- .../parameters/test_nonlinear_constraints.py | 5 +- .../parameters/test_process_selectors.py | 9 +- .../parameters/test_tree_registry.py | 33 ++--- 28 files changed, 176 insertions(+), 284 deletions(-) diff --git a/src/estimagic/bootstrap.py b/src/estimagic/bootstrap.py index c84668e99..776e49e4a 100644 --- a/src/estimagic/bootstrap.py +++ b/src/estimagic/bootstrap.py @@ -13,7 +13,7 @@ from optimagic.batch_evaluators import joblib_batch_evaluator from optimagic.parameters.block_trees import matrix_to_block_tree from optimagic.parameters.tree_registry import ( - get_registry, + extended, leaf_names, tree_flatten, tree_just_flatten, @@ -107,9 +107,8 @@ def bootstrap( # Process results # ================================================================================== - registry = get_registry(extended=True) flat_outcomes = [ - tree_just_flatten(_outcome, registry=registry) for _outcome in all_outcomes + tree_just_flatten(_outcome, namespace=extended) for _outcome in all_outcomes ] internal_outcomes = np.array(flat_outcomes) @@ -167,11 +166,10 @@ def outcomes(self): List[Any]: The boostrap outcomes as a list of pytrees. """ - registry = get_registry(extended=True) - _, treedef = tree_flatten(self._base_outcome, registry=registry) + _, treedef = tree_flatten(self._base_outcome, namespace=extended) outcomes = [ - tree_unflatten(treedef, out, registry=registry) + tree_unflatten(treedef, out, namespace=extended) for out in self._internal_outcomes ] return outcomes @@ -187,10 +185,9 @@ def se(self): cov = self._internal_cov se = np.sqrt(np.diagonal(cov)) - registry = get_registry(extended=True) - _, treedef = tree_flatten(self._base_outcome, registry=registry) + _, treedef = tree_flatten(self._base_outcome, namespace=extended) - se = tree_unflatten(treedef, se, registry=registry) + se = tree_unflatten(treedef, se, namespace=extended) return se def cov(self, return_type="pytree"): @@ -211,8 +208,7 @@ def cov(self, return_type="pytree"): cov = self._internal_cov if return_type == "dataframe": - registry = get_registry(extended=True) - names = np.array(leaf_names(self._base_outcome, registry=registry)) + names = np.array(leaf_names(self._base_outcome, namespace=extended)) cov = pd.DataFrame(cov, columns=names, index=names) elif return_type == "pytree": cov = matrix_to_block_tree(cov, self._base_outcome, self._base_outcome) @@ -239,15 +235,16 @@ def ci(self, ci_method="percentile", ci_level=0.95): bounds of confidence intervals. """ - registry = get_registry(extended=True) - base_outcome_flat, treedef = tree_flatten(self._base_outcome, registry=registry) + base_outcome_flat, treedef = tree_flatten( + self._base_outcome, namespace=extended + ) lower_flat, upper_flat = calculate_ci( base_outcome_flat, self._internal_outcomes, ci_method, ci_level ) - lower = tree_unflatten(treedef, lower_flat, registry=registry) - upper = tree_unflatten(treedef, upper_flat, registry=registry) + lower = tree_unflatten(treedef, lower_flat, namespace=extended) + upper = tree_unflatten(treedef, upper_flat, namespace=extended) return lower, upper def p_values(self): @@ -276,8 +273,7 @@ def summary(self, ci_method="percentile", ci_level=0.95): Soon this will be a pytree. """ - registry = get_registry(extended=True) - names = leaf_names(self.base_outcome, registry=registry) + names = leaf_names(self.base_outcome, namespace=extended) summary_data = _calulcate_summary_data_bootstrap( self, ci_method=ci_method, ci_level=ci_level ) diff --git a/src/estimagic/estimate_msm.py b/src/estimagic/estimate_msm.py index d095eddcd..5cf337799 100644 --- a/src/estimagic/estimate_msm.py +++ b/src/estimagic/estimate_msm.py @@ -51,7 +51,7 @@ from optimagic.parameters.conversion import Converter, get_converter from optimagic.parameters.space_conversion import InternalParams from optimagic.parameters.tree_registry import ( - get_registry, + extended, leaf_names, tree_just_flatten, ) @@ -321,8 +321,7 @@ def func(x): sim_mom = simulate_moments(params, **simulate_moments_kwargs) if isinstance(sim_mom, dict) and "simulated_moments" in sim_mom: sim_mom = sim_mom["simulated_moments"] - registry = get_registry(extended=True) - out = np.array(tree_just_flatten(sim_mom, registry=registry)) + out = np.array(tree_just_flatten(sim_mom, namespace=extended)) return out int_jac = first_derivative( @@ -421,8 +420,7 @@ def get_msm_optimization_functions( chol_weights = np.linalg.cholesky(flat_weights) - registry = get_registry(extended=True) - flat_emp_mom = tree_just_flatten(empirical_moments, registry=registry) + flat_emp_mom = tree_just_flatten(empirical_moments, namespace=extended) _simulate_moments = _partial_kwargs(simulate_moments, simulate_moments_kwargs) _jacobian = _partial_kwargs(jacobian, jacobian_kwargs) @@ -433,7 +431,7 @@ def get_msm_optimization_functions( simulate_moments=_simulate_moments, flat_empirical_moments=flat_emp_mom, chol_weights=chol_weights, - registry=registry, + namespace=extended, ) ) @@ -448,7 +446,7 @@ def get_msm_optimization_functions( def _msm_criterion( - params, simulate_moments, flat_empirical_moments, chol_weights, registry + params, simulate_moments, flat_empirical_moments, chol_weights, namespace ): """Calculate msm criterion given parameters and building blocks.""" simulated = simulate_moments(params) @@ -457,7 +455,7 @@ def _msm_criterion( if isinstance(simulated, np.ndarray) and simulated.ndim == 1: simulated_flat = simulated else: - simulated_flat = np.array(tree_just_flatten(simulated, registry=registry)) + simulated_flat = np.array(tree_just_flatten(simulated, namespace=namespace)) deviations = simulated_flat - flat_empirical_moments residuals = deviations @ chol_weights @@ -978,9 +976,8 @@ def sensitivity( inner_tree=self._empirical_moments, ) elif return_type == "dataframe": - registry = get_registry(extended=True) row_names = self._internal_estimates.names - col_names = leaf_names(self._empirical_moments, registry=registry) + col_names = leaf_names(self._empirical_moments, namespace=extended) out = pd.DataFrame( data=raw, index=row_names, diff --git a/src/estimagic/msm_weighting.py b/src/estimagic/msm_weighting.py index 6be390130..34222bb6d 100644 --- a/src/estimagic/msm_weighting.py +++ b/src/estimagic/msm_weighting.py @@ -6,7 +6,7 @@ from estimagic.bootstrap import bootstrap from optimagic.parameters.block_trees import block_tree_to_matrix, matrix_to_block_tree -from optimagic.parameters.tree_registry import get_registry, tree_just_flatten +from optimagic.parameters.tree_registry import extended, tree_just_flatten from optimagic.utilities import robust_inverse @@ -50,13 +50,11 @@ def get_moments_cov( first_eval = calculate_moments(data, **moment_kwargs) - registry = get_registry(extended=True) - @functools.wraps(calculate_moments) def func(data, **kwargs): raw = calculate_moments(data, **kwargs) out = pd.Series( - tree_just_flatten(raw, registry=registry) + tree_just_flatten(raw, namespace=extended) ) # xxxx won't be necessary soon! return out diff --git a/src/estimagic/shared_covs.py b/src/estimagic/shared_covs.py index b95c3eb4a..c1f0b782f 100644 --- a/src/estimagic/shared_covs.py +++ b/src/estimagic/shared_covs.py @@ -6,7 +6,7 @@ from optimagic.parameters.block_trees import matrix_to_block_tree from optimagic.parameters.tree_registry import ( - get_registry, + extended, tree_just_flatten, tree_unflatten, ) @@ -149,9 +149,8 @@ def calculate_estimation_summary( # Flatten summary and construct data frame for flat estimates # ================================================================================== - registry = get_registry(extended=True) flat_data = { - key: tree_just_flatten(val, registry=registry) + key: tree_just_flatten(val, namespace=extended) for key, val in summary_data.items() } @@ -170,7 +169,7 @@ def calculate_estimation_summary( # ================================================================================== # create tree with values corresponding to indices of df - indices = tree_unflatten(summary_data["value"], names, registry=registry) + indices = tree_unflatten(summary_data["value"], names, namespace=extended) estimates_flat = tree_just_flatten(summary_data["value"]) indices_flat = tree_just_flatten(indices) @@ -319,8 +318,7 @@ def calculate_free_estimates(estimates, internal_estimates): mask = internal_estimates.free_mask names = internal_estimates.names - registry = get_registry(extended=True) - external_flat = np.array(tree_just_flatten(estimates, registry=registry)) + external_flat = np.array(tree_just_flatten(estimates, namespace=extended)) free_estimates = FreeParams( values=external_flat[mask], @@ -354,8 +352,7 @@ def transform_free_values_to_params_tree(values, free_params, params): mask = free_params.free_mask flat = np.full(len(mask), np.nan) flat[np.ix_(mask)] = values - registry = get_registry(extended=True) - pytree = tree_unflatten(params, flat, registry=registry) + pytree = tree_unflatten(params, flat, namespace=extended) return pytree diff --git a/src/optimagic/benchmarking/run_benchmark.py b/src/optimagic/benchmarking/run_benchmark.py index 265b838d9..e89b73b7e 100644 --- a/src/optimagic/benchmarking/run_benchmark.py +++ b/src/optimagic/benchmarking/run_benchmark.py @@ -13,7 +13,7 @@ from optimagic import batch_evaluators from optimagic.algorithms import AVAILABLE_ALGORITHMS from optimagic.optimization.optimize import minimize -from optimagic.parameters.tree_registry import get_registry, tree_just_flatten +from optimagic.parameters.tree_registry import extended, tree_just_flatten def run_benchmark( @@ -179,7 +179,6 @@ def _process_one_result(optimize_result, problem): dict: Processed result. """ - _registry = get_registry(extended=True) _criterion = problem["noise_free_fun"] _start_x = problem["inputs"]["params"] _start_crit_value = _criterion(_start_x) @@ -190,7 +189,7 @@ def _process_one_result(optimize_result, problem): # This will happen if the optimization raised an error if isinstance(optimize_result, str): - params_history_flat = [tree_just_flatten(_start_x, registry=_registry)] + params_history_flat = [tree_just_flatten(_start_x, namespace=extended)] criterion_history = [_start_crit_value] time_history = [np.inf] batches_history = [0] @@ -198,7 +197,7 @@ def _process_one_result(optimize_result, problem): history = optimize_result.history params_history = history.params params_history_flat = [ - tree_just_flatten(p, registry=_registry) for p in params_history + tree_just_flatten(p, namespace=extended) for p in params_history ] if _is_noisy: criterion_history = np.array([_criterion(p) for p in params_history]) diff --git a/src/optimagic/differentiation/derivatives.py b/src/optimagic/differentiation/derivatives.py index ac28b3af3..a249ebeea 100644 --- a/src/optimagic/differentiation/derivatives.py +++ b/src/optimagic/differentiation/derivatives.py @@ -21,7 +21,7 @@ from optimagic.parameters.block_trees import hessian_to_block_tree, matrix_to_block_tree from optimagic.parameters.bounds import Bounds, get_internal_bounds, pre_process_bounds from optimagic.parameters.tree_registry import ( - get_registry, + extended, tree_flatten, tree_just_flatten, tree_unflatten, @@ -218,24 +218,23 @@ def first_derivative( # ================================================================================== # Convert scalar | pytree arguments to 1d arrays of floats # ================================================================================== - registry = get_registry(extended=True) is_fast_path = _is_1d_array(params) if not is_fast_path: - x, params_treedef = tree_flatten(params, registry=registry) + x, params_treedef = tree_flatten(params, namespace=extended) x = np.array(x, dtype=np.float64) if scaling_factor is not None and not np.isscalar(scaling_factor): scaling_factor = np.array( - tree_just_flatten(scaling_factor, registry=registry) + tree_just_flatten(scaling_factor, namespace=extended) ) if min_steps is not None and not np.isscalar(min_steps): - min_steps = np.array(tree_just_flatten(min_steps, registry=registry)) + min_steps = np.array(tree_just_flatten(min_steps, namespace=extended)) if step_size is not None and not np.isscalar(step_size): - step_size = np.array(tree_just_flatten(step_size, registry=registry)) + step_size = np.array(tree_just_flatten(step_size, namespace=extended)) else: x = params.astype(np.float64) @@ -289,7 +288,7 @@ def first_derivative( if not is_fast_path: evaluation_points = [ # entries are either a numpy.ndarray or np.nan - _unflatten_if_not_nan(p, params_treedef, registry) + _unflatten_if_not_nan(p, params_treedef, extended) for p in evaluation_points ] @@ -328,14 +327,14 @@ def first_derivative( elif vector_out: f0 = f0_tree.astype(float) else: - f0 = tree_leaves(f0_tree, registry=registry) + f0 = tree_leaves(f0_tree, namespace=extended) f0 = np.array(f0, dtype=np.float64) # convert the raw evaluations to numpy arrays raw_evals_arr = _convert_evals_to_numpy( raw_evals=raw_evals, unpacker=unpacker, - registry=registry, + namespace=extended, is_scalar_out=scalar_out, is_vector_out=vector_out, ) @@ -533,24 +532,23 @@ def second_derivative( # ================================================================================== # Convert scalar | pytree arguments to 1d arrays of floats # ================================================================================== - registry = get_registry(extended=True) is_fast_path = _is_1d_array(params) if not is_fast_path: - x, params_treedef = tree_flatten(params, registry=registry) + x, params_treedef = tree_flatten(params, namespace=extended) x = np.array(x, dtype=np.float64) if scaling_factor is not None and not np.isscalar(scaling_factor): scaling_factor = np.array( - tree_just_flatten(scaling_factor, registry=registry) + tree_just_flatten(scaling_factor, namespace=extended) ) if min_steps is not None and not np.isscalar(min_steps): - min_steps = np.array(tree_just_flatten(min_steps, registry=registry)) + min_steps = np.array(tree_just_flatten(min_steps, namespace=extended)) if step_size is not None and not np.isscalar(step_size): - step_size = np.array(tree_just_flatten(step_size, registry=registry)) + step_size = np.array(tree_just_flatten(step_size, namespace=extended)) else: x = params.astype(np.float64) @@ -626,7 +624,7 @@ def second_derivative( evaluation_points = { # entries are either a numpy.ndarray or np.nan, we unflatten only step_type: [ - _unflatten_if_not_nan(p, params_treedef, registry) for p in points + _unflatten_if_not_nan(p, params_treedef, extended) for p in points ] for step_type, points in evaluation_points.items() } @@ -665,13 +663,13 @@ def second_derivative( func_value = f0 f0_tree = unpacker(f0) - f0 = tree_leaves(f0_tree, registry=registry) + f0 = tree_leaves(f0_tree, namespace=extended) f0 = np.array(f0, dtype=np.float64) # convert the raw evaluations to numpy arrays raw_evals = { step_type: _convert_evals_to_numpy( - raw_evals=evals, unpacker=unpacker, registry=registry + raw_evals=evals, unpacker=unpacker, namespace=extended ) for step_type, evals in raw_evals.items() } @@ -925,7 +923,7 @@ def _convert_richardson_candidates_to_frame(jac, err): def _convert_evals_to_numpy( - raw_evals, unpacker, registry, is_scalar_out=False, is_vector_out=False + raw_evals, unpacker, namespace, is_scalar_out=False, is_vector_out=False ): """Harmonize the output of the function evaluations. @@ -949,7 +947,7 @@ def _convert_evals_to_numpy( else: evals = [ ( - np.array(tree_leaves(val, registry=registry), dtype=np.float64) + np.array(tree_leaves(val, namespace=namespace), dtype=np.float64) if not _is_scalar_nan(val) else val ) @@ -1208,9 +1206,9 @@ def _is_scalar_nan(value): return isinstance(value, float) and np.isnan(value) -def _unflatten_if_not_nan(leaves, treedef, registry): +def _unflatten_if_not_nan(leaves, treedef, namespace): if isinstance(leaves, np.ndarray): - out = tree_unflatten(treedef, leaves, registry=registry) + out = tree_unflatten(treedef, leaves, namespace=namespace) else: out = leaves return out diff --git a/src/optimagic/examples/criterion_functions.py b/src/optimagic/examples/criterion_functions.py index ae3708143..52f9dabf8 100644 --- a/src/optimagic/examples/criterion_functions.py +++ b/src/optimagic/examples/criterion_functions.py @@ -17,14 +17,12 @@ ) from optimagic.parameters.block_trees import matrix_to_block_tree from optimagic.parameters.tree_registry import ( - get_registry, + extended, tree_just_flatten, tree_unflatten, ) from optimagic.typing import PyTree -REGISTRY = get_registry(extended=True) - @mark.scalar def trid_scalar(params: PyTree) -> float: @@ -217,11 +215,10 @@ def _get_x(params: PyTree) -> NDArray[np.float64]: if isinstance(params, np.ndarray) and params.ndim == 1: x = params.astype(float) else: - registry = get_registry(extended=True) - x = np.array(tree_just_flatten(params, registry=registry), dtype=np.float64) + x = np.array(tree_just_flatten(params, namespace=extended), dtype=np.float64) return x def _unflatten_gradient(flat: NDArray[np.float64], params: PyTree) -> PyTree: - out = tree_unflatten(params, flat.tolist(), registry=REGISTRY) + out = tree_unflatten(params, flat.tolist(), namespace=extended) return out diff --git a/src/optimagic/optimization/fun_value.py b/src/optimagic/optimization/fun_value.py index 3cca7cbc2..9672ba46c 100644 --- a/src/optimagic/optimization/fun_value.py +++ b/src/optimagic/optimization/fun_value.py @@ -7,7 +7,7 @@ from numpy.typing import NDArray from optimagic.exceptions import InvalidFunctionError -from optimagic.parameters.tree_registry import get_registry, tree_just_flatten +from optimagic.parameters.tree_registry import extended, tree_just_flatten from optimagic.typing import AggregationLevel, PyTree, Scalar from optimagic.utilities import isscalar @@ -123,8 +123,7 @@ def _get_flat_value(value: PyTree) -> NDArray[np.float64]: elif isinstance(value, np.ndarray): flat = value.flatten() else: - registry = get_registry(extended=True) - flat = tree_just_flatten(value, registry=registry) + flat = tree_just_flatten(value, namespace=extended) flat_arr = np.asarray(flat, dtype=np.float64) return flat_arr diff --git a/src/optimagic/optimization/history.py b/src/optimagic/optimization/history.py index a3e2069d2..6744c5cef 100644 --- a/src/optimagic/optimization/history.py +++ b/src/optimagic/optimization/history.py @@ -8,7 +8,7 @@ from numpy.typing import NDArray from optimagic.parameters.tree_registry import ( - get_registry, + extended, leaf_names, tree_just_flatten, ) @@ -401,8 +401,7 @@ def _get_flat_params(params: list[PyTree]) -> list[list[float]]: if fast_path: flatten = lambda x: x.tolist() else: - registry = get_registry(extended=True) - flatten = partial(tree_just_flatten, registry=registry) + flatten = partial(tree_just_flatten, namespace=extended) return [flatten(p) for p in params] @@ -414,8 +413,7 @@ def _get_flat_param_names(param: PyTree) -> list[str]: # arrays, but the fast path is only taken for 1d arrays, so it can be ignored. return np.arange(param.size).astype(str).tolist() - registry = get_registry(extended=True) - return leaf_names(param, registry=registry) + return leaf_names(param, namespace=extended) def _is_1d_array(param: PyTree) -> bool: diff --git a/src/optimagic/parameters/block_trees.py b/src/optimagic/parameters/block_trees.py index 9e34b842b..f3c620088 100644 --- a/src/optimagic/parameters/block_trees.py +++ b/src/optimagic/parameters/block_trees.py @@ -4,7 +4,7 @@ import pandas as pd from optimagic.parameters.tree_registry import ( - get_registry, + extended, tree_flatten, tree_unflatten, ) @@ -332,9 +332,8 @@ def _is_pd_object(obj): def _check_dimensions_matrix(matrix, outer_tree, inner_tree): - extended_registry = get_registry(extended=True) - flat_outer = tree_leaves(outer_tree, registry=extended_registry) - flat_inner = tree_leaves(inner_tree, registry=extended_registry) + flat_outer = tree_leaves(outer_tree, namespace=extended) + flat_inner = tree_leaves(inner_tree, namespace=extended) if matrix.shape[0] != len(flat_outer): raise ValueError("First dimension of matrix does not match that of outer_tree.") @@ -345,9 +344,8 @@ def _check_dimensions_matrix(matrix, outer_tree, inner_tree): def _check_dimensions_hessian(hessian, f_tree, params_tree): - extended_registry = get_registry(extended=True) - flat_f = tree_leaves(f_tree, registry=extended_registry) - flat_p = tree_leaves(params_tree, registry=extended_registry) + flat_f = tree_leaves(f_tree, namespace=extended) + flat_p = tree_leaves(params_tree, namespace=extended) if len(flat_f) == 1: # consider only dimensions with non trivial size (larger than 1) diff --git a/src/optimagic/parameters/bounds.py b/src/optimagic/parameters/bounds.py index eed12a819..4e5efcad3 100644 --- a/src/optimagic/parameters/bounds.py +++ b/src/optimagic/parameters/bounds.py @@ -9,7 +9,7 @@ from optimagic.exceptions import InvalidBoundsError from optimagic.parameters.tree_registry import ( - get_registry, + extended, leaf_names, set_data_col_df_attribute, tree_map, @@ -17,7 +17,7 @@ from optimagic.parameters.tree_registry import ( tree_just_flatten as tree_leaves, ) -from optimagic.typing import PyTree, PyTreeRegistry +from optimagic.typing import PyTree from optimagic.utilities import fast_numpy_full @@ -81,7 +81,7 @@ def _process_bounds_sequence(bounds: Sequence[tuple[float, float]]) -> Bounds: def get_internal_bounds( params: PyTree, bounds: Bounds | None = None, - registry: PyTreeRegistry | None = None, + namespace: str = extended, add_soft_bounds: bool = False, ) -> tuple[NDArray[np.float64] | None, NDArray[np.float64] | None]: """Create consolidated and flattened bounds for params. @@ -98,7 +98,7 @@ def get_internal_bounds( Args: params: The parameter pytree. bounds: The lower and upper bounds. - registry: pybaum registry. + namespace: optree namespace. add_soft_bounds: If True, the element-wise maximum (minimum) of the lower and soft_lower (upper and soft_upper) bounds are taken. If False, the lower (upper) bounds are returned. @@ -122,12 +122,11 @@ def get_internal_bounds( # None-valued bounds are replaced with arrays of np.inf and -np.inf, and then # translated back to None if all entries are non-finite. - registry = get_registry(extended=True) if registry is None else registry - n_params = len(tree_leaves(params, registry=registry)) + n_params = len(tree_leaves(params, namespace=namespace)) # Fill leaves with np.nan. If params contains a data frame with bounds as a column, # that column is NOT overwritten (as long as an extended registry is used). - nan_tree = tree_map(lambda leaf: np.nan, params, registry=registry) # noqa: ARG005 + nan_tree = tree_map(lambda leaf: np.nan, params, namespace=namespace) # noqa: ARG005 lower_flat = _update_bounds_and_flatten(nan_tree, bounds.lower, kind="lower_bound") upper_flat = _update_bounds_and_flatten(nan_tree, bounds.upper, kind="upper_bound") @@ -182,18 +181,15 @@ def _update_bounds_and_flatten( np.ndarray: The updated and flattened bounds. """ - registry = get_registry(extended=True, data_col=kind) flat_nan_tree = tree_leaves( - set_data_col_df_attribute(nan_tree, data_col=kind), registry=registry + set_data_col_df_attribute(nan_tree, data_col=kind), namespace=extended ) - if bounds is not None: - registry = get_registry(extended=True) - flat_bounds = tree_leaves(bounds, registry=registry) + flat_bounds = tree_leaves(bounds, namespace=extended) seperator = 10 * "$" - params_names = leaf_names(nan_tree, registry=registry, separator=seperator) - bounds_names = leaf_names(bounds, registry=registry, separator=seperator) + params_names = leaf_names(nan_tree, namespace=extended, separator=seperator) + bounds_names = leaf_names(bounds, namespace=extended, separator=seperator) flat_nan_dict = dict(zip(params_names, flat_nan_tree, strict=False)) diff --git a/src/optimagic/parameters/nonlinear_constraints.py b/src/optimagic/parameters/nonlinear_constraints.py index 1d8a022ae..6af14bf6d 100644 --- a/src/optimagic/parameters/nonlinear_constraints.py +++ b/src/optimagic/parameters/nonlinear_constraints.py @@ -10,7 +10,7 @@ from optimagic.optimization.algo_options import CONSTRAINTS_ABSOLUTE_TOLERANCE from optimagic.parameters.block_trees import block_tree_to_matrix from optimagic.parameters.tree_registry import ( - get_registry, + extended, tree_flatten, tree_just_flatten, tree_unflatten, @@ -365,14 +365,13 @@ def _extend_jacobian(jac_mat, selection_indices, n_params): def _get_selection_indices(params, selector): """Get index of selected flat params and number of flat params.""" - registry = get_registry(extended=True) - flat_params, params_treedef = tree_flatten(params, registry=registry) + flat_params, params_treedef = tree_flatten(params, namespace=extended) n_params = len(flat_params) indices = np.arange(n_params, dtype=int) - params_indices = tree_unflatten(params_treedef, indices, registry=registry) + params_indices = tree_unflatten(params_treedef, indices, namespace=extended) selected = selector(params_indices) selection_indices = np.array( - tree_just_flatten(selected, registry=registry), dtype=int + tree_just_flatten(selected, namespace=extended), dtype=int ) return selection_indices, n_params diff --git a/src/optimagic/parameters/process_selectors.py b/src/optimagic/parameters/process_selectors.py index 8b05a3dab..3c44ce25c 100644 --- a/src/optimagic/parameters/process_selectors.py +++ b/src/optimagic/parameters/process_selectors.py @@ -6,7 +6,7 @@ from optimagic.constraints import Constraint from optimagic.exceptions import InvalidConstraintError -from optimagic.parameters.tree_registry import get_registry, tree_just_flatten +from optimagic.parameters.tree_registry import extended, tree_just_flatten def process_selectors(constraints, params, tree_converter, param_names): @@ -36,7 +36,6 @@ def process_selectors(constraints, params, tree_converter, param_names): if isinstance(constraints, dict): constraints = [constraints] - registry = get_registry(extended=True) n_params = len(tree_converter.params_flatten(params)) helper = tree_converter.params_unflatten(np.arange(n_params)) params_case = _get_params_case(params) @@ -52,7 +51,7 @@ def process_selectors(constraints, params, tree_converter, param_names): field=field, constraint=constr, params_case=params_case, - registry=registry, + namespace=extended, ) try: with warnings.catch_warnings(): @@ -135,19 +134,19 @@ def _get_selection_field(constraint, selector_case, params_case): return field -def _get_selection_evaluator(field, constraint, params_case, registry): +def _get_selection_evaluator(field, constraint, params_case, namespace): if field == "selector": def evaluator(params): raw = constraint["selector"](params) - flat = tree_just_flatten(raw, registry=registry) + flat = tree_just_flatten(raw, namespace=namespace) return flat elif field == "selectors": def evaluator(params): raw = [sel(params) for sel in constraint["selectors"]] - flat = [tree_just_flatten(r, registry=registry) for r in raw] + flat = [tree_just_flatten(r, namespace=namespace) for r in raw] return flat elif field == "loc": diff --git a/src/optimagic/parameters/tree_conversion.py b/src/optimagic/parameters/tree_conversion.py index 0dddd328e..1dcbabd9b 100644 --- a/src/optimagic/parameters/tree_conversion.py +++ b/src/optimagic/parameters/tree_conversion.py @@ -6,7 +6,7 @@ from optimagic.parameters.block_trees import block_tree_to_matrix from optimagic.parameters.bounds import get_internal_bounds from optimagic.parameters.tree_registry import ( - get_registry, + extended, leaf_names, tree_flatten, tree_just_flatten, @@ -50,26 +50,25 @@ def get_tree_converter( FlatParams: NamedTuple of 1d arrays with flattened bounds and param names. """ - _registry = get_registry(extended=True) - _params_vec, _params_treedef = tree_flatten(params, registry=_registry) + _params_vec, _params_treedef = tree_flatten(params, namespace=extended) _params_vec = np.array(_params_vec).astype(float) _lower, _upper = get_internal_bounds( params=params, bounds=bounds, - registry=_registry, + namespace=extended, ) if add_soft_bounds: _soft_lower, _soft_upper = get_internal_bounds( params=params, bounds=bounds, - registry=_registry, + namespace=extended, add_soft_bounds=add_soft_bounds, ) else: _soft_lower, _soft_upper = None, None - _param_names = leaf_names(params, registry=_registry) + _param_names = leaf_names(params, namespace=extended) flat_params = FlatParams( values=_params_vec, @@ -80,13 +79,13 @@ def get_tree_converter( soft_upper_bounds=_soft_upper, ) - _params_flatten = _get_params_flatten(registry=_registry) + _params_flatten = _get_params_flatten(namespace=extended) _params_unflatten = _get_params_unflatten( - registry=_registry, treedef=_params_treedef + namespace=extended, treedef=_params_treedef ) _derivative_flatten = _get_derivative_flatten( - registry=_registry, + namespace=extended, solver_type=solver_type, params=params, func_eval=func_eval, @@ -102,16 +101,16 @@ def get_tree_converter( return converter, flat_params -def _get_params_flatten(registry): +def _get_params_flatten(namespace): def params_flatten(params): - return np.array(tree_just_flatten(params, registry=registry)).astype(float) + return np.array(tree_just_flatten(params, namespace=namespace)).astype(float) return params_flatten -def _get_params_unflatten(registry, treedef): +def _get_params_unflatten(namespace, treedef): def params_unflatten(x): - return tree_unflatten(treedef=treedef, leaves=list(x), registry=registry) + return tree_unflatten(treedef=treedef, leaves=list(x), namespace=namespace) return params_unflatten @@ -143,13 +142,13 @@ def _get_best_key_and_aggregator(needed_key, available_keys): return key, aggregate -def _get_derivative_flatten(registry, solver_type, params, func_eval, derivative_eval): +def _get_derivative_flatten(namespace, solver_type, params, func_eval, derivative_eval): # gradient case if solver_type == AggregationLevel.SCALAR: def derivative_flatten(derivative_eval): flat = np.array( - tree_just_flatten(derivative_eval, registry=registry) + tree_just_flatten(derivative_eval, namespace=namespace) ).astype(float) return flat diff --git a/src/optimagic/parameters/tree_registry.py b/src/optimagic/parameters/tree_registry.py index efa74dfa8..d182973c5 100644 --- a/src/optimagic/parameters/tree_registry.py +++ b/src/optimagic/parameters/tree_registry.py @@ -1,80 +1,16 @@ """Wrapper around pybaum get_registry to tailor it to optimagic.""" import itertools -from functools import partial from itertools import product import numpy as np import optree import pandas as pd from optree.pytree import PyTreeSpec -from pybaum import get_registry as get_pybaum_registry from optimagic.typing import extended_namespace -def get_registry(extended=False, data_col="value"): - """Return pytree registry. - - Special Rules - ------------- - If extended is True the registry contains pd.DataFrame. In optimagic a data frame - can represent a 1d object with extra information, instead of a 2d object. This is - only allowed for params data frames, in which case they contain a 'value' column. - The extra information of such an object can be accessed using the data_col argument. - By default the 'value' column is extracted. If data_col is not 'value' but the data - frame contains a 'value' column, a list of np.nan is returned. - - Args: - extended (bool): If True appends types 'numpy.ndarray', 'pandas.Series' and - 'pandas.DataFrame' to the registry. - data_col (str): This column is used as the data source in a data frame when - flattening and unflattening a pytree. Defaults to 'value'; see special rules - above for behavior with non-default values. - - Returns: - dict: The pytree registry. - - """ - types = ( - ["numpy.ndarray", "pandas.Series", "jax.numpy.ndarray"] if extended else None - ) - registry = get_pybaum_registry(types=types) - if extended: - registry[pd.DataFrame] = { - "flatten": partial(_flatten_df, data_col=data_col), - "unflatten": partial(_unflatten_df, data_col=data_col), - "names": _get_df_names, - } - return registry - - -def _flatten_df(df, data_col): - is_value_df = "value" in df - if is_value_df: - flat = df.get(data_col, default=np.full(len(df), np.nan)).tolist() - else: - flat = df.to_numpy().flatten().tolist() - - aux_data = { - "is_value_df": is_value_df, - "df": df, - } - return flat, aux_data - - -def _unflatten_df(aux_data, leaves, data_col): - if aux_data["is_value_df"]: - out = aux_data["df"].assign(**{data_col: leaves}) - else: - out = pd.DataFrame( - data=np.array(leaves).reshape(aux_data["df"].shape), - columns=aux_data["df"].columns, - index=aux_data["df"].index, - ) - return out - - def _get_df_names(df): index_strings = list(df.index.map(_index_element_to_string)) if "value" in df: @@ -95,32 +31,31 @@ def _index_element_to_string(element): return res_string -def tree_flatten(tree, is_leaf=None, registry=None): +def tree_flatten(tree, is_leaf=None, namespace=""): with optree.dict_insertion_ordered(True, namespace=extended_namespace): - return optree.tree_flatten( - tree, is_leaf=is_leaf, namespace=extended_namespace if registry else "" - ) + return optree.tree_flatten(tree, is_leaf=is_leaf, namespace=namespace) -def tree_just_flatten(tree, is_leaf=None, registry=None): - leaves, _ = tree_flatten(tree, is_leaf=is_leaf, registry=registry) +def tree_just_flatten(tree, is_leaf=None, namespace=""): + leaves, _ = tree_flatten(tree, is_leaf=is_leaf, namespace=namespace) return leaves -def tree_unflatten(treedef, leaves, is_leaf=None, registry=None): +extended = extended_namespace + + +def tree_unflatten(treedef, leaves, is_leaf=None, namespace=""): if not isinstance(treedef, PyTreeSpec): - _, treedef = tree_flatten(treedef, is_leaf=is_leaf, registry=registry) + _, treedef = tree_flatten(treedef, is_leaf=is_leaf, namespace=namespace) return optree.tree_unflatten(treespec=treedef, leaves=leaves) -def tree_map(func, tree, is_leaf=None, registry=None): - return optree.tree_map( - func, tree, is_leaf=is_leaf, namespace=extended_namespace if registry else "" - ) +def tree_map(func, tree, is_leaf=None, namespace=""): + return optree.tree_map(func, tree, is_leaf=is_leaf, namespace=namespace) -def leaf_names(tree, is_leaf=None, registry=None, separator="_"): - _, treespec = tree_flatten(tree, is_leaf=is_leaf, registry=registry) +def leaf_names(tree, is_leaf=None, namespace="", separator="_"): + _, treespec = tree_flatten(tree, is_leaf=is_leaf, namespace=namespace) paths = treespec.paths() return [separator.join(str(p) for p in path) for path in paths] @@ -158,7 +93,15 @@ def _flatten_df_optree(df): def _unflatten_df_optree(aux_data, leaves): data_col = aux_data["df"].attrs.get("data_col", "value") - return _unflatten_df(aux_data=aux_data, leaves=leaves, data_col=data_col) + if aux_data["is_value_df"]: + out = aux_data["df"].assign(**{data_col: leaves}) + else: + out = pd.DataFrame( + data=np.array(leaves).reshape(aux_data["df"].shape), + columns=aux_data["df"].columns, + index=aux_data["df"].index, + ) + return out optree.register_pytree_node( @@ -192,20 +135,22 @@ def _unflatten_df_optree(aux_data, leaves): EQUALITY_CHECKERS[pd.DataFrame] = lambda a, b: a.equals(b) -def tree_equal(tree, other, is_leaf=None, registry=None, equality_checkers=None): +def tree_equal(tree, other, is_leaf=None, namespace="", equality_checkers=None): equality_checkers = ( EQUALITY_CHECKERS if equality_checkers is None else {**EQUALITY_CHECKERS, **equality_checkers} ) - first_flat, first_treespec = tree_flatten(tree, is_leaf=is_leaf, registry=registry) + first_flat, first_treespec = tree_flatten( + tree, is_leaf=is_leaf, namespace=namespace + ) second_flat, second_treespec = tree_flatten( - other, is_leaf=is_leaf, registry=registry + other, is_leaf=is_leaf, namespace=namespace ) - first_names = leaf_names(tree, is_leaf=is_leaf, registry=registry) - second_names = leaf_names(tree, is_leaf=is_leaf, registry=registry) + first_names = leaf_names(tree, is_leaf=is_leaf, namespace=namespace) + second_names = leaf_names(tree, is_leaf=is_leaf, namespace=namespace) equal = first_names == second_names and first_treespec == second_treespec diff --git a/src/optimagic/visualization/history_plots.py b/src/optimagic/visualization/history_plots.py index f6089ec00..5fc618cbf 100644 --- a/src/optimagic/visualization/history_plots.py +++ b/src/optimagic/visualization/history_plots.py @@ -12,7 +12,7 @@ from optimagic.optimization.history import History from optimagic.optimization.optimize_result import OptimizeResult from optimagic.parameters.tree_registry import ( - get_registry, + extended, leaf_names, tree_flatten, tree_just_flatten, @@ -585,15 +585,13 @@ def _extract_params_plot_lines( history = data.history.params start_params = data.start_params - registry = get_registry(extended=True) - - hist_arr = np.array([tree_just_flatten(p, registry=registry) for p in history]).T - names = leaf_names(start_params, registry=registry) + hist_arr = np.array([tree_just_flatten(p, namespace=extended) for p in history]).T + names = leaf_names(start_params, namespace=extended) if selector is not None: - flat, treedef = tree_flatten(start_params, registry=registry) - helper = tree_unflatten(treedef, list(range(len(flat))), registry=registry) - selected = np.array(tree_just_flatten(selector(helper), registry=registry)) + flat, treedef = tree_flatten(start_params, namespace=extended) + helper = tree_unflatten(treedef, list(range(len(flat))), namespace=extended) + selected = np.array(tree_just_flatten(selector(helper), namespace=extended)) names = [names[i] for i in selected] hist_arr = hist_arr[selected] diff --git a/src/optimagic/visualization/slice_plot.py b/src/optimagic/visualization/slice_plot.py index fa0581692..7e42c06d1 100644 --- a/src/optimagic/visualization/slice_plot.py +++ b/src/optimagic/visualization/slice_plot.py @@ -23,7 +23,7 @@ from optimagic.parameters.bounds import pre_process_bounds from optimagic.parameters.conversion import get_converter from optimagic.parameters.space_conversion import InternalParams -from optimagic.parameters.tree_registry import get_registry, tree_just_flatten +from optimagic.parameters.tree_registry import extended, tree_just_flatten from optimagic.shared.process_user_function import infer_aggregation_level from optimagic.typing import AggregationLevel, PyTree from optimagic.visualization.backends import grid_line_plot, line_plot @@ -248,9 +248,8 @@ def _get_plot_data( selected = np.arange(n_params, dtype=int) if selector is not None: helper = converter.params_from_internal(selected) - registry = get_registry(extended=True) selected = np.array( - tree_just_flatten(selector(helper), registry=registry), dtype=int + tree_just_flatten(selector(helper), namespace=extended), dtype=int ).ravel() # Ensure the result is a 1D array if not np.isfinite(internal_params.lower_bounds[selected]).all(): diff --git a/src/optimagic/visualization/slice_plot_3d.py b/src/optimagic/visualization/slice_plot_3d.py index 50f3f175b..f46c2447a 100644 --- a/src/optimagic/visualization/slice_plot_3d.py +++ b/src/optimagic/visualization/slice_plot_3d.py @@ -19,7 +19,7 @@ ) from optimagic.parameters.bounds import pre_process_bounds from optimagic.parameters.conversion import get_converter -from optimagic.parameters.tree_registry import get_registry, tree_just_flatten +from optimagic.parameters.tree_registry import extended, tree_just_flatten from optimagic.shared.process_user_function import infer_aggregation_level from optimagic.typing import AggregationLevel @@ -149,9 +149,8 @@ def slice_plot_3d( # type: ignore[no-untyped-def] selected = np.arange(n_params, dtype=int) if selector is not None: helper = converter.params_from_internal(selected) - registry = get_registry(extended=True) selected = np.array( - tree_just_flatten(selector(helper), registry=registry), dtype=int + tree_just_flatten(selector(helper), namespace=extended), dtype=int ).reshape(-1) n_params = len(selected) if not np.isfinite(internal_params.lower_bounds[selected]).all(): diff --git a/tests/estimagic/test_bootstrap_ci.py b/tests/estimagic/test_bootstrap_ci.py index 8610d348a..801faaebb 100644 --- a/tests/estimagic/test_bootstrap_ci.py +++ b/tests/estimagic/test_bootstrap_ci.py @@ -6,7 +6,7 @@ from estimagic.bootstrap_ci import calculate_ci, check_inputs from estimagic.bootstrap_samples import get_bootstrap_indices -from optimagic.parameters.tree_registry import get_registry, tree_just_flatten +from optimagic.parameters.tree_registry import extended, tree_just_flatten from optimagic.utilities import get_rng @@ -68,10 +68,9 @@ def _outcome_func_arr(data): @pytest.mark.parametrize("outcome, method", TEST_CASES) def test_ci(outcome, method, setup, expected): - registry = get_registry(extended=True) def outcome_flat(data): - return tree_just_flatten(outcome(data), registry=registry) + return tree_just_flatten(outcome(data), namespace=extended) base_outcome = outcome_flat(setup["df"]) lower, upper = calculate_ci(base_outcome, setup["estimates"], ci_method=method) diff --git a/tests/estimagic/test_estimate_msm_dict_params_and_moments.py b/tests/estimagic/test_estimate_msm_dict_params_and_moments.py index 79a6314e6..2cb2bacbb 100644 --- a/tests/estimagic/test_estimate_msm_dict_params_and_moments.py +++ b/tests/estimagic/test_estimate_msm_dict_params_and_moments.py @@ -5,7 +5,7 @@ from numpy.testing import assert_array_almost_equal as aaae from estimagic.estimate_msm import estimate_msm -from optimagic.parameters.tree_registry import get_registry, tree_just_flatten +from optimagic.parameters.tree_registry import extended, tree_just_flatten def test_estimate_msm_dict_params_and_moments(): @@ -97,8 +97,7 @@ def assert_almost_equal(x, y, decimal=6): x_flat = x y_flat = y else: - registry = get_registry(extended=True) - x_flat = np.array(tree_just_flatten(x, registry=registry)) - y_flat = np.array(tree_just_flatten(x, registry=registry)) + x_flat = np.array(tree_just_flatten(x, namespace=extended)) + y_flat = np.array(tree_just_flatten(x, namespace=extended)) aaae(x_flat, y_flat, decimal=decimal) diff --git a/tests/estimagic/test_shared.py b/tests/estimagic/test_shared.py index fded1a3d9..3ddd3a485 100644 --- a/tests/estimagic/test_shared.py +++ b/tests/estimagic/test_shared.py @@ -14,7 +14,7 @@ transform_free_cov_to_cov, transform_free_values_to_params_tree, ) -from optimagic.parameters.tree_registry import get_registry, leaf_names, tree_equal +from optimagic.parameters.tree_registry import extended, leaf_names, tree_equal from optimagic.utilities import get_rng @@ -239,8 +239,7 @@ def test_calculate_estimation_summary(): "free": np.array([True, True, True]), } - registry = get_registry(extended=True) - names = leaf_names(summary_data["value"], registry=registry) + names = leaf_names(summary_data["value"], namespace=extended) free_names = names # function call diff --git a/tests/optimagic/logging/test_logger.py b/tests/optimagic/logging/test_logger.py index 35e9e6c9f..86fa29af8 100644 --- a/tests/optimagic/logging/test_logger.py +++ b/tests/optimagic/logging/test_logger.py @@ -13,7 +13,7 @@ ) from optimagic.optimization.optimize import minimize from optimagic.parameters.tree_registry import ( - get_registry, + extended, tree_equal, tree_just_flatten, ) @@ -84,10 +84,9 @@ def test_log_reader_read_multistart_history(example_db): assert local_history is None assert exploration is None - registry = get_registry(extended=True) assert tree_equal( - tree_just_flatten(asdict(history), registry=registry), - tree_just_flatten(asdict(reader.read_history()), registry=registry), + tree_just_flatten(asdict(history), namespace=extended), + tree_just_flatten(asdict(reader.read_history()), namespace=extended), ) diff --git a/tests/optimagic/optimization/test_params_versions.py b/tests/optimagic/optimization/test_params_versions.py index 39bbaac82..52db60eed 100644 --- a/tests/optimagic/optimization/test_params_versions.py +++ b/tests/optimagic/optimization/test_params_versions.py @@ -10,9 +10,7 @@ sos_scalar, ) from optimagic.optimization.optimize import minimize -from optimagic.parameters.tree_registry import get_registry, tree_just_flatten - -REGISTRY = get_registry(extended=True) +from optimagic.parameters.tree_registry import extended, tree_just_flatten PARAMS = [ {"a": 1.0, "b": 2, "c": 3, "d": 4, "e": 5}, @@ -28,7 +26,7 @@ @pytest.mark.parametrize("params", PARAMS + SCALAR_PARAMS) def test_tree_params_numerical_derivative_scalar_criterion(params): - flat = np.array(tree_just_flatten(params, registry=REGISTRY)) + flat = np.array(tree_just_flatten(params, namespace=extended)) expected = np.zeros_like(flat) res = minimize( @@ -36,13 +34,13 @@ def test_tree_params_numerical_derivative_scalar_criterion(params): params=params, algorithm="scipy_lbfgsb", ) - calculated = np.array(tree_just_flatten(res.params, registry=REGISTRY)) + calculated = np.array(tree_just_flatten(res.params, namespace=extended)) aaae(calculated, expected) @pytest.mark.parametrize("params", PARAMS + SCALAR_PARAMS) def test_tree_params_scalar_criterion(params): - flat = np.array(tree_just_flatten(params, registry=REGISTRY)) + flat = np.array(tree_just_flatten(params, namespace=extended)) expected = np.zeros_like(flat) res = minimize( @@ -51,7 +49,7 @@ def test_tree_params_scalar_criterion(params): params=params, algorithm="scipy_lbfgsb", ) - calculated = np.array(tree_just_flatten(res.params, registry=REGISTRY)) + calculated = np.array(tree_just_flatten(res.params, namespace=extended)) aaae(calculated, expected) @@ -63,7 +61,7 @@ def test_tree_params_scalar_criterion(params): @pytest.mark.parametrize("params, algorithm", TEST_CASES_SOS_LS) def test_tree_params_numerical_derivative_sos_ls(params, algorithm): - flat = np.array(tree_just_flatten(params, registry=REGISTRY)) + flat = np.array(tree_just_flatten(params, namespace=extended)) expected = np.zeros_like(flat) res = minimize( @@ -71,13 +69,13 @@ def test_tree_params_numerical_derivative_sos_ls(params, algorithm): params=params, algorithm=algorithm, ) - calculated = np.array(tree_just_flatten(res.params, registry=REGISTRY)) + calculated = np.array(tree_just_flatten(res.params, namespace=extended)) aaae(calculated, expected) @pytest.mark.parametrize("params, algorithm", TEST_CASES_SOS_LS) def test_tree_params_sos_ls(params, algorithm): - flat = np.array(tree_just_flatten(params, registry=REGISTRY)) + flat = np.array(tree_just_flatten(params, namespace=extended)) expected = np.zeros_like(flat) derivatives = [sos_gradient, sos_ls_jacobian] @@ -87,5 +85,5 @@ def test_tree_params_sos_ls(params, algorithm): params=params, algorithm=algorithm, ) - calculated = np.array(tree_just_flatten(res.params, registry=REGISTRY)) + calculated = np.array(tree_just_flatten(res.params, namespace=extended)) aaae(calculated, expected) diff --git a/tests/optimagic/optimization/test_with_logging.py b/tests/optimagic/optimization/test_with_logging.py index d8cc7bd35..9abc05bf4 100644 --- a/tests/optimagic/optimization/test_with_logging.py +++ b/tests/optimagic/optimization/test_with_logging.py @@ -20,7 +20,7 @@ from optimagic.logging.logger import SQLiteLogOptions from optimagic.logging.types import ExistenceStrategy from optimagic.optimization.optimize import minimize -from optimagic.parameters.tree_registry import get_registry, tree_just_flatten +from optimagic.parameters.tree_registry import extended, tree_just_flatten @mark.least_squares @@ -46,8 +46,7 @@ def test_optimization_with_valid_logging(algorithm, params): algorithm=algorithm, logging="logging.db", ) - registry = get_registry(extended=True) - flat = np.array(tree_just_flatten(res.params, registry=registry)) + flat = np.array(tree_just_flatten(res.params, namespace=extended)) aaae(flat, np.zeros(3)) diff --git a/tests/optimagic/parameters/test_block_trees.py b/tests/optimagic/parameters/test_block_trees.py index 30c991c19..95703f750 100644 --- a/tests/optimagic/parameters/test_block_trees.py +++ b/tests/optimagic/parameters/test_block_trees.py @@ -10,7 +10,7 @@ hessian_to_block_tree, matrix_to_block_tree, ) -from optimagic.parameters.tree_registry import get_registry, tree_equal +from optimagic.parameters.tree_registry import extended, tree_equal from optimagic.parameters.tree_registry import tree_just_flatten as tree_leaves @@ -128,9 +128,8 @@ def test_block_tree_to_hessian_bijection(): params = {"a": np.arange(4), "b": [{"c": (1, 2), "d": np.array([5, 6])}]} f_tree = {"e": np.arange(3), "f": (5, 6, [7, 8, {"g": 1.0}])} - registry = get_registry(extended=True) - n_p = len(tree_leaves(params, registry=registry)) - n_f = len(tree_leaves(f_tree, registry=registry)) + n_p = len(tree_leaves(params, namespace=extended)) + n_f = len(tree_leaves(f_tree, namespace=extended)) expected = np.arange(n_f * n_p**2).reshape(n_f, n_p, n_p) block_hessian = hessian_to_block_tree(expected, f_tree, params) diff --git a/tests/optimagic/parameters/test_nonlinear_constraints.py b/tests/optimagic/parameters/test_nonlinear_constraints.py index 6e428ed71..1209658f7 100644 --- a/tests/optimagic/parameters/test_nonlinear_constraints.py +++ b/tests/optimagic/parameters/test_nonlinear_constraints.py @@ -21,7 +21,7 @@ process_nonlinear_constraints, vector_as_list_of_scalar_constraints, ) -from optimagic.parameters.tree_registry import get_registry, tree_just_flatten +from optimagic.parameters.tree_registry import extended, tree_just_flatten @dataclass @@ -30,8 +30,7 @@ def params_from_internal(self, x): return x def params_to_internal(self, params): - registry = get_registry(extended=True) - return np.array(tree_just_flatten(params, registry=registry)) + return np.array(tree_just_flatten(params, namespace=extended)) # ====================================================================================== diff --git a/tests/optimagic/parameters/test_process_selectors.py b/tests/optimagic/parameters/test_process_selectors.py index 726864a5f..805e8a356 100644 --- a/tests/optimagic/parameters/test_process_selectors.py +++ b/tests/optimagic/parameters/test_process_selectors.py @@ -7,7 +7,7 @@ from optimagic.parameters.process_selectors import process_selectors from optimagic.parameters.tree_conversion import TreeConverter from optimagic.parameters.tree_registry import ( - get_registry, + extended, tree_flatten, tree_just_flatten, tree_unflatten, @@ -35,15 +35,14 @@ def tree_params(): @pytest.fixture() def tree_params_converter(tree_params): - registry = get_registry(extended=True) - _, treedef = tree_flatten(tree_params, registry=registry) + _, treedef = tree_flatten(tree_params, namespace=extended) converter = TreeConverter( params_flatten=lambda params: np.array( - tree_just_flatten(params, registry=registry) + tree_just_flatten(params, namespace=extended) ), params_unflatten=lambda x: tree_unflatten( - treedef, x.tolist(), registry=registry + treedef, x.tolist(), namespace=extended ), derivative_flatten=None, ) diff --git a/tests/optimagic/parameters/test_tree_registry.py b/tests/optimagic/parameters/test_tree_registry.py index 600c93ad4..cb7d2e2da 100644 --- a/tests/optimagic/parameters/test_tree_registry.py +++ b/tests/optimagic/parameters/test_tree_registry.py @@ -4,7 +4,7 @@ from pandas.testing import assert_frame_equal from optimagic.parameters.tree_registry import ( - get_registry, + extended, leaf_names, set_data_col_df_attribute, tree_flatten, @@ -31,41 +31,35 @@ def other_df(): def test_flatten_df_with_value_column(value_df): - registry = get_registry(extended=True) - flat, _ = tree_flatten(value_df, registry=registry) + flat, _ = tree_flatten(value_df, namespace=extended) assert flat == [1, 3, 5] def test_unflatten_df_with_value_column(value_df): - registry = get_registry(extended=True) - _, treedef = tree_flatten(value_df, registry=registry) - unflat = tree_unflatten(treedef, [10, 11, 12], registry=registry) + _, treedef = tree_flatten(value_df, namespace=extended) + unflat = tree_unflatten(treedef, [10, 11, 12], namespace=extended) assert unflat.equals(value_df.assign(value=[10, 11, 12])) def test_leaf_names_df_with_value_column(value_df): - registry = get_registry(extended=True) - names = leaf_names(value_df, registry=registry) + names = leaf_names(value_df, namespace=extended) assert names == ["alpha", "beta", "gamma"] def test_flatten_partially_numeric_df(other_df): - registry = get_registry(extended=True) - flat, _ = tree_flatten(other_df, registry=registry) + flat, _ = tree_flatten(other_df, namespace=extended) assert flat == [0, 3.14, 1, 3.14, 2, 3.14] def test_unflatten_partially_numeric_df(other_df): - registry = get_registry(extended=True) - _, treedef = tree_flatten(other_df, registry=registry) - unflat = tree_unflatten(treedef, [1, 2, 3, 4, 5, 6], registry=registry) + _, treedef = tree_flatten(other_df, namespace=extended) + unflat = tree_unflatten(treedef, [1, 2, 3, 4, 5, 6], namespace=extended) other_df = other_df.assign(b=[1, 3, 5], c=[2, 4, 6]) assert_frame_equal(unflat, other_df, check_dtype=False) def test_leaf_names_partially_numeric_df(other_df): - registry = get_registry(extended=True) - names = leaf_names(other_df, registry=registry) + names = leaf_names(other_df, namespace=extended) assert names == ["alpha_b", "alpha_c", "beta_b", "beta_c", "gamma_b", "gamma_c"] @@ -76,22 +70,19 @@ def test_set_data_col_attribute_assigns_attribute(value_df): def test_set_data_col_attribute_unflattened_tree_has_attribute(value_df): - registry = get_registry(extended=True) df = set_data_col_df_attribute(value_df, data_col="attr") - tree, treedef = tree_flatten(df, registry=registry) + tree, treedef = tree_flatten(df, namespace=extended) df = tree_unflatten(treedef, tree) assert df.attrs.get("data_col") == "attr" def test_set_data_col_attribute_returns_nan(value_df): - registry = get_registry(extended=True) df = set_data_col_df_attribute(value_df, data_col="attr") - tree, treedef = tree_flatten(df, registry=registry) + tree, treedef = tree_flatten(df, namespace=extended) assert all(np.isnan(value) for value in tree) def test_set_data_col_attribute_returs_column_values(value_df): - registry = get_registry(extended=True) df = set_data_col_df_attribute(value_df, data_col="a") - tree, treedef = tree_flatten(df, registry=registry) + tree, treedef = tree_flatten(df, namespace=extended) assert tree == [0, 2, 4] From 9d8c744c21141fa7b29d6819ec467543bf265948 Mon Sep 17 00:00:00 2001 From: Abel Abate Date: Thu, 2 Apr 2026 12:33:31 +0200 Subject: [PATCH 12/19] chore: use namespaces for passing data_col value for dataframes --- src/optimagic/parameters/bounds.py | 5 +- src/optimagic/parameters/tree_registry.py | 185 ++++++++++-------- src/optimagic/typing.py | 1 - .../parameters/test_tree_registry.py | 26 --- 4 files changed, 101 insertions(+), 116 deletions(-) diff --git a/src/optimagic/parameters/bounds.py b/src/optimagic/parameters/bounds.py index 4e5efcad3..2c909271f 100644 --- a/src/optimagic/parameters/bounds.py +++ b/src/optimagic/parameters/bounds.py @@ -11,7 +11,6 @@ from optimagic.parameters.tree_registry import ( extended, leaf_names, - set_data_col_df_attribute, tree_map, ) from optimagic.parameters.tree_registry import ( @@ -181,9 +180,7 @@ def _update_bounds_and_flatten( np.ndarray: The updated and flattened bounds. """ - flat_nan_tree = tree_leaves( - set_data_col_df_attribute(nan_tree, data_col=kind), namespace=extended - ) + flat_nan_tree = tree_leaves(nan_tree, namespace=kind) if bounds is not None: flat_bounds = tree_leaves(bounds, namespace=extended) diff --git a/src/optimagic/parameters/tree_registry.py b/src/optimagic/parameters/tree_registry.py index d182973c5..695d2d3aa 100644 --- a/src/optimagic/parameters/tree_registry.py +++ b/src/optimagic/parameters/tree_registry.py @@ -1,6 +1,7 @@ """Wrapper around pybaum get_registry to tailor it to optimagic.""" import itertools +from functools import partial from itertools import product import numpy as np @@ -8,31 +9,26 @@ import pandas as pd from optree.pytree import PyTreeSpec -from optimagic.typing import extended_namespace +extended = "value" +namespaces = [ + extended, + "lower_bound", + "upper_bound", + "soft_lower_bound", + "soft_upper_bound", +] - -def _get_df_names(df): - index_strings = list(df.index.map(_index_element_to_string)) - if "value" in df: - out = index_strings - else: - out = ["_".join([loc, col]) for loc, col in product(index_strings, df.columns)] - - return out - - -def _index_element_to_string(element): - if isinstance(element, (tuple, list)): - as_strings = [str(entry) for entry in element] - res_string = "_".join(as_strings) - else: - res_string = str(element) - - return res_string +EQUALITY_CHECKERS = {} +EQUALITY_CHECKERS[np.ndarray.__name__] = lambda a, b: bool((a == b).all()) +EQUALITY_CHECKERS[pd.Series.__name__] = lambda a, b: a.equals(b) +EQUALITY_CHECKERS[pd.DataFrame.__name__] = lambda a, b: a.equals(b) def tree_flatten(tree, is_leaf=None, namespace=""): - with optree.dict_insertion_ordered(True, namespace=extended_namespace): + if namespace: + with optree.dict_insertion_ordered(True, namespace=namespace): + return optree.tree_flatten(tree, is_leaf=is_leaf, namespace=namespace) + else: return optree.tree_flatten(tree, is_leaf=is_leaf, namespace=namespace) @@ -41,9 +37,6 @@ def tree_just_flatten(tree, is_leaf=None, namespace=""): return leaves -extended = extended_namespace - - def tree_unflatten(treedef, leaves, is_leaf=None, namespace=""): if not isinstance(treedef, PyTreeSpec): _, treedef = tree_flatten(treedef, is_leaf=is_leaf, namespace=namespace) @@ -60,14 +53,35 @@ def leaf_names(tree, is_leaf=None, namespace="", separator="_"): return [separator.join(str(p) for p in path) for path in paths] -def set_data_col_df_attribute(tree, data_col): - def set_attr(node): - if isinstance(node, pd.DataFrame): - node = node.copy() - node.attrs["data_col"] = data_col - return node +def tree_equal(tree, other, is_leaf=None, namespace="", equality_checkers=None): + equality_checkers = ( + EQUALITY_CHECKERS + if equality_checkers is None + else {**EQUALITY_CHECKERS, **equality_checkers} + ) + + first_flat, first_treespec = tree_flatten( + tree, is_leaf=is_leaf, namespace=namespace + ) + second_flat, second_treespec = tree_flatten( + other, is_leaf=is_leaf, namespace=namespace + ) - return tree_map(set_attr, tree) + first_names = leaf_names(tree, is_leaf=is_leaf, namespace=namespace) + second_names = leaf_names(other, is_leaf=is_leaf, namespace=namespace) + + equal = first_names == second_names and first_treespec == second_treespec + + if equal: + for first, second in zip(first_flat, second_flat, strict=True): + check_func = equality_checkers.get( + type(first).__name__, lambda a, b: a == b + ) + equal = equal and check_func(first, second) + if not equal: + break + + return equal def _array_element_names(arr): @@ -76,8 +90,27 @@ def _array_element_names(arr): return names -def _flatten_df_optree(df): - data_col = df.attrs.get("data_col", "value") +def _get_df_names(df): + index_strings = list(df.index.map(_index_element_to_string)) + if "value" in df: + out = index_strings + else: + out = ["_".join([loc, col]) for loc, col in product(index_strings, df.columns)] + + return out + + +def _index_element_to_string(element): + if isinstance(element, (tuple, list)): + as_strings = [str(entry) for entry in element] + res_string = "_".join(as_strings) + else: + res_string = str(element) + + return res_string + + +def _flatten_df(df, data_col): is_value_df = "value" in df if is_value_df: flat = df.get(data_col, default=np.full(len(df), np.nan)).tolist() @@ -91,8 +124,7 @@ def _flatten_df_optree(df): return flat, aux_data, _get_df_names(df) -def _unflatten_df_optree(aux_data, leaves): - data_col = aux_data["df"].attrs.get("data_col", "value") +def _unflatten_df(aux_data, leaves, data_col): if aux_data["is_value_df"]: out = aux_data["df"].assign(**{data_col: leaves}) else: @@ -104,61 +136,44 @@ def _unflatten_df_optree(aux_data, leaves): return out -optree.register_pytree_node( - pd.DataFrame, - _flatten_df_optree, - _unflatten_df_optree, - namespace=extended_namespace, -) - -optree.register_pytree_node( - pd.Series, - lambda sr: ( - sr.tolist(), - {"index": sr.index, "name": sr.name}, - list(sr.index.map(_index_element_to_string)), - ), - lambda aux_data, leaves: pd.Series(leaves, **aux_data), - namespace=extended_namespace, -) - -optree.register_pytree_node( - np.ndarray, - lambda arr: (arr.flatten().tolist(), arr.shape, _array_element_names(arr)), - lambda aux_data, leaves: np.array(leaves).reshape(aux_data), - namespace=extended_namespace, -) +def _flatten_series(series: pd.Series): + return ( + series.tolist(), + {"index": series.index, "name": series.name}, + list(series.index.map(_index_element_to_string)), + ) -EQUALITY_CHECKERS = {} -EQUALITY_CHECKERS[np.ndarray] = lambda a, b: bool((a == b).all()) -EQUALITY_CHECKERS[pd.Series] = lambda a, b: a.equals(b) -EQUALITY_CHECKERS[pd.DataFrame] = lambda a, b: a.equals(b) +def _unflatten_series(aux_data, leaves): + return pd.Series(leaves, **aux_data) -def tree_equal(tree, other, is_leaf=None, namespace="", equality_checkers=None): - equality_checkers = ( - EQUALITY_CHECKERS - if equality_checkers is None - else {**EQUALITY_CHECKERS, **equality_checkers} - ) - first_flat, first_treespec = tree_flatten( - tree, is_leaf=is_leaf, namespace=namespace - ) - second_flat, second_treespec = tree_flatten( - other, is_leaf=is_leaf, namespace=namespace - ) +def _flatten_ndarray(arr: np.ndarray): + return arr.flatten().tolist(), arr.shape, _array_element_names(arr) - first_names = leaf_names(tree, is_leaf=is_leaf, namespace=namespace) - second_names = leaf_names(tree, is_leaf=is_leaf, namespace=namespace) - equal = first_names == second_names and first_treespec == second_treespec +def _unflatten_ndarray(aux_data, leaves): + return np.array(leaves).reshape(aux_data) - if equal: - for first, second in zip(first_flat, second_flat, strict=True): - check_func = equality_checkers.get(type(first), lambda a, b: a == b) - equal = equal and check_func(first, second) - if not equal: - break - return equal +for namespace in namespaces: + optree.register_pytree_node( + pd.DataFrame, + partial(_flatten_df, data_col=namespace), + partial(_unflatten_df, data_col=namespace), + namespace=namespace, + ) + + optree.register_pytree_node( + pd.Series, + _flatten_series, + _unflatten_series, + namespace=namespace, + ) + + optree.register_pytree_node( + np.ndarray, + _flatten_ndarray, + _unflatten_ndarray, + namespace=namespace, + ) diff --git a/src/optimagic/typing.py b/src/optimagic/typing.py index 795b98174..9b389ced2 100644 --- a/src/optimagic/typing.py +++ b/src/optimagic/typing.py @@ -22,7 +22,6 @@ Scalar = Any T = TypeVar("T") -extended_namespace = "extended_namespace" class AggregationLevel(Enum): diff --git a/tests/optimagic/parameters/test_tree_registry.py b/tests/optimagic/parameters/test_tree_registry.py index cb7d2e2da..150681624 100644 --- a/tests/optimagic/parameters/test_tree_registry.py +++ b/tests/optimagic/parameters/test_tree_registry.py @@ -6,7 +6,6 @@ from optimagic.parameters.tree_registry import ( extended, leaf_names, - set_data_col_df_attribute, tree_flatten, tree_unflatten, ) @@ -61,28 +60,3 @@ def test_unflatten_partially_numeric_df(other_df): def test_leaf_names_partially_numeric_df(other_df): names = leaf_names(other_df, namespace=extended) assert names == ["alpha_b", "alpha_c", "beta_b", "beta_c", "gamma_b", "gamma_c"] - - -def test_set_data_col_attribute_assigns_attribute(value_df): - df = set_data_col_df_attribute(value_df, data_col="attr") - assert df.attrs.get("data_col") == "attr" - assert value_df.attrs.get("data_col") is None - - -def test_set_data_col_attribute_unflattened_tree_has_attribute(value_df): - df = set_data_col_df_attribute(value_df, data_col="attr") - tree, treedef = tree_flatten(df, namespace=extended) - df = tree_unflatten(treedef, tree) - assert df.attrs.get("data_col") == "attr" - - -def test_set_data_col_attribute_returns_nan(value_df): - df = set_data_col_df_attribute(value_df, data_col="attr") - tree, treedef = tree_flatten(df, namespace=extended) - assert all(np.isnan(value) for value in tree) - - -def test_set_data_col_attribute_returs_column_values(value_df): - df = set_data_col_df_attribute(value_df, data_col="a") - tree, treedef = tree_flatten(df, namespace=extended) - assert tree == [0, 2, 4] From 94516da3d8214078d3cf1a8e67be332afe4361ed Mon Sep 17 00:00:00 2001 From: Abel Abate Date: Thu, 2 Apr 2026 13:28:26 +0200 Subject: [PATCH 13/19] chore: move namespaces list to typing.py --- src/estimagic/bootstrap.py | 23 ++++++----- src/estimagic/estimate_msm.py | 10 ++--- src/estimagic/msm_weighting.py | 5 ++- src/estimagic/shared_covs.py | 10 ++--- src/optimagic/benchmarking/run_benchmark.py | 7 ++-- src/optimagic/differentiation/derivatives.py | 40 +++++++++++-------- src/optimagic/examples/criterion_functions.py | 9 +++-- src/optimagic/optimization/fun_value.py | 6 +-- src/optimagic/optimization/history.py | 7 ++-- src/optimagic/parameters/block_trees.py | 10 ++--- src/optimagic/parameters/bounds.py | 15 ++++--- .../parameters/nonlinear_constraints.py | 8 ++-- src/optimagic/parameters/process_selectors.py | 5 ++- src/optimagic/parameters/tree_conversion.py | 17 ++++---- src/optimagic/parameters/tree_registry.py | 11 +---- src/optimagic/typing.py | 11 +++++ src/optimagic/visualization/history_plots.py | 19 +++++---- src/optimagic/visualization/slice_plot.py | 6 +-- src/optimagic/visualization/slice_plot_3d.py | 6 +-- tests/estimagic/test_bootstrap_ci.py | 5 ++- ...st_estimate_msm_dict_params_and_moments.py | 7 ++-- tests/estimagic/test_shared.py | 5 ++- tests/optimagic/logging/test_logger.py | 7 ++-- .../optimization/test_params_versions.py | 19 ++++----- .../optimization/test_with_logging.py | 5 ++- .../optimagic/parameters/test_block_trees.py | 7 ++-- .../parameters/test_nonlinear_constraints.py | 5 ++- .../parameters/test_process_selectors.py | 8 ++-- .../parameters/test_tree_registry.py | 18 ++++----- 29 files changed, 170 insertions(+), 141 deletions(-) diff --git a/src/estimagic/bootstrap.py b/src/estimagic/bootstrap.py index 776e49e4a..3844cfd83 100644 --- a/src/estimagic/bootstrap.py +++ b/src/estimagic/bootstrap.py @@ -13,12 +13,12 @@ from optimagic.batch_evaluators import joblib_batch_evaluator from optimagic.parameters.block_trees import matrix_to_block_tree from optimagic.parameters.tree_registry import ( - extended, leaf_names, tree_flatten, tree_just_flatten, tree_unflatten, ) +from optimagic.typing import value_namespace from optimagic.utilities import get_rng @@ -108,7 +108,8 @@ def bootstrap( # ================================================================================== flat_outcomes = [ - tree_just_flatten(_outcome, namespace=extended) for _outcome in all_outcomes + tree_just_flatten(_outcome, namespace=value_namespace) + for _outcome in all_outcomes ] internal_outcomes = np.array(flat_outcomes) @@ -166,10 +167,10 @@ def outcomes(self): List[Any]: The boostrap outcomes as a list of pytrees. """ - _, treedef = tree_flatten(self._base_outcome, namespace=extended) + _, treedef = tree_flatten(self._base_outcome, namespace=value_namespace) outcomes = [ - tree_unflatten(treedef, out, namespace=extended) + tree_unflatten(treedef, out, namespace=value_namespace) for out in self._internal_outcomes ] return outcomes @@ -185,9 +186,9 @@ def se(self): cov = self._internal_cov se = np.sqrt(np.diagonal(cov)) - _, treedef = tree_flatten(self._base_outcome, namespace=extended) + _, treedef = tree_flatten(self._base_outcome, namespace=value_namespace) - se = tree_unflatten(treedef, se, namespace=extended) + se = tree_unflatten(treedef, se, namespace=value_namespace) return se def cov(self, return_type="pytree"): @@ -208,7 +209,7 @@ def cov(self, return_type="pytree"): cov = self._internal_cov if return_type == "dataframe": - names = np.array(leaf_names(self._base_outcome, namespace=extended)) + names = np.array(leaf_names(self._base_outcome, namespace=value_namespace)) cov = pd.DataFrame(cov, columns=names, index=names) elif return_type == "pytree": cov = matrix_to_block_tree(cov, self._base_outcome, self._base_outcome) @@ -236,15 +237,15 @@ def ci(self, ci_method="percentile", ci_level=0.95): """ base_outcome_flat, treedef = tree_flatten( - self._base_outcome, namespace=extended + self._base_outcome, namespace=value_namespace ) lower_flat, upper_flat = calculate_ci( base_outcome_flat, self._internal_outcomes, ci_method, ci_level ) - lower = tree_unflatten(treedef, lower_flat, namespace=extended) - upper = tree_unflatten(treedef, upper_flat, namespace=extended) + lower = tree_unflatten(treedef, lower_flat, namespace=value_namespace) + upper = tree_unflatten(treedef, upper_flat, namespace=value_namespace) return lower, upper def p_values(self): @@ -273,7 +274,7 @@ def summary(self, ci_method="percentile", ci_level=0.95): Soon this will be a pytree. """ - names = leaf_names(self.base_outcome, namespace=extended) + names = leaf_names(self.base_outcome, namespace=value_namespace) summary_data = _calulcate_summary_data_bootstrap( self, ci_method=ci_method, ci_level=ci_level ) diff --git a/src/estimagic/estimate_msm.py b/src/estimagic/estimate_msm.py index 5cf337799..7da185d1b 100644 --- a/src/estimagic/estimate_msm.py +++ b/src/estimagic/estimate_msm.py @@ -51,13 +51,13 @@ from optimagic.parameters.conversion import Converter, get_converter from optimagic.parameters.space_conversion import InternalParams from optimagic.parameters.tree_registry import ( - extended, leaf_names, tree_just_flatten, ) from optimagic.shared.check_option_dicts import ( check_optimization_options, ) +from optimagic.typing import value_namespace from optimagic.utilities import get_rng, to_pickle @@ -321,7 +321,7 @@ def func(x): sim_mom = simulate_moments(params, **simulate_moments_kwargs) if isinstance(sim_mom, dict) and "simulated_moments" in sim_mom: sim_mom = sim_mom["simulated_moments"] - out = np.array(tree_just_flatten(sim_mom, namespace=extended)) + out = np.array(tree_just_flatten(sim_mom, namespace=value_namespace)) return out int_jac = first_derivative( @@ -420,7 +420,7 @@ def get_msm_optimization_functions( chol_weights = np.linalg.cholesky(flat_weights) - flat_emp_mom = tree_just_flatten(empirical_moments, namespace=extended) + flat_emp_mom = tree_just_flatten(empirical_moments, namespace=value_namespace) _simulate_moments = _partial_kwargs(simulate_moments, simulate_moments_kwargs) _jacobian = _partial_kwargs(jacobian, jacobian_kwargs) @@ -431,7 +431,7 @@ def get_msm_optimization_functions( simulate_moments=_simulate_moments, flat_empirical_moments=flat_emp_mom, chol_weights=chol_weights, - namespace=extended, + namespace=value_namespace, ) ) @@ -977,7 +977,7 @@ def sensitivity( ) elif return_type == "dataframe": row_names = self._internal_estimates.names - col_names = leaf_names(self._empirical_moments, namespace=extended) + col_names = leaf_names(self._empirical_moments, namespace=value_namespace) out = pd.DataFrame( data=raw, index=row_names, diff --git a/src/estimagic/msm_weighting.py b/src/estimagic/msm_weighting.py index 34222bb6d..22602bcc3 100644 --- a/src/estimagic/msm_weighting.py +++ b/src/estimagic/msm_weighting.py @@ -6,7 +6,8 @@ from estimagic.bootstrap import bootstrap from optimagic.parameters.block_trees import block_tree_to_matrix, matrix_to_block_tree -from optimagic.parameters.tree_registry import extended, tree_just_flatten +from optimagic.parameters.tree_registry import tree_just_flatten +from optimagic.typing import value_namespace from optimagic.utilities import robust_inverse @@ -54,7 +55,7 @@ def get_moments_cov( def func(data, **kwargs): raw = calculate_moments(data, **kwargs) out = pd.Series( - tree_just_flatten(raw, namespace=extended) + tree_just_flatten(raw, namespace=value_namespace) ) # xxxx won't be necessary soon! return out diff --git a/src/estimagic/shared_covs.py b/src/estimagic/shared_covs.py index c1f0b782f..a5931113a 100644 --- a/src/estimagic/shared_covs.py +++ b/src/estimagic/shared_covs.py @@ -6,10 +6,10 @@ from optimagic.parameters.block_trees import matrix_to_block_tree from optimagic.parameters.tree_registry import ( - extended, tree_just_flatten, tree_unflatten, ) +from optimagic.typing import value_namespace def transform_covariance( @@ -150,7 +150,7 @@ def calculate_estimation_summary( # ================================================================================== flat_data = { - key: tree_just_flatten(val, namespace=extended) + key: tree_just_flatten(val, namespace=value_namespace) for key, val in summary_data.items() } @@ -169,7 +169,7 @@ def calculate_estimation_summary( # ================================================================================== # create tree with values corresponding to indices of df - indices = tree_unflatten(summary_data["value"], names, namespace=extended) + indices = tree_unflatten(summary_data["value"], names, namespace=value_namespace) estimates_flat = tree_just_flatten(summary_data["value"]) indices_flat = tree_just_flatten(indices) @@ -318,7 +318,7 @@ def calculate_free_estimates(estimates, internal_estimates): mask = internal_estimates.free_mask names = internal_estimates.names - external_flat = np.array(tree_just_flatten(estimates, namespace=extended)) + external_flat = np.array(tree_just_flatten(estimates, namespace=value_namespace)) free_estimates = FreeParams( values=external_flat[mask], @@ -352,7 +352,7 @@ def transform_free_values_to_params_tree(values, free_params, params): mask = free_params.free_mask flat = np.full(len(mask), np.nan) flat[np.ix_(mask)] = values - pytree = tree_unflatten(params, flat, namespace=extended) + pytree = tree_unflatten(params, flat, namespace=value_namespace) return pytree diff --git a/src/optimagic/benchmarking/run_benchmark.py b/src/optimagic/benchmarking/run_benchmark.py index e89b73b7e..7f104ba66 100644 --- a/src/optimagic/benchmarking/run_benchmark.py +++ b/src/optimagic/benchmarking/run_benchmark.py @@ -13,7 +13,8 @@ from optimagic import batch_evaluators from optimagic.algorithms import AVAILABLE_ALGORITHMS from optimagic.optimization.optimize import minimize -from optimagic.parameters.tree_registry import extended, tree_just_flatten +from optimagic.parameters.tree_registry import tree_just_flatten +from optimagic.typing import value_namespace def run_benchmark( @@ -189,7 +190,7 @@ def _process_one_result(optimize_result, problem): # This will happen if the optimization raised an error if isinstance(optimize_result, str): - params_history_flat = [tree_just_flatten(_start_x, namespace=extended)] + params_history_flat = [tree_just_flatten(_start_x, namespace=value_namespace)] criterion_history = [_start_crit_value] time_history = [np.inf] batches_history = [0] @@ -197,7 +198,7 @@ def _process_one_result(optimize_result, problem): history = optimize_result.history params_history = history.params params_history_flat = [ - tree_just_flatten(p, namespace=extended) for p in params_history + tree_just_flatten(p, namespace=value_namespace) for p in params_history ] if _is_noisy: criterion_history = np.array([_criterion(p) for p in params_history]) diff --git a/src/optimagic/differentiation/derivatives.py b/src/optimagic/differentiation/derivatives.py index a249ebeea..4dcc6f97b 100644 --- a/src/optimagic/differentiation/derivatives.py +++ b/src/optimagic/differentiation/derivatives.py @@ -21,13 +21,12 @@ from optimagic.parameters.block_trees import hessian_to_block_tree, matrix_to_block_tree from optimagic.parameters.bounds import Bounds, get_internal_bounds, pre_process_bounds from optimagic.parameters.tree_registry import ( - extended, tree_flatten, tree_just_flatten, tree_unflatten, ) from optimagic.parameters.tree_registry import tree_just_flatten as tree_leaves -from optimagic.typing import BatchEvaluatorLiteral, PyTree +from optimagic.typing import BatchEvaluatorLiteral, PyTree, value_namespace @dataclass(frozen=True) @@ -222,19 +221,23 @@ def first_derivative( is_fast_path = _is_1d_array(params) if not is_fast_path: - x, params_treedef = tree_flatten(params, namespace=extended) + x, params_treedef = tree_flatten(params, namespace=value_namespace) x = np.array(x, dtype=np.float64) if scaling_factor is not None and not np.isscalar(scaling_factor): scaling_factor = np.array( - tree_just_flatten(scaling_factor, namespace=extended) + tree_just_flatten(scaling_factor, namespace=value_namespace) ) if min_steps is not None and not np.isscalar(min_steps): - min_steps = np.array(tree_just_flatten(min_steps, namespace=extended)) + min_steps = np.array( + tree_just_flatten(min_steps, namespace=value_namespace) + ) if step_size is not None and not np.isscalar(step_size): - step_size = np.array(tree_just_flatten(step_size, namespace=extended)) + step_size = np.array( + tree_just_flatten(step_size, namespace=value_namespace) + ) else: x = params.astype(np.float64) @@ -288,7 +291,7 @@ def first_derivative( if not is_fast_path: evaluation_points = [ # entries are either a numpy.ndarray or np.nan - _unflatten_if_not_nan(p, params_treedef, extended) + _unflatten_if_not_nan(p, params_treedef, value_namespace) for p in evaluation_points ] @@ -327,14 +330,14 @@ def first_derivative( elif vector_out: f0 = f0_tree.astype(float) else: - f0 = tree_leaves(f0_tree, namespace=extended) + f0 = tree_leaves(f0_tree, namespace=value_namespace) f0 = np.array(f0, dtype=np.float64) # convert the raw evaluations to numpy arrays raw_evals_arr = _convert_evals_to_numpy( raw_evals=raw_evals, unpacker=unpacker, - namespace=extended, + namespace=value_namespace, is_scalar_out=scalar_out, is_vector_out=vector_out, ) @@ -536,19 +539,23 @@ def second_derivative( is_fast_path = _is_1d_array(params) if not is_fast_path: - x, params_treedef = tree_flatten(params, namespace=extended) + x, params_treedef = tree_flatten(params, namespace=value_namespace) x = np.array(x, dtype=np.float64) if scaling_factor is not None and not np.isscalar(scaling_factor): scaling_factor = np.array( - tree_just_flatten(scaling_factor, namespace=extended) + tree_just_flatten(scaling_factor, namespace=value_namespace) ) if min_steps is not None and not np.isscalar(min_steps): - min_steps = np.array(tree_just_flatten(min_steps, namespace=extended)) + min_steps = np.array( + tree_just_flatten(min_steps, namespace=value_namespace) + ) if step_size is not None and not np.isscalar(step_size): - step_size = np.array(tree_just_flatten(step_size, namespace=extended)) + step_size = np.array( + tree_just_flatten(step_size, namespace=value_namespace) + ) else: x = params.astype(np.float64) @@ -624,7 +631,8 @@ def second_derivative( evaluation_points = { # entries are either a numpy.ndarray or np.nan, we unflatten only step_type: [ - _unflatten_if_not_nan(p, params_treedef, extended) for p in points + _unflatten_if_not_nan(p, params_treedef, value_namespace) + for p in points ] for step_type, points in evaluation_points.items() } @@ -663,13 +671,13 @@ def second_derivative( func_value = f0 f0_tree = unpacker(f0) - f0 = tree_leaves(f0_tree, namespace=extended) + f0 = tree_leaves(f0_tree, namespace=value_namespace) f0 = np.array(f0, dtype=np.float64) # convert the raw evaluations to numpy arrays raw_evals = { step_type: _convert_evals_to_numpy( - raw_evals=evals, unpacker=unpacker, namespace=extended + raw_evals=evals, unpacker=unpacker, namespace=value_namespace ) for step_type, evals in raw_evals.items() } diff --git a/src/optimagic/examples/criterion_functions.py b/src/optimagic/examples/criterion_functions.py index 52f9dabf8..0c89563f1 100644 --- a/src/optimagic/examples/criterion_functions.py +++ b/src/optimagic/examples/criterion_functions.py @@ -17,11 +17,10 @@ ) from optimagic.parameters.block_trees import matrix_to_block_tree from optimagic.parameters.tree_registry import ( - extended, tree_just_flatten, tree_unflatten, ) -from optimagic.typing import PyTree +from optimagic.typing import PyTree, value_namespace @mark.scalar @@ -215,10 +214,12 @@ def _get_x(params: PyTree) -> NDArray[np.float64]: if isinstance(params, np.ndarray) and params.ndim == 1: x = params.astype(float) else: - x = np.array(tree_just_flatten(params, namespace=extended), dtype=np.float64) + x = np.array( + tree_just_flatten(params, namespace=value_namespace), dtype=np.float64 + ) return x def _unflatten_gradient(flat: NDArray[np.float64], params: PyTree) -> PyTree: - out = tree_unflatten(params, flat.tolist(), namespace=extended) + out = tree_unflatten(params, flat.tolist(), namespace=value_namespace) return out diff --git a/src/optimagic/optimization/fun_value.py b/src/optimagic/optimization/fun_value.py index 9672ba46c..d07719f72 100644 --- a/src/optimagic/optimization/fun_value.py +++ b/src/optimagic/optimization/fun_value.py @@ -7,8 +7,8 @@ from numpy.typing import NDArray from optimagic.exceptions import InvalidFunctionError -from optimagic.parameters.tree_registry import extended, tree_just_flatten -from optimagic.typing import AggregationLevel, PyTree, Scalar +from optimagic.parameters.tree_registry import tree_just_flatten +from optimagic.typing import AggregationLevel, PyTree, Scalar, value_namespace from optimagic.utilities import isscalar @@ -123,7 +123,7 @@ def _get_flat_value(value: PyTree) -> NDArray[np.float64]: elif isinstance(value, np.ndarray): flat = value.flatten() else: - flat = tree_just_flatten(value, namespace=extended) + flat = tree_just_flatten(value, namespace=value_namespace) flat_arr = np.asarray(flat, dtype=np.float64) return flat_arr diff --git a/src/optimagic/optimization/history.py b/src/optimagic/optimization/history.py index 6744c5cef..e000b1431 100644 --- a/src/optimagic/optimization/history.py +++ b/src/optimagic/optimization/history.py @@ -8,12 +8,11 @@ from numpy.typing import NDArray from optimagic.parameters.tree_registry import ( - extended, leaf_names, tree_just_flatten, ) from optimagic.timing import CostModel -from optimagic.typing import Direction, EvalTask, PyTree +from optimagic.typing import Direction, EvalTask, PyTree, value_namespace @dataclass(frozen=True) @@ -401,7 +400,7 @@ def _get_flat_params(params: list[PyTree]) -> list[list[float]]: if fast_path: flatten = lambda x: x.tolist() else: - flatten = partial(tree_just_flatten, namespace=extended) + flatten = partial(tree_just_flatten, namespace=value_namespace) return [flatten(p) for p in params] @@ -413,7 +412,7 @@ def _get_flat_param_names(param: PyTree) -> list[str]: # arrays, but the fast path is only taken for 1d arrays, so it can be ignored. return np.arange(param.size).astype(str).tolist() - return leaf_names(param, namespace=extended) + return leaf_names(param, namespace=value_namespace) def _is_1d_array(param: PyTree) -> bool: diff --git a/src/optimagic/parameters/block_trees.py b/src/optimagic/parameters/block_trees.py index f3c620088..2724deb59 100644 --- a/src/optimagic/parameters/block_trees.py +++ b/src/optimagic/parameters/block_trees.py @@ -4,11 +4,11 @@ import pandas as pd from optimagic.parameters.tree_registry import ( - extended, tree_flatten, tree_unflatten, ) from optimagic.parameters.tree_registry import tree_just_flatten as tree_leaves +from optimagic.typing import value_namespace def matrix_to_block_tree(matrix, outer_tree, inner_tree): @@ -332,8 +332,8 @@ def _is_pd_object(obj): def _check_dimensions_matrix(matrix, outer_tree, inner_tree): - flat_outer = tree_leaves(outer_tree, namespace=extended) - flat_inner = tree_leaves(inner_tree, namespace=extended) + flat_outer = tree_leaves(outer_tree, namespace=value_namespace) + flat_inner = tree_leaves(inner_tree, namespace=value_namespace) if matrix.shape[0] != len(flat_outer): raise ValueError("First dimension of matrix does not match that of outer_tree.") @@ -344,8 +344,8 @@ def _check_dimensions_matrix(matrix, outer_tree, inner_tree): def _check_dimensions_hessian(hessian, f_tree, params_tree): - flat_f = tree_leaves(f_tree, namespace=extended) - flat_p = tree_leaves(params_tree, namespace=extended) + flat_f = tree_leaves(f_tree, namespace=value_namespace) + flat_p = tree_leaves(params_tree, namespace=value_namespace) if len(flat_f) == 1: # consider only dimensions with non trivial size (larger than 1) diff --git a/src/optimagic/parameters/bounds.py b/src/optimagic/parameters/bounds.py index 2c909271f..0013c6767 100644 --- a/src/optimagic/parameters/bounds.py +++ b/src/optimagic/parameters/bounds.py @@ -9,14 +9,13 @@ from optimagic.exceptions import InvalidBoundsError from optimagic.parameters.tree_registry import ( - extended, leaf_names, tree_map, ) from optimagic.parameters.tree_registry import ( tree_just_flatten as tree_leaves, ) -from optimagic.typing import PyTree +from optimagic.typing import PyTree, value_namespace from optimagic.utilities import fast_numpy_full @@ -80,7 +79,7 @@ def _process_bounds_sequence(bounds: Sequence[tuple[float, float]]) -> Bounds: def get_internal_bounds( params: PyTree, bounds: Bounds | None = None, - namespace: str = extended, + namespace: str = value_namespace, add_soft_bounds: bool = False, ) -> tuple[NDArray[np.float64] | None, NDArray[np.float64] | None]: """Create consolidated and flattened bounds for params. @@ -182,11 +181,15 @@ def _update_bounds_and_flatten( """ flat_nan_tree = tree_leaves(nan_tree, namespace=kind) if bounds is not None: - flat_bounds = tree_leaves(bounds, namespace=extended) + flat_bounds = tree_leaves(bounds, namespace=value_namespace) seperator = 10 * "$" - params_names = leaf_names(nan_tree, namespace=extended, separator=seperator) - bounds_names = leaf_names(bounds, namespace=extended, separator=seperator) + params_names = leaf_names( + nan_tree, namespace=value_namespace, separator=seperator + ) + bounds_names = leaf_names( + bounds, namespace=value_namespace, separator=seperator + ) flat_nan_dict = dict(zip(params_names, flat_nan_tree, strict=False)) diff --git a/src/optimagic/parameters/nonlinear_constraints.py b/src/optimagic/parameters/nonlinear_constraints.py index 6af14bf6d..87345a85c 100644 --- a/src/optimagic/parameters/nonlinear_constraints.py +++ b/src/optimagic/parameters/nonlinear_constraints.py @@ -10,11 +10,11 @@ from optimagic.optimization.algo_options import CONSTRAINTS_ABSOLUTE_TOLERANCE from optimagic.parameters.block_trees import block_tree_to_matrix from optimagic.parameters.tree_registry import ( - extended, tree_flatten, tree_just_flatten, tree_unflatten, ) +from optimagic.typing import value_namespace def process_nonlinear_constraints( @@ -365,13 +365,13 @@ def _extend_jacobian(jac_mat, selection_indices, n_params): def _get_selection_indices(params, selector): """Get index of selected flat params and number of flat params.""" - flat_params, params_treedef = tree_flatten(params, namespace=extended) + flat_params, params_treedef = tree_flatten(params, namespace=value_namespace) n_params = len(flat_params) indices = np.arange(n_params, dtype=int) - params_indices = tree_unflatten(params_treedef, indices, namespace=extended) + params_indices = tree_unflatten(params_treedef, indices, namespace=value_namespace) selected = selector(params_indices) selection_indices = np.array( - tree_just_flatten(selected, namespace=extended), dtype=int + tree_just_flatten(selected, namespace=value_namespace), dtype=int ) return selection_indices, n_params diff --git a/src/optimagic/parameters/process_selectors.py b/src/optimagic/parameters/process_selectors.py index 3c44ce25c..83cee9eee 100644 --- a/src/optimagic/parameters/process_selectors.py +++ b/src/optimagic/parameters/process_selectors.py @@ -6,7 +6,8 @@ from optimagic.constraints import Constraint from optimagic.exceptions import InvalidConstraintError -from optimagic.parameters.tree_registry import extended, tree_just_flatten +from optimagic.parameters.tree_registry import tree_just_flatten +from optimagic.typing import value_namespace def process_selectors(constraints, params, tree_converter, param_names): @@ -51,7 +52,7 @@ def process_selectors(constraints, params, tree_converter, param_names): field=field, constraint=constr, params_case=params_case, - namespace=extended, + namespace=value_namespace, ) try: with warnings.catch_warnings(): diff --git a/src/optimagic/parameters/tree_conversion.py b/src/optimagic/parameters/tree_conversion.py index 1dcbabd9b..12594094e 100644 --- a/src/optimagic/parameters/tree_conversion.py +++ b/src/optimagic/parameters/tree_conversion.py @@ -6,13 +6,12 @@ from optimagic.parameters.block_trees import block_tree_to_matrix from optimagic.parameters.bounds import get_internal_bounds from optimagic.parameters.tree_registry import ( - extended, leaf_names, tree_flatten, tree_just_flatten, tree_unflatten, ) -from optimagic.typing import AggregationLevel +from optimagic.typing import AggregationLevel, value_namespace def get_tree_converter( @@ -50,25 +49,25 @@ def get_tree_converter( FlatParams: NamedTuple of 1d arrays with flattened bounds and param names. """ - _params_vec, _params_treedef = tree_flatten(params, namespace=extended) + _params_vec, _params_treedef = tree_flatten(params, namespace=value_namespace) _params_vec = np.array(_params_vec).astype(float) _lower, _upper = get_internal_bounds( params=params, bounds=bounds, - namespace=extended, + namespace=value_namespace, ) if add_soft_bounds: _soft_lower, _soft_upper = get_internal_bounds( params=params, bounds=bounds, - namespace=extended, + namespace=value_namespace, add_soft_bounds=add_soft_bounds, ) else: _soft_lower, _soft_upper = None, None - _param_names = leaf_names(params, namespace=extended) + _param_names = leaf_names(params, namespace=value_namespace) flat_params = FlatParams( values=_params_vec, @@ -79,13 +78,13 @@ def get_tree_converter( soft_upper_bounds=_soft_upper, ) - _params_flatten = _get_params_flatten(namespace=extended) + _params_flatten = _get_params_flatten(namespace=value_namespace) _params_unflatten = _get_params_unflatten( - namespace=extended, treedef=_params_treedef + namespace=value_namespace, treedef=_params_treedef ) _derivative_flatten = _get_derivative_flatten( - namespace=extended, + namespace=value_namespace, solver_type=solver_type, params=params, func_eval=func_eval, diff --git a/src/optimagic/parameters/tree_registry.py b/src/optimagic/parameters/tree_registry.py index 695d2d3aa..8d0e5dbe0 100644 --- a/src/optimagic/parameters/tree_registry.py +++ b/src/optimagic/parameters/tree_registry.py @@ -9,14 +9,7 @@ import pandas as pd from optree.pytree import PyTreeSpec -extended = "value" -namespaces = [ - extended, - "lower_bound", - "upper_bound", - "soft_lower_bound", - "soft_upper_bound", -] +from optimagic.typing import optree_namespaces EQUALITY_CHECKERS = {} EQUALITY_CHECKERS[np.ndarray.__name__] = lambda a, b: bool((a == b).all()) @@ -156,7 +149,7 @@ def _unflatten_ndarray(aux_data, leaves): return np.array(leaves).reshape(aux_data) -for namespace in namespaces: +for namespace in optree_namespaces: optree.register_pytree_node( pd.DataFrame, partial(_flatten_df, data_col=namespace), diff --git a/src/optimagic/typing.py b/src/optimagic/typing.py index 9b389ced2..455b37157 100644 --- a/src/optimagic/typing.py +++ b/src/optimagic/typing.py @@ -173,3 +173,14 @@ class MultiStartIterationHistory(TupleLikeAccess): history: IterationHistory local_histories: list[IterationHistory] | None = None exploration: IterationHistory | None = None + + +optree_namespaces = [ + "value", + "lower_bound", + "upper_bound", + "soft_lower_bound", + "soft_upper_bound", +] + +value_namespace = optree_namespaces[0] diff --git a/src/optimagic/visualization/history_plots.py b/src/optimagic/visualization/history_plots.py index 5fc618cbf..d1ec9b475 100644 --- a/src/optimagic/visualization/history_plots.py +++ b/src/optimagic/visualization/history_plots.py @@ -12,13 +12,12 @@ from optimagic.optimization.history import History from optimagic.optimization.optimize_result import OptimizeResult from optimagic.parameters.tree_registry import ( - extended, leaf_names, tree_flatten, tree_just_flatten, tree_unflatten, ) -from optimagic.typing import IterationHistory, PyTree +from optimagic.typing import IterationHistory, PyTree, value_namespace from optimagic.visualization.backends import line_plot from optimagic.visualization.plotting_utilities import LineData, get_palette_cycle @@ -585,13 +584,19 @@ def _extract_params_plot_lines( history = data.history.params start_params = data.start_params - hist_arr = np.array([tree_just_flatten(p, namespace=extended) for p in history]).T - names = leaf_names(start_params, namespace=extended) + hist_arr = np.array( + [tree_just_flatten(p, namespace=value_namespace) for p in history] + ).T + names = leaf_names(start_params, namespace=value_namespace) if selector is not None: - flat, treedef = tree_flatten(start_params, namespace=extended) - helper = tree_unflatten(treedef, list(range(len(flat))), namespace=extended) - selected = np.array(tree_just_flatten(selector(helper), namespace=extended)) + flat, treedef = tree_flatten(start_params, namespace=value_namespace) + helper = tree_unflatten( + treedef, list(range(len(flat))), namespace=value_namespace + ) + selected = np.array( + tree_just_flatten(selector(helper), namespace=value_namespace) + ) names = [names[i] for i in selected] hist_arr = hist_arr[selected] diff --git a/src/optimagic/visualization/slice_plot.py b/src/optimagic/visualization/slice_plot.py index 7e42c06d1..b0cf68242 100644 --- a/src/optimagic/visualization/slice_plot.py +++ b/src/optimagic/visualization/slice_plot.py @@ -23,9 +23,9 @@ from optimagic.parameters.bounds import pre_process_bounds from optimagic.parameters.conversion import get_converter from optimagic.parameters.space_conversion import InternalParams -from optimagic.parameters.tree_registry import extended, tree_just_flatten +from optimagic.parameters.tree_registry import tree_just_flatten from optimagic.shared.process_user_function import infer_aggregation_level -from optimagic.typing import AggregationLevel, PyTree +from optimagic.typing import AggregationLevel, PyTree, value_namespace from optimagic.visualization.backends import grid_line_plot, line_plot from optimagic.visualization.plotting_utilities import LineData, MarkerData @@ -249,7 +249,7 @@ def _get_plot_data( if selector is not None: helper = converter.params_from_internal(selected) selected = np.array( - tree_just_flatten(selector(helper), namespace=extended), dtype=int + tree_just_flatten(selector(helper), namespace=value_namespace), dtype=int ).ravel() # Ensure the result is a 1D array if not np.isfinite(internal_params.lower_bounds[selected]).all(): diff --git a/src/optimagic/visualization/slice_plot_3d.py b/src/optimagic/visualization/slice_plot_3d.py index f46c2447a..be33779d7 100644 --- a/src/optimagic/visualization/slice_plot_3d.py +++ b/src/optimagic/visualization/slice_plot_3d.py @@ -19,9 +19,9 @@ ) from optimagic.parameters.bounds import pre_process_bounds from optimagic.parameters.conversion import get_converter -from optimagic.parameters.tree_registry import extended, tree_just_flatten +from optimagic.parameters.tree_registry import tree_just_flatten from optimagic.shared.process_user_function import infer_aggregation_level -from optimagic.typing import AggregationLevel +from optimagic.typing import AggregationLevel, value_namespace def slice_plot_3d( # type: ignore[no-untyped-def] @@ -150,7 +150,7 @@ def slice_plot_3d( # type: ignore[no-untyped-def] if selector is not None: helper = converter.params_from_internal(selected) selected = np.array( - tree_just_flatten(selector(helper), namespace=extended), dtype=int + tree_just_flatten(selector(helper), namespace=value_namespace), dtype=int ).reshape(-1) n_params = len(selected) if not np.isfinite(internal_params.lower_bounds[selected]).all(): diff --git a/tests/estimagic/test_bootstrap_ci.py b/tests/estimagic/test_bootstrap_ci.py index 801faaebb..5bd7dc5f0 100644 --- a/tests/estimagic/test_bootstrap_ci.py +++ b/tests/estimagic/test_bootstrap_ci.py @@ -6,7 +6,8 @@ from estimagic.bootstrap_ci import calculate_ci, check_inputs from estimagic.bootstrap_samples import get_bootstrap_indices -from optimagic.parameters.tree_registry import extended, tree_just_flatten +from optimagic.parameters.tree_registry import tree_just_flatten +from optimagic.typing import value_namespace from optimagic.utilities import get_rng @@ -70,7 +71,7 @@ def _outcome_func_arr(data): def test_ci(outcome, method, setup, expected): def outcome_flat(data): - return tree_just_flatten(outcome(data), namespace=extended) + return tree_just_flatten(outcome(data), namespace=value_namespace) base_outcome = outcome_flat(setup["df"]) lower, upper = calculate_ci(base_outcome, setup["estimates"], ci_method=method) diff --git a/tests/estimagic/test_estimate_msm_dict_params_and_moments.py b/tests/estimagic/test_estimate_msm_dict_params_and_moments.py index 2cb2bacbb..ae2459c4f 100644 --- a/tests/estimagic/test_estimate_msm_dict_params_and_moments.py +++ b/tests/estimagic/test_estimate_msm_dict_params_and_moments.py @@ -5,7 +5,8 @@ from numpy.testing import assert_array_almost_equal as aaae from estimagic.estimate_msm import estimate_msm -from optimagic.parameters.tree_registry import extended, tree_just_flatten +from optimagic.parameters.tree_registry import tree_just_flatten +from optimagic.typing import value_namespace def test_estimate_msm_dict_params_and_moments(): @@ -97,7 +98,7 @@ def assert_almost_equal(x, y, decimal=6): x_flat = x y_flat = y else: - x_flat = np.array(tree_just_flatten(x, namespace=extended)) - y_flat = np.array(tree_just_flatten(x, namespace=extended)) + x_flat = np.array(tree_just_flatten(x, namespace=value_namespace)) + y_flat = np.array(tree_just_flatten(x, namespace=value_namespace)) aaae(x_flat, y_flat, decimal=decimal) diff --git a/tests/estimagic/test_shared.py b/tests/estimagic/test_shared.py index 3ddd3a485..a1c21b641 100644 --- a/tests/estimagic/test_shared.py +++ b/tests/estimagic/test_shared.py @@ -14,7 +14,8 @@ transform_free_cov_to_cov, transform_free_values_to_params_tree, ) -from optimagic.parameters.tree_registry import extended, leaf_names, tree_equal +from optimagic.parameters.tree_registry import leaf_names, tree_equal +from optimagic.typing import value_namespace from optimagic.utilities import get_rng @@ -239,7 +240,7 @@ def test_calculate_estimation_summary(): "free": np.array([True, True, True]), } - names = leaf_names(summary_data["value"], namespace=extended) + names = leaf_names(summary_data["value"], namespace=value_namespace) free_names = names # function call diff --git a/tests/optimagic/logging/test_logger.py b/tests/optimagic/logging/test_logger.py index 86fa29af8..76421f630 100644 --- a/tests/optimagic/logging/test_logger.py +++ b/tests/optimagic/logging/test_logger.py @@ -13,11 +13,10 @@ ) from optimagic.optimization.optimize import minimize from optimagic.parameters.tree_registry import ( - extended, tree_equal, tree_just_flatten, ) -from optimagic.typing import Direction +from optimagic.typing import Direction, value_namespace @pytest.fixture() @@ -85,8 +84,8 @@ def test_log_reader_read_multistart_history(example_db): assert exploration is None assert tree_equal( - tree_just_flatten(asdict(history), namespace=extended), - tree_just_flatten(asdict(reader.read_history()), namespace=extended), + tree_just_flatten(asdict(history), namespace=value_namespace), + tree_just_flatten(asdict(reader.read_history()), namespace=value_namespace), ) diff --git a/tests/optimagic/optimization/test_params_versions.py b/tests/optimagic/optimization/test_params_versions.py index 52db60eed..f4c868f16 100644 --- a/tests/optimagic/optimization/test_params_versions.py +++ b/tests/optimagic/optimization/test_params_versions.py @@ -10,7 +10,8 @@ sos_scalar, ) from optimagic.optimization.optimize import minimize -from optimagic.parameters.tree_registry import extended, tree_just_flatten +from optimagic.parameters.tree_registry import tree_just_flatten +from optimagic.typing import value_namespace PARAMS = [ {"a": 1.0, "b": 2, "c": 3, "d": 4, "e": 5}, @@ -26,7 +27,7 @@ @pytest.mark.parametrize("params", PARAMS + SCALAR_PARAMS) def test_tree_params_numerical_derivative_scalar_criterion(params): - flat = np.array(tree_just_flatten(params, namespace=extended)) + flat = np.array(tree_just_flatten(params, namespace=value_namespace)) expected = np.zeros_like(flat) res = minimize( @@ -34,13 +35,13 @@ def test_tree_params_numerical_derivative_scalar_criterion(params): params=params, algorithm="scipy_lbfgsb", ) - calculated = np.array(tree_just_flatten(res.params, namespace=extended)) + calculated = np.array(tree_just_flatten(res.params, namespace=value_namespace)) aaae(calculated, expected) @pytest.mark.parametrize("params", PARAMS + SCALAR_PARAMS) def test_tree_params_scalar_criterion(params): - flat = np.array(tree_just_flatten(params, namespace=extended)) + flat = np.array(tree_just_flatten(params, namespace=value_namespace)) expected = np.zeros_like(flat) res = minimize( @@ -49,7 +50,7 @@ def test_tree_params_scalar_criterion(params): params=params, algorithm="scipy_lbfgsb", ) - calculated = np.array(tree_just_flatten(res.params, namespace=extended)) + calculated = np.array(tree_just_flatten(res.params, namespace=value_namespace)) aaae(calculated, expected) @@ -61,7 +62,7 @@ def test_tree_params_scalar_criterion(params): @pytest.mark.parametrize("params, algorithm", TEST_CASES_SOS_LS) def test_tree_params_numerical_derivative_sos_ls(params, algorithm): - flat = np.array(tree_just_flatten(params, namespace=extended)) + flat = np.array(tree_just_flatten(params, namespace=value_namespace)) expected = np.zeros_like(flat) res = minimize( @@ -69,13 +70,13 @@ def test_tree_params_numerical_derivative_sos_ls(params, algorithm): params=params, algorithm=algorithm, ) - calculated = np.array(tree_just_flatten(res.params, namespace=extended)) + calculated = np.array(tree_just_flatten(res.params, namespace=value_namespace)) aaae(calculated, expected) @pytest.mark.parametrize("params, algorithm", TEST_CASES_SOS_LS) def test_tree_params_sos_ls(params, algorithm): - flat = np.array(tree_just_flatten(params, namespace=extended)) + flat = np.array(tree_just_flatten(params, namespace=value_namespace)) expected = np.zeros_like(flat) derivatives = [sos_gradient, sos_ls_jacobian] @@ -85,5 +86,5 @@ def test_tree_params_sos_ls(params, algorithm): params=params, algorithm=algorithm, ) - calculated = np.array(tree_just_flatten(res.params, namespace=extended)) + calculated = np.array(tree_just_flatten(res.params, namespace=value_namespace)) aaae(calculated, expected) diff --git a/tests/optimagic/optimization/test_with_logging.py b/tests/optimagic/optimization/test_with_logging.py index 9abc05bf4..594f8bdd0 100644 --- a/tests/optimagic/optimization/test_with_logging.py +++ b/tests/optimagic/optimization/test_with_logging.py @@ -20,7 +20,8 @@ from optimagic.logging.logger import SQLiteLogOptions from optimagic.logging.types import ExistenceStrategy from optimagic.optimization.optimize import minimize -from optimagic.parameters.tree_registry import extended, tree_just_flatten +from optimagic.parameters.tree_registry import tree_just_flatten +from optimagic.typing import value_namespace @mark.least_squares @@ -46,7 +47,7 @@ def test_optimization_with_valid_logging(algorithm, params): algorithm=algorithm, logging="logging.db", ) - flat = np.array(tree_just_flatten(res.params, namespace=extended)) + flat = np.array(tree_just_flatten(res.params, namespace=value_namespace)) aaae(flat, np.zeros(3)) diff --git a/tests/optimagic/parameters/test_block_trees.py b/tests/optimagic/parameters/test_block_trees.py index 95703f750..afe70e494 100644 --- a/tests/optimagic/parameters/test_block_trees.py +++ b/tests/optimagic/parameters/test_block_trees.py @@ -10,8 +10,9 @@ hessian_to_block_tree, matrix_to_block_tree, ) -from optimagic.parameters.tree_registry import extended, tree_equal +from optimagic.parameters.tree_registry import tree_equal from optimagic.parameters.tree_registry import tree_just_flatten as tree_leaves +from optimagic.typing import value_namespace def test_matrix_to_block_tree_array_and_scalar(): @@ -128,8 +129,8 @@ def test_block_tree_to_hessian_bijection(): params = {"a": np.arange(4), "b": [{"c": (1, 2), "d": np.array([5, 6])}]} f_tree = {"e": np.arange(3), "f": (5, 6, [7, 8, {"g": 1.0}])} - n_p = len(tree_leaves(params, namespace=extended)) - n_f = len(tree_leaves(f_tree, namespace=extended)) + n_p = len(tree_leaves(params, namespace=value_namespace)) + n_f = len(tree_leaves(f_tree, namespace=value_namespace)) expected = np.arange(n_f * n_p**2).reshape(n_f, n_p, n_p) block_hessian = hessian_to_block_tree(expected, f_tree, params) diff --git a/tests/optimagic/parameters/test_nonlinear_constraints.py b/tests/optimagic/parameters/test_nonlinear_constraints.py index 1209658f7..cd1e1d491 100644 --- a/tests/optimagic/parameters/test_nonlinear_constraints.py +++ b/tests/optimagic/parameters/test_nonlinear_constraints.py @@ -21,7 +21,8 @@ process_nonlinear_constraints, vector_as_list_of_scalar_constraints, ) -from optimagic.parameters.tree_registry import extended, tree_just_flatten +from optimagic.parameters.tree_registry import tree_just_flatten +from optimagic.typing import value_namespace @dataclass @@ -30,7 +31,7 @@ def params_from_internal(self, x): return x def params_to_internal(self, params): - return np.array(tree_just_flatten(params, namespace=extended)) + return np.array(tree_just_flatten(params, namespace=value_namespace)) # ====================================================================================== diff --git a/tests/optimagic/parameters/test_process_selectors.py b/tests/optimagic/parameters/test_process_selectors.py index 805e8a356..ced0a8d91 100644 --- a/tests/optimagic/parameters/test_process_selectors.py +++ b/tests/optimagic/parameters/test_process_selectors.py @@ -7,11 +7,11 @@ from optimagic.parameters.process_selectors import process_selectors from optimagic.parameters.tree_conversion import TreeConverter from optimagic.parameters.tree_registry import ( - extended, tree_flatten, tree_just_flatten, tree_unflatten, ) +from optimagic.typing import value_namespace @pytest.mark.parametrize("constraints", [None, []]) @@ -35,14 +35,14 @@ def tree_params(): @pytest.fixture() def tree_params_converter(tree_params): - _, treedef = tree_flatten(tree_params, namespace=extended) + _, treedef = tree_flatten(tree_params, namespace=value_namespace) converter = TreeConverter( params_flatten=lambda params: np.array( - tree_just_flatten(params, namespace=extended) + tree_just_flatten(params, namespace=value_namespace) ), params_unflatten=lambda x: tree_unflatten( - treedef, x.tolist(), namespace=extended + treedef, x.tolist(), namespace=value_namespace ), derivative_flatten=None, ) diff --git a/tests/optimagic/parameters/test_tree_registry.py b/tests/optimagic/parameters/test_tree_registry.py index 150681624..4f5f77d37 100644 --- a/tests/optimagic/parameters/test_tree_registry.py +++ b/tests/optimagic/parameters/test_tree_registry.py @@ -4,11 +4,11 @@ from pandas.testing import assert_frame_equal from optimagic.parameters.tree_registry import ( - extended, leaf_names, tree_flatten, tree_unflatten, ) +from optimagic.typing import value_namespace @pytest.fixture() @@ -30,33 +30,33 @@ def other_df(): def test_flatten_df_with_value_column(value_df): - flat, _ = tree_flatten(value_df, namespace=extended) + flat, _ = tree_flatten(value_df, namespace=value_namespace) assert flat == [1, 3, 5] def test_unflatten_df_with_value_column(value_df): - _, treedef = tree_flatten(value_df, namespace=extended) - unflat = tree_unflatten(treedef, [10, 11, 12], namespace=extended) + _, treedef = tree_flatten(value_df, namespace=value_namespace) + unflat = tree_unflatten(treedef, [10, 11, 12], namespace=value_namespace) assert unflat.equals(value_df.assign(value=[10, 11, 12])) def test_leaf_names_df_with_value_column(value_df): - names = leaf_names(value_df, namespace=extended) + names = leaf_names(value_df, namespace=value_namespace) assert names == ["alpha", "beta", "gamma"] def test_flatten_partially_numeric_df(other_df): - flat, _ = tree_flatten(other_df, namespace=extended) + flat, _ = tree_flatten(other_df, namespace=value_namespace) assert flat == [0, 3.14, 1, 3.14, 2, 3.14] def test_unflatten_partially_numeric_df(other_df): - _, treedef = tree_flatten(other_df, namespace=extended) - unflat = tree_unflatten(treedef, [1, 2, 3, 4, 5, 6], namespace=extended) + _, treedef = tree_flatten(other_df, namespace=value_namespace) + unflat = tree_unflatten(treedef, [1, 2, 3, 4, 5, 6], namespace=value_namespace) other_df = other_df.assign(b=[1, 3, 5], c=[2, 4, 6]) assert_frame_equal(unflat, other_df, check_dtype=False) def test_leaf_names_partially_numeric_df(other_df): - names = leaf_names(other_df, namespace=extended) + names = leaf_names(other_df, namespace=value_namespace) assert names == ["alpha_b", "alpha_c", "beta_b", "beta_c", "gamma_b", "gamma_c"] From 9e7996dd940986ba929719029d994a788e3a5155 Mon Sep 17 00:00:00 2001 From: Abel Abate Date: Thu, 2 Apr 2026 14:52:33 +0200 Subject: [PATCH 14/19] chore: register jax arrays --- src/optimagic/parameters/tree_registry.py | 28 +++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/src/optimagic/parameters/tree_registry.py b/src/optimagic/parameters/tree_registry.py index 8d0e5dbe0..19099da79 100644 --- a/src/optimagic/parameters/tree_registry.py +++ b/src/optimagic/parameters/tree_registry.py @@ -11,11 +11,22 @@ from optimagic.typing import optree_namespaces +try: + import jax.numpy as jnp # type: ignore[import-not-found] + + _has_jax = True +except ImportError: + _has_jax = False + + EQUALITY_CHECKERS = {} EQUALITY_CHECKERS[np.ndarray.__name__] = lambda a, b: bool((a == b).all()) EQUALITY_CHECKERS[pd.Series.__name__] = lambda a, b: a.equals(b) EQUALITY_CHECKERS[pd.DataFrame.__name__] = lambda a, b: a.equals(b) +if _has_jax: + EQUALITY_CHECKERS[jnp.ndarray.__name__] = lambda a, b: bool((a == b).all()) + def tree_flatten(tree, is_leaf=None, namespace=""): if namespace: @@ -149,6 +160,15 @@ def _unflatten_ndarray(aux_data, leaves): return np.array(leaves).reshape(aux_data) +if _has_jax: + + def _flatten_jax_array(arr: jnp.ndarray): + return arr.flatten().tolist(), arr.shape, _array_element_names(arr) + + def _unflatten_jax_array(aux_data, leaves): + return jnp.array(leaves).reshape(aux_data) + + for namespace in optree_namespaces: optree.register_pytree_node( pd.DataFrame, @@ -170,3 +190,11 @@ def _unflatten_ndarray(aux_data, leaves): _unflatten_ndarray, namespace=namespace, ) + + if _has_jax: + optree.register_pytree_node( + jnp.ndarray, + _flatten_jax_array, + _unflatten_jax_array, + namespace=namespace, + ) From b932c30afceb84f5b979a383489cf7bcedf529dd Mon Sep 17 00:00:00 2001 From: Abel Abate Date: Thu, 2 Apr 2026 15:08:54 +0200 Subject: [PATCH 15/19] chore: remove type hints --- src/optimagic/parameters/tree_registry.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/optimagic/parameters/tree_registry.py b/src/optimagic/parameters/tree_registry.py index 19099da79..358437bee 100644 --- a/src/optimagic/parameters/tree_registry.py +++ b/src/optimagic/parameters/tree_registry.py @@ -140,7 +140,7 @@ def _unflatten_df(aux_data, leaves, data_col): return out -def _flatten_series(series: pd.Series): +def _flatten_series(series): return ( series.tolist(), {"index": series.index, "name": series.name}, @@ -152,7 +152,7 @@ def _unflatten_series(aux_data, leaves): return pd.Series(leaves, **aux_data) -def _flatten_ndarray(arr: np.ndarray): +def _flatten_ndarray(arr): return arr.flatten().tolist(), arr.shape, _array_element_names(arr) @@ -162,7 +162,7 @@ def _unflatten_ndarray(aux_data, leaves): if _has_jax: - def _flatten_jax_array(arr: jnp.ndarray): + def _flatten_jax_array(arr): return arr.flatten().tolist(), arr.shape, _array_element_names(arr) def _unflatten_jax_array(aux_data, leaves): From 8276c6d37e088e70e82f560e83f4ee6a8ba220a6 Mon Sep 17 00:00:00 2001 From: Abel Abate Date: Thu, 2 Apr 2026 16:40:54 +0200 Subject: [PATCH 16/19] chore: rearrange method order --- src/optimagic/parameters/tree_registry.py | 54 +++++++++++------------ 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/src/optimagic/parameters/tree_registry.py b/src/optimagic/parameters/tree_registry.py index 358437bee..9883689ab 100644 --- a/src/optimagic/parameters/tree_registry.py +++ b/src/optimagic/parameters/tree_registry.py @@ -1,4 +1,4 @@ -"""Wrapper around pybaum get_registry to tailor it to optimagic.""" +"""Wrapper around optree to tailor it to optimagic.""" import itertools from functools import partial @@ -88,32 +88,6 @@ def tree_equal(tree, other, is_leaf=None, namespace="", equality_checkers=None): return equal -def _array_element_names(arr): - dim_names = [map(str, range(n)) for n in arr.shape] - names = list(map("_".join, itertools.product(*dim_names))) - return names - - -def _get_df_names(df): - index_strings = list(df.index.map(_index_element_to_string)) - if "value" in df: - out = index_strings - else: - out = ["_".join([loc, col]) for loc, col in product(index_strings, df.columns)] - - return out - - -def _index_element_to_string(element): - if isinstance(element, (tuple, list)): - as_strings = [str(entry) for entry in element] - res_string = "_".join(as_strings) - else: - res_string = str(element) - - return res_string - - def _flatten_df(df, data_col): is_value_df = "value" in df if is_value_df: @@ -169,6 +143,32 @@ def _unflatten_jax_array(aux_data, leaves): return jnp.array(leaves).reshape(aux_data) +def _get_df_names(df): + index_strings = list(df.index.map(_index_element_to_string)) + if "value" in df: + out = index_strings + else: + out = ["_".join([loc, col]) for loc, col in product(index_strings, df.columns)] + + return out + + +def _index_element_to_string(element): + if isinstance(element, (tuple, list)): + as_strings = [str(entry) for entry in element] + res_string = "_".join(as_strings) + else: + res_string = str(element) + + return res_string + + +def _array_element_names(arr): + dim_names = [map(str, range(n)) for n in arr.shape] + names = list(map("_".join, itertools.product(*dim_names))) + return names + + for namespace in optree_namespaces: optree.register_pytree_node( pd.DataFrame, From 6f03b1289ef3a83c5caa5e8b1445d7dbe8dbdc8a Mon Sep 17 00:00:00 2001 From: Abel Abate Date: Thu, 2 Apr 2026 16:50:54 +0200 Subject: [PATCH 17/19] chore: remove pybaum dependency --- pixi.lock | 43 ++----------------------------------------- pyproject.toml | 1 - 2 files changed, 2 insertions(+), 42 deletions(-) diff --git a/pixi.lock b/pixi.lock index cad1467ab..51557bc34 100644 --- a/pixi.lock +++ b/pixi.lock @@ -333,7 +333,6 @@ environments: - pypi: https://files.pythonhosted.org/packages/51/fe/53ac0cd932db5dcaf55961bc7cb7afdca8d80d8cc7406ed661f0c7dc111a/pdbp-1.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/52/d2/c6e44dba74f17c6216ce1b56044a9b93a929f1c2d5bdaff892512b260f5e/plotly-6.6.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f4/1e/0d44a4e3a291c009a357fbd1d61511d9306c2c4db9a7ceb6e8104d8d385f/Py_BOBYQA-1.5.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/d3/20/bcd8f317a17900e5daeb3f87a87327399c9bbe9dcd97f4025f4663c3bdf1/pybaum-0.1.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d1/fd/5c2baba82425b75baf7dbec5af57219cd252aa8a1ace4f5cd1d88e472276/pyswarms-1.3.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b8/cb/861b41341d6f1245e6ca80b1c1a8c4dfce43255b03df034429089ca2a2c5/scikit_learn-1.8.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl @@ -606,7 +605,6 @@ environments: - pypi: https://files.pythonhosted.org/packages/51/fe/53ac0cd932db5dcaf55961bc7cb7afdca8d80d8cc7406ed661f0c7dc111a/pdbp-1.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/52/d2/c6e44dba74f17c6216ce1b56044a9b93a929f1c2d5bdaff892512b260f5e/plotly-6.6.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f4/1e/0d44a4e3a291c009a357fbd1d61511d9306c2c4db9a7ceb6e8104d8d385f/Py_BOBYQA-1.5.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/d3/20/bcd8f317a17900e5daeb3f87a87327399c9bbe9dcd97f4025f4663c3bdf1/pybaum-0.1.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d1/fd/5c2baba82425b75baf7dbec5af57219cd252aa8a1ace4f5cd1d88e472276/pyswarms-1.3.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/a8/25/01c0af38fe969473fb292bba9dc2b8f9b451f3112ff242c647fee3d0dfe7/scikit_learn-1.8.0-cp314-cp314-macosx_12_0_arm64.whl @@ -890,7 +888,6 @@ environments: - pypi: https://files.pythonhosted.org/packages/51/fe/53ac0cd932db5dcaf55961bc7cb7afdca8d80d8cc7406ed661f0c7dc111a/pdbp-1.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/52/d2/c6e44dba74f17c6216ce1b56044a9b93a929f1c2d5bdaff892512b260f5e/plotly-6.6.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f4/1e/0d44a4e3a291c009a357fbd1d61511d9306c2c4db9a7ceb6e8104d8d385f/Py_BOBYQA-1.5.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/d3/20/bcd8f317a17900e5daeb3f87a87327399c9bbe9dcd97f4025f4663c3bdf1/pybaum-0.1.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5a/dc/491b7661614ab97483abf2056be1deee4dc2490ecbf7bff9ab5cdbac86e1/pyreadline3-3.5.4-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d1/fd/5c2baba82425b75baf7dbec5af57219cd252aa8a1ace4f5cd1d88e472276/pyswarms-1.3.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl @@ -1268,7 +1265,6 @@ environments: - pypi: https://files.pythonhosted.org/packages/51/fe/53ac0cd932db5dcaf55961bc7cb7afdca8d80d8cc7406ed661f0c7dc111a/pdbp-1.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/52/d2/c6e44dba74f17c6216ce1b56044a9b93a929f1c2d5bdaff892512b260f5e/plotly-6.6.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f4/1e/0d44a4e3a291c009a357fbd1d61511d9306c2c4db9a7ceb6e8104d8d385f/Py_BOBYQA-1.5.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/d3/20/bcd8f317a17900e5daeb3f87a87327399c9bbe9dcd97f4025f4663c3bdf1/pybaum-0.1.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d1/fd/5c2baba82425b75baf7dbec5af57219cd252aa8a1ace4f5cd1d88e472276/pyswarms-1.3.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b8/cb/861b41341d6f1245e6ca80b1c1a8c4dfce43255b03df034429089ca2a2c5/scikit_learn-1.8.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl @@ -1578,7 +1574,6 @@ environments: - pypi: https://files.pythonhosted.org/packages/51/fe/53ac0cd932db5dcaf55961bc7cb7afdca8d80d8cc7406ed661f0c7dc111a/pdbp-1.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/52/d2/c6e44dba74f17c6216ce1b56044a9b93a929f1c2d5bdaff892512b260f5e/plotly-6.6.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f4/1e/0d44a4e3a291c009a357fbd1d61511d9306c2c4db9a7ceb6e8104d8d385f/Py_BOBYQA-1.5.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/d3/20/bcd8f317a17900e5daeb3f87a87327399c9bbe9dcd97f4025f4663c3bdf1/pybaum-0.1.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d1/fd/5c2baba82425b75baf7dbec5af57219cd252aa8a1ace4f5cd1d88e472276/pyswarms-1.3.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/a8/25/01c0af38fe969473fb292bba9dc2b8f9b451f3112ff242c647fee3d0dfe7/scikit_learn-1.8.0-cp314-cp314-macosx_12_0_arm64.whl @@ -1899,7 +1894,6 @@ environments: - pypi: https://files.pythonhosted.org/packages/51/fe/53ac0cd932db5dcaf55961bc7cb7afdca8d80d8cc7406ed661f0c7dc111a/pdbp-1.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/52/d2/c6e44dba74f17c6216ce1b56044a9b93a929f1c2d5bdaff892512b260f5e/plotly-6.6.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f4/1e/0d44a4e3a291c009a357fbd1d61511d9306c2c4db9a7ceb6e8104d8d385f/Py_BOBYQA-1.5.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/d3/20/bcd8f317a17900e5daeb3f87a87327399c9bbe9dcd97f4025f4663c3bdf1/pybaum-0.1.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5a/dc/491b7661614ab97483abf2056be1deee4dc2490ecbf7bff9ab5cdbac86e1/pyreadline3-3.5.4-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d1/fd/5c2baba82425b75baf7dbec5af57219cd252aa8a1ace4f5cd1d88e472276/pyswarms-1.3.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl @@ -2291,7 +2285,6 @@ environments: - pypi: https://files.pythonhosted.org/packages/51/fe/53ac0cd932db5dcaf55961bc7cb7afdca8d80d8cc7406ed661f0c7dc111a/pdbp-1.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/52/d2/c6e44dba74f17c6216ce1b56044a9b93a929f1c2d5bdaff892512b260f5e/plotly-6.6.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f4/1e/0d44a4e3a291c009a357fbd1d61511d9306c2c4db9a7ceb6e8104d8d385f/Py_BOBYQA-1.5.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/d3/20/bcd8f317a17900e5daeb3f87a87327399c9bbe9dcd97f4025f4663c3bdf1/pybaum-0.1.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d1/fd/5c2baba82425b75baf7dbec5af57219cd252aa8a1ace4f5cd1d88e472276/pyswarms-1.3.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/97/74/b7a304feb2b49df9fafa9382d4d09061a96ee9a9449a7cbea7988dda0828/scikit_learn-1.8.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl @@ -2681,7 +2674,6 @@ environments: - pypi: https://files.pythonhosted.org/packages/51/fe/53ac0cd932db5dcaf55961bc7cb7afdca8d80d8cc7406ed661f0c7dc111a/pdbp-1.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/52/d2/c6e44dba74f17c6216ce1b56044a9b93a929f1c2d5bdaff892512b260f5e/plotly-6.6.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f4/1e/0d44a4e3a291c009a357fbd1d61511d9306c2c4db9a7ceb6e8104d8d385f/Py_BOBYQA-1.5.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/d3/20/bcd8f317a17900e5daeb3f87a87327399c9bbe9dcd97f4025f4663c3bdf1/pybaum-0.1.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d1/fd/5c2baba82425b75baf7dbec5af57219cd252aa8a1ace4f5cd1d88e472276/pyswarms-1.3.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/38/cf/06896db3f71c75902a8e9943b444a56e727418f6b4b4a90c98c934f51ed4/scikit_learn-1.8.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl @@ -3072,7 +3064,6 @@ environments: - pypi: https://files.pythonhosted.org/packages/51/fe/53ac0cd932db5dcaf55961bc7cb7afdca8d80d8cc7406ed661f0c7dc111a/pdbp-1.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/52/d2/c6e44dba74f17c6216ce1b56044a9b93a929f1c2d5bdaff892512b260f5e/plotly-6.6.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f4/1e/0d44a4e3a291c009a357fbd1d61511d9306c2c4db9a7ceb6e8104d8d385f/Py_BOBYQA-1.5.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/d3/20/bcd8f317a17900e5daeb3f87a87327399c9bbe9dcd97f4025f4663c3bdf1/pybaum-0.1.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d1/fd/5c2baba82425b75baf7dbec5af57219cd252aa8a1ace4f5cd1d88e472276/pyswarms-1.3.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b8/cb/861b41341d6f1245e6ca80b1c1a8c4dfce43255b03df034429089ca2a2c5/scikit_learn-1.8.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl @@ -3415,7 +3406,6 @@ environments: - pypi: https://files.pythonhosted.org/packages/51/fe/53ac0cd932db5dcaf55961bc7cb7afdca8d80d8cc7406ed661f0c7dc111a/pdbp-1.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/52/d2/c6e44dba74f17c6216ce1b56044a9b93a929f1c2d5bdaff892512b260f5e/plotly-6.6.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f4/1e/0d44a4e3a291c009a357fbd1d61511d9306c2c4db9a7ceb6e8104d8d385f/Py_BOBYQA-1.5.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/d3/20/bcd8f317a17900e5daeb3f87a87327399c9bbe9dcd97f4025f4663c3bdf1/pybaum-0.1.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d1/fd/5c2baba82425b75baf7dbec5af57219cd252aa8a1ace4f5cd1d88e472276/pyswarms-1.3.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/97/74/b7a304feb2b49df9fafa9382d4d09061a96ee9a9449a7cbea7988dda0828/scikit_learn-1.8.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl @@ -3688,7 +3678,6 @@ environments: - pypi: https://files.pythonhosted.org/packages/51/fe/53ac0cd932db5dcaf55961bc7cb7afdca8d80d8cc7406ed661f0c7dc111a/pdbp-1.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/52/d2/c6e44dba74f17c6216ce1b56044a9b93a929f1c2d5bdaff892512b260f5e/plotly-6.6.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f4/1e/0d44a4e3a291c009a357fbd1d61511d9306c2c4db9a7ceb6e8104d8d385f/Py_BOBYQA-1.5.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/d3/20/bcd8f317a17900e5daeb3f87a87327399c9bbe9dcd97f4025f4663c3bdf1/pybaum-0.1.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d1/fd/5c2baba82425b75baf7dbec5af57219cd252aa8a1ace4f5cd1d88e472276/pyswarms-1.3.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/49/d8/9be608c6024d021041c7f0b3928d4749a706f4e2c3832bbede4fb4f58c95/scikit_learn-1.8.0-cp312-cp312-macosx_12_0_arm64.whl @@ -3973,7 +3962,6 @@ environments: - pypi: https://files.pythonhosted.org/packages/51/fe/53ac0cd932db5dcaf55961bc7cb7afdca8d80d8cc7406ed661f0c7dc111a/pdbp-1.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/52/d2/c6e44dba74f17c6216ce1b56044a9b93a929f1c2d5bdaff892512b260f5e/plotly-6.6.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f4/1e/0d44a4e3a291c009a357fbd1d61511d9306c2c4db9a7ceb6e8104d8d385f/Py_BOBYQA-1.5.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/d3/20/bcd8f317a17900e5daeb3f87a87327399c9bbe9dcd97f4025f4663c3bdf1/pybaum-0.1.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5a/dc/491b7661614ab97483abf2056be1deee4dc2490ecbf7bff9ab5cdbac86e1/pyreadline3-3.5.4-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d1/fd/5c2baba82425b75baf7dbec5af57219cd252aa8a1ace4f5cd1d88e472276/pyswarms-1.3.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl @@ -4316,7 +4304,6 @@ environments: - pypi: https://files.pythonhosted.org/packages/51/fe/53ac0cd932db5dcaf55961bc7cb7afdca8d80d8cc7406ed661f0c7dc111a/pdbp-1.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/52/d2/c6e44dba74f17c6216ce1b56044a9b93a929f1c2d5bdaff892512b260f5e/plotly-6.6.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f4/1e/0d44a4e3a291c009a357fbd1d61511d9306c2c4db9a7ceb6e8104d8d385f/Py_BOBYQA-1.5.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/d3/20/bcd8f317a17900e5daeb3f87a87327399c9bbe9dcd97f4025f4663c3bdf1/pybaum-0.1.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d1/fd/5c2baba82425b75baf7dbec5af57219cd252aa8a1ace4f5cd1d88e472276/pyswarms-1.3.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/38/cf/06896db3f71c75902a8e9943b444a56e727418f6b4b4a90c98c934f51ed4/scikit_learn-1.8.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl @@ -4589,7 +4576,6 @@ environments: - pypi: https://files.pythonhosted.org/packages/51/fe/53ac0cd932db5dcaf55961bc7cb7afdca8d80d8cc7406ed661f0c7dc111a/pdbp-1.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/52/d2/c6e44dba74f17c6216ce1b56044a9b93a929f1c2d5bdaff892512b260f5e/plotly-6.6.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f4/1e/0d44a4e3a291c009a357fbd1d61511d9306c2c4db9a7ceb6e8104d8d385f/Py_BOBYQA-1.5.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/d3/20/bcd8f317a17900e5daeb3f87a87327399c9bbe9dcd97f4025f4663c3bdf1/pybaum-0.1.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d1/fd/5c2baba82425b75baf7dbec5af57219cd252aa8a1ace4f5cd1d88e472276/pyswarms-1.3.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/58/37/31b83b2594105f61a381fc74ca19e8780ee923be2d496fcd8d2e1147bd99/scikit_learn-1.8.0-cp313-cp313-macosx_12_0_arm64.whl @@ -4874,7 +4860,6 @@ environments: - pypi: https://files.pythonhosted.org/packages/51/fe/53ac0cd932db5dcaf55961bc7cb7afdca8d80d8cc7406ed661f0c7dc111a/pdbp-1.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/52/d2/c6e44dba74f17c6216ce1b56044a9b93a929f1c2d5bdaff892512b260f5e/plotly-6.6.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f4/1e/0d44a4e3a291c009a357fbd1d61511d9306c2c4db9a7ceb6e8104d8d385f/Py_BOBYQA-1.5.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/d3/20/bcd8f317a17900e5daeb3f87a87327399c9bbe9dcd97f4025f4663c3bdf1/pybaum-0.1.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5a/dc/491b7661614ab97483abf2056be1deee4dc2490ecbf7bff9ab5cdbac86e1/pyreadline3-3.5.4-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d1/fd/5c2baba82425b75baf7dbec5af57219cd252aa8a1ace4f5cd1d88e472276/pyswarms-1.3.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl @@ -5218,7 +5203,6 @@ environments: - pypi: https://files.pythonhosted.org/packages/51/fe/53ac0cd932db5dcaf55961bc7cb7afdca8d80d8cc7406ed661f0c7dc111a/pdbp-1.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/52/d2/c6e44dba74f17c6216ce1b56044a9b93a929f1c2d5bdaff892512b260f5e/plotly-6.6.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f4/1e/0d44a4e3a291c009a357fbd1d61511d9306c2c4db9a7ceb6e8104d8d385f/Py_BOBYQA-1.5.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/d3/20/bcd8f317a17900e5daeb3f87a87327399c9bbe9dcd97f4025f4663c3bdf1/pybaum-0.1.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d1/fd/5c2baba82425b75baf7dbec5af57219cd252aa8a1ace4f5cd1d88e472276/pyswarms-1.3.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b8/cb/861b41341d6f1245e6ca80b1c1a8c4dfce43255b03df034429089ca2a2c5/scikit_learn-1.8.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl @@ -5492,7 +5476,6 @@ environments: - pypi: https://files.pythonhosted.org/packages/51/fe/53ac0cd932db5dcaf55961bc7cb7afdca8d80d8cc7406ed661f0c7dc111a/pdbp-1.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/52/d2/c6e44dba74f17c6216ce1b56044a9b93a929f1c2d5bdaff892512b260f5e/plotly-6.6.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f4/1e/0d44a4e3a291c009a357fbd1d61511d9306c2c4db9a7ceb6e8104d8d385f/Py_BOBYQA-1.5.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/d3/20/bcd8f317a17900e5daeb3f87a87327399c9bbe9dcd97f4025f4663c3bdf1/pybaum-0.1.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d1/fd/5c2baba82425b75baf7dbec5af57219cd252aa8a1ace4f5cd1d88e472276/pyswarms-1.3.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/a8/25/01c0af38fe969473fb292bba9dc2b8f9b451f3112ff242c647fee3d0dfe7/scikit_learn-1.8.0-cp314-cp314-macosx_12_0_arm64.whl @@ -5778,7 +5761,6 @@ environments: - pypi: https://files.pythonhosted.org/packages/51/fe/53ac0cd932db5dcaf55961bc7cb7afdca8d80d8cc7406ed661f0c7dc111a/pdbp-1.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/52/d2/c6e44dba74f17c6216ce1b56044a9b93a929f1c2d5bdaff892512b260f5e/plotly-6.6.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f4/1e/0d44a4e3a291c009a357fbd1d61511d9306c2c4db9a7ceb6e8104d8d385f/Py_BOBYQA-1.5.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/d3/20/bcd8f317a17900e5daeb3f87a87327399c9bbe9dcd97f4025f4663c3bdf1/pybaum-0.1.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5a/dc/491b7661614ab97483abf2056be1deee4dc2490ecbf7bff9ab5cdbac86e1/pyreadline3-3.5.4-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d1/fd/5c2baba82425b75baf7dbec5af57219cd252aa8a1ace4f5cd1d88e472276/pyswarms-1.3.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl @@ -6118,7 +6100,6 @@ environments: - pypi: https://files.pythonhosted.org/packages/ae/b3/a0f0f4faac229b0011d8c4a7ee6da7c2dca0b6fd08039c95920846f23ca4/kaleido-0.2.1-py2.py3-none-manylinux1_x86_64.whl - pypi: https://files.pythonhosted.org/packages/51/fe/53ac0cd932db5dcaf55961bc7cb7afdca8d80d8cc7406ed661f0c7dc111a/pdbp-1.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f4/1e/0d44a4e3a291c009a357fbd1d61511d9306c2c4db9a7ceb6e8104d8d385f/Py_BOBYQA-1.5.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/d3/20/bcd8f317a17900e5daeb3f87a87327399c9bbe9dcd97f4025f4663c3bdf1/pybaum-0.1.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d1/fd/5c2baba82425b75baf7dbec5af57219cd252aa8a1ace4f5cd1d88e472276/pyswarms-1.3.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/97/74/b7a304feb2b49df9fafa9382d4d09061a96ee9a9449a7cbea7988dda0828/scikit_learn-1.8.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/37/9a/0c28b6371e0cdcb14f8f1930778cb3123acfcbd2c95bb9cf6b4a2ba0cce3/sqlalchemy-2.0.48-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl @@ -6385,7 +6366,6 @@ environments: - pypi: https://files.pythonhosted.org/packages/45/8e/4297556be5a07b713bb42dde0f748354de9a6918dee251c0e6bdcda341e7/kaleido-0.2.1-py2.py3-none-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/51/fe/53ac0cd932db5dcaf55961bc7cb7afdca8d80d8cc7406ed661f0c7dc111a/pdbp-1.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f4/1e/0d44a4e3a291c009a357fbd1d61511d9306c2c4db9a7ceb6e8104d8d385f/Py_BOBYQA-1.5.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/d3/20/bcd8f317a17900e5daeb3f87a87327399c9bbe9dcd97f4025f4663c3bdf1/pybaum-0.1.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d1/fd/5c2baba82425b75baf7dbec5af57219cd252aa8a1ace4f5cd1d88e472276/pyswarms-1.3.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/49/d8/9be608c6024d021041c7f0b3928d4749a706f4e2c3832bbede4fb4f58c95/scikit_learn-1.8.0-cp312-cp312-macosx_12_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/ef/91/a42ae716f8925e9659df2da21ba941f158686856107a61cc97a95e7647a3/sqlalchemy-2.0.48-cp312-cp312-macosx_11_0_arm64.whl @@ -6664,7 +6644,6 @@ environments: - pypi: https://files.pythonhosted.org/packages/f7/9a/0408b02a4bcb3cf8b338a2b074ac7d1b2099e2b092b42473def22f7b625f/kaleido-0.2.1-py2.py3-none-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/51/fe/53ac0cd932db5dcaf55961bc7cb7afdca8d80d8cc7406ed661f0c7dc111a/pdbp-1.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f4/1e/0d44a4e3a291c009a357fbd1d61511d9306c2c4db9a7ceb6e8104d8d385f/Py_BOBYQA-1.5.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/d3/20/bcd8f317a17900e5daeb3f87a87327399c9bbe9dcd97f4025f4663c3bdf1/pybaum-0.1.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5a/dc/491b7661614ab97483abf2056be1deee4dc2490ecbf7bff9ab5cdbac86e1/pyreadline3-3.5.4-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d1/fd/5c2baba82425b75baf7dbec5af57219cd252aa8a1ace4f5cd1d88e472276/pyswarms-1.3.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/9f/c4/0ab22726a04ede56f689476b760f98f8f46607caecff993017ac1b64aa5d/scikit_learn-1.8.0-cp312-cp312-win_amd64.whl @@ -7004,7 +6983,6 @@ environments: - pypi: https://files.pythonhosted.org/packages/51/fe/53ac0cd932db5dcaf55961bc7cb7afdca8d80d8cc7406ed661f0c7dc111a/pdbp-1.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/52/d2/c6e44dba74f17c6216ce1b56044a9b93a929f1c2d5bdaff892512b260f5e/plotly-6.6.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f4/1e/0d44a4e3a291c009a357fbd1d61511d9306c2c4db9a7ceb6e8104d8d385f/Py_BOBYQA-1.5.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/d3/20/bcd8f317a17900e5daeb3f87a87327399c9bbe9dcd97f4025f4663c3bdf1/pybaum-0.1.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d1/fd/5c2baba82425b75baf7dbec5af57219cd252aa8a1ace4f5cd1d88e472276/pyswarms-1.3.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/97/74/b7a304feb2b49df9fafa9382d4d09061a96ee9a9449a7cbea7988dda0828/scikit_learn-1.8.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl @@ -7275,7 +7253,6 @@ environments: - pypi: https://files.pythonhosted.org/packages/51/fe/53ac0cd932db5dcaf55961bc7cb7afdca8d80d8cc7406ed661f0c7dc111a/pdbp-1.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/52/d2/c6e44dba74f17c6216ce1b56044a9b93a929f1c2d5bdaff892512b260f5e/plotly-6.6.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f4/1e/0d44a4e3a291c009a357fbd1d61511d9306c2c4db9a7ceb6e8104d8d385f/Py_BOBYQA-1.5.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/d3/20/bcd8f317a17900e5daeb3f87a87327399c9bbe9dcd97f4025f4663c3bdf1/pybaum-0.1.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d1/fd/5c2baba82425b75baf7dbec5af57219cd252aa8a1ace4f5cd1d88e472276/pyswarms-1.3.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/49/d8/9be608c6024d021041c7f0b3928d4749a706f4e2c3832bbede4fb4f58c95/scikit_learn-1.8.0-cp312-cp312-macosx_12_0_arm64.whl @@ -7558,7 +7535,6 @@ environments: - pypi: https://files.pythonhosted.org/packages/51/fe/53ac0cd932db5dcaf55961bc7cb7afdca8d80d8cc7406ed661f0c7dc111a/pdbp-1.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/52/d2/c6e44dba74f17c6216ce1b56044a9b93a929f1c2d5bdaff892512b260f5e/plotly-6.6.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f4/1e/0d44a4e3a291c009a357fbd1d61511d9306c2c4db9a7ceb6e8104d8d385f/Py_BOBYQA-1.5.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/d3/20/bcd8f317a17900e5daeb3f87a87327399c9bbe9dcd97f4025f4663c3bdf1/pybaum-0.1.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5a/dc/491b7661614ab97483abf2056be1deee4dc2490ecbf7bff9ab5cdbac86e1/pyreadline3-3.5.4-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d1/fd/5c2baba82425b75baf7dbec5af57219cd252aa8a1ace4f5cd1d88e472276/pyswarms-1.3.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl @@ -7899,7 +7875,6 @@ environments: - pypi: https://files.pythonhosted.org/packages/51/fe/53ac0cd932db5dcaf55961bc7cb7afdca8d80d8cc7406ed661f0c7dc111a/pdbp-1.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/52/d2/c6e44dba74f17c6216ce1b56044a9b93a929f1c2d5bdaff892512b260f5e/plotly-6.6.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f4/1e/0d44a4e3a291c009a357fbd1d61511d9306c2c4db9a7ceb6e8104d8d385f/Py_BOBYQA-1.5.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/d3/20/bcd8f317a17900e5daeb3f87a87327399c9bbe9dcd97f4025f4663c3bdf1/pybaum-0.1.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d1/fd/5c2baba82425b75baf7dbec5af57219cd252aa8a1ace4f5cd1d88e472276/pyswarms-1.3.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/38/cf/06896db3f71c75902a8e9943b444a56e727418f6b4b4a90c98c934f51ed4/scikit_learn-1.8.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl @@ -8170,7 +8145,6 @@ environments: - pypi: https://files.pythonhosted.org/packages/51/fe/53ac0cd932db5dcaf55961bc7cb7afdca8d80d8cc7406ed661f0c7dc111a/pdbp-1.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/52/d2/c6e44dba74f17c6216ce1b56044a9b93a929f1c2d5bdaff892512b260f5e/plotly-6.6.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f4/1e/0d44a4e3a291c009a357fbd1d61511d9306c2c4db9a7ceb6e8104d8d385f/Py_BOBYQA-1.5.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/d3/20/bcd8f317a17900e5daeb3f87a87327399c9bbe9dcd97f4025f4663c3bdf1/pybaum-0.1.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d1/fd/5c2baba82425b75baf7dbec5af57219cd252aa8a1ace4f5cd1d88e472276/pyswarms-1.3.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/58/37/31b83b2594105f61a381fc74ca19e8780ee923be2d496fcd8d2e1147bd99/scikit_learn-1.8.0-cp313-cp313-macosx_12_0_arm64.whl @@ -8453,7 +8427,6 @@ environments: - pypi: https://files.pythonhosted.org/packages/51/fe/53ac0cd932db5dcaf55961bc7cb7afdca8d80d8cc7406ed661f0c7dc111a/pdbp-1.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/52/d2/c6e44dba74f17c6216ce1b56044a9b93a929f1c2d5bdaff892512b260f5e/plotly-6.6.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f4/1e/0d44a4e3a291c009a357fbd1d61511d9306c2c4db9a7ceb6e8104d8d385f/Py_BOBYQA-1.5.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/d3/20/bcd8f317a17900e5daeb3f87a87327399c9bbe9dcd97f4025f4663c3bdf1/pybaum-0.1.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5a/dc/491b7661614ab97483abf2056be1deee4dc2490ecbf7bff9ab5cdbac86e1/pyreadline3-3.5.4-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d1/fd/5c2baba82425b75baf7dbec5af57219cd252aa8a1ace4f5cd1d88e472276/pyswarms-1.3.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl @@ -8796,7 +8769,6 @@ environments: - pypi: https://files.pythonhosted.org/packages/51/fe/53ac0cd932db5dcaf55961bc7cb7afdca8d80d8cc7406ed661f0c7dc111a/pdbp-1.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/52/d2/c6e44dba74f17c6216ce1b56044a9b93a929f1c2d5bdaff892512b260f5e/plotly-6.6.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f4/1e/0d44a4e3a291c009a357fbd1d61511d9306c2c4db9a7ceb6e8104d8d385f/Py_BOBYQA-1.5.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/d3/20/bcd8f317a17900e5daeb3f87a87327399c9bbe9dcd97f4025f4663c3bdf1/pybaum-0.1.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d1/fd/5c2baba82425b75baf7dbec5af57219cd252aa8a1ace4f5cd1d88e472276/pyswarms-1.3.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b8/cb/861b41341d6f1245e6ca80b1c1a8c4dfce43255b03df034429089ca2a2c5/scikit_learn-1.8.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl @@ -9069,7 +9041,6 @@ environments: - pypi: https://files.pythonhosted.org/packages/51/fe/53ac0cd932db5dcaf55961bc7cb7afdca8d80d8cc7406ed661f0c7dc111a/pdbp-1.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/52/d2/c6e44dba74f17c6216ce1b56044a9b93a929f1c2d5bdaff892512b260f5e/plotly-6.6.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f4/1e/0d44a4e3a291c009a357fbd1d61511d9306c2c4db9a7ceb6e8104d8d385f/Py_BOBYQA-1.5.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/d3/20/bcd8f317a17900e5daeb3f87a87327399c9bbe9dcd97f4025f4663c3bdf1/pybaum-0.1.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d1/fd/5c2baba82425b75baf7dbec5af57219cd252aa8a1ace4f5cd1d88e472276/pyswarms-1.3.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/a8/25/01c0af38fe969473fb292bba9dc2b8f9b451f3112ff242c647fee3d0dfe7/scikit_learn-1.8.0-cp314-cp314-macosx_12_0_arm64.whl @@ -9353,7 +9324,6 @@ environments: - pypi: https://files.pythonhosted.org/packages/51/fe/53ac0cd932db5dcaf55961bc7cb7afdca8d80d8cc7406ed661f0c7dc111a/pdbp-1.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/52/d2/c6e44dba74f17c6216ce1b56044a9b93a929f1c2d5bdaff892512b260f5e/plotly-6.6.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f4/1e/0d44a4e3a291c009a357fbd1d61511d9306c2c4db9a7ceb6e8104d8d385f/Py_BOBYQA-1.5.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/d3/20/bcd8f317a17900e5daeb3f87a87327399c9bbe9dcd97f4025f4663c3bdf1/pybaum-0.1.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5a/dc/491b7661614ab97483abf2056be1deee4dc2490ecbf7bff9ab5cdbac86e1/pyreadline3-3.5.4-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d1/fd/5c2baba82425b75baf7dbec5af57219cd252aa8a1ace4f5cd1d88e472276/pyswarms-1.3.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl @@ -9696,7 +9666,6 @@ environments: - pypi: https://files.pythonhosted.org/packages/51/fe/53ac0cd932db5dcaf55961bc7cb7afdca8d80d8cc7406ed661f0c7dc111a/pdbp-1.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/52/d2/c6e44dba74f17c6216ce1b56044a9b93a929f1c2d5bdaff892512b260f5e/plotly-6.6.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f4/1e/0d44a4e3a291c009a357fbd1d61511d9306c2c4db9a7ceb6e8104d8d385f/Py_BOBYQA-1.5.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/d3/20/bcd8f317a17900e5daeb3f87a87327399c9bbe9dcd97f4025f4663c3bdf1/pybaum-0.1.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d1/fd/5c2baba82425b75baf7dbec5af57219cd252aa8a1ace4f5cd1d88e472276/pyswarms-1.3.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b8/cb/861b41341d6f1245e6ca80b1c1a8c4dfce43255b03df034429089ca2a2c5/scikit_learn-1.8.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl @@ -9976,7 +9945,6 @@ environments: - pypi: https://files.pythonhosted.org/packages/51/fe/53ac0cd932db5dcaf55961bc7cb7afdca8d80d8cc7406ed661f0c7dc111a/pdbp-1.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/52/d2/c6e44dba74f17c6216ce1b56044a9b93a929f1c2d5bdaff892512b260f5e/plotly-6.6.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f4/1e/0d44a4e3a291c009a357fbd1d61511d9306c2c4db9a7ceb6e8104d8d385f/Py_BOBYQA-1.5.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/d3/20/bcd8f317a17900e5daeb3f87a87327399c9bbe9dcd97f4025f4663c3bdf1/pybaum-0.1.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d1/fd/5c2baba82425b75baf7dbec5af57219cd252aa8a1ace4f5cd1d88e472276/pyswarms-1.3.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/a8/25/01c0af38fe969473fb292bba9dc2b8f9b451f3112ff242c647fee3d0dfe7/scikit_learn-1.8.0-cp314-cp314-macosx_12_0_arm64.whl @@ -10267,7 +10235,6 @@ environments: - pypi: https://files.pythonhosted.org/packages/51/fe/53ac0cd932db5dcaf55961bc7cb7afdca8d80d8cc7406ed661f0c7dc111a/pdbp-1.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/52/d2/c6e44dba74f17c6216ce1b56044a9b93a929f1c2d5bdaff892512b260f5e/plotly-6.6.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f4/1e/0d44a4e3a291c009a357fbd1d61511d9306c2c4db9a7ceb6e8104d8d385f/Py_BOBYQA-1.5.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/d3/20/bcd8f317a17900e5daeb3f87a87327399c9bbe9dcd97f4025f4663c3bdf1/pybaum-0.1.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5a/dc/491b7661614ab97483abf2056be1deee4dc2490ecbf7bff9ab5cdbac86e1/pyreadline3-3.5.4-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d1/fd/5c2baba82425b75baf7dbec5af57219cd252aa8a1ace4f5cd1d88e472276/pyswarms-1.3.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl @@ -19402,8 +19369,8 @@ packages: timestamp: 1733688053334 - pypi: ./ name: optimagic - version: 0.1.dev472+g7cc224247.d20260331 - sha256: 3fd4339b58c2f6c6ec2bec645fbce927d4e9a2180d17832e181c41f478e57915 + version: 0.1.dev488+g6dfe91919.d20260402 + sha256: 61e99547fb3283f2e855507263091d8d04b8b6dd01c0faf1b1e576a20d8ca14f requires_dist: - annotated-types>=0.4 - cloudpickle>=2.2 @@ -19411,7 +19378,6 @@ packages: - numpy>=1.26 - pandas>=2.1 - plotly>=5.14 - - pybaum>=0.1.2 - scipy>=1.11 - sqlalchemy>=2.0 - typing-extensions>=4.5 @@ -20954,11 +20920,6 @@ packages: - sphinx-rtd-theme ; extra == 'dev' - trustregion>=1.1 ; extra == 'trustregion' requires_python: '>=3.8' -- pypi: https://files.pythonhosted.org/packages/d3/20/bcd8f317a17900e5daeb3f87a87327399c9bbe9dcd97f4025f4663c3bdf1/pybaum-0.1.3-py3-none-any.whl - name: pybaum - version: 0.1.3 - sha256: a1d74200d0477c7da121af2f67a236d658ec62aaef5f5c35562ec959f52efa3d - requires_python: '>=3.7' - conda: https://conda.anaconda.org/conda-forge/noarch/pybind11-3.0.1-pyh7a1b43c_0.conda sha256: 2558727093f13d4c30e124724566d16badd7de532fd8ee7483628977117d02be md5: 70ece62498c769280f791e836ac53fff diff --git a/pyproject.toml b/pyproject.toml index 0c6a4905d..596d17288 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,6 @@ dependencies = [ "numpy>=1.26", "pandas>=2.1", "plotly>=5.14", - "pybaum>=0.1.2", "scipy>=1.11", "sqlalchemy>=2.0", "annotated-types>=0.4", From d3edeedb244661f99663f219f06a05928c2a089e Mon Sep 17 00:00:00 2001 From: Abel Abate Date: Thu, 2 Apr 2026 19:01:29 +0200 Subject: [PATCH 18/19] chore: remove remaining pybaum string --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 596d17288..0673b3f3d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -348,7 +348,6 @@ ignore_errors = true [[tool.mypy.overrides]] module = [ - "pybaum", "scipy", "scipy.linalg", "scipy.linalg.lapack", From ca240b53616eaeb047c7c3dc05a2a995f5913f15 Mon Sep 17 00:00:00 2001 From: Abel Abate Date: Thu, 2 Apr 2026 19:03:30 +0200 Subject: [PATCH 19/19] chore: remove duplicate imports --- src/optimagic/parameters/tree_registry.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/optimagic/parameters/tree_registry.py b/src/optimagic/parameters/tree_registry.py index 9883689ab..d09b1fced 100644 --- a/src/optimagic/parameters/tree_registry.py +++ b/src/optimagic/parameters/tree_registry.py @@ -1,6 +1,5 @@ """Wrapper around optree to tailor it to optimagic.""" -import itertools from functools import partial from itertools import product @@ -165,7 +164,7 @@ def _index_element_to_string(element): def _array_element_names(arr): dim_names = [map(str, range(n)) for n in arr.shape] - names = list(map("_".join, itertools.product(*dim_names))) + names = list(map("_".join, product(*dim_names))) return names