diff --git a/conf/experimental/megatron_bridge/test/b200/megatron_bridge_qwen_30b.toml b/conf/experimental/megatron_bridge/test/b200/megatron_bridge_qwen_30b.toml index 238dbc7de..89cd95dd0 100644 --- a/conf/experimental/megatron_bridge/test/b200/megatron_bridge_qwen_30b.toml +++ b/conf/experimental/megatron_bridge/test/b200/megatron_bridge_qwen_30b.toml @@ -23,7 +23,6 @@ extra_container_mounts = [] [[git_repos]] url = "https://github.com/NVIDIA-NeMo/Megatron-Bridge.git" commit = "v0.3.0" -mount_as = "/opt/Megatron-Bridge" [cmd_args] gpu_type = "b200" diff --git a/conf/experimental/megatron_bridge/test/gb200/megatron_bridge_qwen_30b.toml b/conf/experimental/megatron_bridge/test/gb200/megatron_bridge_qwen_30b.toml index b06b9d1f6..d0ca6e2fb 100644 --- a/conf/experimental/megatron_bridge/test/gb200/megatron_bridge_qwen_30b.toml +++ b/conf/experimental/megatron_bridge/test/gb200/megatron_bridge_qwen_30b.toml @@ -23,7 +23,6 @@ extra_container_mounts = [] [[git_repos]] url = "https://github.com/NVIDIA-NeMo/Megatron-Bridge.git" commit = "v0.3.0" -mount_as = "/opt/Megatron-Bridge" [cmd_args] gpu_type = "gb200" diff --git a/conf/experimental/megatron_bridge/test/gb300/megatron_bridge_qwen_30b.toml b/conf/experimental/megatron_bridge/test/gb300/megatron_bridge_qwen_30b.toml index 22b692e27..f68899a08 100644 --- a/conf/experimental/megatron_bridge/test/gb300/megatron_bridge_qwen_30b.toml +++ b/conf/experimental/megatron_bridge/test/gb300/megatron_bridge_qwen_30b.toml @@ -23,7 +23,6 @@ extra_container_mounts = [] [[git_repos]] url = "https://github.com/NVIDIA-NeMo/Megatron-Bridge.git" commit = "v0.3.0" -mount_as = "/opt/Megatron-Bridge" [cmd_args] gpu_type = "gb300" @@ -32,6 +31,7 @@ model_family_name = "qwen" model_recipe_name = "qwen3_30b_a3b" gpus_per_node = 4 num_gpus = 8 +# mb = 4 # In case OOM, uncomment this for smaller micro-batch size domain = "llm" task = "pretrain" compute_dtype = "fp8_mx" diff --git a/conf/experimental/megatron_bridge/test/h100/megatron_bridge_qwen_30b.toml b/conf/experimental/megatron_bridge/test/h100/megatron_bridge_qwen_30b.toml index 68d378052..84c52f893 100644 --- a/conf/experimental/megatron_bridge/test/h100/megatron_bridge_qwen_30b.toml +++ b/conf/experimental/megatron_bridge/test/h100/megatron_bridge_qwen_30b.toml @@ -23,7 +23,6 @@ extra_container_mounts = [] [[git_repos]] url = "https://github.com/NVIDIA-NeMo/Megatron-Bridge.git" commit = "v0.3.0" -mount_as = "/opt/Megatron-Bridge" [cmd_args] gpu_type = "h100" diff --git a/conf/experimental/megatron_bridge/test_scenario/megatron_bridge_qwen_30b.toml b/conf/experimental/megatron_bridge/test_scenario/megatron_bridge_qwen_30b.toml index 16d218f84..d9dc23419 100644 --- a/conf/experimental/megatron_bridge/test_scenario/megatron_bridge_qwen_30b.toml +++ b/conf/experimental/megatron_bridge/test_scenario/megatron_bridge_qwen_30b.toml @@ -1,5 +1,5 @@ # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES -# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/conf/experimental/megatron_bridge/test_scenario/megatron_bridge_r0.3.0_qwen_30b.toml b/conf/experimental/megatron_bridge/test_scenario/megatron_bridge_r0.3.0_qwen_30b.toml new file mode 100644 index 000000000..70e62e188 --- /dev/null +++ b/conf/experimental/megatron_bridge/test_scenario/megatron_bridge_r0.3.0_qwen_30b.toml @@ -0,0 +1,31 @@ +# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name = "megatron_bridge_qwen_30b" + +[[Tests]] +id = "megatron_bridge_qwen_30b" +test_name = "megatron_bridge_qwen_30b" +num_nodes = "2" + + [[Tests.git_repos]] + url = "https://github.com/NVIDIA-NeMo/Megatron-Bridge.git" + commit = "r0.3.0" + mount_as = "/opt/Megatron-Bridge" + init_submodules = true + + [Tests.extra_env_vars] + PYTHONPATH = "/opt/Megatron-Bridge/3rdparty/Megatron-LM:${PYTHONPATH}" diff --git a/src/cloudai/_core/installables.py b/src/cloudai/_core/installables.py index f8527876c..d6db18589 100644 --- a/src/cloudai/_core/installables.py +++ b/src/cloudai/_core/installables.py @@ -1,5 +1,5 @@ # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES -# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -91,6 +91,7 @@ class GitRepo(Installable, BaseModel): url: str commit: str + init_submodules: bool = False installed_path: Optional[Path] = None mount_as: Optional[str] = None diff --git a/src/cloudai/_core/test_scenario.py b/src/cloudai/_core/test_scenario.py index 5480adf78..5eb28820c 100644 --- a/src/cloudai/_core/test_scenario.py +++ b/src/cloudai/_core/test_scenario.py @@ -1,5 +1,5 @@ # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES -# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -62,6 +62,7 @@ class TestRun: test: TestDefinition num_nodes: Union[int, list[int]] nodes: List[str] + exclude_nodes: List[str] = field(default_factory=list) output_path: Path = Path("") iterations: int = 1 current_iteration: int = 0 diff --git a/src/cloudai/models/scenario.py b/src/cloudai/models/scenario.py index cb5452af2..beeb84244 100644 --- a/src/cloudai/models/scenario.py +++ b/src/cloudai/models/scenario.py @@ -73,6 +73,13 @@ class TestRunModel(BaseModel): test_name: Optional[str] = None num_nodes: int | list[int] | None = None nodes: list[str] = Field(default_factory=list) + exclude_nodes: list[str] = Field( + default_factory=list, + description=( + "Hostnames to exclude from the resolved node list. " + "Supports Slurm range syntax, e.g. ['node-048', 'node-[101-104]']." + ), + ) weight: int = 0 iterations: int = 1 sol: Optional[float] = None diff --git a/src/cloudai/systems/kubernetes/kubernetes_installer.py b/src/cloudai/systems/kubernetes/kubernetes_installer.py index 6f18f2a7a..5b1ebdf30 100644 --- a/src/cloudai/systems/kubernetes/kubernetes_installer.py +++ b/src/cloudai/systems/kubernetes/kubernetes_installer.py @@ -156,11 +156,23 @@ def _install_one_git_repo(self, item: GitRepo) -> InstallStatusResult: verify_res = self._verify_commit(item.commit, repo_path) if not verify_res.success: return verify_res + if item.init_submodules: + res = self._init_submodules(repo_path) + if not res.success: + return res item.installed_path = repo_path msg = f"Git repository already exists at {repo_path}." logging.debug(msg) return InstallStatusResult(True, msg) + res = self._clone_and_setup_repo(item, repo_path) + if not res.success: + return res + + item.installed_path = repo_path + return InstallStatusResult(True) + + def _clone_and_setup_repo(self, item: GitRepo, repo_path: Path) -> InstallStatusResult: res = self._clone_repository(item.url, repo_path) if not res.success: return res @@ -172,7 +184,14 @@ def _install_one_git_repo(self, item: GitRepo) -> InstallStatusResult: rmtree(repo_path) return res - item.installed_path = repo_path + if item.init_submodules: + res = self._init_submodules(repo_path) + if not res.success: + logging.error(f"Submodule init failed, removing cloned repository at {repo_path}") + if repo_path.exists(): + rmtree(repo_path) + return res + return InstallStatusResult(True) def _install_python_executable(self, item: PythonExecutable) -> InstallStatusResult: @@ -237,6 +256,14 @@ def _checkout_commit(self, commit_hash: str, path: Path) -> InstallStatusResult: return InstallStatusResult(False, f"Failed to checkout commit {commit_hash}: {result.stderr}") return InstallStatusResult(True) + def _init_submodules(self, path: Path) -> InstallStatusResult: + logging.debug(f"Initializing submodules in {path}") + submodule_cmd = ["git", "submodule", "update", "--init", "--recursive"] + result = subprocess.run(submodule_cmd, cwd=str(path), capture_output=True, text=True) + if result.returncode != 0: + return InstallStatusResult(False, f"Failed to initialize submodules: {result.stderr}") + return InstallStatusResult(True) + def _verify_commit(self, ref: str, path: Path) -> InstallStatusResult: try: result = subprocess.run(["git", "rev-parse", "HEAD"], cwd=str(path), capture_output=True, text=True) diff --git a/src/cloudai/systems/slurm/slurm_command_gen_strategy.py b/src/cloudai/systems/slurm/slurm_command_gen_strategy.py index 5c17b4d17..67619a1b9 100644 --- a/src/cloudai/systems/slurm/slurm_command_gen_strategy.py +++ b/src/cloudai/systems/slurm/slurm_command_gen_strategy.py @@ -424,6 +424,9 @@ def _append_nodes_related_directives(self, content: List[str]) -> Optional[Path] content.append(f"#SBATCH -N {num_nodes}") + if self.test_run.exclude_nodes: + content.append(f"#SBATCH --exclude={','.join(self.test_run.exclude_nodes)}") + return None def _format_env_vars(self, env_vars: Dict[str, Any]) -> str: @@ -465,6 +468,7 @@ def get_cached_nodes_spec(self) -> tuple[int, list[str]]: str(self.test_run.step), str(self.test_run.nnodes), ",".join(self.test_run.nodes), + ",".join(self.test_run.exclude_nodes), ] ) @@ -472,5 +476,9 @@ def get_cached_nodes_spec(self) -> tuple[int, list[str]]: logging.debug(f"Using cached node allocation for {cache_key}: {self._node_spec_cache[cache_key]}") return self._node_spec_cache[cache_key] - self._node_spec_cache[cache_key] = self.system.get_nodes_by_spec(self.test_run.nnodes, self.test_run.nodes) + num_nodes, node_list = self.system.get_nodes_by_spec( + self.test_run.nnodes, self.test_run.nodes, exclude_nodes=self.test_run.exclude_nodes or None + ) + + self._node_spec_cache[cache_key] = (num_nodes, node_list) return self._node_spec_cache[cache_key] diff --git a/src/cloudai/systems/slurm/slurm_installer.py b/src/cloudai/systems/slurm/slurm_installer.py index 4800becc5..e30a87395 100644 --- a/src/cloudai/systems/slurm/slurm_installer.py +++ b/src/cloudai/systems/slurm/slurm_installer.py @@ -209,11 +209,23 @@ def _install_one_git_repo(self, item: GitRepo) -> InstallStatusResult: verify_res = self._verify_commit(item.commit, repo_path) if not verify_res.success: return verify_res + if item.init_submodules: + res = self._init_submodules(repo_path) + if not res.success: + return res item.installed_path = repo_path msg = f"Git repository already exists at {repo_path}." logging.debug(msg) return InstallStatusResult(True, msg) + res = self._clone_and_setup_repo(item, repo_path) + if not res.success: + return res + + item.installed_path = repo_path + return InstallStatusResult(True) + + def _clone_and_setup_repo(self, item: GitRepo, repo_path: Path) -> InstallStatusResult: res = self._clone_repository(item.url, repo_path) if not res.success: return res @@ -225,7 +237,14 @@ def _install_one_git_repo(self, item: GitRepo) -> InstallStatusResult: rmtree(repo_path) return res - item.installed_path = repo_path + if item.init_submodules: + res = self._init_submodules(repo_path) + if not res.success: + logging.error(f"Submodule init failed, removing cloned repository at {repo_path}") + if repo_path.exists(): + rmtree(repo_path) + return res + return InstallStatusResult(True) def _install_python_executable(self, item: PythonExecutable) -> InstallStatusResult: @@ -290,6 +309,14 @@ def _checkout_commit(self, commit_hash: str, path: Path) -> InstallStatusResult: return InstallStatusResult(False, f"Failed to checkout commit {commit_hash}: {result.stderr}") return InstallStatusResult(True) + def _init_submodules(self, path: Path) -> InstallStatusResult: + logging.debug(f"Initializing submodules in {path}") + submodule_cmd = ["git", "submodule", "update", "--init", "--recursive"] + result = subprocess.run(submodule_cmd, cwd=str(path), capture_output=True, text=True) + if result.returncode != 0: + return InstallStatusResult(False, f"Failed to initialize submodules: {result.stderr}") + return InstallStatusResult(True) + def _verify_commit(self, ref: str, path: Path) -> InstallStatusResult: try: result = subprocess.run(["git", "rev-parse", "HEAD"], cwd=str(path), capture_output=True, text=True) diff --git a/src/cloudai/systems/slurm/slurm_system.py b/src/cloudai/systems/slurm/slurm_system.py index 5edbed01a..2ec73b6fb 100644 --- a/src/cloudai/systems/slurm/slurm_system.py +++ b/src/cloudai/systems/slurm/slurm_system.py @@ -1,5 +1,5 @@ # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES -# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -443,7 +443,11 @@ def format_range(lst: List[int], padding: int) -> List[str]: return ", ".join(formatted_ranges) def get_available_nodes_from_group( - self, partition_name: str, group_name: str, number_of_nodes: Union[int, str] + self, + partition_name: str, + group_name: str, + number_of_nodes: Union[int, str], + exclude_nodes: list[str] | None = None, ) -> List[SlurmNode]: """ Retrieve a specific number of potentially available nodes from a group within a partition. @@ -456,6 +460,7 @@ def get_available_nodes_from_group( group_name (str): The name of the group. number_of_nodes (Union[int,str]): The number of nodes to retrieve. Could also be 'all' to retrieve all the nodes from the group. + exclude_nodes (list[str] | None): Node names to exclude from the pool before selection. Returns: List[SlurmNode]: Objects that are potentially available for use. @@ -468,7 +473,7 @@ def get_available_nodes_from_group( self.validate_partition_and_group(partition_name, group_name) - grouped_nodes = self.group_nodes_by_state(partition_name, group_name) + grouped_nodes = self.group_nodes_by_state(partition_name, group_name, exclude_nodes=exclude_nodes) try: allocated_nodes = self.allocate_nodes(grouped_nodes, number_of_nodes, group_name) @@ -505,14 +510,19 @@ def validate_partition_and_group(self, partition_name: str, group_name: str) -> if group_name not in self.groups[partition_name]: raise ValueError(f"Group '{group_name}' not found in partition '{partition_name}'.") - def group_nodes_by_state(self, partition_name: str, group_name: str) -> Dict[SlurmNodeState, List[SlurmNode]]: + def group_nodes_by_state( + self, + partition_name: str, + group_name: str, + exclude_nodes: list[str] | None = None, + ) -> Dict[SlurmNodeState, List[SlurmNode]]: """ Group nodes by their states, excluding nodes allocated to the current user. Args: partition_name (str): The name of the partition. group_name (str): The name of the group. - current_user (str): The username of the current user. + exclude_nodes (list[str] | None): Node names to exclude from the pool before grouping. Returns: Dict[SlurmNodeState, List[SlurmNode]]: A dictionary grouping nodes by their state. @@ -521,9 +531,12 @@ def group_nodes_by_state(self, partition_name: str, group_name: str) -> Dict[Slu SlurmNodeState.IDLE: [], SlurmNodeState.COMPLETING: [], SlurmNodeState.ALLOCATED: [], + SlurmNodeState.RESERVED: [], } for node in self.groups[partition_name][group_name]: + if exclude_nodes and node.name in exclude_nodes: + continue if node.state in grouped_nodes: grouped_nodes[node.state].append(node) @@ -554,6 +567,7 @@ def allocate_nodes( if isinstance(number_of_nodes, str) and number_of_nodes == "max_avail": allocated_nodes.extend(grouped_nodes[SlurmNodeState.IDLE]) allocated_nodes.extend(grouped_nodes[SlurmNodeState.COMPLETING]) + allocated_nodes.extend(grouped_nodes[SlurmNodeState.RESERVED]) if len(allocated_nodes) == 0: raise ValueError( @@ -671,7 +685,7 @@ def convert_state_to_enum(self, state_str: str) -> SlurmNodeState: logging.debug(f"Unknown node state: {core_state}") return SlurmNodeState.UNKNOWN_STATE - def parse_nodes(self, nodes: List[str]) -> List[str]: + def parse_nodes(self, nodes: List[str], exclude_nodes: list[str] | None = None) -> List[str]: """ Parse a list of node specifications into individual node names. @@ -684,6 +698,8 @@ def parse_nodes(self, nodes: List[str]) -> List[str]: "partition:group:num_nodes", where "partition" is the partition name, "group" is a group within that partition, and "num_nodes" is the number of nodes requested. Node ranges should be specified with square brackets and dashes, e.g., "node[01-03]" for "node01", "node02", "node03". + exclude_nodes (list[str] | None): Node names (or Slurm range expressions) to exclude from group pools + before selection. Ranges are expanded internally. Returns: List[str]: A list of node names. For specifications, it includes names of allocated nodes based on the @@ -693,6 +709,9 @@ def parse_nodes(self, nodes: List[str]) -> List[str]: ValueError: If a specification is malformed, a specified node is not found, or a node range cannot be parsed. This ensures users are aware of incorrect inputs. """ + if exclude_nodes: + exclude_nodes = [n for spec in exclude_nodes for n in parse_node_list(spec)] + parsed_nodes = [] for node_spec in nodes: if ":" in node_spec: @@ -701,17 +720,23 @@ def parse_nodes(self, nodes: List[str]) -> List[str]: raise ValueError("Format should be partition:group:num_nodes") partition_name, group_name, num_nodes_spec = parts num_nodes = int(num_nodes_spec) if num_nodes_spec != "max_avail" else num_nodes_spec - group_nodes = self.get_available_nodes_from_group(partition_name, group_name, num_nodes) + group_nodes = self.get_available_nodes_from_group( + partition_name, group_name, num_nodes, exclude_nodes=exclude_nodes + ) parsed_nodes += [node.name for node in group_nodes] else: expanded_nodes = parse_node_list(node_spec) + if exclude_nodes: + expanded_nodes = [n for n in expanded_nodes if n not in exclude_nodes] parsed_nodes += expanded_nodes # Remove duplicates while preserving order parsed_nodes = list(dict.fromkeys(parsed_nodes)) return parsed_nodes - def get_nodes_by_spec(self, num_nodes: int, nodes: list[str]) -> Tuple[int, list[str]]: + def get_nodes_by_spec( + self, num_nodes: int, nodes: list[str], exclude_nodes: list[str] | None = None + ) -> Tuple[int, list[str]]: """ Retrieve a list of node names based on specifications. @@ -721,15 +746,29 @@ def get_nodes_by_spec(self, num_nodes: int, nodes: list[str]) -> Tuple[int, list Args: num_nodes (int): The number of nodes, can't be `0`. nodes (list[str]): A list of node names specifications, slurm format or `PARTITION:GROUP:NUM_NODES`. + exclude_nodes (list[str] | None): Node names to exclude from group pools before selection. Returns: Tuple[int, list[str]]: The number of nodes and a list of node names. + + Raises: + ValueError: If node specifications were provided but resolved to an empty list. """ num_nodes, node_list = num_nodes, [] - parsed_nodes = self.parse_nodes(nodes) + parsed_nodes = self.parse_nodes(nodes, exclude_nodes=exclude_nodes) if parsed_nodes: num_nodes = len(parsed_nodes) node_list = parsed_nodes + elif nodes: + reason = ( + f"after excluding nodes {exclude_nodes}" + if exclude_nodes + else "— no nodes are available (all may be DRAIN/DOWN)" + ) + raise ValueError( + f"Node specifications {nodes} resolved to an empty node list {reason}. " + "Cannot fall back to unconstrained allocation." + ) return num_nodes, sorted(node_list) def system_installables(self) -> list[Installable]: diff --git a/src/cloudai/test_scenario_parser.py b/src/cloudai/test_scenario_parser.py index 774e8dd9c..b324d638c 100644 --- a/src/cloudai/test_scenario_parser.py +++ b/src/cloudai/test_scenario_parser.py @@ -1,5 +1,5 @@ # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES -# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -208,6 +208,7 @@ def _create_test_run( post_test=post_test, reports=get_reporters(test_info, tdef), extra_srun_args=test_info.extra_srun_args, + exclude_nodes=test_info.exclude_nodes, ) return tr diff --git a/src/cloudai/workloads/megatron_bridge/megatron_bridge.py b/src/cloudai/workloads/megatron_bridge/megatron_bridge.py index 783072c59..7ff8f72b3 100644 --- a/src/cloudai/workloads/megatron_bridge/megatron_bridge.py +++ b/src/cloudai/workloads/megatron_bridge/megatron_bridge.py @@ -184,10 +184,16 @@ class MegatronBridgeTestDefinition(TestDefinition): @staticmethod def _select_megatron_bridge_repo(git_repos: list[GitRepo]) -> GitRepo | None: - """Return the Megatron-Bridge repo from `git_repos` (normalized to mount_as=/opt/Megatron-Bridge).""" + """ + Return the Megatron-Bridge repo from `git_repos`. + + When the user sets ``mount_as`` (e.g. ``/opt/Megatron-Bridge``), the installed clone will be bind-mounted + into the container at that path, overriding whatever the container image ships. When ``mount_as`` is *not* + set the container's built-in ``/opt/Megatron-Bridge`` is used. + """ for repo in git_repos: if "Megatron-Bridge" in repo.url or (repo.mount_as or "").rstrip("/") == "/opt/Megatron-Bridge": - return repo if repo.mount_as else repo.model_copy(update={"mount_as": "/opt/Megatron-Bridge"}) + return repo return None @field_validator("git_repos", mode="after") diff --git a/src/cloudai/workloads/megatron_bridge/slurm_command_gen_strategy.py b/src/cloudai/workloads/megatron_bridge/slurm_command_gen_strategy.py index 8d08c5696..abb0a1330 100644 --- a/src/cloudai/workloads/megatron_bridge/slurm_command_gen_strategy.py +++ b/src/cloudai/workloads/megatron_bridge/slurm_command_gen_strategy.py @@ -37,6 +37,14 @@ class MegatronBridgeSlurmCommandGenStrategy(SlurmCommandGenStrategy): The launcher submits the actual training sbatch job; CloudAI tracks that job ID via SlurmRunner parsing. """ + CONTAINER_RUNTIME_ENV_VARS: frozenset[str] = frozenset( + { + "MELLANOX_VISIBLE_DEVICES", + "NVIDIA_VISIBLE_DEVICES", + "NVIDIA_DRIVER_CAPABILITIES", + } + ) + def _container_mounts(self) -> list[str]: # This workload submits its own sbatch job and passes mounts via `-cm`. return [] @@ -96,6 +104,25 @@ def _build_custom_bash_env_exports(self) -> list[str]: exports.extend(["-cb", shlex.quote(f"export {key}={value}")]) return exports + def _container_runtime_env_exports(self) -> list[str]: + """ + Build ``export`` lines for container-runtime env vars. + + Variables like ``MELLANOX_VISIBLE_DEVICES`` and ``NVIDIA_VISIBLE_DEVICES`` + are consumed by the NVIDIA container toolkit / enroot at container-creation + time to decide which devices to mount. They must be present in the process + environment **before** the Megatron-Bridge launcher calls ``sbatch`` so that + Slurm inherits them into the job and ``srun`` passes them to the container + runtime. Exporting them in the wrapper script (which runs on the submit + node) achieves this. The same variables are still passed via ``-cb`` as + well, so they are also set inside the container for any runtime readers. + """ + lines: list[str] = [] + for key, value in sorted(self.final_env_vars.items()): + if key in self.CONTAINER_RUNTIME_ENV_VARS: + lines.append(f"export {key}={shlex.quote(str(value))}") + return lines + def _normalize_recompute_modules(self, val: Any) -> str: if isinstance(val, list): items = [str(x).strip().strip("\"'") for x in val if str(x).strip()] @@ -108,6 +135,30 @@ def _normalize_recompute_modules(self, val: Any) -> str: joined = ",".join(items) return f'"{joined}"' + @staticmethod + def _parse_srun_args_as_slurm_params(srun_args: str) -> list[str]: + """ + Convert ``--key value`` pairs from extra_srun_args into ``key=value`` for --additional_slurm_params. + + Standalone boolean flags (e.g. ``--exclusive``) are emitted as bare + key names without a ``=value`` suffix. + """ + params: list[str] = [] + tokens = shlex.split(srun_args) + i = 0 + while i < len(tokens): + tok = tokens[i] + if tok.startswith("--") and "=" in tok: + key, val = tok[2:].split("=", 1) + params.append(f"{key}={val}") + elif tok.startswith("--") and i + 1 < len(tokens) and not tokens[i + 1].startswith("--"): + params.append(f"{tok[2:]}={tokens[i + 1]}") + i += 1 + elif tok.startswith("--"): + params.append(tok[2:]) + i += 1 + return params + def _normalize_cuda_graph_scope_arg(self, val: Any) -> str: s = str(val).strip().strip("\"'") if s.startswith("[") and s.endswith("]"): @@ -128,6 +179,8 @@ def _wrap_launcher_for_job_id_and_quiet_output(self, launcher_cmd: str, launcher wrapper_path = output_dir / "cloudai_megatron_bridge_submit_and_parse_jobid.sh" log_path = output_dir / "cloudai_megatron_bridge_launcher.log" + container_runtime_exports = self._container_runtime_env_exports() + script_lines = [ "#!/usr/bin/env bash", "set -o pipefail", @@ -140,7 +193,7 @@ def _wrap_launcher_for_job_id_and_quiet_output(self, launcher_cmd: str, launcher # Mirror wrapper stdout/stderr to files for debugging while still emitting to the parent process. 'exec > >(tee -a "$WRAPPER_STDOUT") 2> >(tee -a "$WRAPPER_STDERR" >&2)', "", - # Launch Megatron-Bridge (log stdout/stderr to file) + *container_runtime_exports, "", ': >"$LOG"', "WANDB_INSTALL_RC=0", @@ -225,14 +278,14 @@ def _installed_container_path() -> str: else: container_path = _installed_container_path() - # Use only test-level extra_container_mounts; never mount the Megatron-Bridge repo via -cm - # because the container uses its built-in copy. - mounts = [str(m).strip() for m in (tdef.extra_container_mounts or []) if str(m).strip()] - mounts = [ - m - for m in mounts - if "/opt/Megatron-Bridge" not in m and "Megatron-Bridge" not in m.split(":")[0].split("/")[-1] - ] + mounts: list[str] = [str(m).strip() for m in (tdef.extra_container_mounts or []) if str(m).strip()] + + # When the user sets mount_as on the Megatron-Bridge git repo, bind-mount the + # installed clone into the container to override the image's built-in copy. + mb_repo = tdef.megatron_bridge_repo + if mb_repo.mount_as: + mb_host = mb_repo.installed_path.absolute() if mb_repo.installed_path else repo_path + mounts.append(f"{mb_host}:{mb_repo.mount_as}") venv_path = tdef.python_executable.venv_path or (self.system.install_path / tdef.python_executable.venv_name) python_bin = (venv_path / "bin" / "python").absolute() @@ -400,6 +453,26 @@ def add_field(field: str, flag: str, value: Any) -> None: add_field("nsys_trace", "--nsys_trace", self._list_or_comma_str(args.nsys_trace)) add_field("nsys_extra_args", "--nsys_extra_args", self._list_or_comma_str(args.nsys_extra_args)) + additional_slurm_params: list[str] = [] + + if args.gpus_per_node and self.system.supports_gpu_directives: + additional_slurm_params.append(f"gpus-per-node={args.gpus_per_node}") + additional_slurm_params.append(f"gres=gpu:{args.gpus_per_node}") + + _, node_list = self.get_cached_nodes_spec() + if node_list: + nodelist_str = ",".join(node_list) + additional_slurm_params.append(f"nodelist={nodelist_str}") + elif self.test_run.exclude_nodes: + additional_slurm_params.append(f"exclude={','.join(self.test_run.exclude_nodes)}") + + for source in (self.system.extra_srun_args, self.test_run.extra_srun_args): + if source: + additional_slurm_params.extend(self._parse_srun_args_as_slurm_params(source)) + + if additional_slurm_params: + parts.extend(["--additional_slurm_params", shlex.quote(";".join(additional_slurm_params))]) + # Config variant add_field("config_variant", "-cv", args.config_variant) if args.list_config_variants and "list_config_variants" in fields_set: diff --git a/tests/ref_data/megatron-bridge.sbatch b/tests/ref_data/megatron-bridge.sbatch index a65cc44ed..884b90091 100644 --- a/tests/ref_data/megatron-bridge.sbatch +++ b/tests/ref_data/megatron-bridge.sbatch @@ -19,7 +19,7 @@ if [ "${WANDB_INSTALL_RC}" -ne 0 ]; then fi LAUNCH_RC=0 -NEMORUN_HOME="__OUTPUT_DIR__/output" __INSTALL_DIR__/Run__main-venv/bin/python __INSTALL_DIR__/Megatron-Bridge__main/scripts/performance/setup_experiment.py -p main -i __OUTPUT_DIR__/output/megatron_bridge_image.sqsh -hf dummy_token -ng 8 -gn 4 --golden_values_path cloudai_megatron_bridge_golden_values.json -cb 'export CUDA_VISIBLE_DEVICES=0,1,2,3' -cb 'export NCCL_DEBUG=INFO' -m qwen3 -mr 30b_a3b --detach false >>"$LOG" 2>&1 || LAUNCH_RC=$? +NEMORUN_HOME="__OUTPUT_DIR__/output" __INSTALL_DIR__/Run__main-venv/bin/python __INSTALL_DIR__/Megatron-Bridge__main/scripts/performance/setup_experiment.py -p main -i __OUTPUT_DIR__/output/megatron_bridge_image.sqsh -hf dummy_token -ng 8 -gn 4 --golden_values_path cloudai_megatron_bridge_golden_values.json -cm __INSTALL_DIR__/Megatron-Bridge__main:/opt/Megatron-Bridge -cb 'export CUDA_VISIBLE_DEVICES=0,1,2,3' -cb 'export NCCL_DEBUG=INFO' -m qwen3 -mr 30b_a3b --detach false --additional_slurm_params 'gpus-per-node=4;gres=gpu:4' >>"$LOG" 2>&1 || LAUNCH_RC=$? JOB_ID="" diff --git a/tests/systems/slurm/test_allocation.py b/tests/systems/slurm/test_allocation.py index 9426f5fa5..c8f06a624 100644 --- a/tests/systems/slurm/test_allocation.py +++ b/tests/systems/slurm/test_allocation.py @@ -115,3 +115,22 @@ def test_group_allocation_is_preserved_on_updated(self, slurm_system: SlurmSyste ): system.update() assert all(node.state == SlurmNodeState.ALLOCATED for node in system.group_allocated) + + def test_exclude_nodes_selects_from_remaining_pool( + self, slurm_system: SlurmSystem, monkeypatch: pytest.MonkeyPatch + ): + """Excluding a node from a group should still yield the requested count from the remaining pool.""" + system, _all_nodes, _ = self.prepare(slurm_system, [], monkeypatch) + nnodes, nodes_list = system.get_nodes_by_spec(1, ["main:group1:4"], exclude_nodes=["node03"]) + assert nnodes == 4 + assert "node03" not in nodes_list + assert len(nodes_list) == 4 + + def test_exclude_multiple_nodes_from_group(self, slurm_system: SlurmSystem, monkeypatch: pytest.MonkeyPatch): + """Excluding multiple nodes still selects the requested count from remaining nodes.""" + system, _all_nodes, _ = self.prepare(slurm_system, [], monkeypatch) + nnodes, nodes_list = system.get_nodes_by_spec(1, ["main:group1:3"], exclude_nodes=["node01", "node05"]) + assert nnodes == 3 + assert "node01" not in nodes_list + assert "node05" not in nodes_list + assert len(nodes_list) == 3 diff --git a/tests/systems/slurm/test_command_gen_strategy.py b/tests/systems/slurm/test_command_gen_strategy.py index e18fdf5ed..98e103704 100644 --- a/tests/systems/slurm/test_command_gen_strategy.py +++ b/tests/systems/slurm/test_command_gen_strategy.py @@ -361,6 +361,39 @@ def test_distribution_fallback_when_no_nodes(strategy_fixture: SlurmCommandGenSt assert "#SBATCH --nodelist=" not in content +def test_exclude_nodes_directive_when_no_nodelist(strategy_fixture: SlurmCommandGenStrategy) -> None: + strategy_fixture.test_run.nodes = [] + strategy_fixture.test_run.num_nodes = 3 + strategy_fixture.test_run.exclude_nodes = ["node01", "node02"] + content: List[str] = [] + strategy_fixture._append_nodes_related_directives(content) + + assert "#SBATCH -N 3" in content + assert "#SBATCH --exclude=node01,node02" in content + + +def test_no_exclude_directive_when_nodelist_present(slurm_system: SlurmSystem, testrun_fixture: TestRun) -> None: + testrun_fixture.nodes = ["node3", "node4"] + testrun_fixture.exclude_nodes = ["node01", "node02"] + strategy = MySlurmCommandGenStrategy(slurm_system, testrun_fixture) + content: List[str] = [] + strategy._append_nodes_related_directives(content) + + assert "#SBATCH --nodelist=node3,node4" in content + assert "#SBATCH --exclude=" not in content + + +def test_no_exclude_directive_when_exclude_nodes_unset(strategy_fixture: SlurmCommandGenStrategy) -> None: + strategy_fixture.test_run.nodes = [] + strategy_fixture.test_run.num_nodes = 2 + strategy_fixture.test_run.exclude_nodes = [] + content: List[str] = [] + strategy_fixture._append_nodes_related_directives(content) + + assert "#SBATCH -N 2" in content + assert not any("--exclude" in line for line in content) + + def test_nodelist_over_num_nodes(slurm_system: SlurmSystem, testrun_fixture: TestRun) -> None: testrun_fixture.nodes = ["nodeA", "nodeB", "nodeC"] testrun_fixture.num_nodes = 5 diff --git a/tests/systems/slurm/test_system.py b/tests/systems/slurm/test_system.py index 0ad79b61f..477015ed2 100644 --- a/tests/systems/slurm/test_system.py +++ b/tests/systems/slurm/test_system.py @@ -140,6 +140,7 @@ def grouped_nodes() -> dict[SlurmNodeState, list[SlurmNode]]: SlurmNode(name="node04", partition=partition_name, state=SlurmNodeState.COMPLETING) ], SlurmNodeState.ALLOCATED: [SlurmNode(name="node05", partition=partition_name, state=SlurmNodeState.ALLOCATED)], + SlurmNodeState.RESERVED: [SlurmNode(name="node06", partition=partition_name, state=SlurmNodeState.RESERVED)], } return grouped_nodes @@ -166,6 +167,7 @@ def test_allocate_nodes_max_avail(slurm_system: SlurmSystem, grouped_nodes: dict grouped_nodes[SlurmNodeState.IDLE][0].name, grouped_nodes[SlurmNodeState.IDLE][1].name, grouped_nodes[SlurmNodeState.COMPLETING][0].name, + grouped_nodes[SlurmNodeState.RESERVED][0].name, ] returned_node_names = [node.name for node in available_nodes] @@ -193,8 +195,8 @@ def test_allocate_nodes_exceeding_limit( slurm_system: SlurmSystem, grouped_nodes: dict[SlurmNodeState, list[SlurmNode]] ): group_name = "group_name" - num_nodes = 5 - available_nodes = 4 + num_nodes = 6 + available_nodes = 5 with pytest.raises( ValueError, @@ -363,10 +365,32 @@ def test_explicit_node_names( num_nodes, node_list = slurm_system.get_nodes_by_spec(in_nnodes, in_nodes) - mock_parse_nodes.assert_called_once_with(in_nodes) + mock_parse_nodes.assert_called_once_with(in_nodes, exclude_nodes=None) assert num_nodes == exp_nnodes assert node_list == exp_nodes + @patch("cloudai.systems.slurm.slurm_system.SlurmSystem.parse_nodes") + def test_raises_when_all_nodes_excluded(self, mock_parse_nodes: Mock, slurm_system: SlurmSystem): + mock_parse_nodes.return_value = [] + exclude = ["node01", "node02"] + + with pytest.raises(ValueError, match="after excluding nodes"): + slurm_system.get_nodes_by_spec(2, ["node0[1-2]"], exclude_nodes=exclude) + + @patch("cloudai.systems.slurm.slurm_system.SlurmSystem.parse_nodes") + def test_raises_when_parse_nodes_returns_empty_for_nonempty_specs( + self, mock_parse_nodes: Mock, slurm_system: SlurmSystem + ): + mock_parse_nodes.return_value = [] + + with pytest.raises(ValueError, match="no nodes are available"): + slurm_system.get_nodes_by_spec(1, ["main:group1:3"]) + + def test_empty_nodes_with_exclude_still_returns_unconstrained(self, slurm_system: SlurmSystem): + num_nodes, node_list = slurm_system.get_nodes_by_spec(3, [], exclude_nodes=["node01"]) + assert num_nodes == 3 + assert node_list == [] + class ConcreteSlurmStrategy(SlurmCommandGenStrategy): def _container_mounts(self) -> list[str]: diff --git a/tests/test_git_repo_installer.py b/tests/test_git_repo_installer.py index 545f1fe1d..fbcaffe52 100644 --- a/tests/test_git_repo_installer.py +++ b/tests/test_git_repo_installer.py @@ -232,6 +232,45 @@ def test_verify_commit_branch_name_match(self, installer: Union[KubernetesInstal res = installer._verify_commit(ref, repo_path) assert res.success + def test_submodules_initialized(self, installer: Union[KubernetesInstaller, SlurmInstaller], git: GitRepo): + repo_path = installer.system.install_path / git.repo_name + repo_path.mkdir() + with patch("subprocess.run") as mock_run: + mock_run.return_value = CompletedProcess(args=[], returncode=0) + res = installer._init_submodules(repo_path) + assert res.success + mock_run.assert_called_once_with( + ["git", "submodule", "update", "--init", "--recursive"], + cwd=str(repo_path), + capture_output=True, + text=True, + ) + + def test_error_initializing_submodules(self, installer: Union[KubernetesInstaller, SlurmInstaller], git: GitRepo): + repo_path = installer.system.install_path / git.repo_name + repo_path.mkdir() + with patch("subprocess.run") as mock_run: + mock_run.return_value = CompletedProcess(args=[], returncode=1, stderr="err") + res = installer._init_submodules(repo_path) + assert not res.success + assert res.message == "Failed to initialize submodules: err" + + def test_submodule_failure_cleans_up_repo( + self, installer: Union[KubernetesInstaller, SlurmInstaller], git: GitRepo + ): + git.init_submodules = True + repo_path = installer.system.install_path / git.repo_name + installer._clone_repository = Mock( + side_effect=lambda url, path: (path.mkdir(parents=True, exist_ok=True), InstallStatusResult(True))[1] + ) + installer._checkout_commit = Mock(return_value=InstallStatusResult(True)) + installer._init_submodules = Mock( + return_value=InstallStatusResult(False, "Failed to initialize submodules: err") + ) + res = installer._install_one_git_repo(git) + assert not res.success + assert not repo_path.exists() + def test_all_good_flow(self, installer: Union[KubernetesInstaller, SlurmInstaller], git: GitRepo): installer._clone_repository = Mock(return_value=InstallStatusResult(True)) installer._checkout_commit = Mock(return_value=InstallStatusResult(True)) @@ -239,6 +278,38 @@ def test_all_good_flow(self, installer: Union[KubernetesInstaller, SlurmInstalle assert res.success assert git.installed_path == installer.system.install_path / git.repo_name + def test_submodules_skipped_when_not_requested( + self, installer: Union[KubernetesInstaller, SlurmInstaller], git: GitRepo + ): + installer._clone_repository = Mock(return_value=InstallStatusResult(True)) + installer._checkout_commit = Mock(return_value=InstallStatusResult(True)) + installer._init_submodules = Mock(return_value=InstallStatusResult(True)) + res = installer._install_one_git_repo(git) + assert res.success + installer._init_submodules.assert_not_called() + + def test_submodules_run_when_requested(self, installer: Union[KubernetesInstaller, SlurmInstaller], git: GitRepo): + git.init_submodules = True + installer._clone_repository = Mock(return_value=InstallStatusResult(True)) + installer._checkout_commit = Mock(return_value=InstallStatusResult(True)) + installer._init_submodules = Mock(return_value=InstallStatusResult(True)) + res = installer._install_one_git_repo(git) + assert res.success + installer._init_submodules.assert_called_once() + + def test_existing_repo_inits_submodules_when_requested( + self, installer: Union[KubernetesInstaller, SlurmInstaller], git: GitRepo + ): + git.init_submodules = True + repo_path = installer.system.install_path / git.repo_name + repo_path.mkdir() + installer._verify_commit = Mock(return_value=InstallStatusResult(True)) + installer._init_submodules = Mock(return_value=InstallStatusResult(True)) + res = installer._install_one_git_repo(git) + assert res.success + assert git.installed_path == repo_path + installer._init_submodules.assert_called_once_with(repo_path) + def test_uninstall_no_repo(self, installer: Union[KubernetesInstaller, SlurmInstaller], git: GitRepo): res = installer._uninstall_git_repo(git) assert res.success diff --git a/tests/workloads/megatron_bridge/test_command_gen_strategy_slurm.py b/tests/workloads/megatron_bridge/test_command_gen_strategy_slurm.py index 23708e50c..6a977d95e 100644 --- a/tests/workloads/megatron_bridge/test_command_gen_strategy_slurm.py +++ b/tests/workloads/megatron_bridge/test_command_gen_strategy_slurm.py @@ -64,6 +64,7 @@ def _make( *, cmd_args_overrides: dict[str, Any] | None = None, git_commit: str = "r0.2.0", + mount_as: str | None = "/opt/Megatron-Bridge", output_subdir: str = "out", num_nodes: int = 2, ) -> TestRun: @@ -78,19 +79,20 @@ def _make( if cmd_args_overrides: cmd_args_data.update(cmd_args_overrides) + repo_kwargs: dict[str, Any] = { + "url": "https://github.com/NVIDIA-NeMo/Megatron-Bridge.git", + "commit": git_commit, + } + if mount_as is not None: + repo_kwargs["mount_as"] = mount_as + tdef = MegatronBridgeTestDefinition( name="mb", description="desc", test_template_name="MegatronBridge", cmd_args=MegatronBridgeCmdArgs.model_validate(cmd_args_data), extra_container_mounts=[], - git_repos=[ - GitRepo( - url="https://github.com/NVIDIA-NeMo/Megatron-Bridge.git", - commit=git_commit, - mount_as="/opt/Megatron-Bridge", - ) - ], + git_repos=[GitRepo(**repo_kwargs)], ) self._configure_fake_installs(tdef, tmp_path) return TestRun( @@ -223,6 +225,38 @@ def test_env_vars_are_forwarded_via_custom_bash_cmds( assert "-cb 'export CUDA_VISIBLE_DEVICES=0,1,2,3'" in wrapper_content assert "-cb 'export NCCL_DEBUG=INFO'" in wrapper_content + def test_container_runtime_env_vars_exported_in_wrapper_script( + self, configured_slurm_system: SlurmSystem, make_test_run: Callable[..., TestRun] + ) -> None: + configured_slurm_system.global_env_vars = { + "MELLANOX_VISIBLE_DEVICES": "0,1,4,5", + "NCCL_IB_HCA": "roce_p0_r0,roce_p0_r1,roce_p0_r2,roce_p0_r3", + "NCCL_IB_GID_INDEX": "3", + } + tr = make_test_run(output_subdir="out_container_rt") + tdef = cast(MegatronBridgeTestDefinition, tr.test) + tdef.extra_env_vars = {"NVIDIA_VISIBLE_DEVICES": "all", "NCCL_DEBUG": "INFO"} + + cmd_gen = MegatronBridgeSlurmCommandGenStrategy(configured_slurm_system, tr) + wrapper_content = self._wrapper_content(cmd_gen) + + launcher_idx = wrapper_content.index("setup_experiment.py") + + assert "export MELLANOX_VISIBLE_DEVICES=0,1,4,5" in wrapper_content + assert "export NVIDIA_VISIBLE_DEVICES=all" in wrapper_content + mvd_idx = wrapper_content.index("export MELLANOX_VISIBLE_DEVICES=") + nvd_idx = wrapper_content.index("export NVIDIA_VISIBLE_DEVICES=") + assert mvd_idx < launcher_idx, "MELLANOX_VISIBLE_DEVICES must be exported before the launcher" + assert nvd_idx < launcher_idx, "NVIDIA_VISIBLE_DEVICES must be exported before the launcher" + + assert "-cb 'export MELLANOX_VISIBLE_DEVICES=0,1,4,5'" in wrapper_content + assert "-cb 'export NVIDIA_VISIBLE_DEVICES=all'" in wrapper_content + assert "-cb 'export NCCL_IB_HCA=roce_p0_r0,roce_p0_r1,roce_p0_r2,roce_p0_r3'" in wrapper_content + assert "-cb 'export NCCL_DEBUG=INFO'" in wrapper_content + + assert "export NCCL_IB_HCA=" not in wrapper_content.split("setup_experiment.py")[0] + assert "export NCCL_DEBUG=" not in wrapper_content.split("setup_experiment.py")[0] + def test_wrapper_emits_job_id_even_when_launcher_non_zero( self, configured_slurm_system: SlurmSystem, make_test_run: Callable[..., TestRun] ) -> None: @@ -317,3 +351,79 @@ def test_use_recipes_emitted_only_when_true( cmd_gen = MegatronBridgeSlurmCommandGenStrategy(configured_slurm_system, tr) wrapper_content = self._wrapper_content(cmd_gen) assert ("--use_recipes" in wrapper_content) is expected_in_wrapper + + def test_mount_as_adds_repo_to_container_mounts( + self, configured_slurm_system: SlurmSystem, make_test_run: Callable[..., TestRun], tmp_path: Path + ) -> None: + tr = make_test_run(mount_as="/opt/custom-megatron", output_subdir="out_mount") + tdef = cast(MegatronBridgeTestDefinition, tr.test) + repo_path = tdef.megatron_bridge_repo.installed_path + assert repo_path is not None + + cmd_gen = MegatronBridgeSlurmCommandGenStrategy(configured_slurm_system, tr) + wrapper_content = self._wrapper_content(cmd_gen) + assert f"-cm {repo_path.absolute()}:/opt/custom-megatron" in wrapper_content + + def test_no_mount_as_skips_repo_container_mount( + self, configured_slurm_system: SlurmSystem, make_test_run: Callable[..., TestRun] + ) -> None: + tr = make_test_run(mount_as=None, output_subdir="out_no_mount") + tdef = cast(MegatronBridgeTestDefinition, tr.test) + repo_path = tdef.megatron_bridge_repo.installed_path + assert repo_path is not None + + cmd_gen = MegatronBridgeSlurmCommandGenStrategy(configured_slurm_system, tr) + wrapper_content = self._wrapper_content(cmd_gen) + assert f"{repo_path.absolute()}:" not in wrapper_content + assert ":/opt/Megatron-Bridge" not in wrapper_content + + def test_gpus_per_node_passed_as_additional_slurm_param( + self, configured_slurm_system: SlurmSystem, make_test_run: Callable[..., TestRun] + ) -> None: + configured_slurm_system.supports_gpu_directives_cache = True + tr = make_test_run(cmd_args_overrides={"gpus_per_node": 2}, output_subdir="out_gpus") + cmd_gen = MegatronBridgeSlurmCommandGenStrategy(configured_slurm_system, tr) + wrapper_content = self._wrapper_content(cmd_gen) + assert "--additional_slurm_params" in wrapper_content + assert "gpus-per-node=2" in wrapper_content + assert "gres=gpu:2" in wrapper_content + + def test_gpus_per_node_skipped_when_gpu_directives_unsupported( + self, configured_slurm_system: SlurmSystem, make_test_run: Callable[..., TestRun] + ) -> None: + configured_slurm_system.supports_gpu_directives_cache = False + tr = make_test_run(cmd_args_overrides={"gpus_per_node": 2}, output_subdir="out_no_gpu_directives") + cmd_gen = MegatronBridgeSlurmCommandGenStrategy(configured_slurm_system, tr) + wrapper_content = self._wrapper_content(cmd_gen) + assert "gpus-per-node=2" not in wrapper_content + assert "gres=gpu:2" not in wrapper_content + + def test_system_extra_srun_args_forwarded( + self, configured_slurm_system: SlurmSystem, make_test_run: Callable[..., TestRun] + ) -> None: + configured_slurm_system.extra_srun_args = "--reservation my_reserv" + tr = make_test_run(output_subdir="out_srun") + cmd_gen = MegatronBridgeSlurmCommandGenStrategy(configured_slurm_system, tr) + wrapper_content = self._wrapper_content(cmd_gen) + assert "reservation=my_reserv" in wrapper_content + + def test_test_run_extra_srun_args_forwarded( + self, configured_slurm_system: SlurmSystem, make_test_run: Callable[..., TestRun] + ) -> None: + tr = make_test_run(output_subdir="out_tr_srun") + tr.extra_srun_args = "--constraint gpu" + cmd_gen = MegatronBridgeSlurmCommandGenStrategy(configured_slurm_system, tr) + wrapper_content = self._wrapper_content(cmd_gen) + assert "constraint=gpu" in wrapper_content + + def test_parse_srun_args_as_slurm_params(self) -> None: + result = MegatronBridgeSlurmCommandGenStrategy._parse_srun_args_as_slurm_params( + "--reservation my_reserv --constraint=gpu" + ) + assert result == ["reservation=my_reserv", "constraint=gpu"] + + def test_parse_srun_args_boolean_flags(self) -> None: + result = MegatronBridgeSlurmCommandGenStrategy._parse_srun_args_as_slurm_params( + "--exclusive --reservation my_reserv --overcommit" + ) + assert result == ["exclusive", "reservation=my_reserv", "overcommit"]