-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
195 lines (170 loc) · 7.22 KB
/
main.py
File metadata and controls
195 lines (170 loc) · 7.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
"""Lobe segmentation with MONAI"""
from monai.utils import set_determinism
from monai.metrics import DiceMetric
from monai.losses import DiceCELoss
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import random_split
import torch
import os
import yaml
import random
import glob
import argparse
from pathlib import Path
import MetricLogger
from dataloader import train_dataloader, val_dataloader, test_dataloader, npy_train_loader
from models import unet64, unet128, unet256, unet512, unet1024, unetr16
from scheduler import WarmupCosineSchedule
from train import train
from test import test
def run_train(config, config_id):
    """Train a lobe-segmentation model as described by ``config``.

    Creates the per-run model/checkpoint/log directories, splits the images
    matching ``config["image_type"]`` in ``config["data_dir"]`` into train and
    validation sets, builds the model, loss, optimizer and warmup-cosine LR
    schedule, optionally loads pretrained weights and/or resumes from a
    checkpoint, and then hands everything to ``train``.

    Args:
        config: Parsed YAML configuration dictionary.
        config_id: Run identifier; used as the sub-directory name under
            ``model_dir``, ``checkpoint_dir`` and ``log_dir``.
    """
    # Unwrap directory paths
    MODEL_DIR = os.path.join(config["model_dir"], config_id)
    CHECKPOINT_DIR = os.path.join(config["checkpoint_dir"], config_id)
    LOG_DIR = os.path.join(config["log_dir"], config_id)
    DATA_DIR = config["data_dir"]

    # Set randomness. The order of random.* calls below is kept unchanged so
    # that sampling / train-val splits stay reproducible across versions.
    set_determinism(seed=config["random_seed"])
    random.seed(config["random_seed"])

    # Make paths
    for directory in (CHECKPOINT_DIR, LOG_DIR, MODEL_DIR):
        Path(directory).mkdir(parents=True, exist_ok=True)

    # Logger
    logger = MetricLogger.MetricLogger(config, config_id)

    # Load data
    images = sorted(glob.glob(os.path.join(DATA_DIR, config["image_type"])))
    # Limit sample size if specified
    if config["sample_size"]:
        images = random.sample(images, config["sample_size"])
    # Split dataset into train and validation
    val_size = int(len(images) * config["val_ratio"])
    random.shuffle(images)
    val_images, train_images = images[:val_size], images[val_size:]

    # Get dataloaders
    if config["image_type"] == "*.npy":
        print("From pre transformed npys")
        train_loader = npy_train_loader(config, train_images)
        val_loader = npy_train_loader(config, val_images)
    else:
        train_loader = train_dataloader(config, train_images)
        val_loader = val_dataloader(config, val_images)
    # NOTE: all labels are (512, 512, 320), but input shapes vary.

    # Initialize model. Unknown model names fall back to unet256, as before.
    # The 6 output channels are presumably background + 5 lobes — TODO confirm.
    device = torch.device("cuda:0")
    model_factories = {
        "unet64": unet64,
        "unet128": unet128,
        "unet256": unet256,
        "unet512": unet512,
        "unet1024": unet1024,
        "unetr16": unetr16,
    }
    model = model_factories.get(config["model"], unet256)(6).to(device)

    # Loss, optimizer and validation metric
    loss_function = DiceCELoss(include_background=config["include_bg_loss"],
                               to_onehot_y=True, softmax=True)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config["lr"])
    dice_metric = DiceMetric(include_background=False, reduction="mean",
                             get_not_nans=False)

    # Fine-tune a pretrained model if indicated. Loaded *before* any resume
    # checkpoint so that resumed weights are never overwritten by it.
    if config["pretrained"]:
        print(f"Fine tuning model from {config['pretrained']}")
        pretrained = torch.load(config["pretrained"], map_location=device)
        model.load_state_dict(pretrained)

    # Resume training from checkpoint if indicated
    start_epoch = 0
    if config["checkpoint"]:
        print(f"Resuming training of {config_id} from {config['checkpoint']}")
        checkpoint = torch.load(os.path.join(CHECKPOINT_DIR, config["checkpoint"]),
                                map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        start_epoch = checkpoint["epoch"] + 1

    # Scheduler: stepped once per batch (t_total = epochs * batches). It is
    # built only after the checkpoint is loaded so that `last_epoch` reflects
    # the resumed position instead of restarting warmup at step 0.
    n_batches = len(train_loader)
    print(f"Total steps: {config['epochs']*n_batches}")
    if start_epoch > 0:
        # Resuming a scheduler (last_epoch != -1) requires 'initial_lr' in
        # every param group; normally restored with the optimizer state, but
        # backfill it with the configured base LR in case it is missing.
        for group in optimizer.param_groups:
            group.setdefault("initial_lr", config["lr"])
    scheduler = WarmupCosineSchedule(optimizer, warmup_steps=config["warmup_steps"],
                                     t_total=config["epochs"] * n_batches,
                                     last_epoch=start_epoch * n_batches - 1)

    # The context manager guarantees TensorBoard events are flushed/closed.
    with SummaryWriter(log_dir=LOG_DIR) as writer:
        train(config,
              config_id,
              model,
              device,
              optimizer,
              scheduler,
              loss_function,
              dice_metric,
              train_loader,
              val_loader,
              (start_epoch, config["epochs"]),
              logger,
              writer,
              CHECKPOINT_DIR,
              MODEL_DIR)
def run_test(config, config_id, out_name, output_seg=False, output_clip=False):
    """Evaluate the best saved model of run ``config_id`` on the test set.

    Images matching ``config["test_image_type"]`` in ``config["test_dir"]``
    are loaded and passed, together with the model weights at
    ``<model_dir>/<config_id>/<config_id>_best_model.pth``, to ``test``.

    Args:
        config: Parsed YAML configuration dictionary.
        config_id: Run identifier (sub-directory name under ``model_dir``).
        out_name: File name, inside the run's model directory, handed to
            ``test`` as the metrics output path.
        output_seg: If true, a ``segs`` output directory is created and
            passed to ``test``; otherwise ``False`` is passed.
        output_clip: If true, a ``clips`` output directory is created and
            passed to ``test``; otherwise ``False`` is passed.
    """
    DATA_DIR = config["test_dir"]
    MODEL_DIR = os.path.join(config["model_dir"], config_id)
    metrics_path = os.path.join(MODEL_DIR, out_name)
    # `False` (not None) is kept as the "disabled" value because that is what
    # test() has always received; presumably it checks truthiness — confirm.
    seg_dir = os.path.join(MODEL_DIR, 'segs') if output_seg else False
    clip_dir = os.path.join(MODEL_DIR, 'clips') if output_clip else False
    model_path = os.path.join(MODEL_DIR, f"{config_id}_best_model.pth")

    # Create only the output directories that were actually requested.
    for out_dir in (seg_dir, clip_dir):
        if out_dir:
            Path(out_dir).mkdir(parents=True, exist_ok=True)

    # Set randomness
    set_determinism(seed=config["random_seed"])
    random.seed(config["random_seed"])

    # Load data
    images = sorted(glob.glob(os.path.join(DATA_DIR, config["test_image_type"])))
    test_loader = test_dataloader(config, images)

    # Initialize model. Unknown model names fall back to unet256, as before.
    device = torch.device("cuda:0")
    model_factories = {
        "unet64": unet64,
        "unet128": unet128,
        "unet256": unet256,
        "unet512": unet512,
        "unet1024": unet1024,
        "unetr16": unetr16,
    }
    model = model_factories.get(config["model"], unet256)(6).to(device)

    # reduction="none" keeps per-image, per-class scores instead of averaging.
    test_metric = DiceMetric(include_background=False, reduction="none")
    test(config,
         config_id,
         device,
         model,
         model_path,
         test_metric,
         test_loader,
         metrics_path,
         seg_dir,
         clip_dir)
def load_config(config_name, config_dir):
    """Load a YAML run configuration.

    Args:
        config_name: File name of the config inside ``config_dir``.
        config_dir: Directory containing the config files.

    Returns:
        The parsed YAML document (callers index it as a mapping).

    Raises:
        FileNotFoundError: If the config file does not exist.
        ValueError: If the config file is empty.
    """
    config_path = os.path.join(config_dir, config_name)
    # YAML is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(config_path, encoding="utf-8") as file:
        # FullLoader is kept (rather than safe_load) to preserve the existing
        # parsing behaviour; only use this with trusted config files.
        config = yaml.load(file, Loader=yaml.FullLoader)
    if config is None:
        # An empty file parses to None and would fail later with an opaque
        # TypeError on the first config[...] lookup.
        raise ValueError(f"Config file {config_path} is empty")
    return config
if __name__ == "__main__":
    # Command-line entry point: load Config_<config-id>.YAML and train and/or
    # test the corresponding run.
    parser = argparse.ArgumentParser(
        description="Train and/or test a MONAI lobe segmentation model.")
    parser.add_argument('--config-id', type=str, required=True,
                        help="Run id; loads Config_<config-id>.YAML from --config-dir.")
    parser.add_argument('--config-dir', type=str,
                        default="/home/local/VANDERBILT/ohas/Desktop/Programming/new_lobe/lobe_seg/configs",
                        help="Directory containing the Config_*.YAML files.")
    parser.add_argument('--out-name', type=str, default='test.csv',
                        help="Metrics file name written inside the run's model directory.")
    parser.add_argument('--train', action='store_true', default=False,
                        help="Run training.")
    parser.add_argument('--test', action='store_true', default=False,
                        help="Run testing with the best saved model.")
    parser.add_argument('--output-seg', action='store_true', default=False,
                        help="When testing, also output segmentations.")
    parser.add_argument('--output-clip', action='store_true', default=False,
                        help="When testing, also output clips.")
    args = parser.parse_args()

    # Previously the script silently did nothing without either flag.
    if not (args.train or args.test):
        parser.error("nothing to do: pass --train and/or --test")

    config = load_config(f"Config_{args.config_id}.YAML", args.config_dir)
    if args.train:
        run_train(config, args.config_id)
    if args.test:
        run_test(config, args.config_id, args.out_name, output_seg=args.output_seg,
                 output_clip=args.output_clip)