-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathfair_classifier.py
More file actions
56 lines (43 loc) · 1.58 KB
/
fair_classifier.py
File metadata and controls
56 lines (43 loc) · 1.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import numpy as np
import matplotlib.pyplot as plt
import Data_Generata
from quad_fair_glvq import MeanDiffGlvqModel
from GLVQ.plot_2d import to_tango_colors, tango_color
def split_x(x, dim_protected):
    """Split one (protected) feature column out of every sample in ``x``.

    Args:
        x: Iterable of samples. Each sample is an indexable row: a list,
            tuple or 1-D ``numpy.ndarray``. A 2-D ndarray also works,
            since iterating over it yields 1-D rows.
        dim_protected: Index of the protected feature within each row.
            Negative indices count from the end of the row, as in normal
            Python indexing.

    Returns:
        ``(new_x, protected)``:
        - ``new_x`` is a list of rows with the protected column removed.
          Each row keeps its original container type; ndarray rows stay
          ndarrays.
        - ``protected`` is a list holding the removed value from each row.

    Raises:
        IndexError: if ``dim_protected`` is out of range for a row.
    """
    new_x = []
    protected = []
    for row in x:
        # Normalise negative indices. Otherwise ``row[idx + 1:]`` would
        # become ``row[0:]`` for idx == -1 and duplicate the whole row.
        idx = dim_protected + len(row) if dim_protected < 0 else dim_protected
        protected.append(row[idx])
        if isinstance(row, np.ndarray):
            # ``+`` on ndarrays is elementwise addition, not concatenation,
            # so remove the column explicitly.
            new_x.append(np.delete(row, idx))
        else:
            new_x.append(row[:idx] + row[idx + 1:])
    return new_x, protected
def main(nb_ppc=100, seed=None):
    """Fit a MeanDiffGlvqModel on 2-D toy data, report accuracy and plot it.

    Builds two Gaussian blobs with covariance I/2, centred at (0, 0)
    (class 0) and (5, 0) (class 1), with ``nb_ppc`` points each. It then
    fits the model, prints the training accuracy and shows a scatter plot
    of:
    - the true labels (translucent),
    - the predicted labels (dots),
    - the learned prototypes (diamonds).

    Args:
        nb_ppc: Number of points generated per class.
        seed: Optional seed for the random generator. ``None`` (the
            default) draws fresh data on every run.
    """
    # The module currently has no docstring, so only print one if it exists
    # (previously this printed the literal text "None").
    if __doc__:
        print(__doc__)
    print('Fair GLVQ:')

    # generate random data
    rng = np.random.default_rng(seed)
    cov = np.eye(2) / 2
    toy_data = np.concatenate([
        rng.multivariate_normal([0, 0], cov, size=nb_ppc),
        rng.multivariate_normal([5, 0], cov, size=nb_ppc),
    ], axis=0)
    toy_label = np.concatenate([np.zeros(nb_ppc), np.ones(nb_ppc)])
    # NOTE(review): the protected attribute is currently identical to the
    # class label, so this toy set does not exercise any real unfairness.
    # TODO: use unfair data (e.g. random protected labels, or a split_x()
    # column from a real data set).
    toy_protected_labels = np.concatenate([np.zeros(nb_ppc), np.ones(nb_ppc)])
    print(len(toy_protected_labels))
    print(len(toy_data))

    # model fitting
    # TODO: add platt scaling
    # NOTE(review): the meaning of the constructor argument (presumably
    # prototypes per class) is defined in quad_fair_glvq — confirm there.
    glvq = MeanDiffGlvqModel(1)
    glvq.fit(toy_data, toy_label, toy_protected_labels)
    pred = glvq.predict(toy_data)

    # plotting: true labels (translucent), predictions (dots), prototypes
    plt.scatter(toy_data[:, 0], toy_data[:, 1],
                c=to_tango_colors(toy_label), alpha=0.5)
    plt.scatter(toy_data[:, 0], toy_data[:, 1],
                c=to_tango_colors(pred), marker='.')
    plt.scatter(glvq.w_[:, 0], glvq.w_[:, 1],
                c=tango_color('aluminium', 5), marker='D')
    plt.scatter(glvq.w_[:, 0], glvq.w_[:, 1],
                c=to_tango_colors(glvq.c_w_, 0), marker='.')
    plt.axis('equal')
    print('classification accuracy:', glvq.score(toy_data, toy_label))
    plt.show()


# Guard the demo so that importing this module (e.g. to reuse split_x) does
# not generate data, fit a model and block in plt.show().
if __name__ == "__main__":
    main()