-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathconvnext_tiny.py
More file actions
127 lines (113 loc) · 6.52 KB
/
convnext_tiny.py
File metadata and controls
127 lines (113 loc) · 6.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# proprietary to SiMa and may be covered by U.S. and Foreign Patents,
# patents in process, and are protected by trade secret or copyright law.
#
# Dissemination of this information or reproduction of this material is
# strictly forbidden unless prior written permission is obtained from
# SiMa.ai. Access to the source code contained herein is hereby forbidden
# to anyone except current SiMa.ai employees, managers or contractors who
# have executed Confidentiality and Non-disclosure agreements explicitly
# covering such access.
#
# The copyright notice above does not evidence any actual or intended
# publication or disclosure of this source code, which includes information
# that is confidential and/or proprietary, and is a trade secret, of SiMa.ai.
#
# ANY REPRODUCTION, MODIFICATION, DISTRIBUTION, PUBLIC PERFORMANCE, OR PUBLIC
# DISPLAY OF OR THROUGH USE OF THIS SOURCE CODE WITHOUT THE EXPRESS WRITTEN
# CONSENT OF SiMa.ai IS STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE
# LAWS AND INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE
# CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS TO
# REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE, USE, OR
# SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART.
#
# **************************************************************************
import os
import logging
import numpy as np
import dataclasses
from afe.apis.defines import QuantizationParams, quantization_scheme, default_calibration, CalibrationMethod, gen1_target, gen2_target, bfloat16_scheme
from afe.apis.loaded_net import load_model
from afe.apis.model import Model
from afe.apis.release_v1 import get_model_sdk_version
from afe.apis.error_handling_variables import enable_verbose_error_messages
from afe.backends.mpk.interface import L2CachingMode
from afe.ir.defines import InputName, RequantizationMode
from afe.ir.tensor_type import ScalarType
from afe.load.importers.general_importer import onnx_source
def main(arm_only, asym, per_ch, calibration,
         retain, l2_cache, verbose, load_net,
         bias_correction, mode, nbits,
         batch_size, target, bf16, ceq):
    """Quantize, execute, optionally save/reload, and compile the ConvNeXt-Tiny ONNX model.

    The model is read from ``models/convnext_tiny.onnx`` and calibrated with a
    single randomly generated input sample. Artifacts are written below
    ``<cwd>/result/convnext_tiny_asym_<asym>_per_ch_<per_ch>/``.

    Args:
        arm_only: Forwarded unchanged to ``loaded_net.quantize``.
        asym: First argument of ``quantization_scheme`` for activations.
        per_ch: Second argument of ``quantization_scheme`` for weights.
        calibration: One of ``'min_max'``, ``'moving_average'``, ``'entropy'``
            or ``'percentile'`` to override the calibration method; any other
            value keeps ``default_calibration()``.
        retain: When truthy, the MPK output directory is also passed to
            ``compile`` as ``retained_temporary_directory_name``.
        l2_cache: When truthy, ``compile`` is called with
            ``L2CachingMode.SINGLE_MODEL``.
        verbose: When truthy, ``enable_verbose_error_messages()`` is called.
        load_net: When truthy, the quantized net is saved, loaded back and
            checked to be a ``Model``.
        bias_correction: When truthy, passed to ``with_bias_correction``.
        mode: ``'tflite'`` selects ``RequantizationMode.tflite``; other values
            leave the requantization mode untouched.
        nbits: When equal to 16, the activation scheme is rebuilt with this
            bit width.
        batch_size: Forwarded unchanged to ``sdk_net.compile``.
        target: ``'gen1'`` selects ``gen1_target``; anything else selects
            ``gen2_target``.
        bf16: When truthy, both activation and weight schemes are replaced
            with ``bfloat16_scheme()``.
        ceq: When truthy, passed to ``with_channel_equalization``.
    """
    if verbose:
        enable_verbose_error_messages()

    # Report which Model SDK release is in use.
    sdk_version = get_model_sdk_version()
    print(f"Model SDK version: {sdk_version}")

    # Describe the single model input and build the ONNX importer parameters.
    input_name = "modelInput"
    input_shape = (1, 3, 224, 224)
    input_type = ScalarType.float32
    importer_params = onnx_source(
        "models/convnext_tiny.onnx",
        {input_name: input_shape},
        {input_name: input_type},
    )

    if target == 'gen1':
        platform = gen1_target
    else:
        platform = gen2_target
    loaded_net = load_model(importer_params, target=platform)

    # One random sample serves both as calibration data and as the test input.
    inputs = {InputName(input_name): np.random.rand(1, 224, 224, 3)}
    calibration_data = [inputs]

    # Baseline quantization configuration; the switches below refine it in a
    # fixed order, so later overrides win over earlier ones.
    quant_configs: QuantizationParams = QuantizationParams(
        calibration_method=default_calibration(),
        activation_quantization_scheme=quantization_scheme(asym, False),
        weight_quantization_scheme=quantization_scheme(False, per_ch),
        node_names={''},
        custom_quantization_configs=None,
    )

    if nbits == 16:
        quant_configs = quant_configs.with_activation_quantization(
            quantization_scheme(asym, False, nbits)
        )

    if bf16:
        quant_configs = quant_configs.with_activation_quantization(bfloat16_scheme())
        quant_configs = quant_configs.with_weight_quantization(bfloat16_scheme())

    if mode == 'tflite':
        quant_configs = quant_configs.with_requantization_mode(RequantizationMode.tflite)

    if bias_correction:
        quant_configs = quant_configs.with_bias_correction(bias_correction)

    if ceq:
        quant_configs = quant_configs.with_channel_equalization(ceq)

    if calibration in ('min_max', 'moving_average', 'entropy', 'percentile'):
        quant_configs = dataclasses.replace(
            quant_configs,
            calibration_method=CalibrationMethod.from_str(calibration),
        )

    sdk_net = loaded_net.quantize(
        calibration_data,
        quant_configs,
        model_name="convnext_tiny",
        arm_only=arm_only,
        log_level=logging.INFO,
    )

    # Execute the quantized net; the returned value is not inspected here.
    sdk_net_output = sdk_net.execute(inputs=inputs)

    saved_model_name = f"convnext_tiny_asym_{asym}_per_ch_{per_ch}"
    result_root = os.path.join(os.getcwd(), 'result', saved_model_name)

    if load_net:
        # Saving produces a sima model file and a JSON file for Netron; the
        # ".sima" extension is added internally when the name lacks it.
        saved_model_directory = os.path.join(result_root, 'sdk_net')
        os.makedirs(saved_model_directory, mode=0o777, exist_ok=True)
        sdk_net.save(model_name=saved_model_name, output_directory=saved_model_directory)

        # Load the saved net back - note that the sima extension is optional.
        load_model_name = f"convnext_tiny_asym_{asym}_per_ch_{per_ch}.sima"
        net_read_back = Model.load(model_name=load_model_name,
                                   network_directory=saved_model_directory)
        assert isinstance(net_read_back, Model)

    # Compile the quantized net and generate the LM file and MPK JSON file.
    saved_mpk_directory = os.path.join(result_root, 'mpk')
    os.makedirs(saved_mpk_directory, mode=0o777, exist_ok=True)

    # Optional keyword arguments are added only when their switch is set, so
    # compile() falls back to its own defaults otherwise.
    compile_options = {
        "output_path": saved_mpk_directory,
        "batch_size": batch_size,
    }
    if l2_cache:
        compile_options["l2_caching_mode"] = L2CachingMode.SINGLE_MODEL
    if retain:
        compile_options["retained_temporary_directory_name"] = saved_mpk_directory
    sdk_net.compile(**compile_options)
if __name__ == "__main__":
    # Sample invocation; keyword arguments spell out which switch each value sets.
    main(
        arm_only=False,
        asym=True,
        per_ch=True,
        calibration='mse',
        retain=True,
        l2_cache=False,
        verbose=True,
        load_net=False,
        bias_correction=False,
        mode='sima',
        nbits=8,
        batch_size=1,
        target='gen1',
        bf16=False,
        ceq=False,
    )