diff --git a/src/lmflow/pipeline/evaluator.py b/src/lmflow/pipeline/evaluator.py
index fd5a0127f..36d6be20b 100644
--- a/src/lmflow/pipeline/evaluator.py
+++ b/src/lmflow/pipeline/evaluator.py
@@ -32,7 +32,7 @@
 from lmflow.datasets.dataset import Dataset
 from lmflow.pipeline.base_pipeline import BasePipeline
 from lmflow.utils.data_utils import answer_extraction, batchlize, set_random_seed
-from lmflow.utils.envs import is_accelerate_env
+from lmflow.utils.envs import is_accelerate_env, set_cuda_device
 from lmflow.utils.versioning import is_deepspeed_available
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"  # To avoid warnings about parallelism in tokenizers
@@ -74,7 +74,7 @@ def __init__(
         set_random_seed(self.evaluator_args.random_seed)
         self.local_rank = int(os.getenv("LOCAL_RANK", "0"))
         self.world_size = int(os.getenv("WORLD_SIZE", "1"))
-        torch.cuda.set_device(self.local_rank)  # NOTE: cpu-only machine will have error
+        set_cuda_device(self.local_rank)
 
         if is_accelerate_env():
             self.accelerator = Accelerator()
diff --git a/src/lmflow/pipeline/inferencer.py b/src/lmflow/pipeline/inferencer.py
index f79ba5d34..ec64b290d 100644
--- a/src/lmflow/pipeline/inferencer.py
+++ b/src/lmflow/pipeline/inferencer.py
@@ -23,7 +23,7 @@
 from lmflow.pipeline.base_pipeline import BasePipeline
 from lmflow.utils.constants import IMAGE_TOKEN_INDEX
 from lmflow.utils.data_utils import batchlize, set_random_seed
-from lmflow.utils.envs import is_accelerate_env
+from lmflow.utils.envs import is_accelerate_env, set_cuda_device
 from lmflow.utils.versioning import is_deepspeed_available
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"  # To avoid warnings about parallelism in tokenizers
@@ -74,7 +74,7 @@ def __init__(
         self.local_rank = int(os.getenv("LOCAL_RANK", "0"))
         self.world_size = int(os.getenv("WORLD_SIZE", "1"))
         if inferencer_args.device == "gpu":  # FIXME: a bit weird here
-            torch.cuda.set_device(self.local_rank)  # NOTE: cpu-only machine will have error
+            set_cuda_device(self.local_rank)
 
         if not is_accelerate_env() and is_deepspeed_available():
             import deepspeed
diff --git a/src/lmflow/pipeline/rm_inferencer.py b/src/lmflow/pipeline/rm_inferencer.py
index bddac6241..f7fc18f30 100644
--- a/src/lmflow/pipeline/rm_inferencer.py
+++ b/src/lmflow/pipeline/rm_inferencer.py
@@ -25,7 +25,7 @@
     batchlize,
     set_random_seed,
 )
-from lmflow.utils.envs import is_accelerate_env
+from lmflow.utils.envs import is_accelerate_env, set_cuda_device
 from lmflow.utils.versioning import is_deepspeed_available, is_ray_available
 
 if is_ray_available():
@@ -70,7 +70,7 @@ def __init__(
         self.local_rank = int(os.getenv("LOCAL_RANK", "0"))
         self.world_size = int(os.getenv("WORLD_SIZE", "1"))
         if inferencer_args.device == "gpu":  # FIXME: a bit weird here
-            torch.cuda.set_device(self.local_rank)  # NOTE: cpu-only machine will have error
+            set_cuda_device(self.local_rank)
 
         if not is_accelerate_env() and is_deepspeed_available():
             import deepspeed
diff --git a/src/lmflow/utils/envs.py b/src/lmflow/utils/envs.py
index ff1a40d8c..6368e32cf 100644
--- a/src/lmflow/utils/envs.py
+++ b/src/lmflow/utils/envs.py
@@ -2,38 +2,60 @@
 ref: https://github.com/pytorch/torchtune/blob/main/torchtune/utils/_device.py
 """
 
-import os
 import logging
+import os
+from typing import Any
 
 import torch
 
-
 logger = logging.getLogger(__name__)
 
-is_cuda_available = torch.cuda.is_available()
+
+__all__ = [
+    "get_device_name",
+    "get_torch_device",
+    "is_accelerate_env",
+    "require_cuda_for_gpu_mode",
+    "set_cuda_device",
+]
 
 
 def is_accelerate_env():
-    for key, _ in os.environ.items():
-        if key.startswith("ACCELERATE_"):
-            return True
-    return False
+    """Return True if any environment variable *name* starts with ``ACCELERATE_``."""
+    return any(key.startswith("ACCELERATE_") for key in os.environ)
+
+
+def require_cuda_for_gpu_mode() -> None:
+    """Raise if GPU execution was requested but CUDA is not available."""
+    if not torch.cuda.is_available():
+        raise RuntimeError(
+            "CUDA is not available on this machine, but GPU execution was requested. "
+            "Install a CUDA-enabled PyTorch build and run on a GPU, or use CPU-compatible "
+            "settings where the pipeline supports them."
+        )
+
+
+def set_cuda_device(local_rank: int) -> None:
+    """Bind this process to ``local_rank`` on CUDA; raises if CUDA is unavailable."""
+    require_cuda_for_gpu_mode()
+    torch.cuda.set_device(local_rank)
 
 
 def get_device_name() -> str:
     """
     Get the device name based on the current machine.
     """
-    if is_cuda_available:
+    if torch.cuda.is_available():
         device = "cuda"
     else:
         device = "cpu"
     return device
 
 
-def get_torch_device() -> any:
-    """Return the corresponding torch attribute based on the device type string.
-    Returns:
-        module: The corresponding torch device namespace, or torch.cuda if not found.
+def get_torch_device() -> Any:
+    """Return ``torch.<device_name>`` for the current device name.
+
+    If ``torch`` has no attribute with that name, logs a warning and returns
+    ``torch.cuda`` as a fallback.
     """
     device_name = get_device_name()
     try:
diff --git a/tests/utils/test_envs.py b/tests/utils/test_envs.py
new file mode 100644
index 000000000..86b131e47
--- /dev/null
+++ b/tests/utils/test_envs.py
@@ -0,0 +1,72 @@
+import os
+import unittest
+from unittest.mock import patch
+
+import torch
+
+from lmflow.utils.envs import (
+    get_device_name,
+    get_torch_device,
+    is_accelerate_env,
+    require_cuda_for_gpu_mode,
+    set_cuda_device,
+)
+
+
+class TestEnvs(unittest.TestCase):
+    def test_is_accelerate_env_false_without_prefix(self):
+        with patch.dict(os.environ, {"FOO": "1"}, clear=True):
+            self.assertFalse(is_accelerate_env())
+
+    def test_is_accelerate_env_true_with_prefix(self):
+        with patch.dict(os.environ, {"ACCELERATE_USE_CPU": "1"}, clear=True):
+            self.assertTrue(is_accelerate_env())
+
+    def test_is_accelerate_env_false_when_accelerate_not_prefix(self):
+        """Names containing 'ACCELERATE' but not starting with ACCELERATE_ must be ignored."""
+        with patch.dict(os.environ, {"MY_ACCELERATE_SETTING": "1"}, clear=True):
+            self.assertFalse(is_accelerate_env())
+
+    @patch("torch.cuda.is_available", return_value=False)
+    def test_get_device_name_cpu_when_cuda_unavailable(self, _mock_cuda: object):
+        self.assertEqual(get_device_name(), "cpu")
+
+    @patch("torch.cuda.is_available", return_value=True)
+    def test_get_device_name_cuda_when_cuda_available(self, _mock_cuda: object):
+        self.assertEqual(get_device_name(), "cuda")
+
+    def test_get_torch_device_matches_device_name(self):
+        with patch("torch.cuda.is_available", return_value=False):
+            self.assertIs(get_torch_device(), torch.cpu)
+        with patch("torch.cuda.is_available", return_value=True):
+            self.assertIs(get_torch_device(), torch.cuda)
+
+    @patch(
+        "lmflow.utils.envs.get_device_name",
+        return_value="zzz_nonexistent_lmflow_test",
+    )
+    def test_get_torch_device_fallback_returns_cuda_on_attribute_error(self, _mock_name: object):
+        with self.assertLogs("lmflow.utils.envs", level="WARNING") as log_ctx:
+            self.assertIs(get_torch_device(), torch.cuda)
+        self.assertTrue(
+            any("zzz_nonexistent_lmflow_test" in entry and "not found" in entry for entry in log_ctx.output),
+        )
+
@patch("torch.cuda.is_available", return_value=False) + def test_require_cuda_for_gpu_mode_raises_when_cuda_unavailable(self, _mock_cuda: object): + with self.assertRaises(RuntimeError) as ctx: + require_cuda_for_gpu_mode() + self.assertIn("CUDA is not available", str(ctx.exception)) + + @patch("torch.cuda.is_available", return_value=True) + @patch("torch.cuda.set_device") + def test_set_cuda_device_calls_torch_set_device(self, mock_set_device: object, _mock_cuda: object): + set_cuda_device(2) + mock_set_device.assert_called_once_with(2) + + @patch("torch.cuda.is_available", return_value=False) + @patch("torch.cuda.set_device") + def test_set_cuda_device_raises_without_cuda(self, mock_set_device: object, _mock_cuda: object): + with self.assertRaises(RuntimeError): + set_cuda_device(2) + mock_set_device.assert_not_called()