# Copyright 2024-2026 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# Example script to compile the model for the NXP Neutron NPU
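#
# Example invocation (illustrative; it assumes this script lives under examples/nxp/
# in an ExecuTorch checkout and is run from the repository root, so that the relative
# imports below resolve):
#
#   python -m examples.nxp.aot_neutron_compile -m cifar10 --quantize --delegate --target imxrt700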
import argparse
import io
import logging
from collections import defaultdict
import executorch.extension.pybindings.portable_lib  # noqa F401
import executorch.kernels.quantized # noqa F401
import torch
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
from executorch.backends.nxp.edge_passes.neutron_edge_pass_manager import (
NeutronEdgePassManager,
)
from executorch.backends.nxp.edge_passes.remove_additional_quantize_dequantize_nodes_pass import (
RemoveAdditionalQDQClustersPass,
)
from executorch.backends.nxp.edge_passes.remove_io_quant_ops_pass import (
RemoveIOQuantOpsPass,
)
from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner
from executorch.backends.nxp.nxp_backend import (
core_aten_ops_exception_list,
generate_neutron_compile_spec,
)
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
from executorch.backends.nxp.quantizer.utils import calibrate_and_quantize
from executorch.devtools.visualization.visualization_utils import (
visualize_with_clusters,
)
from executorch.examples.models import MODEL_NAME_TO_MODEL
from executorch.examples.models.model_factory import EagerModelFactory
from executorch.exir import (
EdgeCompileConfig,
ExecutorchBackendConfig,
to_edge_transform_and_lower,
)
from executorch.extension.export_util import save_pte_program
from torch.export import export
from torchao.quantization.pt2e import (
move_exported_model_to_eval,
move_exported_model_to_train,
)
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_qat_pt2e
from .experimental.cifar_net.cifar_net import (
CifarNet,
train_cifarnet_model,
verify_cifarnet_model,
)
from .models.mobilenet_v2 import MobilenetV2

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)


def print_ops_in_edge_program(edge_program):
"""Find all ops used in the `edge_program` and print them out along with their occurrence counts."""
    # Map each op to the number of times it is used in the program.
    ops_and_counts = defaultdict(int)
for node in edge_program.graph.nodes:
if "call" not in node.op:
continue # `placeholder` or `output`. (not an operator)
if hasattr(node.target, "_schema"):
# Regular op.
# noinspection PyProtectedMember
op = node.target._schema.schema.name
else:
# Builtin function.
op = str(node.target)
ops_and_counts[op] += 1
# Sort the ops based on how many times they are used in the model.
ops_and_counts = sorted(ops_and_counts.items(), key=lambda x: x[1], reverse=True)
# Print the ops and use counts.
for op, count in ops_and_counts:
        print(f"{op: <50} {count}x")


def get_model_and_inputs_from_name(model_name: str, use_random_dataset: bool):
"""Given the name of an example pytorch model, return it, example inputs and calibration inputs (can be None)
Raises RuntimeError if there is no example model corresponding to the given name.
"""
calibration_inputs = None
# Case 1: Model is defined in this file
    if model_name in models:
if use_random_dataset:
if model_name != "mobilenetv2":
raise NotImplementedError(
f"Random dataset for model {model_name} is not implemented."
)
m = models[model_name](use_random_dataset=use_random_dataset)
else:
m = models[model_name]()
model = m.get_eager_model()
example_inputs = m.get_example_inputs()
calibration_inputs = m.get_calibration_inputs(64)
# Case 2: Model is defined in executorch/examples/models/
    elif model_name in MODEL_NAME_TO_MODEL:
        logging.warning(
            "Using a model from examples/models; not all of these are currently supported."
        )
model, example_inputs, _, _ = EagerModelFactory.create_model(
*MODEL_NAME_TO_MODEL[model_name]
)
else:
raise RuntimeError(
f"Model '{model_name}' is not a valid name. Use --help for a list of available models."
)
    return model, example_inputs, calibration_inputs


models = {
"cifar10": CifarNet,
"mobilenetv2": MobilenetV2,
}


if __name__ == "__main__": # noqa C901
parser = argparse.ArgumentParser()
parser.add_argument(
"-m",
"--model_name",
required=True,
help=f"Provide model name. Valid ones: {set(models.keys())}",
)
parser.add_argument(
"-d",
"--delegate",
action="store_true",
required=False,
default=False,
help="Flag for producing eIQ NeutronBackend delegated model",
)
parser.add_argument(
"--target",
required=False,
default="imxrt700",
help="Platform for running the delegated model",
)
parser.add_argument(
"-q",
"--quantize",
action="store_true",
required=False,
default=False,
help="Produce a quantized model",
)
parser.add_argument(
"--use_qat",
action="store_true",
required=False,
default=False,
help="Use QAT mode for quantization (performs two QAT training epochs)",
)
parser.add_argument(
"-s",
"--so_library",
required=False,
default=None,
help="Path to custome kernel library",
)
parser.add_argument(
"--debug", action="store_true", help="Set the logging level to debug."
)
parser.add_argument(
"-t",
"--test",
action="store_true",
required=False,
default=False,
help="Test the selected model and print the accuracy between 0 and 1.",
)
parser.add_argument(
"-r",
"--remove-quant-io-ops",
action="store_true",
required=False,
default=False,
help="Remove I/O De/Quantize nodes. Model will start to accept quantized "
"inputs and produce quantized outputs.",
)
parser.add_argument(
"--operators_not_to_delegate",
required=False,
default=[],
type=str,
nargs="*",
help="List of operators not to delegate. E.g., --operators_not_to_delegate aten::convolution aten::mm",
)
parser.add_argument(
"--visualize",
choices=["show", "store"],
help="Visualize the lowered program. `show` launches a browser tab with the visualization. `store` stores the "
"visualization in a json file for later inspection. See `docs/source/visualize-with-clusters.md` for details.",
)
parser.add_argument(
"--use_channels_last_dim_order",
required=False,
default=False,
action="store_true",
help="The model (including the Neutron backend) will use the channels last dim order, which can result in "
"faster inference. The inputs must also be provided in the channels last dim order.",
)
parser.add_argument(
"--dump_kernel_selection_code",
required=False,
default=False,
action="store_true",
help="During conversion to Neutron microcode by Neutron Converter, a kernel selection file will be dumped in "
"the working directory. This file can be used for reduction of Neutron Firmware size in the built app."
"See `docs/source/backends/nxp/nxp-kernel-selection.md` for details.",
)
parser.add_argument(
"--use_random_dataset",
required=False,
default=False,
action="store_true",
help="The calibration and testing datasets will be generated randomly instead of being downloaded.",
)
parser.add_argument(
"--fetch_constants_to_sram",
required=False,
default=False,
action="store_true",
help="This feature allows running models which do not fit into SRAM by offloading them to an external memory.",
)
args = parser.parse_args()
if args.debug:
logging.basicConfig(level=logging.DEBUG, format=FORMAT, force=True)
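
    # Resolve the target name (default: "imxrt700") into a Neutron NPU target
    # specification, consumed by both the quantizer and the partitioner below.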
neutron_target_spec = NeutronTargetSpec(target=args.target)
# 1. pick model from one of the supported lists
model, example_inputs, calibration_inputs = get_model_and_inputs_from_name(
args.model_name, args.use_random_dataset
)
model = model.eval()
if args.use_channels_last_dim_order:
# Turn the model to channels last.
model.to(memory_format=torch.channels_last)
# The dim order of the example inputs will define the dim order of the intermediate tensors in the model.
example_inputs = tuple(
i.to(memory_format=torch.channels_last) for i in example_inputs
)
else:
# Notify the user of this option.
print(
"HINT: Converting your model to channels last may significantly improve inference speed. You can use the "
"flag `--use_channels_last_dim_order`. See `docs/source/backends/nxp/nxp-dim-order.md` for more information."
)
# 2. Export the model to ATEN
exported_program = torch.export.export(model, example_inputs, strict=True)
module = exported_program.module()
# 3. Quantize if required
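    # Two flows are supported: quantization-aware training (QAT), which inserts
    # fake-quantization observers, trains the model, and then converts it; and
    # post-training quantization (PTQ), which only calibrates the observers on a
    # set of sample inputs.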
if args.quantize:
quantizer = NeutronQuantizer(neutron_target_spec, is_qat=args.use_qat)
if args.use_qat:
match args.model_name:
case "cifar10":
print("Starting two epochs of QAT training with CifarNet model...")
module = prepare_qat_pt2e(module, quantizer)
module = move_exported_model_to_train(module)
module = train_cifarnet_model(module, num_epochs=2)
module = move_exported_model_to_eval(module)
module = convert_pt2e(module)
case _:
raise ValueError(
f"QAT training is not supported for model '{args.model_name}'"
)
else:
if calibration_inputs is None:
logging.warning(
"No calibration inputs available, using the example inputs instead"
)
calibration_inputs = example_inputs
module = calibrate_and_quantize(module, calibration_inputs, quantizer)
if args.so_library is not None:
logging.debug(f"Loading libraries: {args.so_library}")
torch.ops.load_library(args.so_library)
if args.test:
match args.model_name:
case "cifar10":
accuracy = verify_cifarnet_model(module)
case _:
raise NotImplementedError(
f"Testing of model `{args.model_name}` is not yet supported."
)
quantized_str = "quantized " if args.quantize else ""
print(f"\nAccuracy of the {quantized_str}`{args.model_name}`: {accuracy}\n")
# 4. Transform and lower
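    # The compile spec carries the target-specific options into the Neutron backend;
    # the partitioner (used only with --delegate) marks the subgraphs that will be
    # offloaded to the NPU.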
compile_spec = generate_neutron_compile_spec(
args.target,
operators_not_to_delegate=args.operators_not_to_delegate,
fetch_constants_to_sram=args.fetch_constants_to_sram,
dump_kernel_selection_code=args.dump_kernel_selection_code,
)
partitioners = (
[
NeutronPartitioner(
compile_spec,
neutron_target_spec,
post_quantization_state_dict=module.state_dict(),
)
]
if args.delegate
else []
)
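    # Lower to the edge dialect, run the Neutron edge passes, and delegate the
    # partitioned subgraphs to the Neutron backend.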
edge_program_manager = to_edge_transform_and_lower(
export(module, example_inputs, strict=True),
transform_passes=NeutronEdgePassManager(),
partitioner=partitioners,
compile_config=EdgeCompileConfig(
_core_aten_ops_exception_list=core_aten_ops_exception_list,
),
)
if args.remove_quant_io_ops:
edge_program_manager = edge_program_manager.transform(
[RemoveIOQuantOpsPass(edge_program_manager=edge_program_manager)]
)
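    # Remove the redundant quantize/dequantize node pairs that remain between
    # delegated clusters after partitioning.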
edge_program_manager = edge_program_manager.transform(
NeutronEdgePassManager([RemoveAdditionalQDQClustersPass()])
)
logging.debug(f"Lowered graph:\n{edge_program_manager.exported_program().graph}")
# 5. Export to ExecuTorch program
try:
exec_prog = edge_program_manager.to_executorch(
config=ExecutorchBackendConfig(extract_delegate_segments=False)
)
except RuntimeError as e:
if "Missing out variants" in str(e.args[0]):
raise RuntimeError(
e.args[0]
+ ".\nThis likely due to an external so library not being loaded. Supply a path to it with the "
"--so_library flag."
).with_traceback(e.__traceback__) from None
        else:
            raise
def executorch_program_to_str(ep, verbose=False):
f = io.StringIO()
ep.dump_executorch_program(out=f, verbose=verbose)
return f.getvalue()
logging.debug(f"Executorch program:\n{executorch_program_to_str(exec_prog)}")
# 6. Serialize to *.pte
model_name = f"{args.model_name}" + (
"_nxp_delegate" if args.delegate is True else ""
)
save_pte_program(exec_prog, model_name)
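
    # The serialized <model_name>.pte can be executed with an ExecuTorch runtime that
    # was built with the Neutron backend enabled (see the NXP backend docs for setup).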
# 7. Optionally visualize the model.
if args.visualize == "show":
visualize_with_clusters(exec_prog.exported_program())
    elif args.visualize == "store":
        file_name = f"{args.model_name}-visualization.json"
        visualize_with_clusters(exec_prog.exported_program(), file_name)
        logging.info(
            f"Saved the graph visualization in `{file_name}`. It can be opened using the ModelExplorer."
        )