From 33f7ee966965ac7c20edc3405c3870ee81fc6a72 Mon Sep 17 00:00:00 2001 From: mj023 Date: Wed, 27 Aug 2025 17:04:08 +0200 Subject: [PATCH 1/4] Add device sharding --- src/lcm/entry_point.py | 4 ++++ src/lcm/state_action_space.py | 21 +++++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/src/lcm/entry_point.py b/src/lcm/entry_point.py index 89d947a9..3c3e2704 100644 --- a/src/lcm/entry_point.py +++ b/src/lcm/entry_point.py @@ -42,6 +42,7 @@ def get_lcm_function( targets: Literal["solve", "simulate", "solve_and_simulate"], debug_mode: bool = True, jit: bool = True, + multi_device_support: bool = False, ) -> tuple[Callable[..., dict[int, Array] | pd.DataFrame], ParamsDict]: """Entry point for users to get high level functions generated by lcm. @@ -58,6 +59,8 @@ def get_lcm_function( "solve_and_simulate" are supported. debug_mode: Whether to log debug messages. jit: Whether to jit the internal functions. + multi_device_support: Whether to use sharded arrays to distribute the + computation on multiple devices. Returns: - A function that can be used to solve and/or simulate the model (see below). @@ -106,6 +109,7 @@ def get_lcm_function( state_action_space = create_state_action_space( model=internal_model, is_last_period=is_last_period, + multi_device_support=multi_device_support ) state_space_info = create_state_space_info( diff --git a/src/lcm/state_action_space.py b/src/lcm/state_action_space.py index ebdc7b87..5dfd0b3f 100644 --- a/src/lcm/state_action_space.py +++ b/src/lcm/state_action_space.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING +import jax from lcm.grids import ContinuousGrid, DiscreteGrid from lcm.interfaces import InternalModel, StateActionSpace, StateSpaceInfo @@ -16,6 +17,7 @@ def create_state_action_space( *, states: dict[str, Array] | None = None, is_last_period: bool = False, + multi_device_support: bool = False, ) -> StateActionSpace: """Create a state-action-space. @@ -28,6 +30,8 @@ def create_state_action_space( are used. 
is_last_period: Whether the state-action-space is created for the last period, in which case auxiliary variables are not included. + multi_device_support: Whether to use sharded arrays to distribute the + computation on multiple devices. Returns: A state-action-space. Contains the grids of the discrete and continuous actions, @@ -41,6 +45,22 @@ def create_state_action_space( if states is None: _states = {sn: model.grids[sn] for sn in vi.query("is_state").index} + if multi_device_support: + device_count = jax.device_count() + sucess = False + for state in _states: + if (_states[state].shape[0] % device_count) == 0: + mesh = jax.make_mesh((device_count,), ('x')) + sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x')) + _states[state] = jax.device_put(_states[state], device=sharding) + sucess = True + break + if not sucess: + raise ValueError( + "If you want to use multiple devices, at least one state variable has to" + f" have a number of gridpoints divisible by the number of available devices.\n" + f"Available devices: {device_count}", + ) else: _validate_all_states_present( provided_states=states, @@ -48,6 +68,7 @@ def create_state_action_space( ) _states = states + discrete_actions = { name: model.grids[name] for name in vi.query("is_action & is_discrete").index } From 07535468c5169db58707b0b7d8ce3422ae2fef54 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 28 Aug 2025 12:17:58 +0000 Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lcm/entry_point.py | 2 +- src/lcm/state_action_space.py | 14 ++++++++------ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/lcm/entry_point.py b/src/lcm/entry_point.py index 3c3e2704..8cd7a707 100644 --- a/src/lcm/entry_point.py +++ b/src/lcm/entry_point.py @@ -109,7 +109,7 @@ def get_lcm_function( state_action_space = create_state_action_space( 
model=internal_model, is_last_period=is_last_period, - multi_device_support=multi_device_support + multi_device_support=multi_device_support, ) state_space_info = create_state_space_info( diff --git a/src/lcm/state_action_space.py b/src/lcm/state_action_space.py index 5dfd0b3f..0e601722 100644 --- a/src/lcm/state_action_space.py +++ b/src/lcm/state_action_space.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING import jax + from lcm.grids import ContinuousGrid, DiscreteGrid from lcm.interfaces import InternalModel, StateActionSpace, StateSpaceInfo @@ -50,16 +51,18 @@ def create_state_action_space( sucess = False for state in _states: if (_states[state].shape[0] % device_count) == 0: - mesh = jax.make_mesh((device_count,), ('x')) - sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x')) + mesh = jax.make_mesh((device_count,), ("x")) + sharding = jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec("x") + ) _states[state] = jax.device_put(_states[state], device=sharding) sucess = True break if not sucess: raise ValueError( - "If you want to use multiple devices, at least one state variable has to" - f" have a number of gridpoints divisible by the number of available devices.\n" - f"Available devices: {device_count}", + "If you want to use multiple devices, at least one state variable has to" + f" have a number of gridpoints divisible by the number of available devices.\n" + f"Available devices: {device_count}", ) else: _validate_all_states_present( @@ -68,7 +71,6 @@ def create_state_action_space( ) _states = states - discrete_actions = { name: model.grids[name] for name in vi.query("is_action & is_discrete").index } From 3006080cd8817e5f96fe4a71a4b6baa5146452ac Mon Sep 17 00:00:00 2001 From: mj023 Date: Thu, 28 Aug 2025 14:37:12 +0200 Subject: [PATCH 3/4] Include simulation --- src/lcm/state_action_space.py | 44 +++++++++++++++++++++-------------- 1 file changed, 26 insertions(+), 18 deletions(-) diff --git 
a/src/lcm/state_action_space.py b/src/lcm/state_action_space.py index 0e601722..b5327320 100644 --- a/src/lcm/state_action_space.py +++ b/src/lcm/state_action_space.py @@ -46,24 +46,6 @@ def create_state_action_space( if states is None: _states = {sn: model.grids[sn] for sn in vi.query("is_state").index} - if multi_device_support: - device_count = jax.device_count() - sucess = False - for state in _states: - if (_states[state].shape[0] % device_count) == 0: - mesh = jax.make_mesh((device_count,), ("x")) - sharding = jax.sharding.NamedSharding( - mesh, jax.sharding.PartitionSpec("x") - ) - _states[state] = jax.device_put(_states[state], device=sharding) - sucess = True - break - if not sucess: - raise ValueError( - "If you want to use multiple devices, at least one state variable has to" - f" have a number of gridpoints divisible by the number of available devices.\n" - f"Available devices: {device_count}", - ) else: _validate_all_states_present( provided_states=states, @@ -71,6 +53,32 @@ def create_state_action_space( ) _states = states + if multi_device_support: + device_count = jax.device_count() + sucess = False + for key, value in _states.items(): + if (value.shape[0] % device_count) == 0: + mesh = jax.make_mesh((device_count,), ("x")) + sharding = jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec("x") + ) + _states[key] = jax.device_put(value, device=sharding) + sucess = True + break + if not sucess: + if states is None: + raise ValueError( + "If you want to use multiple devices, the number of initial states " + "has to be divisible by the number of available devices.\n" + f"Available devices: {device_count}", + ) + raise ValueError( + "If you want to use multiple devices, at least one state variable " + "has to have a number of gridpoints divisible by the number" + " of available devices.\n" + f"Available devices: {device_count}", + ) + discrete_actions = { name: model.grids[name] for name in vi.query("is_action & is_discrete").index } From 
499b8a010ad800e7169a6cadc5dc7ce2d82f9e6d Mon Sep 17 00:00:00 2001 From: mj023 Date: Thu, 28 Aug 2025 18:24:12 +0200 Subject: [PATCH 4/4] Move check to new function --- src/lcm/state_action_space.py | 53 ++++++++++++++++++++--------------- 1 file changed, 30 insertions(+), 23 deletions(-) diff --git a/src/lcm/state_action_space.py b/src/lcm/state_action_space.py index b5327320..2fa32fe8 100644 --- a/src/lcm/state_action_space.py +++ b/src/lcm/state_action_space.py @@ -55,29 +55,16 @@ def create_state_action_space( if multi_device_support: device_count = jax.device_count() - sucess = False - for key, value in _states.items(): - if (value.shape[0] % device_count) == 0: - mesh = jax.make_mesh((device_count,), ("x")) - sharding = jax.sharding.NamedSharding( - mesh, jax.sharding.PartitionSpec("x") - ) - _states[key] = jax.device_put(value, device=sharding) - sucess = True - break - if not sucess: - if states is None: - raise ValueError( - "If you want to use multiple devices, the number of initial states " - "has to be divisible by the number of available devices.\n" - f"Available devices: {device_count}", - ) - raise ValueError( - "If you want to use multiple devices, at least one state variable " - "has to have a number of gridpoints divisible by the number" - " of available devices.\n" - f"Available devices: {device_count}", - ) + divisible_state_name = _find_divisible_state( + states=_states, device_count=device_count, simulation=states is None + ) + mesh = jax.make_mesh((device_count,), (divisible_state_name,)) + sharding = jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec(divisible_state_name) + ) + _states[divisible_state_name] = jax.device_put( + _states[divisible_state_name], device=sharding + ) discrete_actions = { name: model.grids[name] for name in vi.query("is_action & is_discrete").index @@ -151,3 +138,23 @@ def _validate_all_states_present( f"\n\nMissing initial states: {missing}\n", f"Provided variables that are not states: {too_many}", ) + + 
def _find_divisible_state(
    states: dict[str, Array], device_count: int, *, simulation: bool
) -> str:
    """Return the name of the first state that can be split evenly across devices.

    The states are scanned in dictionary order, and the first one whose leading
    axis length is a multiple of ``device_count`` is selected.

    Args:
        states: Mapping from state names to arrays. Only the ``shape`` attribute
            of each array is inspected.
        device_count: Number of devices the leading axis has to be divisible by.
            Must be at least 1.
        simulation: Selects the error message raised when no state qualifies. If
            True, the message refers to the number of initial states; otherwise
            it refers to the number of gridpoints of the state variables.
            NOTE(review): the visible call site passes
            ``simulation=states is None``, which is True when the model grids
            (not user-provided initial states) are used. That looks inverted
            relative to the messages below -- confirm against the caller.

    Returns:
        The name of the first state whose leading axis length is divisible by
        ``device_count``.

    Raises:
        ValueError: If ``device_count`` is smaller than 1, or if no state has a
            leading axis whose length is divisible by ``device_count``.

    """
    if device_count < 1:
        # Guard the modulo below; otherwise a zero count would surface as an
        # unexplained ZeroDivisionError.
        raise ValueError(
            f"device_count must be a positive integer, got {device_count}."
        )

    for name, values in states.items():
        shape = values.shape
        # Zero-dimensional entries have no leading axis that could be split, so
        # they are skipped instead of failing with an IndexError on shape[0].
        if shape and shape[0] % device_count == 0:
            return name

    if simulation:
        raise ValueError(
            "If you want to use multiple devices, the number of initial states "
            "has to be divisible by the number of available devices.\n"
            f"Available devices: {device_count}",
        )
    raise ValueError(
        "If you want to use multiple devices, at least one state variable "
        "has to have a number of gridpoints divisible by the number"
        " of available devices.\n"
        f"Available devices: {device_count}",
    )