Skip to content

Commit c4e1deb

Browse files
committed
adagrowacq improvements
1 parent f8b5cd3 commit c4e1deb

2 files changed

Lines changed: 12 additions & 6 deletions

File tree

pycona/active_algorithms/adagrowacq.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,21 +67,17 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos
6767

6868
init_bias = list(self.env.instance.bias)
6969
init_bias_provided = len(init_bias) > 0
70-
70+
7171
while len(Y) < n_vars:
7272
it += 1
7373

7474
Y_new = self.choose_variables(X, Y, v)
75-
if verbose >= 3:
76-
print(f"Added to GrowAcq: {set(Y_new) - set(Y)}")
77-
78-
# add the constraints involving x and other added variables
7975
if init_bias_provided:
8076
visible_now = set(get_con_subset(init_bias, Y_new))
8177
self.env.instance.bias = list(visible_now)
8278
init_bias = set(init_bias) - visible_now
8379
else:
84-
self.env.instance.construct_bias_for_vars(set(Y_new) - set(Y), Y)
80+
self.env.instance.construct_bias_for_vars(set(Y_new) - set(Y), Y_new)
8581
if verbose >= 3:
8682
print(f"Created {len(self.env.instance.bias)} constraints")
8783

@@ -91,6 +87,11 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos
9187
print(f"\nGrowAcq: calling inner_algorithm for {len(Y)}/{n_vars} variables")
9288
cl_size = len(self.env.instance.cl)
9389
self.env.instance = self.inner_algorithm.learn(self.env.instance, oracle, verbose=verbose, metrics=self.env.metrics, X=Y)
90+
# Add implied constraints from bias to cl
91+
implied_constraints = get_con_subset(self.env.instance.bias, Y)
92+
self.env.instance.cl.extend(implied_constraints)
93+
self.env.instance.bias = [c for c in self.env.instance.bias if c not in set(implied_constraints)] # remove implied constraints from bias
94+
9495
self.env.metrics.print_statistics()
9596
if len(self.env.instance.cl) == cl_size:
9697
if self._adaptive_grow == 1:

pycona/ca_environment/ca_env_core.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ def remove_from_bias(self, C):
7474
print(f"removing the following constraints from bias: {C}")
7575

7676
self.instance.bias = list(set(self.instance.bias) - set(C))
77+
# Persist removed candidates so that if the bias is reconstructed later,
78+
# previously-eliminated negative constraints do not re-enter and get re-labeled.
79+
self.instance.excluded_cons = list(set(self.instance.excluded_cons).union(set(C)))
7780

7881
def add_to_cl(self, C):
7982
"""
@@ -91,6 +94,8 @@ def add_to_cl(self, C):
9194
# Add constraint(s) c to the learned network and remove them from the bias
9295
self.instance.cl.extend(C)
9396
self.instance.bias = list(set(self.instance.bias) - set(C))
97+
# Safety: if something ended up excluded but is now learned, ensure it is no longer excluded.
98+
self.instance.excluded_cons = list(set(self.instance.excluded_cons) - set(C))
9499

95100
self.metrics.cl += len(C)
96101
if self.verbose == 1:

0 commit comments

Comments
 (0)