-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathqat_example.py
More file actions
383 lines (306 loc) · 12.3 KB
/
qat_example.py
File metadata and controls
383 lines (306 loc) · 12.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2023-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# Example script for exporting simple models to flatbuffer
# reference from
# https://github.com/huyvnphan/PyTorch_CIFAR10/tree/master#
import argparse
import copy
import json
import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import torch
from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec
from executorch.backends.arm.ethosu import EthosUCompileSpec, EthosUPartitioner
from executorch.backends.arm.quantizer import (
EthosUQuantizer,
get_symmetric_quantization_config,
TOSAQuantizer,
VgfQuantizer,
)
from executorch.backends.arm.tosa import TosaSpecification
from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec
from executorch.backends.arm.tosa.partitioner import TOSAPartitioner
from executorch.backends.arm.util.arm_model_evaluator import (
GenericModelEvaluator,
MobileNetV2Evaluator,
)
from executorch.backends.arm.vgf import VgfCompileSpec, VgfPartitioner
# To use Cortex-M backend
from executorch.backends.cortex_m.passes.quantized_op_fusion_pass import (
QuantizedOpFusionPass,
)
from executorch.backends.cortex_m.passes.replace_quant_nodes_pass import (
ReplaceQuantNodesPass,
)
from executorch.devtools import generate_etrecord
from executorch.devtools.backend_debug import get_delegation_info
from executorch.devtools.bundled_program.config import MethodTestCase, MethodTestSuite
from executorch.exir import (
EdgeCompileConfig,
ExecutorchBackendConfig,
to_edge_transform_and_lower,
)
from executorch.exir.passes.quantize_io_pass import QuantizeInputs, QuantizeOutputs
from executorch.extension.export_util.utils import save_pte_program
from tabulate import tabulate
from torch.utils.data import DataLoader
# Quantize model if required using the standard export quantization flow.
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e, prepare_qat_pt2e
import torch.nn.utils.prune as prune
import typing
from example_models.mbv2cifar10 import mobilenetv2
import torchvision
import torchvision.transforms as transforms
from torch.ao.quantization.quantize_fx import fuse_fx
def get_dataloader(
    data_path: str,
    img_size: int = 224,
    batch_size: int = 128,
    worker: int = 4,
    distributed: bool = False,
    download: bool = False,
    mean: tuple = (0.4914, 0.4822, 0.4465),
    std: tuple = (0.2471, 0.2435, 0.2616),
) -> typing.Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
    """Build the CIFAR-10 training and validation dataloaders.

    Args:
        data_path (str): Root directory of the dataset.
        img_size (int, optional): Side length images are resized to. Defaults to 224.
        batch_size (int, optional): Batch size for both loaders. Defaults to 128.
        worker (int, optional): Number of dataloader workers. Defaults to 4.
        distributed (bool, optional): Use a DistributedSampler for training (DDP). Defaults to False.
        download (bool, optional): Download the dataset if missing. Defaults to False.
        mean (tuple, optional): Per-channel normalization mean.
        std (tuple, optional): Per-channel normalization std.

    Returns:
        typing.Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
            (train_loader, val_loader)
    """
    normalize = transforms.Normalize(mean, std)
    # Training pipeline: crop/flip augmentation, then resize + normalize.
    train_transform = transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.Resize(img_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]
    )
    # Validation pipeline: deterministic resize + normalize only.
    val_transform = transforms.Compose(
        [transforms.Resize(img_size), transforms.ToTensor(), normalize]
    )
    train_set = torchvision.datasets.CIFAR10(
        root=data_path,
        train=True,
        download=download,
        transform=train_transform,
    )
    sampler = (
        torch.utils.data.distributed.DistributedSampler(dataset=train_set)
        if distributed
        else None
    )
    # Shuffle only when no sampler is used (DistributedSampler shuffles itself).
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=sampler is None,
        sampler=sampler,
        num_workers=worker,
        pin_memory=True,
    )
    val_set = torchvision.datasets.CIFAR10(
        root=data_path,
        train=False,
        download=False,
        transform=val_transform,
    )
    val_loader = torch.utils.data.DataLoader(
        val_set,
        batch_size=batch_size,
        shuffle=False,
        num_workers=worker,
        pin_memory=True,
    )
    return train_loader, val_loader
def quantize_qat_mb(
    model: torch.nn.Module,
    compile_specs: EthosUCompileSpec,
    img_size: int = 32,
) -> torch.nn.Module:
    """Run quantization-aware training (QAT) on ``model`` and return the
    converted int8 module.

    Args:
        model: Exported graph module to quantize (output of ``torch.export``).
        compile_specs: Ethos-U compile spec used to configure the quantizer.
        img_size: Input image side length fed to the dataloader. Defaults to 32.

    Returns:
        torch.nn.Module: The quantized module produced by ``convert_pt2e``.
    """
    logging.info("QAT Quantizing Model...")
    quantizer = EthosUQuantizer(compile_specs)
    try:
        # Newer quantizer versions accept is_qat=True; older ones don't.
        operator_config = get_symmetric_quantization_config(is_qat=True)
    except TypeError:
        operator_config = get_symmetric_quantization_config()
    quantizer.set_global(operator_config)
    m = prepare_qat_pt2e(model, quantizer)
    # Batch size must be 1; gradient accumulation below emulates a larger batch.
    # The validation loader is not needed here.
    train_loader, _ = get_dataloader(
        data_path="./data",
        img_size=img_size,
        batch_size=1,
        worker=4,
        download=True
    )
    max_iteration = None  # set to a small int (e.g. 10) to smoke-test the loop
    num_epochs = 1
    optimizer = torch.optim.SGD(m.parameters(), lr=1e-4, momentum=0.9, weight_decay=1e-4)
    criterion = torch.nn.CrossEntropyLoss()
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    # Accumulate the loss over 64 steps to simulate a batch size of 64
    # before each backpropagation step.
    accumulation_steps = 64
    try:
        m.train()
    except NotImplementedError:
        pass  # exported modules may not implement train(); already trainable
    for epoch in range(num_epochs):
        print(f"\n--- QAT Epoch {epoch+1}/{num_epochs} ---")
        running_loss = 0.0
        pending_grads = False
        optimizer.zero_grad()
        for i, (image, label) in enumerate(train_loader):
            if max_iteration is not None and i >= max_iteration:
                break
            output = m(image)
            # Unwrap container outputs (tuple/list head, or dict 'out' entry).
            if isinstance(output, (tuple, list)):
                output = output[0]
            elif isinstance(output, dict) and 'out' in output:
                output = output['out']
            loss = criterion(output, label)
            loss = loss / accumulation_steps
            loss.backward()
            pending_grads = True
            running_loss += loss.item() * accumulation_steps
            if (i + 1) % accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
                pending_grads = False
                print(f'QAT Epoch [{epoch+1}]: Image [{i+1}/{len(train_loader)}]\tAvg Loss: {running_loss/accumulation_steps:.4f}')
                running_loss = 0.0
        # Flush the epoch tail: without this, gradients from the final
        # partial accumulation window were silently dropped.
        if pending_grads:
            optimizer.step()
            optimizer.zero_grad()
        scheduler.step()
    try:
        m.eval()
    except NotImplementedError:
        pass
    m = convert_pt2e(m)
    logging.debug(f"QAT Quantized model: {m}")
    return m
def evaluate_accuracy(model: Any, data_loader: DataLoader, device: str = 'cpu', limit: Optional[int] = None) -> float:
    """Evaluates the classification accuracy of a model on a given dataset.

    Args:
        model: Callable mapping an image batch to logits (nn.Module or
            exported module; ``eval()`` is attempted when available).
        data_loader: Loader yielding ``(images, labels)`` batches.
        device: Device to move each batch to. Defaults to 'cpu'.
        limit: Optional cap on the number of batches evaluated.

    Returns:
        float: Accuracy in percent; 0.0 when no samples were evaluated
        (empty loader or ``limit=0``).
    """
    try:
        if hasattr(model, 'eval'):
            model.eval()
    except NotImplementedError:
        pass  # ExportedProgram may not support it, but it's already in eval mode
    correct = 0
    total = 0
    print(f"Starting evaluation on {len(data_loader)} batches...")
    with torch.no_grad():
        for i, (images, labels) in enumerate(data_loader):
            if limit is not None and i >= limit:
                break
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # Extract main output tensor if evaluating an ExportedProgram module that returns a tuple/dict
            if isinstance(outputs, (tuple, list)):
                outputs = outputs[0]
            elif isinstance(outputs, dict) and 'out' in outputs:
                outputs = outputs['out']
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            if (i + 1) % 1000 == 0:
                print(f'Eval: [{i+1}/{len(data_loader)}]\tAcc: {100 * correct / total:.2f}%')
    # Guard against division by zero when nothing was evaluated
    # (the original raised ZeroDivisionError for an empty loader or limit=0).
    if total == 0:
        return 0.0
    return 100 * correct / total
if __name__ == "__main__": # noqa: C901
    # CIFAR-10 input resolution used throughout the script.
    model_w = 32
    model_h = 32
    example_inputs = (torch.randn(1, 3, model_w, model_h),)
    original_model = mobilenetv2.MobileNetV2()
    # NOTE(review): weights_only=False will unpickle arbitrary objects from the
    # checkpoint — confirm the state-dict file is trusted.
    original_model.load_state_dict(torch.load(mobilenetv2.DEFAULT_STATE_DICT, weights_only=False))
    model = original_model.eval()
    # export under the assumption we quantize, the exported form also works
    # in to_edge if we don't quantize
    exported_program = torch.export.export(
        model, example_inputs , strict=True
    )
    model = exported_program.module()
    # Keep a handle on the FP32 module for the baseline accuracy comparison below.
    model_fp32 = model
    graph_module = exported_program.graph_module
    # _ = graph_module.print_readable()
    ######################## TO TOSA delegate and quantize int8
    # to_edge_TOSA_delegate
    # Compile spec targeting Ethos-U55 (64 MACs) with a Himax Vela config.
    compile_spec = EthosUCompileSpec(
        "ethos-u55-64",
        system_config="My_Sys_Cfg",
        memory_mode="My_Mem_Mode_Parent",
        extra_flags=["--verbose-operators", "--verbose-cycle-estimate"],
        # extra_flags=["--verbose-operators", "--verbose-cycle-estimate", "--optimise Size"],
        config_ini="himax_vela.ini",
    )
    # Evaluate FP32 and INT8 models
    print("\n--- Evaluating Models ---")
    train_loader, val_loader = get_dataloader(
        data_path="./data",
        img_size=model_w,
        batch_size=1,
        worker=2,
        download=True
    )
    print("\nEvaluating FP32 model...")
    fp32_acc = evaluate_accuracy(model_fp32, val_loader)
    print(f"FP32 Model Accuracy: {fp32_acc:.2f}%\n")
    ###### quantize_qat_model
    print("\n--- Starting QAT ---")
    # Work on a deep copy so the FP32 baseline model is left untouched.
    qat_model_instance = copy.deepcopy(original_model)
    # Disable in-place ops (e.g. ReLU(inplace=True)) which interfere with
    # observer insertion during QAT preparation.
    for m in qat_model_instance.modules():
        if hasattr(m, 'inplace'):
            m.inplace = False
    #### open eval mode to close drop out
    qat_model_instance.eval()
    #### workaround
    #### at this torchao.quantization.pt2e.quantize_pt2e version prepare_qat_pt2e could not correctly fuse the operator pattern (Conv + BN -> Conv)
    #### fuse (Conv + BN -> Conv)
    print("Fusing Conv and BatchNorm dynamically before Export...")
    fused_qat_model = fuse_fx(qat_model_instance)
    example_inputs_qat = (torch.randn(1, 3, model_w, model_h),)
    # NOTE(review): if this export fails, model_qat is unbound and the
    # quantize_qat_mb call below raises NameError — consider re-raising here.
    try:
        model_qat = torch.export.export(
            fused_qat_model,
            example_inputs_qat,
            strict=True
        ).module()
    except Exception as e:
        print(f"Capture failed: {e}")
    #### do QAT
    model_qat_int8 = quantize_qat_mb(
        model_qat,
        compile_spec,
        model_w
    )
    print("\nEvaluating QAT INT8 model...")
    try:
        model_qat_int8.eval()
    except NotImplementedError:
        pass
    qat_int8_acc = evaluate_accuracy(model_qat_int8, val_loader)
    print(f"QAT INT8 Model Accuracy: {qat_int8_acc:.2f}%\n")
    #### re-export int8 module which generated from QAT and freeze all weights and Bias
    exported_qat_program_int8 = torch.export.export(
        model_qat_int8,
        example_inputs_qat,
        strict=True
    )
    # Lower the quantized program: partition delegatable subgraphs to Ethos-U.
    qat_partitioner = EthosUPartitioner(compile_spec)
    qat_edge = to_edge_transform_and_lower(
        exported_qat_program_int8,
        partitioner=[qat_partitioner],
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
    )
    # Quantize the program's I/O so input 0 / output 0 are int8 at the boundary.
    # NOTE(review): the return value of transform() is discarded — verify this
    # mutates qat_edge in place rather than returning a new manager.
    qat_edge.transform(passes=[QuantizeInputs(qat_edge, [0]), QuantizeOutputs(qat_edge, [0])])
    qat_executorch_program_manager = qat_edge.to_executorch(
        config=ExecutorchBackendConfig(extract_delegate_segments=False)
    )
    print("\n=== Delegated Graph (Check if CPU ops are gone!) ===")
    print(qat_executorch_program_manager.exported_program().graph_module)
    # Serialize the final program to a .pte flatbuffer for on-device deployment.
    save_pte_program(qat_executorch_program_manager, "mbv2_cifar10_qat_himax_ini_vela_4_5_0.pte")
    print("\nQAT PTE file generated successfully!")