Skip to content

Commit 9e45a7b

Browse files
committed
Fix calling OpenMP runtime functions in interactive sessions
- Remove unnecessary inspect to check the caller context; __call__ method is only need by python mode - Add tests for calling OpenMP runtime functions within jit and python contexts
1 parent b5c38e3 commit 9e45a7b

2 files changed

Lines changed: 21 additions & 9 deletions

File tree

src/numba/openmp/omp_runtime.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,6 @@
66

77
class _OpenmpExternalFunction(types.ExternalFunction):
88
def __call__(self, *args):
9-
import inspect
10-
11-
frm = inspect.stack()[1]
12-
mod = inspect.getmodule(frm[0])
13-
if mod.__name__.startswith("numba") and not mod.__name__.startswith(
14-
"numba.openmp.tests"
15-
):
16-
return super(ExternalFunction, self).__call__(*args)
17-
189
# Resolve the function address via llvmlite's symbol table so we
1910
# call the same LLVM-registered symbol the JIT uses. Then wrap
2011
# it with ctypes CFUNCTYPE to call from Python. This avoids

src/numba/openmp/tests/test_openmp.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5051,6 +5051,27 @@ def test_impl(lb, num_steps, pi_comp_func):
50515051
njit_output = njit(test_impl)(0, 1024, test_pi_comp_njit)
50525052
self.assert_outputs_equal(py_output, njit_output)
50535053

5054+
class TestOpenmpRuntimeFunctions(TestOpenmpBase):
5055+
def __init__(self, *args):
5056+
TestOpenmpBase.__init__(self, *args)
5057+
5058+
def test_omp_get_max_threads(self):
5059+
@njit
5060+
def test_impl():
5061+
return omp_get_max_threads()
5062+
5063+
jit_max_threads = test_impl()
5064+
python_max_threads = omp_get_max_threads()
5065+
self.assertEqual(jit_max_threads, python_max_threads)
5066+
5067+
def test_omp_get_num_procs(self):
5068+
@njit
5069+
def test_impl():
5070+
return omp_get_num_procs()
5071+
5072+
jit_num_procs = test_impl()
5073+
python_num_procs = omp_get_num_procs()
5074+
self.assertEqual(jit_num_procs, python_num_procs)
50545075

50555076
if __name__ == "__main__":
50565077
unittest.main()

0 commit comments

Comments
 (0)