-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcross_validation.py
More file actions
32 lines (26 loc) · 994 Bytes
/
cross_validation.py
File metadata and controls
32 lines (26 loc) · 994 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import numpy as np
import pipeline
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import StratifiedKFold
def n_kfold_cross_validation(x, y, clf, n, k, random_state=None, verbose=True):
    """Repeat shuffled, stratified k-fold cross-validation ``n`` times.

    Each repetition splits ``(x, y)`` into ``k`` stratified, shuffled folds and
    scores ``clf`` on them with ``cross_val_score`` (``n_jobs=-1``). No
    ``scoring`` argument is passed, so the estimator's default scorer is used.

    Args:
        x: Feature data passed straight to ``cross_val_score``.
        y: Target labels; also used by ``StratifiedKFold`` for stratification.
        clf: Estimator to evaluate.
        n: Number of repetitions.
        k: Number of folds per repetition (``n_splits`` of ``StratifiedKFold``).
        random_state: Optional integer seed (or ``None``). When given, a
            distinct fold seed is drawn for every repetition from a generator
            seeded with this value. The whole run is then reproducible, while
            the repetitions still use different splits. ``None`` (default)
            keeps the previous behaviour: ``StratifiedKFold`` receives
            ``random_state=None``.
        verbose: If true (default), print the array of ``k`` fold scores of
            each repetition, as this function always did before.

    Returns:
        A tuple ``(means, stds)`` of two lists with one entry per repetition:
        the mean and the population standard deviation (``np.std``, ddof=0)
        of that repetition's fold scores.
    """
    # Only create a private generator when a seed was requested, so the default
    # path behaves exactly as before.
    rng = None if random_state is None else np.random.RandomState(random_state)

    means = []
    stds = []
    for _ in range(n):
        # A fresh seed per repetition keeps the splits different across
        # repetitions. Reusing one fixed int would make all n runs identical.
        fold_seed = None if rng is None else int(rng.randint(np.iinfo(np.int32).max))
        skf = StratifiedKFold(n_splits=k, shuffle=True, random_state=fold_seed)
        fold_scores = cross_val_score(clf, x, y, cv=skf, n_jobs=-1)
        if verbose:
            print(fold_scores)
        means.append(np.mean(fold_scores))
        stds.append(np.std(fold_scores))
    return means, stds
def display_info(scores):
    """Print the mean and population standard deviation of ``scores``.

    Args:
        scores: Array-like of numbers accepted by ``np.mean`` / ``np.std``.

    The mean line is printed before the standard deviation is computed.
    """
    summary = (
        ('Mean accuracy is :', np.mean),
        ('Standard deviation is: ', np.std),
    )
    for label, statistic in summary:
        print(label, statistic(scores))