Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/lcm/entry_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def get_lcm_function(
targets: Literal["solve", "simulate", "solve_and_simulate"],
debug_mode: bool = True,
jit: bool = True,
multi_device_support: bool = False,
) -> tuple[Callable[..., dict[int, Array] | pd.DataFrame], ParamsDict]:
"""Entry point for users to get high level functions generated by lcm.

Expand All @@ -58,6 +59,8 @@ def get_lcm_function(
"solve_and_simulate" are supported.
debug_mode: Whether to log debug messages.
jit: Whether to jit the internal functions.
multi_device_support: Whether to use sharded arrays to distribute the
computation on multiple devices.

Returns:
- A function that can be used to solve and/or simulate the model (see below).
Expand Down Expand Up @@ -106,6 +109,7 @@ def get_lcm_function(
state_action_space = create_state_action_space(
model=internal_model,
is_last_period=is_last_period,
multi_device_support=multi_device_support,
)

state_space_info = create_state_space_info(
Expand Down
38 changes: 38 additions & 0 deletions src/lcm/state_action_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from typing import TYPE_CHECKING

import jax

from lcm.grids import ContinuousGrid, DiscreteGrid
from lcm.interfaces import InternalModel, StateActionSpace, StateSpaceInfo

Expand All @@ -16,6 +18,7 @@ def create_state_action_space(
*,
states: dict[str, Array] | None = None,
is_last_period: bool = False,
multi_device_support: bool = False,
) -> StateActionSpace:
"""Create a state-action-space.

Expand All @@ -28,6 +31,8 @@ def create_state_action_space(
are used.
is_last_period: Whether the state-action-space is created for the last period,
in which case auxiliary variables are not included.
multi_device_support: Whether to use sharded arrays to distribute the
computation on multiple devices.

Returns:
A state-action-space. Contains the grids of the discrete and continuous actions,
Expand All @@ -48,6 +53,19 @@ def create_state_action_space(
)
_states = states

if multi_device_support:
device_count = jax.device_count()
divisible_state_name = _find_divisible_state(
states=_states, device_count=device_count, simulation=states is None
)
mesh = jax.make_mesh((device_count,), (divisible_state_name,))
sharding = jax.sharding.NamedSharding(
mesh, jax.sharding.PartitionSpec(divisible_state_name)
)
_states[divisible_state_name] = jax.device_put(
_states[divisible_state_name], device=sharding
)

discrete_actions = {
name: model.grids[name] for name in vi.query("is_action & is_discrete").index
}
Expand Down Expand Up @@ -120,3 +138,23 @@ def _validate_all_states_present(
f"\n\nMissing initial states: {missing}\n",
f"Provided variables that are not states: {too_many}",
)


def _find_divisible_state(
states: dict[str, Array], device_count: int, *, simulation: bool
) -> str:
for key, value in states.items():
if (value.shape[0] % device_count) == 0:
return key
if simulation:
raise ValueError(
"If you want to use multiple devices, the number of initial states "
"has to be divisible by the number of available devices.\n"
f"Available devices: {device_count}",
)
raise ValueError(
"If you want to use multiple devices, at least one state variable "
"has to have a number of gridpoints divisible by the number"
" of available devices.\n"
f"Available devices: {device_count}",
)
Loading