diff --git a/src/lcm/grids/continuous.py b/src/lcm/grids/continuous.py index c1af7efd..c81f9c6a 100644 --- a/src/lcm/grids/continuous.py +++ b/src/lcm/grids/continuous.py @@ -16,6 +16,7 @@ ) +@dataclass(frozen=True, kw_only=True) class ContinuousGrid(Grid): """Base class for grids representing continuous values with coordinate lookup. @@ -24,6 +25,9 @@ class ContinuousGrid(Grid): """ + batch_size: int = 0 + """Size of the batches that are looped over during the solution.""" + @overload def get_coordinate(self, value: ScalarFloat) -> ScalarFloat: ... @overload diff --git a/src/lcm/grids/discrete.py b/src/lcm/grids/discrete.py index 41451dee..f844aad6 100644 --- a/src/lcm/grids/discrete.py +++ b/src/lcm/grids/discrete.py @@ -6,15 +6,26 @@ from lcm.utils.containers import get_field_names_and_values -class _DiscreteGridBase(Grid): - """Base class for discrete grids: categories, codes, and JAX conversion.""" +class DiscreteGrid(Grid): + """A discrete grid defining the outcome space of a categorical variable. + + Args: + category_class: The category class representing the grid categories. Must + be a dataclass with fields that have unique int values. + + Raises: + GridInitializationError: If the `category_class` is not a dataclass with int + fields. 
+ + """ - def __init__(self, category_class: type) -> None: + def __init__(self, category_class: type, batch_size: int = 0) -> None: _validate_discrete_grid(category_class) names_and_values = get_field_names_and_values(category_class) self.__categories = tuple(names_and_values.keys()) self.__codes = tuple(names_and_values.values()) self.__ordered: bool = getattr(category_class, "_ordered", False) + self.__batch_size: int = batch_size @property def categories(self) -> tuple[str, ...]: @@ -31,23 +42,11 @@ def ordered(self) -> bool: """Return whether the categories have a meaningful ordering.""" return self.__ordered + @property + def batch_size(self) -> int: + """Return batch size during solution.""" + return self.__batch_size + def to_jax(self) -> Int1D: """Convert the grid to a Jax array.""" return jnp.array(self.codes) - - -class DiscreteGrid(_DiscreteGridBase): - """A discrete grid defining the outcome space of a categorical variable. - - Args: - category_class: The category class representing the grid categories. Must - be a dataclass with fields that have unique int values. - - Raises: - GridInitializationError: If the `category_class` is not a dataclass with int - fields. 
- - """ - - def __init__(self, category_class: type) -> None: - super().__init__(category_class) diff --git a/src/lcm/regime_building/max_Q_over_a.py b/src/lcm/regime_building/max_Q_over_a.py index 62b32dca..fa875229 100644 --- a/src/lcm/regime_building/max_Q_over_a.py +++ b/src/lcm/regime_building/max_Q_over_a.py @@ -21,6 +21,7 @@ def get_max_Q_over_a( *, Q_and_F: Callable[..., tuple[FloatND, BoolND]], + batch_sizes: dict[str, int], action_names: tuple[str, ...], state_names: tuple[str, ...], ) -> MaxQOverAFunction: @@ -80,7 +81,7 @@ def max_Q_over_a( ) return Q_arr.max(where=F_arr, initial=-jnp.inf) - return productmap(func=max_Q_over_a, variables=state_names) + return productmap(func=max_Q_over_a, variables=state_names, batch_sizes=batch_sizes) def get_argmax_and_max_Q_over_a( diff --git a/src/lcm/regime_building/processing.py b/src/lcm/regime_building/processing.py index aca9c797..6efe8355 100644 --- a/src/lcm/regime_building/processing.py +++ b/src/lcm/regime_building/processing.py @@ -243,6 +243,7 @@ def _build_solve_functions( max_Q_over_a = _build_max_Q_over_a_per_period( state_action_space=state_action_space, Q_and_F_functions=Q_and_F_functions, + all_grids=all_grids[regime_name], enable_jit=enable_jit, ) @@ -1256,6 +1257,7 @@ def _build_max_Q_over_a_per_period( *, state_action_space: StateActionSpace, Q_and_F_functions: MappingProxyType[int, QAndFFunction], + all_grids: MappingProxyType[str, Grid], enable_jit: bool, ) -> MappingProxyType[int, MaxQOverAFunction]: """Build max-Q-over-a closures for each period.""" @@ -1263,6 +1265,7 @@ def _build_max_Q_over_a_per_period( for period, Q_and_F in Q_and_F_functions.items(): func = get_max_Q_over_a( Q_and_F=Q_and_F, + batch_sizes={name: grid.batch_size for name, grid in all_grids.items()}, # ty: ignore[unresolved-attribute] action_names=state_action_space.action_names, state_names=state_action_space.state_names, ) diff --git a/src/lcm/utils/dispatchers.py b/src/lcm/utils/dispatchers.py index 22605685..a7024e0b 
100644 --- a/src/lcm/utils/dispatchers.py +++ b/src/lcm/utils/dispatchers.py @@ -1,10 +1,13 @@ import inspect from collections.abc import Callable +from functools import partial from types import MappingProxyType from typing import Literal, TypeVar, cast +import jax from jax import Array, vmap +from lcm.typing import Float1D, FloatND from lcm.utils.containers import find_duplicates from lcm.utils.functools import allow_args, allow_only_kwargs @@ -65,7 +68,13 @@ def simulation_spacemap( mappable_func = allow_args(func) - vmapped = _base_productmap(mappable_func, action_names) + vmapped = allow_args( + productmap( + func=mappable_func, + variables=action_names, + batch_sizes=dict.fromkeys(action_names, 0), + ) + ) vmapped = vmap_1d(func=vmapped, variables=state_names, callable_with="only_args") # Callables do not necessarily have a __signature__ attribute. @@ -147,19 +156,23 @@ def vmap_1d( def productmap( - *, func: FunctionWithArrayReturn, variables: tuple[str, ...] + *, + func: FunctionWithArrayReturn, + variables: tuple[str, ...], + batch_sizes: dict[str, int] | None = None, ) -> FunctionWithArrayReturn: - """Apply vmap such that func is evaluated on the Cartesian product of variables. + """Apply map such that func is evaluated on the Cartesian product of variables. - This is achieved by an iterative application of vmap. + This is achieved by an iterative application of map. - In contrast to _base_productmap, productmap preserves the function signature and - allows the function to be called with keyword arguments. + In contrast to _base_productmap_batched, productmap preserves the function signature + and allows the function to be called with keyword arguments. Args: func: The function to be dispatched. variables: Tuple with names of arguments that over which the Cartesian product should be formed. + batch_sizes: Dict with the batch sizes for each variable. 
Returns: A callable with the same arguments as func (but with an additional leading @@ -179,45 +192,82 @@ def productmap( func_callable_with_args = allow_args(func) - vmapped = _base_productmap(func_callable_with_args, variables) + # If no batch size provided just vmap over all vars + if batch_sizes is None: + batch_sizes = dict.fromkeys(variables, 0) - # Callables do not necessarily have a __signature__ attribute. - vmapped.__signature__ = inspect.signature(func_callable_with_args) # ty: ignore[unresolved-attribute] + vmapped = _base_productmap_batched( + func_callable_with_args, variables, batch_sizes=batch_sizes + ) + + # Create new signature where every parameter is kw-only as + # batched_vmap takes only kwargs + signature = inspect.signature(func_callable_with_args) + new_parameters = [ + p.replace(kind=inspect.Parameter.KEYWORD_ONLY) + for p in signature.parameters.values() + ] + new_signature = signature.replace(parameters=new_parameters) + vmapped.__signature__ = new_signature # ty: ignore[unresolved-attribute] return cast("FunctionWithArrayReturn", allow_only_kwargs(vmapped, enforce=False)) -def _base_productmap( - func: FunctionWithArrayReturn, product_axes: tuple[str, ...] +def _base_productmap_batched( + func: FunctionWithArrayReturn, + product_axes: tuple[str, ...], + batch_sizes: dict[str, int], ) -> FunctionWithArrayReturn: - """Map func over the Cartesian product of product_axes. + """Map func over the Cartesian product of product_axes and execute in batches. - Like vmap, this function does not preserve the function signature and does not allow - the function to be called with keyword arguments. + Like vmap, this function does not preserve the function signature. Args: func: The function to be dispatched. Cannot have keyword-only arguments. product_axes: Tuple with names of arguments over which we apply vmap. - + batch_sizes: Dict with the batch sizes for each product_axis. Returns: A callable with the same arguments as func. 
See `product_map` for details. """ - signature = inspect.signature(func) - parameters = list(signature.parameters) - - positions = [parameters.index(ax) for ax in product_axes if ax in parameters] - - vmap_specs = [] - # We iterate in reverse order such that the output dimensions are in the same order - # as the input dimensions. - for pos in reversed(positions): - spec: list[int | None] = cast("list[int | None]", [None] * len(parameters)) - spec[pos] = 0 - vmap_specs.append(spec) - - vmapped = func - for spec in vmap_specs: - vmapped = vmap(vmapped, in_axes=spec) + parameters = inspect.signature(func).parameters + + def batched_vmap(**kwargs: FloatND) -> FloatND: + non_array_kwargs = { + key: val for key, val in kwargs.items() if key not in product_axes + } + func_with_partialled_args = cast( + "FunctionWithArrayReturn", partial(func, **non_array_kwargs) + ) - return vmapped + # Recursively map over one more product axis + def map_one_more( + loop_func: FunctionWithArrayReturn, axis: str + ) -> FunctionWithArrayReturn: + def func_mapped_over_one_more_axis( + *already_mapped_args: Float1D, **already_mapped_kwargs: Float1D + ) -> FloatND: + if parameters[axis].kind == inspect.Parameter.POSITIONAL_ONLY: + return jax.lax.map( + lambda axis_i: loop_func( + axis_i, *already_mapped_args, **already_mapped_kwargs + ), + jax.numpy.atleast_1d(kwargs[axis]), + batch_size=batch_sizes[axis], + ) + return jax.lax.map( + lambda axis_i: loop_func( + *already_mapped_args, **{axis: axis_i}, **already_mapped_kwargs + ), + jax.numpy.atleast_1d(kwargs[axis]), + batch_size=batch_sizes[axis], + ) + + return cast("FunctionWithArrayReturn", func_mapped_over_one_more_axis) + + # Loop over all product axes + for axis in reversed(product_axes): + func_with_partialled_args = map_one_more(func_with_partialled_args, axis) + return func_with_partialled_args() # ty: ignore[invalid-return-type] + + return cast("FunctionWithArrayReturn", batched_vmap) diff --git 
a/src/lcm_examples/mahler_yum_2024/_model.py b/src/lcm_examples/mahler_yum_2024/_model.py index 3ff4369c..536da538 100644 --- a/src/lcm_examples/mahler_yum_2024/_model.py +++ b/src/lcm_examples/mahler_yum_2024/_model.py @@ -37,7 +37,7 @@ Period, RegimeName, ) -from lcm.utils.dispatchers import _base_productmap +from lcm.utils.dispatchers import productmap _DATA_DIR = Path(__file__).parent / "data" @@ -502,8 +502,8 @@ def calc_base(period: Period, health: Int1D, education: Int1D) -> Float1D: * jnp.exp(((sdztemp**2.0) ** 2.0) / 2.0) ) - mapped = _base_productmap(calc_base, ("_period", "health", "education")) - return mapped(j, health, education) + mapped = productmap(func=calc_base, variables=("period", "health", "education")) + return mapped(period=j, health=health, education=education) eff_grid: Float1D = jnp.linspace(0, 1, 40) @@ -536,13 +536,18 @@ def health_trans( return jnp.exp(y) / (1.0 + jnp.exp(y)) -mapped_health_trans = _base_productmap( - health_trans, ("period", "health", "eff", "eff_1", "edu", "ht") +mapped_health_trans = productmap( + func=health_trans, variables=("period", "health", "eff", "eff_1", "edu", "ht") ) tr2yp_grid = tr2yp_grid.at[:, :, :, :, :, :, 1].set( mapped_health_trans( - j, jnp.arange(2), jnp.arange(40), jnp.arange(40), jnp.arange(2), jnp.arange(2) + period=j, + health=jnp.arange(2), + eff=jnp.arange(40), + eff_1=jnp.arange(40), + edu=jnp.arange(2), + ht=jnp.arange(2), ) ) tr2yp_grid = tr2yp_grid.at[:, :, :, :, :, :, 0].set( diff --git a/tests/data/regression_tests/f64/mahler_yum_simulation.pkl b/tests/data/regression_tests/f64/mahler_yum_simulation.pkl index eb25d48b..b4109af2 100644 Binary files a/tests/data/regression_tests/f64/mahler_yum_simulation.pkl and b/tests/data/regression_tests/f64/mahler_yum_simulation.pkl differ diff --git a/tests/solution/test_solve_brute.py b/tests/solution/test_solve_brute.py index c4cd87b9..92db0a97 100644 --- a/tests/solution/test_solve_brute.py +++ b/tests/solution/test_solve_brute.py @@ -119,6 
+119,7 @@ def _Q_and_F( Q_and_F=_Q_and_F, action_names=("consumption", "labor_supply"), state_names=("lazy", "wealth"), + batch_sizes={"lazy": 0, "wealth": 0}, ) # ================================================================================== @@ -176,6 +177,7 @@ def _Q_and_F(a, c, b, d, next_regime_to_V_arr): # noqa: ARG001 Q_and_F=_Q_and_F, action_names=("d",), state_names=("a", "b", "c"), + batch_sizes={"a": 0, "b": 0, "c": 0}, ) expected = np.array([[[6.0, 7, 8], [7, 8, 9]], [[7, 8, 9], [8, 9, 10]]]) diff --git a/tests/test_dispatchers.py b/tests/test_dispatchers.py index 4ae09d3d..ead717b4 100644 --- a/tests/test_dispatchers.py +++ b/tests/test_dispatchers.py @@ -142,18 +142,6 @@ def test_productmap_with_all_arguments_mapped_some_len_one(): aaae(calculated, expected) -def test_productmap_with_all_arguments_mapped_some_scalar(): - grids = { - "a": 1, - "b": 2, - "c": jnp.linspace(1, 5, 5), - } - - decorated = productmap(func=f, variables=("a", "b", "c")) - with pytest.raises(ValueError, match="vmap was requested to map its argument"): - decorated(**grids) # ty: ignore[missing-argument] - - def test_productmap_with_some_arguments_mapped(): grids = { "a": jnp.linspace(-5, 5, 10), diff --git a/tests/test_runtime_shock_params.py b/tests/test_runtime_shock_params.py index 9c225249..1fdf5ffd 100644 --- a/tests/test_runtime_shock_params.py +++ b/tests/test_runtime_shock_params.py @@ -84,7 +84,9 @@ def test_runtime_shock_params_property(): def test_fully_specified_shock(): """Tauchen with all params should have no runtime-supplied params.""" - grid = lcm.shocks.ar1.Tauchen(n_points=5, gauss_hermite=False, **_TAUCHEN_PARAMS) + grid = lcm.shocks.ar1.Tauchen( + n_points=5, gauss_hermite=False, batch_size=0, **_TAUCHEN_PARAMS + ) assert grid.params_to_pass_at_runtime == () assert grid.is_fully_specified diff --git a/tests/test_shock_draw.py b/tests/test_shock_draw.py index 086c9434..6356f86b 100644 --- a/tests/test_shock_draw.py +++ b/tests/test_shock_draw.py @@ -48,7 
+48,7 @@ def test_draw_shock_uniform(params_at_init): """Uniform.draw_shock uses start/stop params.""" kwargs = {"start": 2.0, "stop": 4.0} if params_at_init: - grid = lcm.shocks.iid.Uniform(n_points=5, **kwargs) + grid = lcm.shocks.iid.Uniform(n_points=5, batch_size=0, **kwargs) params = grid.params else: grid = lcm.shocks.iid.Uniform(n_points=5) @@ -64,10 +64,12 @@ def test_draw_shock_normal(params_at_init): """Normal.draw_shock uses mu/sigma params.""" kwargs = {"mu": 5.0, "sigma": 0.1} if params_at_init: - grid = lcm.shocks.iid.Normal(n_points=5, gauss_hermite=True, **kwargs) + grid = lcm.shocks.iid.Normal( + n_points=5, batch_size=0, gauss_hermite=True, **kwargs + ) params = grid.params else: - grid = lcm.shocks.iid.Normal(n_points=5, gauss_hermite=True) + grid = lcm.shocks.iid.Normal(n_points=5, batch_size=0, gauss_hermite=True) params = MappingProxyType(kwargs) draws = _draw_many(grid, params) aaae(draws.mean(), 5.0, decimal=1) @@ -79,7 +81,9 @@ def test_draw_shock_lognormal(params_at_init): """LogNormal.draw_shock produces positive samples with correct log-moments.""" kwargs = {"mu": 1.0, "sigma": 0.1} if params_at_init: - grid = lcm.shocks.iid.LogNormal(n_points=5, gauss_hermite=True, **kwargs) + grid = lcm.shocks.iid.LogNormal( + n_points=5, batch_size=0, gauss_hermite=True, **kwargs + ) params = grid.params else: grid = lcm.shocks.iid.LogNormal(n_points=5, gauss_hermite=True) @@ -95,7 +99,9 @@ def test_draw_shock_tauchen(params_at_init): """Tauchen.draw_shock uses mu/sigma/rho params.""" kwargs = {"rho": 0.5, "sigma": 0.1, "mu": 2.0} if params_at_init: - grid = lcm.shocks.ar1.Tauchen(n_points=5, gauss_hermite=True, **kwargs) + grid = lcm.shocks.ar1.Tauchen( + n_points=5, batch_size=0, gauss_hermite=True, **kwargs + ) params = grid.params else: grid = lcm.shocks.ar1.Tauchen(n_points=5, gauss_hermite=True) @@ -110,7 +116,7 @@ def test_draw_shock_rouwenhorst(params_at_init): """Rouwenhorst.draw_shock uses mu/sigma/rho params.""" kwargs = {"rho": 0.5, 
"sigma": 0.1, "mu": 2.0} if params_at_init: - grid = lcm.shocks.ar1.Rouwenhorst(n_points=5, **kwargs) + grid = lcm.shocks.ar1.Rouwenhorst(n_points=5, batch_size=0, **kwargs) params = grid.params else: grid = lcm.shocks.ar1.Rouwenhorst(n_points=5) @@ -125,7 +131,7 @@ def test_draw_shock_normal_mixture(params_at_init): """NormalMixture.draw_shock produces draws with correct mixture moments.""" kwargs = _NORMAL_MIXTURE_KWARGS if params_at_init: - grid = lcm.shocks.iid.NormalMixture(n_points=5, **kwargs) + grid = lcm.shocks.iid.NormalMixture(n_points=5, batch_size=0, **kwargs) params = grid.params else: grid = lcm.shocks.iid.NormalMixture(n_points=5) @@ -147,7 +153,7 @@ def test_draw_shock_tauchen_normal_mixture(params_at_init): """TauchenNormalMixture.draw_shock produces yields correct conditional moments.""" kwargs = _TAUCHEN_NORMAL_MIXTURE_KWARGS if params_at_init: - grid = lcm.shocks.ar1.TauchenNormalMixture(n_points=5, **kwargs) + grid = lcm.shocks.ar1.TauchenNormalMixture(n_points=5, batch_size=0, **kwargs) params = grid.params else: grid = lcm.shocks.ar1.TauchenNormalMixture(n_points=5) diff --git a/tests/test_shock_grids.py b/tests/test_shock_grids.py index 80b63103..14d46540 100644 --- a/tests/test_shock_grids.py +++ b/tests/test_shock_grids.py @@ -371,7 +371,9 @@ def test_lognormal_gauss_hermite_weights_sum_to_one(): def test_normal_mixture_transition_probs_rows_sum_to_one(): """NormalMixture transition probability rows sum to 1.""" - grid = lcm.shocks.iid.NormalMixture(n_points=7, **_NORMAL_MIXTURE_KWARGS) + grid = lcm.shocks.iid.NormalMixture( + n_points=7, batch_size=0, **_NORMAL_MIXTURE_KWARGS + ) P = grid.get_transition_probs() row_sums = P.sum(axis=1) aaae(row_sums, jnp.ones(7), decimal=DECIMAL_PRECISION) @@ -380,7 +382,7 @@ def test_normal_mixture_transition_probs_rows_sum_to_one(): def test_iid_normal_mixture_stationary_moments(): """IID NormalMixture stationary mean and std match mixture moments.""" kwargs = _NORMAL_MIXTURE_KWARGS - grid = 
lcm.shocks.iid.NormalMixture(n_points=21, **kwargs) + grid = lcm.shocks.iid.NormalMixture(n_points=21, batch_size=0, **kwargs) got_mean, got_std = _stationary_moments( grid.get_gridpoints(), grid.get_transition_probs() ) @@ -410,7 +412,7 @@ def test_iid_normal_mixture_stationary_moments(): def test_tauchen_normal_mixture_transition_probs_rows_sum_to_one(): """TauchenNormalMixture transition probability rows sum to 1.""" grid = lcm.shocks.ar1.TauchenNormalMixture( - n_points=7, **_TAUCHEN_NORMAL_MIXTURE_KWARGS + n_points=7, batch_size=0, **_TAUCHEN_NORMAL_MIXTURE_KWARGS ) P = grid.get_transition_probs() row_sums = P.sum(axis=1) @@ -420,7 +422,7 @@ def test_tauchen_normal_mixture_transition_probs_rows_sum_to_one(): def test_tauchen_normal_mixture_centers_on_unconditional_mean(): """TauchenNormalMixture gridpoints center on (mu + mean_eps) / (1 - rho).""" kwargs = _TAUCHEN_NORMAL_MIXTURE_KWARGS - grid = lcm.shocks.ar1.TauchenNormalMixture(n_points=11, **kwargs) + grid = lcm.shocks.ar1.TauchenNormalMixture(n_points=11, batch_size=0, **kwargs) points = grid.get_gridpoints() midpoint = (points[0] + points[-1]) / 2 mean_eps = kwargs["p1"] * kwargs["mu1"] + (1 - kwargs["p1"]) * kwargs["mu2"] @@ -431,7 +433,7 @@ def test_tauchen_normal_mixture_centers_on_unconditional_mean(): def test_tauchen_normal_mixture_stationary_moments_and_autocorrelation(): """TauchenNormalMixture stationary mean, std, and autocorrelation match theory.""" kwargs = _TAUCHEN_NORMAL_MIXTURE_KWARGS - grid = lcm.shocks.ar1.TauchenNormalMixture(n_points=21, **kwargs) + grid = lcm.shocks.ar1.TauchenNormalMixture(batch_size=0, n_points=21, **kwargs) points = grid.get_gridpoints() P = grid.get_transition_probs()