-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtestfile.py
More file actions
47 lines (35 loc) · 1.83 KB
/
testfile.py
File metadata and controls
47 lines (35 loc) · 1.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import numpy as np
from sklearn.model_selection import train_test_split

import relaxed_equalized_odds.calib_eq_odds as odds
import zliobaite
# Synthetic sample from the project's generator. The four positional arguments
# are forwarded as-is; their meaning is defined by zliobaite.generate_data.
data = zliobaite.generate_data(100, 0.5, 0.1, 0.5)

# Positions used further down to address the prediction, label and group
# entries of the data.
prediction, label, group = 0, 1, 2

print("--------------- start testfile ---------------")

# Recode the group entry of every row in place ('F' becomes 1, anything else
# becomes 0) and echo the row right after it has been rewritten.
for row in data:
    row[group] = 1 if row[group] == 'F' else 0
    print(row)
# Split the rows into two halves with a fixed seed so runs are reproducible.
# The rows are first converted to a 2-D float array so that whole columns can
# be selected below. This assumes every column is numeric once the group
# column has been recoded to 0/1 -- TODO confirm against generate_data.
# NOTE: the first half returned by train_test_split is bound to test_data and
# the second half to val_data.
test_data, val_data = train_test_split(
    np.asarray(data, dtype=float), test_size=1/2, random_state=42
)

# Create model objects - one for each group, validation and test.
# A boolean mask over the group column keeps the rows of one group; indexing
# with the bare column number (val_data[group]) would pick a single row
# instead of a column.
group_0_val_data = val_data[val_data[:, group] == 0]
group_1_val_data = val_data[val_data[:, group] == 1]
group_0_test_data = test_data[test_data[:, group] == 0]
group_1_test_data = test_data[test_data[:, group] == 1]

# Each model receives the prediction column and the label column of its subset.
group_0_val_model = odds.Model(
    group_0_val_data[:, prediction], group_0_val_data[:, label]
)
group_1_val_model = odds.Model(
    group_1_val_data[:, prediction], group_1_val_data[:, label]
)
group_0_test_model = odds.Model(
    group_0_test_data[:, prediction], group_0_test_data[:, label]
)
group_1_test_model = odds.Model(
    group_1_test_data[:, prediction], group_1_test_data[:, label]
)
# Find mixing rates for equalized odds models.
# Called without mixing rates, eq_odds hands back three values here; only the
# last one is kept.
_, _, mix_rates = odds.Model.eq_odds(group_0_val_model, group_1_val_model)

# Apply the mixing rates to the test models.
# With mixing rates supplied, the call is unpacked into the two adjusted models.
eq_odds_group_0_test_model, eq_odds_group_1_test_model = odds.Model.eq_odds(
    group_0_test_model, group_1_test_model, mix_rates
)

# Print results on test model: the two original models first, then the two
# equalized-odds models, each via its repr().
for template, model in (
    ('Original group 0 model:\n%s\n', group_0_test_model),
    ('Original group 1 model:\n%s\n', group_1_test_model),
    ('Equalized odds group 0 model:\n%s\n', eq_odds_group_0_test_model),
    ('Equalized odds group 1 model:\n%s\n', eq_odds_group_1_test_model),
):
    print(template % repr(model))