Skip to content

Commit 72874af

Browse files
committed
Add tests for more openmp string patterns
1 parent 9481dd4 commit 72874af

1 file changed

Lines changed: 132 additions & 1 deletion

File tree

src/numba/openmp/tests/test_openmp.py

Lines changed: 132 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2032,7 +2032,7 @@ def test_impl(v):
20322032
with self.assertRaises(NonconstantOpenmpSpecification) as raises:
20332033
test_impl(np.zeros(100))
20342034
self.assertIn(
2035-
"Non-constant OpenMP specification at line", str(raises.exception)
2035+
"Cannot infer constant OpenMP specification", str(raises.exception)
20362036
)
20372037

20382038
# def test_parallel_for_blocking_if(self):
@@ -5037,5 +5037,136 @@ def test_impl():
50375037
self.assertGreater(jit_num_devices, 0)
50385038

50395039

5040+
class TestOpenmpStringPatterns(TestOpenmpBase):
5041+
def __init__(self, *args):
5042+
TestOpenmpBase.__init__(self, *args)
5043+
5044+
def test_omp_jit_const_string(self):
5045+
@njit
5046+
def test_impl(x):
5047+
with openmp("parallel num_threads(4)"):
5048+
tid = omp_get_thread_num()
5049+
x[tid] = x[tid] + 1
5050+
return x
5051+
5052+
x = np.zeros(4)
5053+
x = test_impl(x)
5054+
np.testing.assert_array_equal(x, np.ones(4))
5055+
5056+
def test_omp_py_const_string(self):
5057+
omp_string = "parallel num_threads(4)"
5058+
5059+
@njit
5060+
def test_impl(x):
5061+
with openmp(omp_string):
5062+
tid = omp_get_thread_num()
5063+
x[tid] = x[tid] + 1
5064+
return x
5065+
5066+
x = np.zeros(4)
5067+
x = test_impl(x)
5068+
np.testing.assert_array_equal(x, np.ones(4))
5069+
5070+
def test_omp_jit_fstring(self):
5071+
@njit
5072+
def test_impl(x):
5073+
num_threads = 4
5074+
with openmp(f"parallel num_threads({num_threads})"):
5075+
tid = omp_get_thread_num()
5076+
x[tid] = x[tid] + 1
5077+
return x
5078+
5079+
x = np.zeros(4)
5080+
x = test_impl(x)
5081+
np.testing.assert_array_equal(x, np.ones(4))
5082+
5083+
def test_omp_py_fstring(self):
5084+
num_threads = 4
5085+
omp_string = f"parallel num_threads({num_threads})"
5086+
5087+
@njit
5088+
def test_impl(x):
5089+
with openmp(omp_string):
5090+
tid = omp_get_thread_num()
5091+
x[tid] = x[tid] + 1
5092+
return x
5093+
5094+
x = np.zeros(4)
5095+
x = test_impl(x)
5096+
np.testing.assert_array_equal(x, np.ones(4))
5097+
5098+
def test_omp_string_concat_literals(self):
5099+
@njit
5100+
def test_impl(x):
5101+
with openmp("parallel " + "num_threads(4)"):
5102+
tid = omp_get_thread_num()
5103+
x[tid] = x[tid] + 1
5104+
return x
5105+
5106+
x = np.zeros(4)
5107+
x = test_impl(x)
5108+
np.testing.assert_array_equal(x, np.ones(4))
5109+
5110+
def test_omp_string_concat_jit_variables(self):
5111+
@njit
5112+
def test_impl(x):
5113+
prefix = "parallel "
5114+
suffix = "num_threads(4)"
5115+
with openmp(prefix + suffix):
5116+
tid = omp_get_thread_num()
5117+
x[tid] = x[tid] + 1
5118+
return x
5119+
5120+
x = np.zeros(4)
5121+
x = test_impl(x)
5122+
np.testing.assert_array_equal(x, np.ones(4))
5123+
5124+
def test_omp_string_concat_variables(self):
5125+
num_threads = 4
5126+
omp_string = "parallel num_threads(" + str(num_threads) + ")"
5127+
5128+
@njit
5129+
def test_impl(x):
5130+
with openmp(omp_string):
5131+
tid = omp_get_thread_num()
5132+
x[tid] = x[tid] + 1
5133+
return x
5134+
5135+
x = np.zeros(4)
5136+
x = test_impl(x)
5137+
np.testing.assert_array_equal(x, np.ones(4))
5138+
5139+
def test_omp_nested_concat(self):
5140+
prefix = "parallel "
5141+
suffix = "num_threads(4)"
5142+
omp_string = prefix + suffix
5143+
5144+
@njit
5145+
def test_impl(x):
5146+
with openmp(omp_string):
5147+
tid = omp_get_thread_num()
5148+
x[tid] = x[tid] + 1
5149+
return x
5150+
5151+
x = np.zeros(4)
5152+
x = test_impl(x)
5153+
np.testing.assert_array_equal(x, np.ones(4))
5154+
5155+
def test_omp_explicit_str_call(self):
5156+
num_threads = 4
5157+
omp_string = "parallel " + "num_threads(" + str(num_threads) + ")"
5158+
5159+
@njit
5160+
def test_impl(x):
5161+
with openmp(omp_string):
5162+
tid = omp_get_thread_num()
5163+
x[tid] = x[tid] + 1
5164+
return x
5165+
5166+
x = np.zeros(4)
5167+
x = test_impl(x)
5168+
np.testing.assert_array_equal(x, np.ones(4))
5169+
5170+
50405171
if __name__ == "__main__":
50415172
unittest.main()

0 commit comments

Comments
 (0)