Skip to content

Commit 3ba9138

Browse files
Harden parsing device info (#64)
1 parent 83b0e7f commit 3ba9138

2 files changed

Lines changed: 18 additions & 11 deletions

File tree

src/numba/openmp/__init__.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -115,13 +115,17 @@ def _init_offloading_info():
115115

116116
num_devices = omp_get_num_devices()
117117

118+
if num_devices == 0:
119+
if DEBUG_OPENMP >= 1:
120+
print("No OpenMP offloading devices found")
121+
return
122+
118123
try:
119124
addr = ll.address_of_symbol("__tgt_get_device_info")
120125
if not addr:
121-
if DEBUG_OPENMP >= 1:
122-
print(
123-
"Symbol __tgt_get_device_info not found in OpenMP runtime, skipping device info initialization"
124-
)
126+
raise RuntimeError(
127+
"Symbol __tgt_get_device_info not found in OpenMP runtime"
128+
)
125129
from ctypes import (
126130
CFUNCTYPE,
127131
c_void_p,
@@ -147,8 +151,7 @@ def _copy_cb(ptr, size):
147151
add_device_info(i, info_str)
148152

149153
except Exception as e:
150-
if DEBUG_OPENMP >= 1:
151-
print(f"Warning: Failed to initialize offloading info: {e}")
154+
raise e
152155

153156

154157
def _init():

src/numba/openmp/offloading.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,7 @@ def add_device_info(device_id, info_text):
2121
try:
2222
_device_info_map[device_id] = _parse_device_info(info_text)
2323
except Exception as e:
24-
raise RuntimeError(
25-
f"Warning: Failed to parse device info for device {device_id}: {e}"
26-
)
24+
raise RuntimeError(f"Failed to parse device info for device {device_id}: {e}")
2725

2826

2927
def _parse_device_info(output: str):
@@ -52,11 +50,17 @@ def _parse_device_info(output: str):
5250

5351
device_info_all[key] = val
5452

55-
if "amd" in device_info_all.get("vendor name", ""):
53+
if (
54+
"hsa openmp device number" in device_info_all.keys()
55+
or "amd" in device_info_all.get("vendor name", "")
56+
):
5657
vendor = "amd"
5758
devtype = "gpu"
5859
arch = device_info_all.get("device name", "")
59-
elif "nvidia" in device_info_all.get("device name", ""):
60+
elif (
61+
"cuda openmp device number" in device_info_all.keys()
62+
or "nvidia" in device_info_all.get("device name", "")
63+
):
6064
vendor = "nvidia"
6165
devtype = "gpu"
6266
arch = device_info_all.get("compute capabilities", "")

0 commit comments

Comments
 (0)