Skip to content

Commit 0d8cddb

Browse files
committed
[tensorrt] [byoc] [plugin] Allow users to specify tuning options
1 parent 2e68c8b commit 0d8cddb

4 files changed

Lines changed: 26 additions & 15 deletions

File tree

python/tvm/tpat/cuda/kernel.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,17 @@
2323

2424

2525
class Config(object):
26-
def __init__(self, onnx_model, input_shapes, target, work_dir) -> None:
26+
def __init__(self, onnx_model, input_shapes, target, tunning_option) -> None:
2727
self.onnx_model = onnx_model
2828
self.input_shapes = input_shapes
29-
self.work_dir = work_dir
29+
self.tunning_option = tunning_option
30+
self.work_dir = tunning_option["work_dir"] or "./log_db"
3031

3132
if target == "gpu":
3233
self.target = self._detect_cuda_target()
3334

3435
def tune_option(self):
35-
return {
36+
default = {
3637
"target": self.target,
3738
"builder": ms.builder.LocalBuilder(),
3839
"runner": ms.runner.LocalRunner(),
@@ -41,6 +42,9 @@ def tune_option(self):
4142
"work_dir": self.work_dir,
4243
}
4344

45+
default.update(self.tunning_option)
46+
return default
47+
4448
def _detect_cuda_target(self):
4549
dev = tvm.cuda()
4650
if not dev.exist:
@@ -59,10 +63,10 @@ def _detect_cuda_target(self):
5963

6064

6165
class Kernel(object):
62-
def __init__(self, name, onnx_model, input_shapes, enable_tunning, work_dir):
66+
def __init__(self, name, onnx_model, input_shapes, enable_tunning, tunning_option):
6367
self._name = name
6468
self._enable_tunning = enable_tunning
65-
self._config = Config(onnx_model, input_shapes, "gpu", work_dir)
69+
self._config = Config(onnx_model, input_shapes, "gpu", tunning_option)
6670

6771
self._lib = None
6872
self._module = None

python/tvm/tpat/cuda/pipeline.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def _extract_target_onnx_node(model, tunning_node):
5959

6060

6161
def pipeline(
62-
onnx_file: str, node_names: list[str], enable_tunning: bool, work_dir: str, output_onnx: str
62+
onnx_file: str, node_names: list[str], enable_tunning: bool, tunning_option: object, output_onnx: str
6363
) -> Tuple[str, list[str]]:
6464
"""Generate plugins for specified nodes in an ONNX model.
6565
@@ -73,8 +73,8 @@ def pipeline(
7373
Names of the nodes to be generated as TensorRT plugins.
7474
enable_tunning : bool
7575
Flag indicating whether tuning is enabled.
76-
work_dir : str
77-
Path to the tunning log file where the records will be saved.
76+
tunning_option : object
77+
Tuning options passed to ms.relay_integration.tune_relay; you don't need to specify mod, params, or target.
7878
output_onnx : str
7979
Path to the output ONNX file where the modified model will be saved.
8080
@@ -106,7 +106,7 @@ def pipeline(
106106

107107
subgraph, submodel, shapes = _extract_target_onnx_node(inferred_model, node)
108108

109-
kernel = Kernel(plugin_name, submodel, shapes, enable_tunning, work_dir)
109+
kernel = Kernel(plugin_name, submodel, shapes, enable_tunning, tunning_option)
110110
kernel.run()
111111

112112
## 3.1 fill in template

python/tvm/tpat/cuda/plugin/Makefile

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,12 @@
1414
# limitations under the License.
1515
#
1616

17+
# Variables that need to be defined by users
1718
CUDA_PATH = /path/to/cuda
1819
CUDNN_PATH = /path/to/cudnn
1920
TRT_PATH = /path/to/TensorRT
21+
ARCH = sm_86
22+
########################################
2023

2124
CUDA_INC_PATH = $(CUDA_PATH)/include
2225
CUDA_LIB_PATH = $(CUDA_PATH)/lib
@@ -28,13 +31,9 @@ CUDNN_LIB_PATH = $(CUDNN_PATH)/lib
2831
TRT_INC_PATH = $(TRT_PATH)/include
2932
TRT_LIB_PATH = $(TRT_PATH)/lib
3033

31-
32-
ARCH = sm_86
3334
GCC = g++
3435
NVCC = $(CUDA_PATH)/bin/nvcc
35-
# CCFLAGS = -g -std=c++11 -DNDEBUG
3636
CCFLAGS = -w -std=c++11
37-
# CCFLAGS+= -DDEBUG_ME
3837
INCLUDES := -I. -I$(CUDA_COM_PATH) -I$(CUDA_INC_PATH) -I$(CUDNN_INC_PATH) -I$(TRT_INC_PATH) -I/usr/include
3938

4039
LDFLAGS := -L$(CUDA_LIB_PATH) -L$(CUDNN_LIB_PATH) -L$(TRT_LIB_PATH)

tests/python/tpat/cuda/common.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,11 @@ def name_without_num(name):
9494
ops_name = [op_name]
9595

9696
_, trt_plugin_names = tpat.cuda.pipeline(
97-
INPUT_MODEL_FILE, ops_name, False, "./log_db", OUTPUT_MODEL_FILE
97+
INPUT_MODEL_FILE,
98+
ops_name,
99+
False,
100+
{"work_dir": "./log_db", "max_trials_per_task": 500},
101+
OUTPUT_MODEL_FILE,
98102
)
99103

100104
load_plugin(trt_plugin_names)
@@ -197,7 +201,11 @@ def verify_with_ort_with_trt(
197201
ops_name = [op_name]
198202

199203
_, trt_plugin_names = tpat.cuda.pipeline(
200-
INPUT_MODEL_FILE, ops_name, False, "./log_db", OUTPUT_MODEL_FILE
204+
INPUT_MODEL_FILE,
205+
ops_name,
206+
False,
207+
{"work_dir": "./log_db", "max_trials_per_task": 500},
208+
OUTPUT_MODEL_FILE,
201209
)
202210

203211
load_plugin(trt_plugin_names)

0 commit comments

Comments
 (0)