-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathaccuracies.py
More file actions
27 lines (24 loc) · 1.01 KB
/
accuracies.py
File metadata and controls
27 lines (24 loc) · 1.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix, ConfusionMatrixDisplay
import numpy as np
def accuracies(y_true, y_pred, adaptive=True, plot=True):
    """Report binary-classification metrics for scored predictions.

    Computes the ROC AUC of ``y_pred`` against ``y_true``, binarizes the
    scores at a decision threshold, and derives accuracy, sensitivity
    (true-positive rate) and specificity (true-negative rate) from the
    resulting 2x2 confusion matrix. The metrics are printed (rounded to 4
    decimals) and, optionally, the confusion matrix is plotted.

    Args:
        y_true: Ground-truth labels, any array-like; flattened before use.
            Must contain only the values 0 (negative) and 1 (positive).
        y_pred: Predicted scores for the positive class (e.g. probabilities),
            any array-like; flattened before use. A sample is predicted
            positive when its score is >= the chosen threshold.
        adaptive: If true, use the ROC threshold whose (FPR, TPR) point is
            closest in Euclidean distance to the ideal corner (0, 1);
            otherwise use a fixed threshold of 0.5.
        plot: If true (the default, matching the original behaviour), draw
            the confusion matrix via ``ConfusionMatrixDisplay.plot``. Pass
            False to skip plotting, e.g. in headless environments.

    Returns:
        dict with float values under the keys 'auc', 'accuracy',
        'sensitivity', 'specificity' and 'threshold' (the cut-off used).

    Raises:
        ValueError: if ``y_true`` contains values other than 0 and 1, or
            (raised by scikit-learn) if it does not contain both classes.
    """
    y_true = np.asarray(y_true).flatten()
    y_pred = np.asarray(y_pred).flatten()

    # The tn/fp/fn/tp unpacking and pos_label=1 below both assume 0/1 labels.
    # Other encodings either crash with an opaque unpacking error or, worse,
    # silently yield a 2x2 matrix for the wrong classes -- fail loudly instead.
    if not np.isin(y_true, (0, 1)).all():
        raise ValueError("y_true must contain only the labels 0 and 1")

    auc = roc_auc_score(y_true, y_pred)

    if adaptive:
        fpr, tpr, thresholds = roc_curve(y_true, y_pred, pos_label=1)
        # Squared distance to the perfect classifier at (FPR=0, TPR=1); the
        # argmin is the same as for the (un-squared) Euclidean distance.
        dist = fpr ** 2 + (1 - tpr) ** 2
        best_thres = thresholds[np.argmin(dist)]
    else:
        best_thres = 0.5

    y_pred_val = np.where(y_pred >= best_thres, 1, 0)

    # Fixing labels guarantees a 2x2 matrix laid out as [[tn, fp], [fn, tp]].
    cm = confusion_matrix(y_true, y_pred_val, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()

    if plot:
        ConfusionMatrixDisplay(confusion_matrix=cm).plot()

    accuracy = (tp + tn) / (tn + fp + fn + tp)
    sensitivity = tp / (tp + fn)
    specificity = tn / (tn + fp)

    print('AUC score:', np.round(auc, 4))
    print('Accuracy:', np.round(accuracy, 4))
    print('Sensitivity:', np.round(sensitivity, 4))
    print('Specificity:', np.round(specificity, 4))

    return {
        'auc': float(auc),
        'accuracy': float(accuracy),
        'sensitivity': float(sensitivity),
        'specificity': float(specificity),
        'threshold': float(best_thres),
    }