-
Notifications
You must be signed in to change notification settings - Fork 12
Expand file tree
/
Copy pathmain.py
More file actions
41 lines (38 loc) · 1.45 KB
/
main.py
File metadata and controls
41 lines (38 loc) · 1.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch
from torch import optim
from torch.autograd import Variable
import numpy as np
import pickle
from utils import Hps
from utils import DataLoader
from utils import Logger
from utils import SingleDataset
from solver import Solver
import argparse
def build_parser():
    """Build the command-line parser for the training script.

    Option spellings (including the single-dash long options such as
    ``-hps_path``) are unchanged so existing command lines keep working.

    Returns:
        A configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser()
    # Training runs by default. ``--train`` is kept for backward compatibility;
    # with ``store_true`` and ``default=True`` it can never yield False, so
    # ``--no_train`` (same dest) is the way to skip the training stages.
    parser.add_argument('--train', default=True, action='store_true')
    parser.add_argument('--no_train', dest='train', action='store_false',
                        help='skip the training stages')
    # NOTE(review): --test is parsed but not acted upon anywhere in this script.
    parser.add_argument('--test', default=False, action='store_true')
    parser.add_argument('--load_model', default=False, action='store_true')
    # Forwarded verbatim to Solver.train as its second positional argument.
    parser.add_argument('-flag', default='train')
    parser.add_argument('-hps_path', default='./hps/vctk.json')
    parser.add_argument('-load_model_path', default='')
    parser.add_argument('-dataset_path', default='./vctk.h5')
    parser.add_argument('-index_path', default='./index.json')
    parser.add_argument('-output_model_path', default='./pkl')
    return parser


def parse_arguments(argv=None):
    """Parse and validate command-line arguments.

    Args:
        argv: list of argument strings to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.

    Raises:
        SystemExit: (via ``parser.error``) when ``--load_model`` is given
            without a non-empty ``-load_model_path``.
    """
    parser = build_parser()
    args = parser.parse_args(argv)
    # Fail fast with a usage message rather than passing '' to load_model.
    if args.load_model and not args.load_model_path:
        parser.error('--load_model requires a non-empty -load_model_path')
    return args


def main(argv=None):
    """Load hyper-parameters and data, optionally restore a model, then train.

    Args:
        argv: list of argument strings to parse; ``None`` means ``sys.argv[1:]``.
    """
    args = parse_arguments(argv)
    hps = Hps()
    hps.load(args.hps_path)
    hps_tuple = hps.get_tuple()
    dataset = SingleDataset(args.dataset_path,
                            args.index_path,
                            seg_len=hps_tuple.seg_len)
    data_loader = DataLoader(dataset)
    solver = Solver(hps_tuple, data_loader)
    if args.load_model:
        solver.load_model(args.load_model_path)
    if args.train:
        # The stages run back to back with the same output path and flag;
        # what each mode does is defined in Solver.train (not visible here).
        for mode in ('pretrain_G', 'pretrain_D', 'train', 'patchGAN'):
            solver.train(args.output_model_path, args.flag, mode=mode)


if __name__ == '__main__':
    main()