@@ -2032,7 +2032,7 @@ def test_impl(v):
20322032 with self .assertRaises (NonconstantOpenmpSpecification ) as raises :
20332033 test_impl (np .zeros (100 ))
20342034 self .assertIn (
2035- "Non- constant OpenMP specification at line " , str (raises .exception )
2035+ "Cannot infer constant OpenMP specification" , str (raises .exception )
20362036 )
20372037
20382038 # def test_parallel_for_blocking_if(self):
@@ -5037,5 +5037,136 @@ def test_impl():
50375037 self .assertGreater (jit_num_devices , 0 )
50385038
50395039
5040+ class TestOpenmpStringPatterns (TestOpenmpBase ):
5041+ def __init__ (self , * args ):
5042+ TestOpenmpBase .__init__ (self , * args )
5043+
5044+ def test_omp_jit_const_string (self ):
5045+ @njit
5046+ def test_impl (x ):
5047+ with openmp ("parallel num_threads(4)" ):
5048+ tid = omp_get_thread_num ()
5049+ x [tid ] = x [tid ] + 1
5050+ return x
5051+
5052+ x = np .zeros (4 )
5053+ x = test_impl (x )
5054+ np .testing .assert_array_equal (x , np .ones (4 ))
5055+
5056+ def test_omp_py_const_string (self ):
5057+ omp_string = "parallel num_threads(4)"
5058+
5059+ @njit
5060+ def test_impl (x ):
5061+ with openmp (omp_string ):
5062+ tid = omp_get_thread_num ()
5063+ x [tid ] = x [tid ] + 1
5064+ return x
5065+
5066+ x = np .zeros (4 )
5067+ x = test_impl (x )
5068+ np .testing .assert_array_equal (x , np .ones (4 ))
5069+
5070+ def test_omp_jit_fstring (self ):
5071+ @njit
5072+ def test_impl (x ):
5073+ num_threads = 4
5074+ with openmp (f"parallel num_threads({ num_threads } )" ):
5075+ tid = omp_get_thread_num ()
5076+ x [tid ] = x [tid ] + 1
5077+ return x
5078+
5079+ x = np .zeros (4 )
5080+ x = test_impl (x )
5081+ np .testing .assert_array_equal (x , np .ones (4 ))
5082+
5083+ def test_omp_py_fstring (self ):
5084+ num_threads = 4
5085+ omp_string = f"parallel num_threads({ num_threads } )"
5086+
5087+ @njit
5088+ def test_impl (x ):
5089+ with openmp (omp_string ):
5090+ tid = omp_get_thread_num ()
5091+ x [tid ] = x [tid ] + 1
5092+ return x
5093+
5094+ x = np .zeros (4 )
5095+ x = test_impl (x )
5096+ np .testing .assert_array_equal (x , np .ones (4 ))
5097+
5098+ def test_omp_string_concat_literals (self ):
5099+ @njit
5100+ def test_impl (x ):
5101+ with openmp ("parallel " + "num_threads(4)" ):
5102+ tid = omp_get_thread_num ()
5103+ x [tid ] = x [tid ] + 1
5104+ return x
5105+
5106+ x = np .zeros (4 )
5107+ x = test_impl (x )
5108+ np .testing .assert_array_equal (x , np .ones (4 ))
5109+
5110+ def test_omp_string_concat_jit_variables (self ):
5111+ @njit
5112+ def test_impl (x ):
5113+ prefix = "parallel "
5114+ suffix = "num_threads(4)"
5115+ with openmp (prefix + suffix ):
5116+ tid = omp_get_thread_num ()
5117+ x [tid ] = x [tid ] + 1
5118+ return x
5119+
5120+ x = np .zeros (4 )
5121+ x = test_impl (x )
5122+ np .testing .assert_array_equal (x , np .ones (4 ))
5123+
5124+ def test_omp_string_concat_variables (self ):
5125+ num_threads = 4
5126+ omp_string = "parallel num_threads(" + str (num_threads ) + ")"
5127+
5128+ @njit
5129+ def test_impl (x ):
5130+ with openmp (omp_string ):
5131+ tid = omp_get_thread_num ()
5132+ x [tid ] = x [tid ] + 1
5133+ return x
5134+
5135+ x = np .zeros (4 )
5136+ x = test_impl (x )
5137+ np .testing .assert_array_equal (x , np .ones (4 ))
5138+
5139+ def test_omp_nested_concat (self ):
5140+ prefix = "parallel "
5141+ suffix = "num_threads(4)"
5142+ omp_string = prefix + suffix
5143+
5144+ @njit
5145+ def test_impl (x ):
5146+ with openmp (omp_string ):
5147+ tid = omp_get_thread_num ()
5148+ x [tid ] = x [tid ] + 1
5149+ return x
5150+
5151+ x = np .zeros (4 )
5152+ x = test_impl (x )
5153+ np .testing .assert_array_equal (x , np .ones (4 ))
5154+
5155+ def test_omp_explicit_str_call (self ):
5156+ num_threads = 4
5157+ omp_string = "parallel " + "num_threads(" + str (num_threads ) + ")"
5158+
5159+ @njit
5160+ def test_impl (x ):
5161+ with openmp (omp_string ):
5162+ tid = omp_get_thread_num ()
5163+ x [tid ] = x [tid ] + 1
5164+ return x
5165+
5166+ x = np .zeros (4 )
5167+ x = test_impl (x )
5168+ np .testing .assert_array_equal (x , np .ones (4 ))
5169+
5170+
50405171if __name__ == "__main__" :
50415172 unittest .main ()
0 commit comments