diff --git a/src/lcm/entry_point.py b/src/lcm/entry_point.py index 89d947a9..8cd7a707 100644 --- a/src/lcm/entry_point.py +++ b/src/lcm/entry_point.py @@ -42,6 +42,7 @@ def get_lcm_function( targets: Literal["solve", "simulate", "solve_and_simulate"], debug_mode: bool = True, jit: bool = True, + multi_device_support: bool = False, ) -> tuple[Callable[..., dict[int, Array] | pd.DataFrame], ParamsDict]: """Entry point for users to get high level functions generated by lcm. @@ -58,6 +59,8 @@ def get_lcm_function( "solve_and_simulate" are supported. debug_mode: Whether to log debug messages. jit: Whether to jit the internal functions. + multi_device_support: Whether to use sharded arrays to distribute the + computation on multiple devices. Returns: - A function that can be used to solve and/or simulate the model (see below). @@ -106,6 +109,7 @@ def get_lcm_function( state_action_space = create_state_action_space( model=internal_model, is_last_period=is_last_period, + multi_device_support=multi_device_support, ) state_space_info = create_state_space_info( diff --git a/src/lcm/state_action_space.py b/src/lcm/state_action_space.py index ebdc7b87..2fa32fe8 100644 --- a/src/lcm/state_action_space.py +++ b/src/lcm/state_action_space.py @@ -4,6 +4,8 @@ from typing import TYPE_CHECKING +import jax + from lcm.grids import ContinuousGrid, DiscreteGrid from lcm.interfaces import InternalModel, StateActionSpace, StateSpaceInfo @@ -16,6 +18,7 @@ def create_state_action_space( *, states: dict[str, Array] | None = None, is_last_period: bool = False, + multi_device_support: bool = False, ) -> StateActionSpace: """Create a state-action-space. @@ -28,6 +31,8 @@ def create_state_action_space( are used. is_last_period: Whether the state-action-space is created for the last period, in which case auxiliary variables are not included. + multi_device_support: Whether to use sharded arrays to distribute the + computation on multiple devices. Returns: A state-action-space. 
def _find_divisible_state(
    states: dict[str, Array], device_count: int, *, simulation: bool
) -> str:
    """Return the name of the first state that splits evenly across devices.

    A state qualifies when the length of its leading axis is a multiple of
    `device_count`. States are checked in the iteration order of `states`, and the
    first match wins.

    Args:
        states: Dictionary mapping state names to arrays.
        device_count: Number of devices the array should be distributed over.
        simulation: Selects the error message used when no state qualifies: the
            message about the number of initial states (True) or the message
            about the number of gridpoints (False).

    Returns:
        The name of the first state whose leading axis is divisible by
        `device_count`.

    Raises:
        ValueError: If no state has a leading axis divisible by `device_count`.

    """
    for name, values in states.items():
        shape = values.shape
        # A zero-dimensional array has no leading axis to shard. Skip it so the
        # caller gets the explanatory ValueError below rather than an IndexError
        # from `shape[0]`.
        if shape and shape[0] % device_count == 0:
            return name

    if simulation:
        raise ValueError(
            "If you want to use multiple devices, the number of initial states "
            "has to be divisible by the number of available devices.\n"
            f"Available devices: {device_count}",
        )
    raise ValueError(
        "If you want to use multiple devices, at least one state variable "
        "has to have a number of gridpoints divisible by the number"
        " of available devices.\n"
        f"Available devices: {device_count}",
    )