Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/lcm/grids/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)


@dataclass(frozen=True, kw_only=True)
class ContinuousGrid(Grid):
"""Base class for grids representing continuous values with coordinate lookup.

Expand All @@ -24,6 +25,9 @@ class ContinuousGrid(Grid):

"""

batch_size: int = 0
"""Size of the batches that are looped over during the solution."""

@overload
def get_coordinate(self, value: ScalarFloat) -> ScalarFloat: ...
@overload
Expand Down
39 changes: 19 additions & 20 deletions src/lcm/grids/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,26 @@
from lcm.utils.containers import get_field_names_and_values


class _DiscreteGridBase(Grid):
"""Base class for discrete grids: categories, codes, and JAX conversion."""
class DiscreteGrid(Grid):
"""A discrete grid defining the outcome space of a categorical variable.

Args:
category_class: The category class representing the grid categories. Must
be a dataclass with fields that have unique int values.

Raises:
GridInitializationError: If the `category_class` is not a dataclass with int
fields.

"""

def __init__(self, category_class: type) -> None:
def __init__(self, category_class: type, batch_size: int = 0) -> None:
_validate_discrete_grid(category_class)
names_and_values = get_field_names_and_values(category_class)
self.__categories = tuple(names_and_values.keys())
self.__codes = tuple(names_and_values.values())
self.__ordered: bool = getattr(category_class, "_ordered", False)
self.__batch_size: int = batch_size

@property
def categories(self) -> tuple[str, ...]:
Expand All @@ -31,23 +42,11 @@ def ordered(self) -> bool:
"""Return whether the categories have a meaningful ordering."""
return self.__ordered

@property
def batch_size(self) -> int:
"""Return batch size during solution."""
return self.__batch_size

def to_jax(self) -> Int1D:
"""Convert the grid to a Jax array."""
return jnp.array(self.codes)


class DiscreteGrid(_DiscreteGridBase):
"""A discrete grid defining the outcome space of a categorical variable.

Args:
category_class: The category class representing the grid categories. Must
be a dataclass with fields that have unique int values.

Raises:
GridInitializationError: If the `category_class` is not a dataclass with int
fields.

"""

def __init__(self, category_class: type) -> None:
super().__init__(category_class)
3 changes: 2 additions & 1 deletion src/lcm/regime_building/max_Q_over_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
def get_max_Q_over_a(
*,
Q_and_F: Callable[..., tuple[FloatND, BoolND]],
batch_sizes: dict[str, int],
action_names: tuple[str, ...],
state_names: tuple[str, ...],
) -> MaxQOverAFunction:
Expand Down Expand Up @@ -80,7 +81,7 @@ def max_Q_over_a(
)
return Q_arr.max(where=F_arr, initial=-jnp.inf)

return productmap(func=max_Q_over_a, variables=state_names)
return productmap(func=max_Q_over_a, variables=state_names, batch_sizes=batch_sizes)


def get_argmax_and_max_Q_over_a(
Expand Down
3 changes: 3 additions & 0 deletions src/lcm/regime_building/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ def _build_solve_functions(
max_Q_over_a = _build_max_Q_over_a_per_period(
state_action_space=state_action_space,
Q_and_F_functions=Q_and_F_functions,
all_grids=all_grids[regime_name],
enable_jit=enable_jit,
)

Expand Down Expand Up @@ -1256,13 +1257,15 @@ def _build_max_Q_over_a_per_period(
*,
state_action_space: StateActionSpace,
Q_and_F_functions: MappingProxyType[int, QAndFFunction],
all_grids: MappingProxyType[str, Grid],
enable_jit: bool,
) -> MappingProxyType[int, MaxQOverAFunction]:
"""Build max-Q-over-a closures for each period."""
result = {}
for period, Q_and_F in Q_and_F_functions.items():
func = get_max_Q_over_a(
Q_and_F=Q_and_F,
batch_sizes={name: grid.batch_size for name, grid in all_grids.items()}, # ty: ignore[unresolved-attribute]
action_names=state_action_space.action_names,
state_names=state_action_space.state_names,
)
Expand Down
114 changes: 82 additions & 32 deletions src/lcm/utils/dispatchers.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import inspect
from collections.abc import Callable
from functools import partial
from types import MappingProxyType
from typing import Literal, TypeVar, cast

import jax
from jax import Array, vmap

from lcm.typing import Float1D, FloatND
from lcm.utils.containers import find_duplicates
from lcm.utils.functools import allow_args, allow_only_kwargs

Expand Down Expand Up @@ -65,7 +68,13 @@ def simulation_spacemap(

mappable_func = allow_args(func)

vmapped = _base_productmap(mappable_func, action_names)
vmapped = allow_args(
productmap(
func=mappable_func,
variables=action_names,
batch_sizes=dict.fromkeys(action_names, 0),
)
)
vmapped = vmap_1d(func=vmapped, variables=state_names, callable_with="only_args")

# Callables do not necessarily have a __signature__ attribute.
Expand Down Expand Up @@ -147,19 +156,23 @@ def vmap_1d(


def productmap(
*, func: FunctionWithArrayReturn, variables: tuple[str, ...]
*,
func: FunctionWithArrayReturn,
variables: tuple[str, ...],
batch_sizes: dict[str, int] | None = None,
) -> FunctionWithArrayReturn:
"""Apply vmap such that func is evaluated on the Cartesian product of variables.
"""Apply map such that func is evaluated on the Cartesian product of variables.

This is achieved by an iterative application of vmap.
This is achieved by an iterative application of map.

In contrast to _base_productmap, productmap preserves the function signature and
allows the function to be called with keyword arguments.
In contrast to _base_productmap_batched, productmap preserves the function signature
and allows the function to be called with keyword arguments.

Args:
func: The function to be dispatched.
variables: Tuple with names of arguments that over which the Cartesian product
should be formed.
batch_sizes: Dict with the batch sizes for each variable.

Returns:
A callable with the same arguments as func (but with an additional leading
Expand All @@ -179,45 +192,82 @@ def productmap(

func_callable_with_args = allow_args(func)

vmapped = _base_productmap(func_callable_with_args, variables)
# If no batch size provided just vmap over all vars
if batch_sizes is None:
batch_sizes = dict.fromkeys(variables, 0)

# Callables do not necessarily have a __signature__ attribute.
vmapped.__signature__ = inspect.signature(func_callable_with_args) # ty: ignore[unresolved-attribute]
vmapped = _base_productmap_batched(
func_callable_with_args, variables, batch_sizes=batch_sizes
)

# Create new signature where every parameter is kw-only as
# batched_vmap takes only kwargs
signature = inspect.signature(func_callable_with_args)
new_parameters = [
p.replace(kind=inspect.Parameter.KEYWORD_ONLY)
for p in signature.parameters.values()
]
new_signature = signature.replace(parameters=new_parameters)
vmapped.__signature__ = new_signature # ty: ignore[unresolved-attribute]

return cast("FunctionWithArrayReturn", allow_only_kwargs(vmapped, enforce=False))


def _base_productmap(
func: FunctionWithArrayReturn, product_axes: tuple[str, ...]
def _base_productmap_batched(
func: FunctionWithArrayReturn,
product_axes: tuple[str, ...],
batch_sizes: dict[str, int],
) -> FunctionWithArrayReturn:
"""Map func over the Cartesian product of product_axes.
"""Map func over the Cartesian product of product_axes and execute in batches.

Like vmap, this function does not preserve the function signature and does not allow
the function to be called with keyword arguments.
Like vmap, this function does not preserve the function signature.

Args:
func: The function to be dispatched. Cannot have keyword-only arguments.
product_axes: Tuple with names of arguments over which we apply vmap.

batch_sizes: Dict with the batch sizes for each product_axis.
Returns:
A callable with the same arguments as func. See `product_map` for details.

"""
signature = inspect.signature(func)
parameters = list(signature.parameters)

positions = [parameters.index(ax) for ax in product_axes if ax in parameters]

vmap_specs = []
# We iterate in reverse order such that the output dimensions are in the same order
# as the input dimensions.
for pos in reversed(positions):
spec: list[int | None] = cast("list[int | None]", [None] * len(parameters))
spec[pos] = 0
vmap_specs.append(spec)

vmapped = func
for spec in vmap_specs:
vmapped = vmap(vmapped, in_axes=spec)
parameters = inspect.signature(func).parameters

def batched_vmap(**kwargs: FloatND) -> FloatND:
non_array_kwargs = {
key: val for key, val in kwargs.items() if key not in product_axes
}
func_with_partialled_args = cast(
"FunctionWithArrayReturn", partial(func, **non_array_kwargs)
)

return vmapped
# Recursively map over one more product axe
def map_one_more(
loop_func: FunctionWithArrayReturn, axis: str
) -> FunctionWithArrayReturn:
def func_mapped_over_one_more_axis(
*already_mapped_args: Float1D, **already_mapped_kwargs: Float1D
) -> FloatND:
if parameters[axis].kind == inspect.Parameter.POSITIONAL_ONLY:
return jax.lax.map(
lambda axis_i: loop_func(
axis_i, *already_mapped_args, **already_mapped_kwargs
),
jax.numpy.atleast_1d(kwargs[axis]),
batch_size=batch_sizes[axis],
)
return jax.lax.map(
lambda axis_i: loop_func(
*already_mapped_args, **{axis: axis_i}, **already_mapped_kwargs
),
jax.numpy.atleast_1d(kwargs[axis]),
batch_size=batch_sizes[axis],
)

return cast("FunctionWithArrayReturn", func_mapped_over_one_more_axis)

# Loop over all product axes
for axis in reversed(product_axes):
func_with_partialled_args = map_one_more(func_with_partialled_args, axis)
return func_with_partialled_args() # ty: ignore[invalid-return-type]

return cast("FunctionWithArrayReturn", batched_vmap)
17 changes: 11 additions & 6 deletions src/lcm_examples/mahler_yum_2024/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
Period,
RegimeName,
)
from lcm.utils.dispatchers import _base_productmap
from lcm.utils.dispatchers import productmap

_DATA_DIR = Path(__file__).parent / "data"

Expand Down Expand Up @@ -502,8 +502,8 @@ def calc_base(period: Period, health: Int1D, education: Int1D) -> Float1D:
* jnp.exp(((sdztemp**2.0) ** 2.0) / 2.0)
)

mapped = _base_productmap(calc_base, ("_period", "health", "education"))
return mapped(j, health, education)
mapped = productmap(func=calc_base, variables=("period", "health", "education"))
return mapped(period=j, health=health, education=education)


eff_grid: Float1D = jnp.linspace(0, 1, 40)
Expand Down Expand Up @@ -536,13 +536,18 @@ def health_trans(
return jnp.exp(y) / (1.0 + jnp.exp(y))


mapped_health_trans = _base_productmap(
health_trans, ("period", "health", "eff", "eff_1", "edu", "ht")
mapped_health_trans = productmap(
func=health_trans, variables=("period", "health", "eff", "eff_1", "edu", "ht")
)

tr2yp_grid = tr2yp_grid.at[:, :, :, :, :, :, 1].set(
mapped_health_trans(
j, jnp.arange(2), jnp.arange(40), jnp.arange(40), jnp.arange(2), jnp.arange(2)
period=j,
health=jnp.arange(2),
eff=jnp.arange(40),
eff_1=jnp.arange(40),
edu=jnp.arange(2),
ht=jnp.arange(2),
)
)
tr2yp_grid = tr2yp_grid.at[:, :, :, :, :, :, 0].set(
Expand Down
Binary file modified tests/data/regression_tests/f64/mahler_yum_simulation.pkl
Binary file not shown.
2 changes: 2 additions & 0 deletions tests/solution/test_solve_brute.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def _Q_and_F(
Q_and_F=_Q_and_F,
action_names=("consumption", "labor_supply"),
state_names=("lazy", "wealth"),
batch_sizes={"lazy": 0, "wealth": 0},
)

# ==================================================================================
Expand Down Expand Up @@ -176,6 +177,7 @@ def _Q_and_F(a, c, b, d, next_regime_to_V_arr): # noqa: ARG001
Q_and_F=_Q_and_F,
action_names=("d",),
state_names=("a", "b", "c"),
batch_sizes={"a": 0, "b": 0, "c": 0},
)

expected = np.array([[[6.0, 7, 8], [7, 8, 9]], [[7, 8, 9], [8, 9, 10]]])
Expand Down
12 changes: 0 additions & 12 deletions tests/test_dispatchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,18 +142,6 @@ def test_productmap_with_all_arguments_mapped_some_len_one():
aaae(calculated, expected)


def test_productmap_with_all_arguments_mapped_some_scalar():
grids = {
"a": 1,
"b": 2,
"c": jnp.linspace(1, 5, 5),
}

decorated = productmap(func=f, variables=("a", "b", "c"))
with pytest.raises(ValueError, match="vmap was requested to map its argument"):
decorated(**grids) # ty: ignore[missing-argument]


def test_productmap_with_some_arguments_mapped():
grids = {
"a": jnp.linspace(-5, 5, 10),
Expand Down
4 changes: 3 additions & 1 deletion tests/test_runtime_shock_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,9 @@ def test_runtime_shock_params_property():

def test_fully_specified_shock():
"""Tauchen with all params should have no runtime-supplied params."""
grid = lcm.shocks.ar1.Tauchen(n_points=5, gauss_hermite=False, **_TAUCHEN_PARAMS)
grid = lcm.shocks.ar1.Tauchen(
n_points=5, gauss_hermite=False, batch_size=0, **_TAUCHEN_PARAMS
)
assert grid.params_to_pass_at_runtime == ()
assert grid.is_fully_specified

Expand Down
Loading
Loading