Skip to content

Commit 8c46112

Browse files
Merge pull request #2859 from devitocodes/TMA-aftermath
compiler: Pass kwargs when invoking compiler passes
2 parents 1e4dc67 + dc068a0 commit 8c46112

10 files changed

Lines changed: 48 additions & 58 deletions

File tree

devito/core/cpu.py

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -140,18 +140,15 @@ class Cpu64NoopOperator(Cpu64OperatorMixin, CoreOperator):
140140
@timed_pass(name='specializing.IET')
141141
def _specialize_iet(cls, graph, **kwargs):
142142
options = kwargs['options']
143-
platform = kwargs['platform']
144-
compiler = kwargs['compiler']
145-
sregistry = kwargs['sregistry']
146143

147144
# Distributed-memory parallelism
148145
mpiize(graph, **kwargs)
149146

150147
# Shared-memory parallelism
151148
if options['openmp']:
152-
parizer = cls._Target.Parizer(sregistry, options, platform, compiler)
149+
parizer = cls._Target.Parizer(**kwargs)
153150
parizer.make_parallel(graph)
154-
parizer.initialize(graph, options=options)
151+
parizer.initialize(graph)
155152

156153
# Symbol definitions
157154
cls._Target.DataManager(**kwargs).process(graph)
@@ -205,11 +202,6 @@ def _specialize_clusters(cls, clusters, **kwargs):
205202
@classmethod
206203
@timed_pass(name='specializing.IET')
207204
def _specialize_iet(cls, graph, **kwargs):
208-
options = kwargs['options']
209-
platform = kwargs['platform']
210-
compiler = kwargs['compiler']
211-
sregistry = kwargs['sregistry']
212-
213205
# Flush denormal numbers
214206
avoid_denormals(graph, **kwargs)
215207

@@ -220,10 +212,10 @@ def _specialize_iet(cls, graph, **kwargs):
220212
relax_incr_dimensions(graph, **kwargs)
221213

222214
# Parallelism
223-
parizer = cls._Target.Parizer(sregistry, options, platform, compiler)
215+
parizer = cls._Target.Parizer(**kwargs)
224216
parizer.make_simd(graph)
225217
parizer.make_parallel(graph)
226-
parizer.initialize(graph, options=options)
218+
parizer.initialize(graph)
227219

228220
# Misc optimizations
229221
hoist_prodders(graph)
@@ -300,12 +292,7 @@ def callback(f, *args):
300292

301293
@classmethod
302294
def _make_iet_passes_mapper(cls, **kwargs):
303-
options = kwargs['options']
304-
platform = kwargs['platform']
305-
compiler = kwargs['compiler']
306-
sregistry = kwargs['sregistry']
307-
308-
parizer = cls._Target.Parizer(sregistry, options, platform, compiler)
295+
parizer = cls._Target.Parizer(**kwargs)
309296

310297
return {
311298
'denormals': partial(avoid_denormals, **kwargs),
@@ -316,7 +303,7 @@ def _make_iet_passes_mapper(cls, **kwargs):
316303
'linearize': partial(linearize, **kwargs),
317304
'simd': partial(parizer.make_simd),
318305
'prodders': hoist_prodders,
319-
'init': partial(parizer.initialize, options=options)
306+
'init': partial(parizer.initialize)
320307
}
321308

322309
_known_passes = (

devito/core/gpu.py

Lines changed: 7 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -180,18 +180,13 @@ class DeviceNoopOperator(DeviceOperatorMixin, CoreOperator):
180180
@classmethod
181181
@timed_pass(name='specializing.IET')
182182
def _specialize_iet(cls, graph, **kwargs):
183-
options = kwargs['options']
184-
platform = kwargs['platform']
185-
compiler = kwargs['compiler']
186-
sregistry = kwargs['sregistry']
187-
188183
# Distributed-memory parallelism
189184
mpiize(graph, **kwargs)
190185

191186
# GPU parallelism
192-
parizer = cls._Target.Parizer(sregistry, options, platform, compiler)
187+
parizer = cls._Target.Parizer(**kwargs)
193188
parizer.make_parallel(graph)
194-
parizer.initialize(graph, options=options)
189+
parizer.initialize(graph)
195190

196191
# Symbol definitions
197192
cls._Target.DataManager(**kwargs).process(graph)
@@ -248,21 +243,16 @@ def _specialize_clusters(cls, clusters, **kwargs):
248243
@classmethod
249244
@timed_pass(name='specializing.IET')
250245
def _specialize_iet(cls, graph, **kwargs):
251-
options = kwargs['options']
252-
platform = kwargs['platform']
253-
compiler = kwargs['compiler']
254-
sregistry = kwargs['sregistry']
255-
256246
# Distributed-memory parallelism
257247
mpiize(graph, **kwargs)
258248

259249
# Lower BlockDimensions so that blocks of arbitrary shape may be used
260250
relax_incr_dimensions(graph, **kwargs)
261251

262252
# GPU parallelism
263-
parizer = cls._Target.Parizer(sregistry, options, platform, compiler)
253+
parizer = cls._Target.Parizer(**kwargs)
264254
parizer.make_parallel(graph)
265-
parizer.initialize(graph, options=options)
255+
parizer.initialize(graph)
266256

267257
# Misc optimizations
268258
hoist_prodders(graph)
@@ -325,22 +315,17 @@ def _make_clusters_passes_mapper(cls, **kwargs):
325315

326316
@classmethod
327317
def _make_iet_passes_mapper(cls, **kwargs):
328-
options = kwargs['options']
329-
platform = kwargs['platform']
330-
compiler = kwargs['compiler']
331-
sregistry = kwargs['sregistry']
332-
333-
parizer = cls._Target.Parizer(sregistry, options, platform, compiler)
318+
parizer = cls._Target.Parizer(**kwargs)
334319
orchestrator = cls._Target.Orchestrator(**kwargs)
335320

336321
return {
337322
'parallel': parizer.make_parallel,
338323
'orchestrate': partial(orchestrator.process),
339-
'pthreadify': partial(pthreadify, sregistry=sregistry),
324+
'pthreadify': partial(pthreadify, **kwargs),
340325
'mpi': partial(mpiize, **kwargs),
341326
'linearize': partial(linearize, **kwargs),
342327
'prodders': partial(hoist_prodders),
343-
'init': partial(parizer.initialize, options=options)
328+
'init': partial(parizer.initialize)
344329
}
345330

346331
_known_passes = (

devito/operator/operator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,7 @@ def _lower_uiet(cls, stree, profiler=None, **kwargs):
474474

475475
@classmethod
476476
@timed_pass(name='lowering.IET')
477-
def _lower_iet(cls, uiet, profiler=None, **kwargs):
477+
def _lower_iet(cls, uiet, **kwargs):
478478
"""
479479
Iteration/Expression tree lowering:
480480
@@ -496,7 +496,7 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs):
496496
# Instrument the IET for C-level profiling
497497
# Note: this is postponed until after _specialize_iet because during
498498
# specialization further Sections may be introduced
499-
cls._Target.instrument(graph, profiler=profiler, **kwargs)
499+
cls._Target.instrument(graph, **kwargs)
500500

501501
# Extract the necessary macros from the symbolic objects
502502
generate_macros(graph, **kwargs)

devito/operator/profiling.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,10 @@ def record_ops_variation(self, initial, final):
180180
def all_sections(self):
181181
return list(self._sections) + flatten(self._subsections.values())
182182

183+
@property
184+
def high_verbosity(self):
185+
return self._verbosity >= 2
186+
183187
def summary(self, args, dtype, reduce_over=None):
184188
"""
185189
Return a PerformanceSummary of the profiled sections.

devito/passes/clusters/aliases.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -292,14 +292,15 @@ def process(self, clusters):
292292
def callback(self, clusters, prefix, xtracted=None):
293293
if not prefix:
294294
return clusters
295-
d = prefix[-1].dim
295+
p = prefix[-1]
296+
d = p.dim
296297

297298
# Rule out extractions that would break data dependencies
298299
exclude = set().union(*[c.scope.writes for c in clusters])
299300

300301
# Rule out extractions that depend on the Dimension currently investigated,
301302
# as they clearly wouldn't be invariants
302-
exclude.add(d)
303+
exclude.update({d, *p.sub_iterators})
303304

304305
key = lambda c: self._lookup_key(c, d)
305306
processed = list(clusters)

devito/passes/iet/instrument.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515

1616
def instrument(graph, **kwargs):
17-
profiler = kwargs['profiler']
17+
profiler = kwargs.get('profiler')
1818
if profiler is None:
1919
return
2020

devito/passes/iet/langbase.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -160,27 +160,36 @@ class LangTransformer:
160160
The constructs of the target language. To be specialized by a subclass.
161161
"""
162162

163-
def __init__(self, key, sregistry, platform, compiler):
163+
def __init__(self, key=None, options=None, sregistry=None, platform=None,
164+
compiler=None, profiler=None, **kwargs):
164165
"""
165166
Parameters
166167
----------
167168
key : callable, optional
168169
Return True if an Iteration can and should be parallelized,
169170
False otherwise.
171+
options : dict, optional
172+
The optimization options.
170173
sregistry : SymbolRegistry
171174
The symbol registry, to access the symbols appearing in an IET.
172175
platform : Platform
173176
The underlying platform.
174177
compiler : Compiler
175178
The underlying JIT compiler.
179+
profiler : Profiler
180+
The underlying Profiler, used to instrument the IET.
176181
"""
177182
if key is not None:
178183
self.key = key
179184
else:
180185
self.key = lambda i: False
186+
187+
self.uses_mpi = options['mpi']
188+
181189
self.sregistry = sregistry
182190
self.platform = platform
183191
self.compiler = compiler
192+
self.profiler = profiler
184193

185194
@iet_pass
186195
def make_parallel(self, iet):
@@ -228,11 +237,11 @@ class ShmTransformer(LangTransformer):
228237
shared-memory-parallel IETs for CPUs.
229238
"""
230239

231-
def __init__(self, key, sregistry, options, platform, compiler):
240+
def __init__(self, key, options=None, **kwargs):
232241
"""
233242
Parameters
234243
----------
235-
key : callable, optional
244+
key : callable
236245
Return True if an Iteration can and should be parallelized,
237246
False otherwise.
238247
sregistry : SymbolRegistry
@@ -251,12 +260,13 @@ def __init__(self, key, sregistry, options, platform, compiler):
251260
iteration exceeds this threshold. Otherwise, use static scheduling.
252261
* 'par-nested': nested parallelism if the number of hyperthreads
253262
per core is greater than this threshold.
263+
* 'mpi': tells whether MPI is enabled.
254264
platform : Platform
255265
The underlying platform.
256266
compiler : Compiler
257267
The underlying JIT compiler.
258268
"""
259-
super().__init__(key, sregistry, platform, compiler)
269+
super().__init__(key, options=options, **kwargs)
260270

261271
self.collapse_ncores = options['par-collapse-ncores']
262272
self.collapse_work = options['par-collapse-work']
@@ -391,7 +401,7 @@ def deviceid(self):
391401
return self.sregistry.deviceid
392402

393403
@iet_pass
394-
def initialize(self, iet, options=None):
404+
def initialize(self, iet):
395405
"""
396406
An `iet_pass` which transforms an IET such that the target language
397407
runtime is initialized.
@@ -416,7 +426,7 @@ def _extract_objcomm(iet):
416426
# Fallback -- might end up here because the Operator has no
417427
# halo exchanges, but we now need it nonetheless to perform
418428
# the rank-GPU assignment
419-
if options['mpi']:
429+
if self.uses_mpi:
420430
for i in iet.parameters:
421431
try:
422432
return i.grid.distributor._obj_comm

devito/passes/iet/parpragma.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -225,9 +225,9 @@ class PragmaShmTransformer(ShmTransformer, PragmaSimdTransformer):
225225
IETs for CPUs.
226226
"""
227227

228-
def __init__(self, sregistry, options, platform, compiler):
228+
def __init__(self, **kwargs):
229229
key = lambda i: i.is_ParallelRelaxed and not i.is_Vectorized
230-
super().__init__(key, sregistry, options, platform, compiler)
230+
super().__init__(key, **kwargs)
231231

232232
def _make_reductions(self, partree):
233233
if not any(i.is_ParallelAtomic for i in partree.collapsed):
@@ -491,8 +491,8 @@ class PragmaDeviceAwareTransformer(DeviceAwareMixin, PragmaShmTransformer):
491491
shared-memory-parallel, and device-parallel IETs.
492492
"""
493493

494-
def __init__(self, sregistry, options, platform, compiler):
495-
super().__init__(sregistry, options, platform, compiler)
494+
def __init__(self, options=None, **kwargs):
495+
super().__init__(options=options, **kwargs)
496496

497497
self.gpu_fit = options['gpu-fit']
498498
# Need to reset the tile in case was already used and iter over by blocking

devito/tools/data_structures.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,10 @@ def __getnewargs_ex__(self):
9999
return tuple(self), sdict
100100

101101
def get(self, key, val=None):
102-
return self.getters.get(key, val)
102+
try:
103+
return self[key]
104+
except KeyError:
105+
return val
103106

104107
@property
105108
def items(self) -> tuple:

tests/test_linearize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -688,4 +688,4 @@ def test_cire_n_strides():
688688

689689
# NOTE: not exact equality because `op2` slightly changes the order of
690690
# arithmetic operations, which in turn causes some rounding differences
691-
assert np.allclose(u.data, u1.data, rtol=1e-5)
691+
assert np.allclose(u.data, u1.data, rtol=1e-4)

0 commit comments

Comments
 (0)