-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrepeat_pruner.py
More file actions
40 lines (29 loc) · 1.17 KB
/
repeat_pruner.py
File metadata and controls
40 lines (29 loc) · 1.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import math
from optuna.pruners import BasePruner
from optuna.storages import BaseStorage # NOQA
from optuna.structs import TrialState
class RepeatPruner(BasePruner):
    """Prune a trial whose parameters repeat those of a completed trial.

    Re-evaluating a parameter combination that already finished would
    presumably just reproduce a known result, so such trials are pruned.
    Progress is reported with ``print`` on every call.
    """

    # Based on https://github.com/Minyus/optkeras/blob/master/optkeras/optkeras.py
    def prune(self, study, trial):
        """Decide whether ``trial`` duplicates an already-completed trial.

        Args:
            study: The study being optimized. Its ``trials`` are scanned for
                trials in ``TrialState.COMPLETE``.
            trial: The trial being evaluated. Its ``params`` dict is compared
                (with ``==``) against the params of every completed trial.
                NOTE(review): assumes the object optuna passes here exposes
                ``params`` like the entries of ``study.trials`` -- confirm
                against the installed optuna version.

        Returns:
            True if a completed trial has exactly the same parameter names
            and values, False otherwise (including when no trial has
            completed yet).
        """
        print(" *** Calling prunner ***")
        # One pass over the study: collect params of completed trials only.
        # The trial under evaluation is not COMPLETE, so it never matches
        # itself.
        completed_params_list = [
            t.params for t in study.trials if t.state == TrialState.COMPLETE
        ]
        if not completed_params_list:
            print("Not pruned Trial n_trials==0")
            return False
        # Use the trial we were asked about instead of study.trials[-1]:
        # with concurrent trials (n_jobs > 1 or shared storage) the last
        # entry of study.trials may be a different running trial.
        if trial.params in completed_params_list:
            print(" ---- Pruned Trial ----")
            return True
        print("Not pruned Trial")
        return False