-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinference.py
More file actions
94 lines (71 loc) · 2.55 KB
/
inference.py
File metadata and controls
94 lines (71 loc) · 2.55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import time
import pandas as pd
import torch
from sklearn.metrics import roc_auc_score
import joblib
from utils import *
def main(args):
    """Run POLARIX inference on a test manifest and save per-slide predictions.

    Reads the manifest CSV, loads the model from the given checkpoint,
    evaluates it over the test loader, reports the loss and ROC AUC,
    calibrates the logits with a saved Platt scaler, and writes the results
    to ``predictions.csv`` in the current working directory.

    The output CSV has the columns ``prob``, ``prob calibrated``, ``logit``,
    ``label`` and ``slide_id``.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``manifest_test``, ``checkpoint``, ``checkpoint_platt_model``,
            ``data_features_dir`` and ``workers``.
    """
    set_seed()
    start_time = time.time()

    df_test = pd.read_csv(args.manifest_test)
    print(f"Read {args.manifest_test} dataset containing {len(df_test)} samples")

    print("Initializing POLARIX model")
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = load_trained_model(device, args.checkpoint)
    print(f"Finished loading POLARIX in {(time.time() - start_time):.2f}s")
    model.eval()

    dataset = FeatureBags(
        df=df_test,
        data_dir=args.data_features_dir,
    )
    loader = get_val_loader(val_split=dataset, workers=args.workers)
    eval_loss, probs, logits, labels, slide_ids = evaluate_loader(model, device, loader)

    # roc_auc_score can raise ValueError when AUC is undefined (e.g. only one
    # class among the labels). Do not let that discard the finished inference:
    # report it and carry on so the predictions are still written below.
    try:
        eval_auc = roc_auc_score(labels, probs)
    except ValueError as err:
        eval_auc = float("nan")
        print(f"Could not compute AUC: {err}")
    print(f"Eval loss: {eval_loss}, AUC: {eval_auc:.4f}")
    print(f"Eval AUC {eval_auc}")

    # Apply Platt scaling for probability calibration.
    # NOTE: joblib.load unpickles the file, which can execute arbitrary code;
    # only point --checkpoint_platt_model at a trusted file.
    platt_scaler = joblib.load(args.checkpoint_platt_model)
    # predict_proba is given one logit per row as a single-feature column;
    # column 1 of its output is kept as the calibrated probability.
    calprobs = platt_scaler.predict_proba(logits.reshape(-1, 1))[:, 1]

    preds_df = pd.DataFrame(probs, columns=["prob"])
    preds_df["prob calibrated"] = calprobs
    preds_df["logit"] = logits
    preds_df["label"] = labels
    preds_df["slide_id"] = slide_ids

    # Save into a csv file.
    print("Saving predictions...")
    preds_df.to_csv("predictions.csv")
    print(f"Finished making POLARIX predictions in {(time.time() - start_time):.2f}s")

    total_time = time.time() - start_time
    num_slides = len(df_test)
    # Guard against an empty manifest so the timing report cannot divide by zero.
    if num_slides:
        avg_time_per_slide = total_time / num_slides
        print(f"Average prediction time per slide: {avg_time_per_slide:.2f}s")
if __name__ == "__main__":
    import argparse

    # Command-line entry point: parse the inference options and run main().
    parser = argparse.ArgumentParser(description="Inference script")
    # Required: main() passes this straight to pd.read_csv, which cannot read
    # from None, so fail fast with a usage error instead.
    parser.add_argument(
        "--manifest_test",
        type=str,
        required=True,
        help="CSV file of test set listing all work_ids, slides, and labels",
    )
    # NOTE(review): left optional because the helper that consumes it is not
    # visible here -- confirm whether a missing checkpoint is ever valid.
    parser.add_argument(
        "--checkpoint",
        type=str,
        help="path to POLARIX model checkpoint",
    )
    # Required: main() passes this straight to joblib.load, and a missing value
    # would otherwise only fail after the whole inference pass has run.
    parser.add_argument(
        "--checkpoint_platt_model",
        type=str,
        required=True,
        help="path to Platt model checkpoint",
    )
    # NOTE(review): left optional because the dataset class that consumes it
    # is not visible here -- confirm whether None is acceptable.
    parser.add_argument(
        "--data_features_dir",
        type=str,
        help="Directory where all *_features.h5 files are stored",
    )
    parser.add_argument(
        "--workers",
        help="The number of workers to use for the data loaders.",
        type=int,
        default=4,
    )
    args = parser.parse_args()
    main(args)