|
12 | 12 | from .. import Metrics |
13 | 13 |
|
14 | 14 |
|
15 | | -class GQuAcq(AlgorithmCAInteractive): |
| 15 | +class MineAcq(AlgorithmCAInteractive): |
16 | 16 |
|
17 | 17 | """ |
18 | | - QuAcq variation algorithm, using mine&Ask to detect types of variables and ask genralization queries. From: |
| 18 | + QuAcq variation algorithm, using mine&Ask to detect types of variables and ask generalization queries. From: |
19 | 19 | "Detecting Types of Variables for Generalization in Constraint Acquisition", ICTAI 2015. |
20 | 20 | """ |
21 | 21 |
|
22 | 22 | def __init__(self, ca_env: ActiveCAEnv = None, qg_max=10): |
23 | 23 | """ |
24 | | - Initialize the GQuAcq algorithm with an optional constraint acquisition environment. |
| 24 | + Initialize the MineAcq algorithm with an optional constraint acquisition environment. |
25 | 25 |
|
26 | 26 | :param ca_env: An instance of ActiveCAEnv, default is None. |
27 | | - : param GQmax: maximum number of generalization queries |
| 27 | + :param qg_max: maximum number of generalization queries |
28 | 28 | """ |
29 | 29 | super().__init__(ca_env) |
30 | 30 | self._negativeQ = [] |
31 | 31 | self._qg_max = qg_max |
32 | 32 |
|
33 | | - def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbose=0, X=None, metrics: Metrics = None): |
| 33 | + def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbose=0, metrics: Metrics = None, X=None): |
34 | 34 | """ |
35 | | - Learn constraints using the GQuAcq algorithm by generating queries and analyzing the results. |
| 35 | + Learn constraints using the QuAcq algorithm by generating queries and analyzing the results. |
36 | 36 |
|
37 | 37 | :param instance: the problem instance to acquire the constraints for |
38 | 38 | :param oracle: An instance of Oracle, default is to use the user as the oracle. |
39 | 39 | :param verbose: Verbosity level, default is 0. |
40 | 40 | :param metrics: statistics logger during learning |
41 | | - :param X: The set of variables to consider, default is None. |
| 41 | + :param X: List of variables to consider for learning. If None, uses all variables from the instance. |
42 | 42 | :return: the learned instance |
43 | 43 | """ |
44 | | - if X is None: |
45 | | - X = instance.X |
46 | | - assert isinstance(X, list), "When using .learn(), set parameter X must be a list of variables. Instead got: {}".format(X) |
47 | | - assert set(X).issubset(set(instance.X)), "When using .learn(), set parameter X must be a subset of the problem instance variables. Instead got: {}".format(X) |
48 | | - |
49 | 44 | self.env.init_state(instance, oracle, verbose, metrics) |
50 | 45 |
|
| 46 | + if X is None: |
| 47 | + X = list(self.env.instance.variables.flat) |
| 48 | + |
51 | 49 | if len(self.env.instance.bias) == 0: |
52 | 50 | self.env.instance.construct_bias(X) |
53 | 51 |
|
54 | 52 | while True: |
55 | 53 | if self.env.verbose > 0: |
56 | 54 | print("Size of CL: ", len(self.env.instance.cl)) |
57 | 55 | print("Size of B: ", len(self.env.instance.bias)) |
58 | | - print("Number of Queries: ", self.env.metrics.membership_queries_count) |
| 56 | + print("Number of Queries: ", self.env.metrics.total_queries) |
59 | 57 |
|
60 | 58 | gen_start = time.time() |
61 | 59 | Y = self.env.run_query_generation(X) |
62 | | - gen_end = time.time() |
| 60 | + gen_end = time.time() |
63 | 61 |
|
64 | 62 | if len(Y) == 0: |
65 | 63 | # if no query can be generated it means we have (prematurely) converged to the target network ----- |
66 | 64 | self.env.metrics.finalize_statistics() |
67 | 65 | if self.env.verbose >= 1: |
68 | 66 | print(f"\nLearned {self.env.metrics.cl} constraints in " |
69 | | - f"{self.env.metrics.membership_queries_count} queries.") |
| 67 | + f"{self.env.metrics.total_queries} queries.") |
70 | 68 | self.env.instance.bias = [] |
71 | 69 | return self.env.instance |
72 | 70 |
|
@@ -130,16 +128,13 @@ def mineAsk(self, r): |
130 | 128 | # potentially generalizing leads to UNSAT |
131 | 129 | new_CL = self.env.instance.cl.copy() |
132 | 130 | new_CL += B |
133 | | - if any(Y2.issubset(Y) for Y2 in self._negativeQ) or not can_be_clique(G.subgraph(Y), D) or \ |
134 | | - len(B) > 0 or cp.Model(new_CL).solve(): |
135 | | - continue |
136 | | - |
137 | | - if self.env.ask_generalization_query(self.env.instance.language[r], B): |
138 | | - gen_flag = True |
139 | | - self.env.add_to_cl(B) |
140 | | - else: |
141 | | - gq_counter += 1 |
142 | | - self._negativeQ.append(Y) |
| 131 | + if not (any(Y2.issubset(Y) for Y2 in self._negativeQ) or not (can_be_clique(G.subgraph(Y), D) and (len(B) > 0) and cp.Model(new_CL).solve())): |
| 132 | + if self.env.ask_generalization_query(self.env.instance.language[r], B): |
| 133 | + gen_flag = True |
| 134 | + self.env.add_to_cl(B) |
| 135 | + else: |
| 136 | + gq_counter += 1 |
| 137 | + self._negativeQ.append(Y) |
143 | 138 |
|
144 | 139 | if not gen_flag: |
145 | 140 | communities = nx.community.greedy_modularity_communities(G.subgraph(Y)) |
|
0 commit comments