Skip to content

Commit 8bfed7d

Browse files
committed
gquacq to mineacq + fixes
1 parent 4f35247 commit 8bfed7d

4 files changed

Lines changed: 23 additions & 28 deletions

File tree

pycona/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from .find_constraint import FindC, FindC2
2727
from .query_generation import QGen, TQGen, PQGen
2828
from .find_scope import FindScope, FindScope2
29-
from .active_algorithms import QuAcq, PQuAcq, GQuAcq, GrowAcq, MQuAcq, MQuAcq2
29+
from .active_algorithms import QuAcq, PQuAcq, MineAcq, GrowAcq, MQuAcq, MQuAcq2
3030
from .problem_instance import ProblemInstance, absvar, langBasic, langDist, langEqNeq
3131
from .predictor import CountsPredictor, FeaturesRelDim, FeaturesSimpleRel
3232

pycona/active_algorithms/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,4 @@
1212
from .mquacq import MQuAcq
1313
from .growacq import GrowAcq
1414
from .pquacq import PQuAcq
15-
from .gquacq import GQuAcq
15+
from .gquacq import MineAcq

pycona/active_algorithms/gquacq.py

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,61 +12,59 @@
1212
from .. import Metrics
1313

1414

15-
class GQuAcq(AlgorithmCAInteractive):
15+
class MineAcq(AlgorithmCAInteractive):
1616

1717
"""
18-
QuAcq variation algorithm, using mine&Ask to detect types of variables and ask genralization queries. From:
18+
QuAcq variation algorithm, using mine&Ask to detect types of variables and ask generalization queries. From:
1919
"Detecting Types of Variables for Generalization in Constraint Acquisition", ICTAI 2015.
2020
"""
2121

2222
def __init__(self, ca_env: ActiveCAEnv = None, qg_max=10):
2323
"""
24-
Initialize the GQuAcq algorithm with an optional constraint acquisition environment.
24+
Initialize the MineAcq algorithm with an optional constraint acquisition environment.
2525
2626
:param ca_env: An instance of ActiveCAEnv, default is None.
27-
: param GQmax: maximum number of generalization queries
27+
:param qg_max: maximum number of generalization queries
2828
"""
2929
super().__init__(ca_env)
3030
self._negativeQ = []
3131
self._qg_max = qg_max
3232

33-
def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbose=0, X=None, metrics: Metrics = None):
33+
def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbose=0, metrics: Metrics = None, X=None):
3434
"""
35-
Learn constraints using the GQuAcq algorithm by generating queries and analyzing the results.
35+
Learn constraints using the QuAcq algorithm by generating queries and analyzing the results.
3636
3737
:param instance: the problem instance to acquire the constraints for
3838
:param oracle: An instance of Oracle, default is to use the user as the oracle.
3939
:param verbose: Verbosity level, default is 0.
4040
:param metrics: statistics logger during learning
41-
:param X: The set of variables to consider, default is None.
41+
:param X: List of variables to consider for learning. If None, uses all variables from the instance.
4242
:return: the learned instance
4343
"""
44-
if X is None:
45-
X = instance.X
46-
assert isinstance(X, list), "When using .learn(), set parameter X must be a list of variables. Instead got: {}".format(X)
47-
assert set(X).issubset(set(instance.X)), "When using .learn(), set parameter X must be a subset of the problem instance variables. Instead got: {}".format(X)
48-
4944
self.env.init_state(instance, oracle, verbose, metrics)
5045

46+
if X is None:
47+
X = list(self.env.instance.variables.flat)
48+
5149
if len(self.env.instance.bias) == 0:
5250
self.env.instance.construct_bias(X)
5351

5452
while True:
5553
if self.env.verbose > 0:
5654
print("Size of CL: ", len(self.env.instance.cl))
5755
print("Size of B: ", len(self.env.instance.bias))
58-
print("Number of Queries: ", self.env.metrics.membership_queries_count)
56+
print("Number of Queries: ", self.env.metrics.total_queries)
5957

6058
gen_start = time.time()
6159
Y = self.env.run_query_generation(X)
62-
gen_end = time.time()
60+
gen_end = time.time()
6361

6462
if len(Y) == 0:
6563
# if no query can be generated it means we have (prematurely) converged to the target network -----
6664
self.env.metrics.finalize_statistics()
6765
if self.env.verbose >= 1:
6866
print(f"\nLearned {self.env.metrics.cl} constraints in "
69-
f"{self.env.metrics.membership_queries_count} queries.")
67+
f"{self.env.metrics.total_queries} queries.")
7068
self.env.instance.bias = []
7169
return self.env.instance
7270

@@ -130,16 +128,13 @@ def mineAsk(self, r):
130128
# potentially generalizing leads to UNSAT
131129
new_CL = self.env.instance.cl.copy()
132130
new_CL += B
133-
if any(Y2.issubset(Y) for Y2 in self._negativeQ) or not can_be_clique(G.subgraph(Y), D) or \
134-
len(B) > 0 or cp.Model(new_CL).solve():
135-
continue
136-
137-
if self.env.ask_generalization_query(self.env.instance.language[r], B):
138-
gen_flag = True
139-
self.env.add_to_cl(B)
140-
else:
141-
gq_counter += 1
142-
self._negativeQ.append(Y)
131+
if not (any(Y2.issubset(Y) for Y2 in self._negativeQ) or not (can_be_clique(G.subgraph(Y), D) and (len(B) > 0) and cp.Model(new_CL).solve())):
132+
if self.env.ask_generalization_query(self.env.instance.language[r], B):
133+
gen_flag = True
134+
self.env.add_to_cl(B)
135+
else:
136+
gq_counter += 1
137+
self._negativeQ.append(Y)
143138

144139
if not gen_flag:
145140
communities = nx.community.greedy_modularity_communities(G.subgraph(Y))

tests/test_algorithms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
problem_generators = [construct_murder_problem(), construct_examtt_simple(), construct_nurse_rostering()]
1515

1616
classifiers = [DecisionTreeClassifier(), RandomForestClassifier()]
17-
algorithms = [ca.QuAcq(), ca.MQuAcq(), ca.MQuAcq2(), ca.GQuAcq(), ca.PQuAcq()]
17+
algorithms = [ca.QuAcq(), ca.MQuAcq(), ca.MQuAcq2(), ca.MineAcq(), ca.PQuAcq()]
1818
fast_tests_algorithms = [ca.QuAcq(), ca.MQuAcq(), ca.MQuAcq2()]
1919

2020
def _generate_fast_benchmarks():

0 commit comments

Comments
 (0)