@@ -160,27 +160,36 @@ class LangTransformer:
160160 The constructs of the target language. To be specialized by a subclass.
161161 """
162162
163- def __init__ (self , key , sregistry , platform , compiler ):
163+ def __init__ (self , key = None , options = None , sregistry = None , platform = None ,
164+ compiler = None , profiler = None , ** kwargs ):
164165 """
165166 Parameters
166167 ----------
167168 key : callable, optional
168169 Return True if an Iteration can and should be parallelized,
169170 False otherwise.
171+ options : dict, optional
172+ The optimization options.
170173 sregistry : SymbolRegistry
171174 The symbol registry, to access the symbols appearing in an IET.
172175 platform : Platform
173176 The underlying platform.
174177 compiler : Compiler
175178 The underlying JIT compiler.
179+ profiler : Profiler
180+ The underlying Profiler, used to instrument the IET.
176181 """
177182 if key is not None :
178183 self .key = key
179184 else :
180185 self .key = lambda i : False
186+
187+ self .uses_mpi = options ['mpi' ]
188+
181189 self .sregistry = sregistry
182190 self .platform = platform
183191 self .compiler = compiler
192+ self .profiler = profiler
184193
185194 @iet_pass
186195 def make_parallel (self , iet ):
@@ -228,11 +237,11 @@ class ShmTransformer(LangTransformer):
228237 shared-memory-parallel IETs for CPUs.
229238 """
230239
231- def __init__ (self , key , sregistry , options , platform , compiler ):
240+ def __init__ (self , key , options = None , ** kwargs ):
232241 """
233242 Parameters
234243 ----------
235- key : callable, optional
244+ key : callable
236245 Return True if an Iteration can and should be parallelized,
237246 False otherwise.
238247 sregistry : SymbolRegistry
@@ -251,12 +260,13 @@ def __init__(self, key, sregistry, options, platform, compiler):
251260 iteration exceeds this threshold. Otherwise, use static scheduling.
252261 * 'par-nested': nested parallelism if the number of hyperthreads
253262 per core is greater than this threshold.
263+ * 'mpi': tells whether MPI is enabled.
254264 platform : Platform
255265 The underlying platform.
256266 compiler : Compiler
257267 The underlying JIT compiler.
258268 """
259- super ().__init__ (key , sregistry , platform , compiler )
269+ super ().__init__ (key , options = options , ** kwargs )
260270
261271 self .collapse_ncores = options ['par-collapse-ncores' ]
262272 self .collapse_work = options ['par-collapse-work' ]
@@ -391,7 +401,7 @@ def deviceid(self):
391401 return self .sregistry .deviceid
392402
393403 @iet_pass
394- def initialize (self , iet , options = None ):
404+ def initialize (self , iet ):
395405 """
396406 An `iet_pass` which transforms an IET such that the target language
397407 runtime is initialized.
@@ -416,7 +426,7 @@ def _extract_objcomm(iet):
416426 # Fallback -- might end up here because the Operator has no
417427 # halo exchanges, but we now need it nonetheless to perform
418428 # the rank-GPU assignment
419- if options [ 'mpi' ] :
429+ if self . uses_mpi :
420430 for i in iet .parameters :
421431 try :
422432 return i .grid .distributor ._obj_comm
0 commit comments