Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions flashinfer/hip_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def get_system_rocm_version_from_hipconfig():
check=False,
)
if result.returncode == 0:
match = re.search(r"(\d+\.\d+\.\d+)", result.stdout)
match = re.search(r"(\d+\.\d+(?:\.\d+)?)", result.stdout)
if match:
return match.group(1)
except (subprocess.TimeoutExpired, FileNotFoundError):
Expand Down Expand Up @@ -147,13 +147,18 @@ def get_system_rocm_version():
"""
Attempt to detect the system ROCm version.

For standard builds, tries methods in order of reliability.
For TheRock builds, prioritizes hipconfig as it's more reliable.
For standard ROCm installations, detection falls back through several
methods in order of reliability: ``ROCM_HOME/.info/version``, ``amd-smi``,
``dpkg``, and finally ``hipconfig``.

For TheRock builds, ``hipconfig`` is used directly because it reports the
HIP runtime version (consistent with ``torch.version.hip``), unlike
``.info/version`` which reports the TheRock SDK version (for example,
``"7.12.0"`` when HIP is ``7.3``).

Returns:
str: ROCm version like "7.1.0" or None if not detectable
"""
# For TheRock builds, prioritize hipconfig
if is_therock_build():
return get_system_rocm_version_from_hipconfig()

Expand Down Expand Up @@ -198,7 +203,7 @@ def validate_rocm_arch(arch_list: str = None, verbose: bool = False) -> str:
# Add new tuple for adding a new version group
_ROCM_ARCH_GROUPS = [
(
["7.3", "7.2", "7.1", "7.0"],
["7.13", "7.12", "7.11", "7.3", "7.2", "7.1", "7.0"],
[
Comment thread
eppaneamd marked this conversation as resolved.
"gfx950",
"gfx1201",
Expand Down
49 changes: 45 additions & 4 deletions tests/rocm_tests/test_hip_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
get_available_gpu_count,
get_rocm_home,
get_supported_device_indices,
get_system_rocm_version_from_hipconfig,
is_therock_build,
validate_flashinfer_rocm_arch,
validate_rocm_arch,
Expand Down Expand Up @@ -94,13 +95,48 @@ def test_manifest_file_exists(self, tmp_path):
assert is_therock_build() is True

def test_manifest_file_missing_and_no_rocm_sdk(self, tmp_path):
import sys

sys.modules.pop("rocm_sdk", None)
with patch("flashinfer.hip_utils.get_rocm_home", return_value=str(tmp_path)):
with (
patch.dict("sys.modules", {"rocm_sdk": None}),
Comment thread
demandal25 marked this conversation as resolved.
patch("flashinfer.hip_utils.get_rocm_home", return_value=str(tmp_path)),
):
assert is_therock_build() is False


# get_system_rocm_version_from_hipconfig
class TestGetSystemRocmVersionFromHipconfig:
    """Unit tests for ``get_system_rocm_version_from_hipconfig``.

    Every test patches ``subprocess.run``, so no real subprocess is
    launched; the patched call either returns a canned result object or
    raises the exception under test.
    """

    def _run_result(self, stdout, returncode=0):
        """Return a mock exposing the given ``stdout`` and ``returncode``."""
        return MagicMock(returncode=returncode, stdout=stdout)

    @pytest.mark.parametrize(
        "stdout,expected",
        [
            # Plain three-component version.
            ("7.1.0\n", "7.1.0"),
            # Trailing "-<hash>" suffix is not part of the parsed version.
            ("7.13.26183-83e9908b71\n", "7.13.26183"),
            # Two-component version without a patch number.
            ("7.13\n", "7.13"),
        ],
    )
    def test_parses_version_string(self, stdout, expected):
        """A zero exit status with version-like output yields that version."""
        fake_completed = self._run_result(stdout)
        with patch("subprocess.run", return_value=fake_completed):
            detected = get_system_rocm_version_from_hipconfig()
        assert detected == expected

    def test_returns_none_on_nonzero_returncode(self):
        """A non-zero exit status results in ``None``."""
        failed = self._run_result("", returncode=1)
        with patch("subprocess.run", return_value=failed):
            detected = get_system_rocm_version_from_hipconfig()
        assert detected is None

    def test_returns_none_when_hipconfig_not_found(self):
        """``FileNotFoundError`` from ``subprocess.run`` results in ``None``."""
        with patch("subprocess.run", side_effect=FileNotFoundError):
            detected = get_system_rocm_version_from_hipconfig()
        assert detected is None

    def test_returns_none_on_timeout(self):
        """``subprocess.TimeoutExpired`` results in ``None``."""
        timeout_error = subprocess.TimeoutExpired("hipconfig", 5)
        with patch("subprocess.run", side_effect=timeout_error):
            detected = get_system_rocm_version_from_hipconfig()
        assert detected is None


# validate_rocm_arch
class TestValidateRocmArch:
def _patch_rocm_version(self, version):
Expand Down Expand Up @@ -169,6 +205,11 @@ def test_rocm_7x_supports_gfx950(self, version):
with self._patch_rocm_version(version):
assert validate_rocm_arch(arch_list="gfx950") == "gfx950"

@pytest.mark.parametrize("version", ["7.13.26183", "7.13.0", "7.12.0", "7.11.0"])
def test_therock_versions_support_gfx950(self, version):
    """Each parametrized 7.11-7.13 version string accepts ``gfx950``."""
    with self._patch_rocm_version(version):
        resolved_arch = validate_rocm_arch(arch_list="gfx950")
    # The requested architecture is expected to come back unchanged.
    assert resolved_arch == "gfx950"

@pytest.mark.parametrize("version", ["6.4.0", "6.3.0"])
def test_rocm_6x_supports_gfx942_not_gfx950(self, version):
with self._patch_rocm_version(version):
Expand Down
Loading