diff --git a/tzrec/utils/test_util.py b/tzrec/utils/test_util.py index e484dbc9..2176715b 100644 --- a/tzrec/utils/test_util.py +++ b/tzrec/utils/test_util.py @@ -21,9 +21,8 @@ from torch import nn from torch.fx import GraphModule -from tzrec.acc.aot_utils import export_model_aot, load_model_aot +from tzrec.acc.aot_utils import export_unified_model_aot, load_model_aot from tzrec.models.model import ScriptWrapper -from tzrec.utils.export_util import split_model from tzrec.utils.fx_util import symbolic_trace nv_gpu_unavailable: Tuple[bool, str] = ( @@ -79,8 +78,7 @@ def create_test_model( model = ScriptWrapper(model) assert data is not None assert test_dir, "test_dir must be specified for AOT_INDUCTOR" - sparse, dense, meta_info = split_model(data, model, test_dir) - export_model_aot(sparse, dense, data, meta_info, test_dir) + export_unified_model_aot(model, data, test_dir) model = load_model_aot(test_dir, torch.device("cuda:0")) return model else: