-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathclassifier.py
More file actions
51 lines (40 loc) · 1.45 KB
/
classifier.py
File metadata and controls
51 lines (40 loc) · 1.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from sklearn import svm
import numpy as np
import matplotlib.pyplot as plt
class eye_classifier:
    """Thin wrapper around a scikit-learn ``SVC`` with a 2-D decision plot.

    Call :meth:`train` before :meth:`classify` or :meth:`plot`; until then
    ``cls`` is ``None`` and those methods fail on the ``None`` attribute.
    """

    def __init__(self):
        # Per-instance state. These were previously class attributes, which
        # made every instance share the same mutable ``samples``/``labels``
        # lists until ``train`` rebound them.
        self.cls = None
        self.samples = []
        self.labels = []

    def train(self, samples, labels, **svc_params):
        """Fit a fresh ``svm.SVC`` on *samples*/*labels* and keep the data.

        Args:
            samples: sequence of feature vectors, one per training example.
            labels: sequence of class labels aligned with *samples*.
            **svc_params: optional keyword arguments forwarded to
                ``svm.SVC``. With none given, the SVC defaults are used,
                exactly as before.
        """
        self.samples = samples
        self.labels = labels
        self.cls = svm.SVC(**svc_params)
        self.cls.fit(self.samples, self.labels)

    def classify(self, sample):
        """Predict the label of a single feature vector.

        ``predict`` works on batches, so *sample* is wrapped in a list and
        the return value is the length-1 array ``predict`` produces, not a
        bare scalar.
        """
        return self.cls.predict([sample])

    def _split_xy(self):
        """Return the first and second feature of every stored sample as two lists."""
        xs = [it[0] for it in self.samples]
        ys = [it[1] for it in self.samples]
        return xs, ys

    def plot(self):
        """Show the trained classifier's decision regions over the training data.

        Only the first two features of each stored sample are plotted, and
        the decision surface is evaluated on a two-column grid. This
        therefore only works for a classifier trained on 2-D samples.
        Blocks in ``plt.show()``.
        """
        _, sub = plt.subplots()
        # Bug fix: this previously read an undefined global ``samples``
        # instead of the data remembered by ``train``.
        X0, X1 = self._split_xy()
        xx, yy = self.make_meshgrid(X0, X1)
        self.plot_contours(sub, self.cls, xx, yy, cmap=plt.cm.coolwarm, alpha=0.8)
        sub.scatter(X0, X1, c=self.labels, cmap=plt.cm.coolwarm, s=20, edgecolors='k')
        sub.set_xlim(xx.min(), xx.max())
        sub.set_ylim(yy.min(), yy.max())
        sub.set_xlabel('x')
        sub.set_ylabel('y')
        sub.set_xticks(())
        sub.set_yticks(())
        plt.show()

    def make_meshgrid(self, x, y, h=.02):
        """Build a grid spanning *x*/*y*, padded by 1 on every side.

        Args:
            x: non-empty sequence of x coordinates.
            y: non-empty sequence of y coordinates.
            h: grid step in both directions.

        Returns:
            ``(xx, yy)`` as produced by ``np.meshgrid`` over
            ``np.arange(min - 1, max + 1, h)`` for each axis.
        """
        x_min, x_max = min(x) - 1, max(x) + 1
        y_min, y_max = min(y) - 1, max(y) + 1
        xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                             np.arange(y_min, y_max, h))
        return xx, yy

    def plot_contours(self, ax, clf, xx, yy, **params):
        """Fill *ax* with *clf*'s predictions over the grid ``(xx, yy)``.

        Each grid point becomes one ``[x, y]`` row passed to
        ``clf.predict``. The predictions are reshaped back to the grid's
        shape and drawn with ``ax.contourf``. Extra *params* are forwarded
        to ``contourf``.

        Returns:
            Whatever ``ax.contourf`` returns.
        """
        Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
        Z = Z.reshape(xx.shape)
        out = ax.contourf(xx, yy, Z, **params)
        return out