From a026ffc09091c6822806571158d4fbad08fa360c Mon Sep 17 00:00:00 2001 From: mj023 Date: Fri, 20 Mar 2026 16:46:03 +0100 Subject: [PATCH 1/8] Batch vmap --- src/lcm/dispatchers.py | 62 ++++++++++++++++++- src/lcm/grids.py | 4 ++ src/lcm/input_processing/regime_components.py | 4 ++ src/lcm/input_processing/regime_processing.py | 1 + src/lcm/interfaces.py | 2 + src/lcm/max_Q_over_a.py | 3 +- src/lcm/state_action_space.py | 2 + 7 files changed, 74 insertions(+), 4 deletions(-) diff --git a/src/lcm/dispatchers.py b/src/lcm/dispatchers.py index f97c8698..00639327 100644 --- a/src/lcm/dispatchers.py +++ b/src/lcm/dispatchers.py @@ -1,11 +1,14 @@ import inspect from collections.abc import Callable +from functools import partial from types import MappingProxyType from typing import Literal, TypeVar, cast +import jax from jax import Array, vmap from lcm.functools import allow_args, allow_only_kwargs +from lcm.typing import Float1D, FloatND from lcm.utils import find_duplicates FunctionWithArrayReturn = TypeVar( @@ -147,7 +150,10 @@ def vmap_1d( def productmap( - *, func: FunctionWithArrayReturn, variables: tuple[str, ...] + *, + func: FunctionWithArrayReturn, + variables: tuple[str, ...], + batch_sizes: dict[str, int] | None = None, ) -> FunctionWithArrayReturn: """Apply vmap such that func is evaluated on the Cartesian product of variables. @@ -179,7 +185,12 @@ def productmap( func_callable_with_args = allow_args(func) - vmapped = _base_productmap(func_callable_with_args, variables) + if batch_sizes is not None: + vmapped = _base_productmap_batched( + func_callable_with_args, variables, batch_sizes + ) + else: + vmapped = _base_productmap(func_callable_with_args, variables) # Callables do not necessarily have a __signature__ attribute. vmapped.__signature__ = inspect.signature(func_callable_with_args) # ty: ignore[unresolved-attribute] @@ -219,5 +230,50 @@ def _base_productmap( vmapped = func for spec in vmap_specs: vmapped = vmap(vmapped, in_axes=spec) - return vmapped + + +def _base_productmap_batched( + func: FunctionWithArrayReturn, + product_axes: tuple[str, ...], + batch_sizes: dict[str, int], +) -> FunctionWithArrayReturn: + """Map func over the Cartesian product of product_axes. + + Like vmap, this function does not preserve the function signature and does not allow + the function to be called with keyword arguments. + + Args: + func: The function to be dispatched. Cannot have keyword-only arguments. + product_axes: Tuple with names of arguments over which we apply vmap. + + Returns: + A callable with the same arguments as func. See `product_map` for details. + + """ + + def nest_map(next_regime_to_V_arr: FloatND, **kwargs: FloatND) -> FloatND: + non_array_kwargs = { + key: val for key, val in kwargs.items() if key not in product_axes + } + loop = partial(func, **non_array_kwargs) + + # induction case: scan over one argument, eliminating it + def scan_one_more( + loop: FunctionWithArrayReturn, x: Float1D + ) -> FunctionWithArrayReturn: + def new_loop(**xs: Float1D) -> FloatND: + return jax.lax.map( + lambda x_i: loop(**{x: x_i}, **xs), + kwargs[x], + batch_size=batch_sizes[x], + ) + + return new_loop + + # compose + for x in reversed(product_axes): + loop = scan_one_more(loop, x) + return loop(next_regime_to_V_arr) + + return nest_map diff --git a/src/lcm/grids.py b/src/lcm/grids.py index 499c1390..b30f8f3a 100644 --- a/src/lcm/grids.py +++ b/src/lcm/grids.py @@ -75,14 +75,18 @@ def _to_categorical_dtype(cls: type) -> pd.CategoricalDtype: return decorator +@dataclass(frozen=True, kw_only=True) class Grid(ABC): """LCM Grid base class.""" + batch_size: int = 0 + @abstractmethod def to_jax(self) -> Int1D | Float1D: """Convert the grid to a Jax array.""" +@dataclass(frozen=True, kw_only=True) class ContinuousGrid(Grid): """Base class for grids representing continuous values with coordinate lookup. diff --git a/src/lcm/input_processing/regime_components.py b/src/lcm/input_processing/regime_components.py index 894b88a1..b6e53e16 100644 --- a/src/lcm/input_processing/regime_components.py +++ b/src/lcm/input_processing/regime_components.py @@ -77,6 +77,7 @@ def build_Q_and_F_functions( def build_max_Q_over_a_functions( *, state_action_space: StateActionSpace, + state_space_info: StateSpaceInfo, Q_and_F_functions: MappingProxyType[int, QAndFFunction], enable_jit: bool, ) -> MappingProxyType[int, MaxQOverAFunction]: @@ -84,6 +85,7 @@ def build_max_Q_over_a_functions( for period, Q_and_F in Q_and_F_functions.items(): max_Q_over_a_functions[period] = _build_max_Q_over_a_function( state_action_space=state_action_space, + state_space_info=state_space_info, Q_and_F=Q_and_F, enable_jit=enable_jit, ) @@ -93,6 +95,7 @@ def build_max_Q_over_a_functions( def _build_max_Q_over_a_function( *, state_action_space: StateActionSpace, + state_space_info: StateSpaceInfo, Q_and_F: QAndFFunction, enable_jit: bool, ) -> MaxQOverAFunction: @@ -100,6 +103,7 @@ def _build_max_Q_over_a_function( Q_and_F=Q_and_F, action_names=state_action_space.action_names, state_names=state_action_space.state_names, + batch_sizes=state_space_info.batch_sizes, ) if enable_jit: diff --git a/src/lcm/input_processing/regime_processing.py b/src/lcm/input_processing/regime_processing.py index 3b99d5cc..28f824b0 100644 --- a/src/lcm/input_processing/regime_processing.py +++ b/src/lcm/input_processing/regime_processing.py @@ -156,6 +156,7 @@ def process_regimes( ) max_Q_over_a_functions = build_max_Q_over_a_functions( state_action_space=state_action_spaces[name], + state_space_info=state_space_infos[name], Q_and_F_functions=Q_and_F_functions, enable_jit=enable_jit, ) diff --git a/src/lcm/interfaces.py b/src/lcm/interfaces.py index ab8ee792..a522b5d3 100644 --- a/src/lcm/interfaces.py +++ b/src/lcm/interfaces.py @@ -140,6 +140,8 @@ class StateSpaceInfo: continuous_states: MappingProxyType[str, ContinuousGrid] """Immutable mapping of continuous state names to their grids.""" + batch_sizes: MappingProxyType[str, int] + """Immutable mapping of state names to their batch sizes.""" class PhaseVariantContainer[S, T]: diff --git a/src/lcm/max_Q_over_a.py b/src/lcm/max_Q_over_a.py index 892f9c8c..6974829a 100644 --- a/src/lcm/max_Q_over_a.py +++ b/src/lcm/max_Q_over_a.py @@ -23,6 +23,7 @@ def get_max_Q_over_a( Q_and_F: Callable[..., tuple[FloatND, BoolND]], action_names: tuple[str, ...], state_names: tuple[str, ...], + batch_sizes: dict[str, int], ) -> MaxQOverAFunction: r"""Get the function returning the maximum of Q over all actions. @@ -80,7 +81,7 @@ def max_Q_over_a( ) return Q_arr.max(where=F_arr, initial=-jnp.inf) - return productmap(func=max_Q_over_a, variables=state_names) + return productmap(func=max_Q_over_a, variables=state_names, batch_sizes=batch_sizes) def get_argmax_and_max_Q_over_a( diff --git a/src/lcm/state_action_space.py b/src/lcm/state_action_space.py index 67fc51a6..ec8f89e0 100644 --- a/src/lcm/state_action_space.py +++ b/src/lcm/state_action_space.py @@ -96,11 +96,13 @@ def create_state_space_info(regime: Regime) -> StateSpaceInfo: and isinstance(grid_spec, ContinuousGrid) and not isinstance(grid_spec, _ShockGrid) } + batch_sizes = {name: grid_spec.batch_size for name, grid_spec in gridspecs.items()} return StateSpaceInfo( state_names=tuple(state_names), discrete_states=MappingProxyType(discrete_states), continuous_states=MappingProxyType(continuous_states), + batch_sizes=MappingProxyType(batch_sizes), ) From 0983ab2e5265cf90be2a15aac0b0423255f2c96f Mon Sep 17 00:00:00 2001 From: mj023 Date: Fri, 27 Mar 2026 16:57:28 +0100 Subject: [PATCH 2/8] Always use batched --- src/lcm/dispatchers.py | 46 +++++++++---------- src/lcm/input_processing/regime_processing.py | 2 +- 2 files changed, 23 insertions(+), 25 deletions(-) diff --git a/src/lcm/dispatchers.py b/src/lcm/dispatchers.py index 00639327..c55abc65 100644 --- a/src/lcm/dispatchers.py +++ b/src/lcm/dispatchers.py @@ -7,7 +7,7 @@ import jax from jax import Array, vmap -from lcm.functools import allow_args, allow_only_kwargs +from lcm.functools import allow_args, allow_only_kwargs, all_as_kwargs from lcm.typing import Float1D, FloatND from lcm.utils import find_duplicates @@ -185,17 +185,15 @@ def productmap( func_callable_with_args = allow_args(func) - if batch_sizes is not None: - vmapped = _base_productmap_batched( - func_callable_with_args, variables, batch_sizes - ) - else: - vmapped = _base_productmap(func_callable_with_args, variables) + if batch_sizes is None: + batch_sizes = {var:0 for var in variables} + + vmapped = _base_productmap_batched(func_callable_with_args, variables, batch_sizes=batch_sizes) # Callables do not necessarily have a __signature__ attribute. vmapped.__signature__ = inspect.signature(func_callable_with_args) # ty: ignore[unresolved-attribute] - return cast("FunctionWithArrayReturn", allow_only_kwargs(vmapped, enforce=False)) + return cast("FunctionWithArrayReturn", vmapped) def _base_productmap( @@ -252,28 +250,28 @@ def _base_productmap_batched( """ - def nest_map(next_regime_to_V_arr: FloatND, **kwargs: FloatND) -> FloatND: + def batched_vmap(**kwargs: FloatND) -> FloatND: + non_array_kwargs = { key: val for key, val in kwargs.items() if key not in product_axes } - loop = partial(func, **non_array_kwargs) + func_with_partialled_args = partial(func, **non_array_kwargs) - # induction case: scan over one argument, eliminating it - def scan_one_more( - loop: FunctionWithArrayReturn, x: Float1D + # Recursively map over one more product axe + def map_one_more( + loop: FunctionWithArrayReturn, axis: Float1D ) -> FunctionWithArrayReturn: - def new_loop(**xs: Float1D) -> FloatND: + def new_mapped_func(**already_mapped_kwargs: Float1D) -> FloatND: return jax.lax.map( - lambda x_i: loop(**{x: x_i}, **xs), - kwargs[x], - batch_size=batch_sizes[x], + lambda axis_i: loop(**{axis: axis_i}, **already_mapped_kwargs), + kwargs[axis], + batch_size=batch_sizes[axis], ) - return new_loop - - # compose - for x in reversed(product_axes): - loop = scan_one_more(loop, x) - return loop(next_regime_to_V_arr) + return new_mapped_func + # Loop over all product axes + for axis in reversed(product_axes): + func_with_partialled_args = map_one_more(func_with_partialled_args, axis) + return func_with_partialled_args() - return nest_map + return batched_vmap diff --git a/src/lcm/input_processing/regime_processing.py b/src/lcm/input_processing/regime_processing.py index 6ca066b2..2403110c 100644 --- a/src/lcm/input_processing/regime_processing.py +++ b/src/lcm/input_processing/regime_processing.py @@ -176,7 +176,7 @@ def process_regimes( max_Q_over_a_functions = build_max_Q_over_a_functions( state_action_space=state_action_spaces[name], state_space_info=state_space_infos[name], - Q_and_F_functions=Q_and_F_functions, + Q_and_F_functions=Q_and_F_solve, enable_jit=enable_jit, ) argmax_and_max_Q_over_a_functions = build_argmax_and_max_Q_over_a_functions( From aa55e1baec05e19b15155e54fbd30285e9174e67 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 27 Mar 2026 16:14:22 +0000 Subject: [PATCH 3/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lcm/utils/dispatchers.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/lcm/utils/dispatchers.py b/src/lcm/utils/dispatchers.py index c55abc65..f33a87b9 100644 --- a/src/lcm/utils/dispatchers.py +++ b/src/lcm/utils/dispatchers.py @@ -7,7 +7,7 @@ import jax from jax import Array, vmap -from lcm.functools import allow_args, allow_only_kwargs, all_as_kwargs +from lcm.functools import allow_args, allow_only_kwargs from lcm.typing import Float1D, FloatND from lcm.utils import find_duplicates @@ -186,9 +186,11 @@ def productmap( func_callable_with_args = allow_args(func) if batch_sizes is None: - batch_sizes = {var:0 for var in variables} + batch_sizes = dict.fromkeys(variables, 0) - vmapped = _base_productmap_batched(func_callable_with_args, variables, batch_sizes=batch_sizes) + vmapped = _base_productmap_batched( + func_callable_with_args, variables, batch_sizes=batch_sizes + ) # Callables do not necessarily have a __signature__ attribute. vmapped.__signature__ = inspect.signature(func_callable_with_args) # ty: ignore[unresolved-attribute] @@ -269,6 +271,7 @@ def new_mapped_func(**already_mapped_kwargs: Float1D) -> FloatND: ) return new_mapped_func + # Loop over all product axes for axis in reversed(product_axes): func_with_partialled_args = map_one_more(func_with_partialled_args, axis) From a7cde5ff68095ee40c723c7c5e6751fef8498e14 Mon Sep 17 00:00:00 2001 From: mj023 Date: Mon, 30 Mar 2026 17:38:07 +0200 Subject: [PATCH 4/8] Handle positional only --- src/lcm/grids/continuous.py | 4 ++ src/lcm/grids/discrete.py | 39 ++++++++++--------- src/lcm/regime_building/max_Q_over_a.py | 2 +- src/lcm/regime_building/processing.py | 3 ++ src/lcm/utils/dispatchers.py | 50 +++++++++++++++++-------- tests/solution/test_solve_brute.py | 2 + tests/test_dispatchers.py | 12 ------ 7 files changed, 64 insertions(+), 48 deletions(-) diff --git a/src/lcm/grids/continuous.py b/src/lcm/grids/continuous.py index c1af7efd..c81f9c6a 100644 --- a/src/lcm/grids/continuous.py +++ b/src/lcm/grids/continuous.py @@ -16,6 +16,7 @@ ) +@dataclass(frozen=True, kw_only=True) class ContinuousGrid(Grid): """Base class for grids representing continuous values with coordinate lookup. @@ -24,6 +25,9 @@ class ContinuousGrid(Grid): """ + batch_size: int = 0 + """Size of the batches that are looped over during the solution.""" + @overload def get_coordinate(self, value: ScalarFloat) -> ScalarFloat: ... @overload diff --git a/src/lcm/grids/discrete.py b/src/lcm/grids/discrete.py index 41451dee..f844aad6 100644 --- a/src/lcm/grids/discrete.py +++ b/src/lcm/grids/discrete.py @@ -6,15 +6,26 @@ from lcm.utils.containers import get_field_names_and_values -class _DiscreteGridBase(Grid): - """Base class for discrete grids: categories, codes, and JAX conversion.""" +class DiscreteGrid(Grid): + """A discrete grid defining the outcome space of a categorical variable. + + Args: + category_class: The category class representing the grid categories. Must + be a dataclass with fields that have unique int values. + + Raises: + GridInitializationError: If the `category_class` is not a dataclass with int + fields. + + """ - def __init__(self, category_class: type) -> None: + def __init__(self, category_class: type, batch_size: int = 0) -> None: _validate_discrete_grid(category_class) names_and_values = get_field_names_and_values(category_class) self.__categories = tuple(names_and_values.keys()) self.__codes = tuple(names_and_values.values()) self.__ordered: bool = getattr(category_class, "_ordered", False) + self.__batch_size: int = batch_size @property def categories(self) -> tuple[str, ...]: @@ -31,23 +42,11 @@ def ordered(self) -> bool: """Return whether the categories have a meaningful ordering.""" return self.__ordered + @property + def batch_size(self) -> int: + """Return batch size during solution.""" + return self.__batch_size + def to_jax(self) -> Int1D: """Convert the grid to a Jax array.""" return jnp.array(self.codes) - - -class DiscreteGrid(_DiscreteGridBase): - """A discrete grid defining the outcome space of a categorical variable. - - Args: - category_class: The category class representing the grid categories. Must - be a dataclass with fields that have unique int values. - - Raises: - GridInitializationError: If the `category_class` is not a dataclass with int - fields. - - """ - - def __init__(self, category_class: type) -> None: - super().__init__(category_class) diff --git a/src/lcm/regime_building/max_Q_over_a.py b/src/lcm/regime_building/max_Q_over_a.py index 5a7e2847..fa875229 100644 --- a/src/lcm/regime_building/max_Q_over_a.py +++ b/src/lcm/regime_building/max_Q_over_a.py @@ -21,9 +21,9 @@ def get_max_Q_over_a( *, Q_and_F: Callable[..., tuple[FloatND, BoolND]], + batch_sizes: dict[str, int], action_names: tuple[str, ...], state_names: tuple[str, ...], - batch_sizes: dict[str, int], ) -> MaxQOverAFunction: r"""Get the function returning the maximum of Q over all actions. diff --git a/src/lcm/regime_building/processing.py b/src/lcm/regime_building/processing.py index aca9c797..6efe8355 100644 --- a/src/lcm/regime_building/processing.py +++ b/src/lcm/regime_building/processing.py @@ -243,6 +243,7 @@ def _build_solve_functions( max_Q_over_a = _build_max_Q_over_a_per_period( state_action_space=state_action_space, Q_and_F_functions=Q_and_F_functions, + all_grids=all_grids[regime_name], enable_jit=enable_jit, ) @@ -1256,6 +1257,7 @@ def _build_max_Q_over_a_per_period( *, state_action_space: StateActionSpace, Q_and_F_functions: MappingProxyType[int, QAndFFunction], + all_grids: MappingProxyType[str, Grid], enable_jit: bool, ) -> MappingProxyType[int, MaxQOverAFunction]: """Build max-Q-over-a closures for each period.""" @@ -1263,6 +1265,7 @@ def _build_max_Q_over_a_per_period( for period, Q_and_F in Q_and_F_functions.items(): func = get_max_Q_over_a( Q_and_F=Q_and_F, + batch_sizes={name: grid.batch_size for name, grid in all_grids.items()}, # ty: ignore[unresolved-attribute] action_names=state_action_space.action_names, state_names=state_action_space.state_names, ) diff --git a/src/lcm/utils/dispatchers.py b/src/lcm/utils/dispatchers.py index f33a87b9..010d3ac7 100644 --- a/src/lcm/utils/dispatchers.py +++ b/src/lcm/utils/dispatchers.py @@ -7,9 +7,9 @@ import jax from jax import Array, vmap -from lcm.functools import allow_args, allow_only_kwargs from lcm.typing import Float1D, FloatND -from lcm.utils import find_duplicates +from lcm.utils.containers import find_duplicates +from lcm.utils.functools import allow_args, allow_only_kwargs FunctionWithArrayReturn = TypeVar( "FunctionWithArrayReturn", @@ -192,10 +192,17 @@ def productmap( func_callable_with_args, variables, batch_sizes=batch_sizes ) - # Callables do not necessarily have a __signature__ attribute. - vmapped.__signature__ = inspect.signature(func_callable_with_args) # ty: ignore[unresolved-attribute] + # Create new signature where every parameter is kw-only as + # batched_vmap takes only kwargs + signature = inspect.signature(func_callable_with_args) + new_parameters = [ + p.replace(kind=inspect.Parameter.KEYWORD_ONLY) + for p in signature.parameters.values() + ] + new_signature = signature.replace(parameters=new_parameters) + vmapped.__signature__ = new_signature # ty: ignore[unresolved-attribute] - return cast("FunctionWithArrayReturn", vmapped) + return cast("FunctionWithArrayReturn", allow_only_kwargs(vmapped, enforce=False)) def _base_productmap( @@ -216,7 +223,6 @@ def _base_productmap( """ signature = inspect.signature(func) parameters = list(signature.parameters) - positions = [parameters.index(ax) for ax in product_axes if ax in parameters] vmap_specs = [] @@ -251,30 +257,44 @@ def _base_productmap_batched( A callable with the same arguments as func. See `product_map` for details. """ + parameters = inspect.signature(func).parameters def batched_vmap(**kwargs: FloatND) -> FloatND: - non_array_kwargs = { key: val for key, val in kwargs.items() if key not in product_axes } - func_with_partialled_args = partial(func, **non_array_kwargs) + func_with_partialled_args = cast( + "FunctionWithArrayReturn", partial(func, **non_array_kwargs) + ) # Recursively map over one more product axe def map_one_more( - loop: FunctionWithArrayReturn, axis: Float1D + loop: FunctionWithArrayReturn, axis: str ) -> FunctionWithArrayReturn: - def new_mapped_func(**already_mapped_kwargs: Float1D) -> FloatND: + def func_mapped_over_one_more_axis( + *already_mapped_args: Float1D, **already_mapped_kwargs: Float1D + ) -> FloatND: + if parameters[axis].kind == inspect.Parameter.POSITIONAL_ONLY: + return jax.lax.map( + lambda axis_i: loop( + axis_i, *already_mapped_args, **already_mapped_kwargs + ), + jax.numpy.atleast_1d(kwargs[axis]), + batch_size=batch_sizes[axis], + ) return jax.lax.map( - lambda axis_i: loop(**{axis: axis_i}, **already_mapped_kwargs), - kwargs[axis], + lambda axis_i: loop( + *already_mapped_args, **{axis: axis_i}, **already_mapped_kwargs + ), + jax.numpy.atleast_1d(kwargs[axis]), batch_size=batch_sizes[axis], ) - return new_mapped_func + return cast("FunctionWithArrayReturn", func_mapped_over_one_more_axis) # Loop over all product axes for axis in reversed(product_axes): func_with_partialled_args = map_one_more(func_with_partialled_args, axis) - return func_with_partialled_args() + return func_with_partialled_args() # ty: ignore[invalid-return-type] - return batched_vmap + return cast("FunctionWithArrayReturn", batched_vmap) diff --git a/tests/solution/test_solve_brute.py b/tests/solution/test_solve_brute.py index c4cd87b9..92db0a97 100644 --- a/tests/solution/test_solve_brute.py +++ b/tests/solution/test_solve_brute.py @@ -119,6 +119,7 @@ def _Q_and_F( Q_and_F=_Q_and_F, action_names=("consumption", "labor_supply"), state_names=("lazy", "wealth"), + batch_sizes={"lazy": 0, "wealth": 0}, ) # ================================================================================== @@ -176,6 +177,7 @@ def _Q_and_F(a, c, b, d, next_regime_to_V_arr): # noqa: ARG001 Q_and_F=_Q_and_F, action_names=("d",), state_names=("a", "b", "c"), + batch_sizes={"a": 0, "b": 0, "c": 0}, ) expected = np.array([[[6.0, 7, 8], [7, 8, 9]], [[7, 8, 9], [8, 9, 10]]]) diff --git a/tests/test_dispatchers.py b/tests/test_dispatchers.py index 4ae09d3d..ead717b4 100644 --- a/tests/test_dispatchers.py +++ b/tests/test_dispatchers.py @@ -142,18 +142,6 @@ def test_productmap_with_all_arguments_mapped_some_len_one(): aaae(calculated, expected) -def test_productmap_with_all_arguments_mapped_some_scalar(): - grids = { - "a": 1, - "b": 2, - "c": jnp.linspace(1, 5, 5), - } - - decorated = productmap(func=f, variables=("a", "b", "c")) - with pytest.raises(ValueError, match="vmap was requested to map its argument"): - decorated(**grids) # ty: ignore[missing-argument] - - def test_productmap_with_some_arguments_mapped(): grids = { "a": jnp.linspace(-5, 5, 10), From 357d2841ef01e262d76f1afff40ddcb0c81bd729 Mon Sep 17 00:00:00 2001 From: mj023 Date: Mon, 30 Mar 2026 18:35:50 +0200 Subject: [PATCH 5/8] Only one base_productmap --- src/lcm/utils/dispatchers.py | 59 ++++++---------------- src/lcm_examples/mahler_yum_2024/_model.py | 17 ++++--- 2 files changed, 27 insertions(+), 49 deletions(-) diff --git a/src/lcm/utils/dispatchers.py b/src/lcm/utils/dispatchers.py index 010d3ac7..60efebf4 100644 --- a/src/lcm/utils/dispatchers.py +++ b/src/lcm/utils/dispatchers.py @@ -68,7 +68,13 @@ def simulation_spacemap( mappable_func = allow_args(func) - vmapped = _base_productmap(mappable_func, action_names) + vmapped = allow_args( + productmap( + func=mappable_func, + variables=action_names, + batch_sizes=dict.fromkeys(action_names, 0), + ) + ) vmapped = vmap_1d(func=vmapped, variables=state_names, callable_with="only_args") # Callables do not necessarily have a __signature__ attribute. @@ -155,17 +161,18 @@ def productmap( variables: tuple[str, ...], batch_sizes: dict[str, int] | None = None, ) -> FunctionWithArrayReturn: - """Apply vmap such that func is evaluated on the Cartesian product of variables. + """Apply mp such that func is evaluated on the Cartesian product of variables. - This is achieved by an iterative application of vmap. + This is achieved by an iterative application of mp. - In contrast to _base_productmap, productmap preserves the function signature and - allows the function to be called with keyword arguments. + In contrast to _base_productmap_batched, productmap preserves the function signature + and allows the function to be called with keyword arguments. Args: func: The function to be dispatched. variables: Tuple with names of arguments that over which the Cartesian product should be formed. + batch_sizes: Dict with the batch sizes for each variable. Returns: A callable with the same arguments as func (but with an additional leading @@ -185,6 +192,7 @@ def productmap( func_callable_with_args = allow_args(func) + # If no batch size provided just vmap over all vars if batch_sizes is None: batch_sizes = dict.fromkeys(variables, 0) @@ -205,54 +213,19 @@ def productmap( return cast("FunctionWithArrayReturn", allow_only_kwargs(vmapped, enforce=False)) -def _base_productmap( - func: FunctionWithArrayReturn, product_axes: tuple[str, ...] -) -> FunctionWithArrayReturn: - """Map func over the Cartesian product of product_axes. - - Like vmap, this function does not preserve the function signature and does not allow - the function to be called with keyword arguments. - - Args: - func: The function to be dispatched. Cannot have keyword-only arguments. - product_axes: Tuple with names of arguments over which we apply vmap. - - Returns: - A callable with the same arguments as func. See `product_map` for details. - - """ - signature = inspect.signature(func) - parameters = list(signature.parameters) - positions = [parameters.index(ax) for ax in product_axes if ax in parameters] - - vmap_specs = [] - # We iterate in reverse order such that the output dimensions are in the same order - # as the input dimensions. - for pos in reversed(positions): - spec: list[int | None] = cast("list[int | None]", [None] * len(parameters)) - spec[pos] = 0 - vmap_specs.append(spec) - - vmapped = func - for spec in vmap_specs: - vmapped = vmap(vmapped, in_axes=spec) - return vmapped - - def _base_productmap_batched( func: FunctionWithArrayReturn, product_axes: tuple[str, ...], batch_sizes: dict[str, int], ) -> FunctionWithArrayReturn: - """Map func over the Cartesian product of product_axes. + """Map func over the Cartesian product of product_axes and execute in batches. - Like vmap, this function does not preserve the function signature and does not allow - the function to be called with keyword arguments. + Like vmap, this function does not preserve the function signature. Args: func: The function to be dispatched. Cannot have keyword-only arguments. product_axes: Tuple with names of arguments over which we apply vmap. - + batch_sizes: Dict with the batch sizes for each product_axis. Returns: A callable with the same arguments as func. See `product_map` for details. diff --git a/src/lcm_examples/mahler_yum_2024/_model.py b/src/lcm_examples/mahler_yum_2024/_model.py index 3ff4369c..536da538 100644 --- a/src/lcm_examples/mahler_yum_2024/_model.py +++ b/src/lcm_examples/mahler_yum_2024/_model.py @@ -37,7 +37,7 @@ Period, RegimeName, ) -from lcm.utils.dispatchers import _base_productmap +from lcm.utils.dispatchers import productmap _DATA_DIR = Path(__file__).parent / "data" @@ -502,8 +502,8 @@ def calc_base(period: Period, health: Int1D, education: Int1D) -> Float1D: * jnp.exp(((sdztemp**2.0) ** 2.0) / 2.0) ) - mapped = _base_productmap(calc_base, ("_period", "health", "education")) - return mapped(j, health, education) + mapped = productmap(func=calc_base, variables=("period", "health", "education")) + return mapped(period=j, health=health, education=education) eff_grid: Float1D = jnp.linspace(0, 1, 40) @@ -536,13 +536,18 @@ def health_trans( return jnp.exp(y) / (1.0 + jnp.exp(y)) -mapped_health_trans = _base_productmap( - health_trans, ("period", "health", "eff", "eff_1", "edu", "ht") +mapped_health_trans = productmap( + func=health_trans, variables=("period", "health", "eff", "eff_1", "edu", "ht") ) tr2yp_grid = tr2yp_grid.at[:, :, :, :, :, :, 1].set( mapped_health_trans( - j, jnp.arange(2), jnp.arange(40), jnp.arange(40), jnp.arange(2), jnp.arange(2) + period=j, + health=jnp.arange(2), + eff=jnp.arange(40), + eff_1=jnp.arange(40), + edu=jnp.arange(2), + ht=jnp.arange(2), ) ) tr2yp_grid = tr2yp_grid.at[:, :, :, :, :, :, 0].set( From 3e29e5bd65d48d9962a39ae77a9002ce215881a7 Mon Sep 17 00:00:00 2001 From: mj023 Date: Mon, 30 Mar 2026 19:19:10 +0200 Subject: [PATCH 6/8] Fix Typing --- tests/test_runtime_shock_params.py | 4 +++- tests/test_shock_draw.py | 22 ++++++++++++++-------- tests/test_shock_grids.py | 12 +++++++----- 3 files changed, 24 insertions(+), 14 deletions(-) diff --git a/tests/test_runtime_shock_params.py b/tests/test_runtime_shock_params.py index 9c225249..1fdf5ffd 100644 --- a/tests/test_runtime_shock_params.py +++ b/tests/test_runtime_shock_params.py @@ -84,7 +84,9 @@ def test_runtime_shock_params_property(): def test_fully_specified_shock(): """Tauchen with all params should have no runtime-supplied params.""" - grid = lcm.shocks.ar1.Tauchen(n_points=5, gauss_hermite=False, **_TAUCHEN_PARAMS) + grid = lcm.shocks.ar1.Tauchen( + n_points=5, gauss_hermite=False, batch_size=0, **_TAUCHEN_PARAMS + ) assert grid.params_to_pass_at_runtime == () assert grid.is_fully_specified diff --git a/tests/test_shock_draw.py b/tests/test_shock_draw.py index 086c9434..6356f86b 100644 --- a/tests/test_shock_draw.py +++ b/tests/test_shock_draw.py @@ -48,7 +48,7 @@ def test_draw_shock_uniform(params_at_init): """Uniform.draw_shock uses start/stop params.""" kwargs = {"start": 2.0, "stop": 4.0} if params_at_init: - grid = lcm.shocks.iid.Uniform(n_points=5, **kwargs) + grid = lcm.shocks.iid.Uniform(n_points=5, batch_size=0, **kwargs) params = grid.params else: grid = lcm.shocks.iid.Uniform(n_points=5) @@ -64,10 +64,12 @@ def test_draw_shock_normal(params_at_init): """Normal.draw_shock uses mu/sigma params.""" kwargs = {"mu": 5.0, "sigma": 0.1} if params_at_init: - grid = lcm.shocks.iid.Normal(n_points=5, gauss_hermite=True, **kwargs) + grid = lcm.shocks.iid.Normal( + n_points=5, batch_size=0, gauss_hermite=True, **kwargs + ) params = grid.params else: - grid = lcm.shocks.iid.Normal(n_points=5, gauss_hermite=True) + grid = lcm.shocks.iid.Normal(n_points=5, batch_size=0, gauss_hermite=True) params = MappingProxyType(kwargs) draws = _draw_many(grid, params) aaae(draws.mean(), 5.0, decimal=1) @@ -79,7 +81,9 @@ def test_draw_shock_lognormal(params_at_init): """LogNormal.draw_shock produces positive samples with correct log-moments.""" kwargs = {"mu": 1.0, "sigma": 0.1} if params_at_init: - grid = lcm.shocks.iid.LogNormal(n_points=5, gauss_hermite=True, **kwargs) + grid = lcm.shocks.iid.LogNormal( + n_points=5, batch_size=0, gauss_hermite=True, **kwargs + ) params = grid.params else: grid = lcm.shocks.iid.LogNormal(n_points=5, gauss_hermite=True) @@ -95,7 +99,9 @@ def test_draw_shock_tauchen(params_at_init): """Tauchen.draw_shock uses mu/sigma/rho params.""" kwargs = {"rho": 0.5, "sigma": 0.1, "mu": 2.0} if params_at_init: - grid = lcm.shocks.ar1.Tauchen(n_points=5, gauss_hermite=True, **kwargs) + grid = lcm.shocks.ar1.Tauchen( + n_points=5, batch_size=0, gauss_hermite=True, **kwargs + ) params = grid.params else: grid = lcm.shocks.ar1.Tauchen(n_points=5, gauss_hermite=True) @@ -110,7 +116,7 @@ def test_draw_shock_rouwenhorst(params_at_init): """Rouwenhorst.draw_shock uses mu/sigma/rho params.""" kwargs = {"rho": 0.5, "sigma": 0.1, "mu": 2.0} if params_at_init: - grid = lcm.shocks.ar1.Rouwenhorst(n_points=5, **kwargs) + grid = lcm.shocks.ar1.Rouwenhorst(n_points=5, batch_size=0, **kwargs) params = grid.params else: grid = lcm.shocks.ar1.Rouwenhorst(n_points=5) @@ -125,7 +131,7 @@ def test_draw_shock_normal_mixture(params_at_init): """NormalMixture.draw_shock produces draws with correct mixture moments.""" kwargs = _NORMAL_MIXTURE_KWARGS if params_at_init: - grid = lcm.shocks.iid.NormalMixture(n_points=5, **kwargs) + grid = lcm.shocks.iid.NormalMixture(n_points=5, batch_size=0, **kwargs) params = grid.params else: grid = lcm.shocks.iid.NormalMixture(n_points=5) @@ -147,7 +153,7 @@ def test_draw_shock_tauchen_normal_mixture(params_at_init): """TauchenNormalMixture.draw_shock produces yields correct conditional moments.""" kwargs = _TAUCHEN_NORMAL_MIXTURE_KWARGS if params_at_init: - grid = lcm.shocks.ar1.TauchenNormalMixture(n_points=5, **kwargs) + grid = lcm.shocks.ar1.TauchenNormalMixture(n_points=5, batch_size=0, **kwargs) params = grid.params else: grid = lcm.shocks.ar1.TauchenNormalMixture(n_points=5) diff --git a/tests/test_shock_grids.py b/tests/test_shock_grids.py index 80b63103..14d46540 100644 --- a/tests/test_shock_grids.py +++ b/tests/test_shock_grids.py @@ -371,7 +371,9 @@ def test_lognormal_gauss_hermite_weights_sum_to_one(): def test_normal_mixture_transition_probs_rows_sum_to_one(): """NormalMixture transition probability rows sum to 1.""" - grid = lcm.shocks.iid.NormalMixture(n_points=7, **_NORMAL_MIXTURE_KWARGS) + grid = lcm.shocks.iid.NormalMixture( + n_points=7, batch_size=0, **_NORMAL_MIXTURE_KWARGS + ) P = grid.get_transition_probs() row_sums = P.sum(axis=1) aaae(row_sums, jnp.ones(7), decimal=DECIMAL_PRECISION) @@ -380,7 +382,7 @@ def test_normal_mixture_transition_probs_rows_sum_to_one(): def test_iid_normal_mixture_stationary_moments(): """IID NormalMixture stationary mean and std match mixture moments.""" kwargs = _NORMAL_MIXTURE_KWARGS - grid = lcm.shocks.iid.NormalMixture(n_points=21, **kwargs) + grid = lcm.shocks.iid.NormalMixture(n_points=21, batch_size=0, **kwargs) got_mean, got_std = _stationary_moments( grid.get_gridpoints(), grid.get_transition_probs() ) @@ -410,7 +412,7 @@ def test_iid_normal_mixture_stationary_moments(): def test_tauchen_normal_mixture_transition_probs_rows_sum_to_one(): """TauchenNormalMixture transition probability rows sum to 1.""" grid = lcm.shocks.ar1.TauchenNormalMixture( - n_points=7, **_TAUCHEN_NORMAL_MIXTURE_KWARGS + n_points=7, batch_size=0, **_TAUCHEN_NORMAL_MIXTURE_KWARGS ) P = grid.get_transition_probs() row_sums = P.sum(axis=1) @@ -420,7 +422,7 @@ def test_tauchen_normal_mixture_transition_probs_rows_sum_to_one(): def test_tauchen_normal_mixture_centers_on_unconditional_mean(): """TauchenNormalMixture gridpoints center on (mu + mean_eps) / (1 - rho).""" kwargs = _TAUCHEN_NORMAL_MIXTURE_KWARGS - grid = lcm.shocks.ar1.TauchenNormalMixture(n_points=11, **kwargs) + grid = lcm.shocks.ar1.TauchenNormalMixture(n_points=11, batch_size=0, **kwargs) points = grid.get_gridpoints() midpoint = (points[0] + points[-1]) / 2 mean_eps = kwargs["p1"] * kwargs["mu1"] + (1 - kwargs["p1"]) * kwargs["mu2"] @@ -431,7 +433,7 @@ def test_tauchen_normal_mixture_centers_on_unconditional_mean(): def test_tauchen_normal_mixture_stationary_moments_and_autocorrelation(): """TauchenNormalMixture stationary mean, std, and autocorrelation match theory.""" kwargs = _TAUCHEN_NORMAL_MIXTURE_KWARGS - grid = lcm.shocks.ar1.TauchenNormalMixture(n_points=21, **kwargs) + grid = lcm.shocks.ar1.TauchenNormalMixture(batch_size=0, n_points=21, **kwargs) points = grid.get_gridpoints() P = grid.get_transition_probs() From 6150b1e8a564ed920efe4311bca77c1b94ea0b95 Mon Sep 17 00:00:00 2001 From: mj023 Date: Mon, 30 Mar 2026 20:06:28 +0200 Subject: [PATCH 7/8] Fix MY Result --- .../f64/mahler_yum_simulation.pkl | Bin 14458 -> 14458 bytes 1 file changed, 0 insertions(+), 0 deletions(-) diff --git a/tests/data/regression_tests/f64/mahler_yum_simulation.pkl b/tests/data/regression_tests/f64/mahler_yum_simulation.pkl index eb25d48b7c28d1072b24568f8bf560cfc66afd83..b4109af22980f351acaa72b25f088cbefd3cac8a 100644 GIT binary patch delta 1852 zcmcgrc~H}55GA2XKrDgCZ38hO0SpiW0pu_n8KmJ35E60{Tf7b6)I!?9s03=g06#q1 zf}(<`C};sII_oes4OvI-s< za@}RHb-2oZkB=1gRwKKtN+~>jU!aoj7Xw}SD@U(a48r~tX9 z;W$V(tMvE&Q~(v}Rad_C5voD*VSCx2MgR|}7hl&Z7J&>1KI%R%2r@1kk;^>=-&4V_^ffB<$X{A*Y-(9vYy35X@G5Fs82lXX)9r1xN<0<1rPSVlEXNw zthw!?9L~I_U!; z78P7xQ4Ba&^_f|PM#F}IXLoK-@}Pp#s5*H&61G>~PV4Ilhr$SpAf+M{^zt@|i^sWe z@5SAqlH)-zwLeaz2s8==mGm-hwQ&HT+7B4&>Cx?IvP%I~%{eXX`wFOTbMd|rrhxRk z&(q23c(5`_nEOrJe&Zw3q>GmvG80cY6%WXOcYUrT*GmQ_wSv=cU8Jz0(nDowCV_QR z%^}pE#jxp+Z(38n7>r+z+HaW^L1*_f+u~{w&|;H)%~pxPaPU(55Lu82?n8qU`Gi=| z<0Lfp4a5T9CUuJbk`GfMYdy%_F@QUgp$e*whO4ZHyqA1KBJ5^v&xK$skxQJD136 zaF@g8=dFDSHt}$%^41Rf7CG#!6?D!Xl|zT!KvW7}1{ziAdhrT55XbsA*!s$#mXvcG zWyoM|m=mnHA_Yd7Fhx@%(N2p0(di~J^hg@-jO2-+;%xMXc`-sb#bxXiuZe?{qrNY8 z91+4pGCL!bqP+&2+DGV0z5u!o1vJFSW5IE+;glvS8jj8l*U)@oAoqQ4*Voi&kQ7(p zPX5k=GCfLtOB&Cvgb z$=sF&5NcP_mR{!%TbySK+0y`4oes)_Laf)q=)KD?<0{PU$y6#$3kroor7nC_hec4d zF)EcrBBJv|3YudXqC%n_#vbiq+LZSYb5^1Tsy)hbAfj%nk-_{T2#7)3Je$=h;i~Q8vO5wxBWkE`%raXqkZ+E=I`2dL`K5%JtMmb0Dj zWBxrC!~f!92Mzy!x@UzUX>tV$Xj1?H delta 1907 zcmc&#drVVj6t_H!;J{X>Kx}!m)RqD*0wSWGioC7RR$BTlW^o9}LkGw}%9v!jAyq%J zMG>4}5D^*Wtwx^H!4!o>#dN-g3W_33SyaemEtJK$m}S`?JNMpm?s?`rznfp~RO(b} zUAM7H^@whAWVr$|T(U+BJd}_)va=_^TnXhTxsRT`lEFpo_*+hu9N7D&pEJmEm~dPl zOu8n6W3KDsPWDUTP{PBsUa<@$$?fsH6bYyo9@l1GmH^wq`*@*`1gvT=x$#595S&@3 z?!P92>7`)}gHj=MNkxf;{;?`}s&y>czd;CLYHZTeI02ZSt#?7`F_2b#u5f0{X1F+X z! zs})jq5UeO7XG9jUVJfljlIrm~&}Jo;#t#R=f>}Z4_yE-=(S#&-Pzh0kAAcFsqy+a) z{E;z<5|rhKwKwQW*n7`qm0Pz0cwFxKkx3bhSI6!r|% z-F((0fd{{)H&*9Mpqrk5$~Ia8rXIt|6`5j46ZT>6;>EzXGWI*4CkC&Txovd|#3~5z zy>;DJB!aJ7a~s!=#e)Aa)3jEfSTGg;u+Oqq2qi*RbUY zcryMOD_aTKr7P(LzDlsRLx&fU6;NFBd$sML91gHFTD7wrCZA8Y*ty7b)63nvW+adr zG!h@xEP-=rU6cz25||9Qz3Jw(7$S4D2}^3kz&jCXtLmds?1Zx=H zau~Mm-uoR#26a(?Oy(p>q44wmkmX9&W?qll`o%gsUm2>S(sH8 z#KQg3H9M`bIvVr^yl`FiQULs`6?P;u0azz=)0(Dv@Oocp5w3?1xYdI3%k@0?#*TAe zrsl%k;GUw8q-f|Bh;w_3H^J@oXN^(JMrdp@JGhM)0aqe>go6QLaN4N)Ovfl2t_qH3 z{YC*uDtv5TGZP3ypVE1nlK@9K+cTD3H}cbedjEzBY@@9c^&O#5bPi=!Ty<|`GMPw3 zWkfQ1>uZ5d5J?ykO7eDSxJx`}hDbC!BXdA06gyN(v9Tb*%&ZVl8-;-48B5U}3KmB( zM4jiefG~JDnqcfe*7jy-HwA~P86n69ze^WmqD3?;O7+CDtp6&X8|rBAKV+$h%OF^I z7^+yPj7C6bbh8W(#s>{}643tTPAHmbgTTuM`m@**4BdExoj=12 z8t@{ZbeatcW7^PIhOGBLSSD(tQPCI$tA}Cn`!G17&zUg|LG%{Pe}HkT4zh&zV9e9S z=m&$*bne}Z1m0jATsjJ{akGuzxC^_Ceszrzq6JN`})BEtCTIYKoeQD+Il I_FHB2C$zq7_y7O^ From c729e8e9eeec0bb0216cac0b965109728c4f0f7e Mon Sep 17 00:00:00 2001 From: mj023 Date: Tue, 31 Mar 2026 00:54:12 +0200 Subject: [PATCH 8/8] Fix spelling --- src/lcm/utils/dispatchers.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/lcm/utils/dispatchers.py b/src/lcm/utils/dispatchers.py index 60efebf4..a7024e0b 100644 --- a/src/lcm/utils/dispatchers.py +++ b/src/lcm/utils/dispatchers.py @@ -161,9 +161,9 @@ def productmap( variables: tuple[str, ...], batch_sizes: dict[str, int] | None = None, ) -> FunctionWithArrayReturn: - """Apply mp such that func is evaluated on the Cartesian product of variables. + """Apply map such that func is evaluated on the Cartesian product of variables. - This is achieved by an iterative application of mp. + This is achieved by an iterative application of map. In contrast to _base_productmap_batched, productmap preserves the function signature and allows the function to be called with keyword arguments. @@ -242,21 +242,21 @@ def batched_vmap(**kwargs: FloatND) -> FloatND: # Recursively map over one more product axe def map_one_more( - loop: FunctionWithArrayReturn, axis: str + loop_func: FunctionWithArrayReturn, axis: str ) -> FunctionWithArrayReturn: def func_mapped_over_one_more_axis( *already_mapped_args: Float1D, **already_mapped_kwargs: Float1D ) -> FloatND: if parameters[axis].kind == inspect.Parameter.POSITIONAL_ONLY: return jax.lax.map( - lambda axis_i: loop( + lambda axis_i: loop_func( axis_i, *already_mapped_args, **already_mapped_kwargs ), jax.numpy.atleast_1d(kwargs[axis]), batch_size=batch_sizes[axis], ) return jax.lax.map( - lambda axis_i: loop( + lambda axis_i: loop_func( *already_mapped_args, **{axis: axis_i}, **already_mapped_kwargs ), jax.numpy.atleast_1d(kwargs[axis]),