# Patch summary (explanatory preamble placed ahead of the first "diff --git"
# header; it is not part of any hunk).
#
# flashinfer/hip_utils.py
#   * get_system_rocm_version_from_hipconfig(): the version regex now treats
#     the third numeric component as optional, so two-part output such as
#     "7.13" is matched as well as three-part output such as "7.1.0".
#   * get_system_rocm_version(): the docstring is expanded to list the
#     detection methods and to say why TheRock builds go straight to
#     hipconfig; the inline comment above the is_therock_build() branch is
#     removed. The branch itself is unchanged.
#   * _ROCM_ARCH_GROUPS: "7.13", "7.12" and "7.11" are added to the version
#     list of the first group (the one that lists gfx950 and gfx1201).
#
# tests/rocm_tests/test_hip_utils.py
#   * Imports get_system_rocm_version_from_hipconfig.
#   * test_manifest_file_missing_and_no_rocm_sdk now uses
#     patch.dict("sys.modules", {"rocm_sdk": None}) instead of
#     sys.modules.pop("rocm_sdk", None).
#   * Adds TestGetSystemRocmVersionFromHipconfig, which patches subprocess.run
#     and checks parsing of three sample outputs, a non-zero return code,
#     FileNotFoundError and subprocess.TimeoutExpired.
#   * Adds test_therock_versions_support_gfx950 for versions 7.11.0 to 7.13.x.
#
# NOTE(review): this copy of the patch was whitespace-collapsed. Line breaks
# were restored from the +/- markers and the hunk line counts; the leading
# indentation of hunk lines is inferred from Python syntax. Verify against
# the target files before applying. In particular, the _ROCM_ARCH_GROUPS hunk
# is assumed to sit inside validate_rocm_arch() (4-space indent) - TODO confirm.
# NOTE(review): the new tests reference MagicMock and subprocess, and the only
# import hunk shown adds neither - confirm the test module already imports them.
# NOTE(review): the "Returns:" line of get_system_rocm_version() still cites
# only "7.1.0"; with the widened regex a two-part string can also be returned
# on the TheRock path.
diff --git a/flashinfer/hip_utils.py b/flashinfer/hip_utils.py
index e085d6d2a8..c28aa4f11f 100644
--- a/flashinfer/hip_utils.py
+++ b/flashinfer/hip_utils.py
@@ -82,7 +82,7 @@ def get_system_rocm_version_from_hipconfig():
             check=False,
         )
         if result.returncode == 0:
-            match = re.search(r"(\d+\.\d+\.\d+)", result.stdout)
+            match = re.search(r"(\d+\.\d+(?:\.\d+)?)", result.stdout)
             if match:
                 return match.group(1)
     except (subprocess.TimeoutExpired, FileNotFoundError):
@@ -147,13 +147,18 @@ def get_system_rocm_version():
     """
     Attempt to detect the system ROCm version.
 
-    For standard builds, tries methods in order of reliability.
-    For TheRock builds, prioritizes hipconfig as it's more reliable.
+    For standard ROCm installations, detection falls back through several
+    methods in order of reliability: ``ROCM_HOME/.info/version``, ``amd-smi``,
+    ``dpkg``, and finally ``hipconfig``.
+
+    For TheRock builds, ``hipconfig`` is used directly because it reports the
+    HIP runtime version (consistent with ``torch.version.hip``), unlike
+    ``.info/version`` which reports the TheRock SDK version (for example,
+    ``"7.12.0"`` when HIP is ``7.3``).
 
     Returns:
         str: ROCm version like "7.1.0" or None if not detectable
     """
-    # For TheRock builds, prioritize hipconfig
     if is_therock_build():
         return get_system_rocm_version_from_hipconfig()
 
@@ -198,7 +203,7 @@ def validate_rocm_arch(arch_list: str = None, verbose: bool = False) -> str:
     # Add new tuple for adding a new version group
     _ROCM_ARCH_GROUPS = [
         (
-            ["7.3", "7.2", "7.1", "7.0"],
+            ["7.13", "7.12", "7.11", "7.3", "7.2", "7.1", "7.0"],
             [
                 "gfx950",
                 "gfx1201",
diff --git a/tests/rocm_tests/test_hip_utils.py b/tests/rocm_tests/test_hip_utils.py
index b2a30a0044..75ea294247 100644
--- a/tests/rocm_tests/test_hip_utils.py
+++ b/tests/rocm_tests/test_hip_utils.py
@@ -21,6 +21,7 @@
     get_available_gpu_count,
     get_rocm_home,
     get_supported_device_indices,
+    get_system_rocm_version_from_hipconfig,
     is_therock_build,
     validate_flashinfer_rocm_arch,
     validate_rocm_arch,
@@ -94,13 +95,48 @@ def test_manifest_file_exists(self, tmp_path):
             assert is_therock_build() is True
 
     def test_manifest_file_missing_and_no_rocm_sdk(self, tmp_path):
-        import sys
-
-        sys.modules.pop("rocm_sdk", None)
-        with patch("flashinfer.hip_utils.get_rocm_home", return_value=str(tmp_path)):
+        with (
+            patch.dict("sys.modules", {"rocm_sdk": None}),
+            patch("flashinfer.hip_utils.get_rocm_home", return_value=str(tmp_path)),
+        ):
             assert is_therock_build() is False
 
 
+# get_system_rocm_version_from_hipconfig
+class TestGetSystemRocmVersionFromHipconfig:
+    def _run_result(self, stdout, returncode=0):
+        result = MagicMock()
+        result.returncode = returncode
+        result.stdout = stdout
+        return result
+
+    @pytest.mark.parametrize(
+        "stdout,expected",
+        [
+            ("7.1.0\n", "7.1.0"),
+            ("7.13.26183-83e9908b71\n", "7.13.26183"),
+            ("7.13\n", "7.13"),
+        ],
+    )
+    def test_parses_version_string(self, stdout, expected):
+        with patch("subprocess.run", return_value=self._run_result(stdout)):
+            assert get_system_rocm_version_from_hipconfig() == expected
+
+    def test_returns_none_on_nonzero_returncode(self):
+        with patch("subprocess.run", return_value=self._run_result("", returncode=1)):
+            assert get_system_rocm_version_from_hipconfig() is None
+
+    def test_returns_none_when_hipconfig_not_found(self):
+        with patch("subprocess.run", side_effect=FileNotFoundError):
+            assert get_system_rocm_version_from_hipconfig() is None
+
+    def test_returns_none_on_timeout(self):
+        with patch(
+            "subprocess.run", side_effect=subprocess.TimeoutExpired("hipconfig", 5)
+        ):
+            assert get_system_rocm_version_from_hipconfig() is None
+
+
 # validate_rocm_arch
 class TestValidateRocmArch:
     def _patch_rocm_version(self, version):
@@ -169,6 +205,11 @@ def test_rocm_7x_supports_gfx950(self, version):
         with self._patch_rocm_version(version):
             assert validate_rocm_arch(arch_list="gfx950") == "gfx950"
 
+    @pytest.mark.parametrize("version", ["7.13.26183", "7.13.0", "7.12.0", "7.11.0"])
+    def test_therock_versions_support_gfx950(self, version):
+        with self._patch_rocm_version(version):
+            assert validate_rocm_arch(arch_list="gfx950") == "gfx950"
+
     @pytest.mark.parametrize("version", ["6.4.0", "6.3.0"])
     def test_rocm_6x_supports_gfx942_not_gfx950(self, version):
         with self._patch_rocm_version(version):