2323
2424
2525class Config (object ):
26- def __init__ (self , onnx_model , input_shapes , target , work_dir ) -> None :
26+ def __init__ (self , onnx_model , input_shapes , target , tunning_option ) -> None :
2727 self .onnx_model = onnx_model
2828 self .input_shapes = input_shapes
29- self .work_dir = work_dir
29+ self .tunning_option = tunning_option
30+ self .work_dir = tunning_option ["work_dir" ] or "./log_db"
3031
3132 if target == "gpu" :
3233 self .target = self ._detect_cuda_target ()
3334
3435 def tune_option (self ):
35- return {
36+ default = {
3637 "target" : self .target ,
3738 "builder" : ms .builder .LocalBuilder (),
3839 "runner" : ms .runner .LocalRunner (),
@@ -41,6 +42,9 @@ def tune_option(self):
4142 "work_dir" : self .work_dir ,
4243 }
4344
45+ default .update (self .tunning_option )
46+ return default
47+
4448 def _detect_cuda_target (self ):
4549 dev = tvm .cuda ()
4650 if not dev .exist :
@@ -59,10 +63,10 @@ def _detect_cuda_target(self):
5963
6064
6165class Kernel (object ):
62- def __init__ (self , name , onnx_model , input_shapes , enable_tunning , work_dir ):
66+ def __init__ (self , name , onnx_model , input_shapes , enable_tunning , tunning_option ):
6367 self ._name = name
6468 self ._enable_tunning = enable_tunning
65- self ._config = Config (onnx_model , input_shapes , "gpu" , work_dir )
69+ self ._config = Config (onnx_model , input_shapes , "gpu" , tunning_option )
6670
6771 self ._lib = None
6872 self ._module = None
0 commit comments