Skip to content

Commit 66a31de

Browse files
authored
Algorithms and methods for specific set of variables (#7)
* query generation on given set of vars * give set of variables to algorithms to focus on * new tests * growacq tests * return flat list of distinct constraints in benchmarks
1 parent 08570e0 commit 66a31de

20 files changed

Lines changed: 223 additions & 64 deletions

File tree

pycona/active_algorithms/gquacq.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class GQuAcq(AlgorithmCAInteractive):
2121

2222
def __init__(self, ca_env: ActiveCAEnv = None, qg_max=10):
2323
"""
24-
Initialize the PQuAcq algorithm with an optional constraint acquisition environment.
24+
Initialize the GQuAcq algorithm with an optional constraint acquisition environment.
2525
2626
:param ca_env: An instance of ActiveCAEnv, default is None.
2727
: param GQmax: maximum number of generalization queries
@@ -30,16 +30,21 @@ def __init__(self, ca_env: ActiveCAEnv = None, qg_max=10):
3030
self._negativeQ = []
3131
self._qg_max = qg_max
3232

33-
def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbose=0, metrics: Metrics = None):
33+
def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbose=0, X=None, metrics: Metrics = None):
3434
"""
35-
Learn constraints using the QuAcq algorithm by generating queries and analyzing the results.
35+
Learn constraints using the GQuAcq algorithm by generating queries and analyzing the results.
3636
3737
:param instance: the problem instance to acquire the constraints for
3838
:param oracle: An instance of Oracle, default is to use the user as the oracle.
3939
:param verbose: Verbosity level, default is 0.
4040
:param metrics: statistics logger during learning
41+
:param X: The set of variables to consider, default is None.
4142
:return: the learned instance
4243
"""
44+
if X is None:
45+
X = instance.X
46+
assert isinstance(X, list) and set(X).issubset(set(instance.X)), "When using .learn(), set parameter X must be a list of variables"
47+
4348
self.env.init_state(instance, oracle, verbose, metrics)
4449

4550
if len(self.env.instance.bias) == 0:
@@ -52,8 +57,8 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos
5257
print("Number of Queries: ", self.env.metrics.membership_queries_count)
5358

5459
gen_start = time.time()
55-
Y = self.env.run_query_generation()
56-
gen_end = time.time()
60+
Y = self.env.run_query_generation(X)
61+
gen_end = time.time()
5762

5863
if len(Y) == 0:
5964
# if no query can be generated it means we have (prematurely) converged to the target network -----

pycona/active_algorithms/growacq.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,39 +26,45 @@ def __init__(self, ca_env: ActiveCAEnv = None, inner_algorithm: AlgorithmCAInter
2626
super().__init__(env)
2727
self.inner_algorithm = inner_algorithm if inner_algorithm is not None else MQuAcq2(ca_env)
2828

29-
def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbose=0, metrics: Metrics = None):
29+
def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbose=0, X=None, metrics: Metrics = None):
3030
"""
3131
Learn constraints by incrementally adding variables and using the inner algorithm to learn constraints
3232
for each added variable.
3333
3434
:param instance: the problem instance to acquire the constraints for
3535
:param oracle: An instance of Oracle, default is to use the user as the oracle.
3636
:param verbose: Verbosity level, default is 0.
37+
:param X: The set of variables to consider, default is None.
3738
:param metrics: statistics logger during learning
3839
:return: the learned instance
3940
"""
41+
if X is None:
42+
X = instance.X
43+
assert isinstance(X, list) and set(X).issubset(set(instance.X)), "When using .learn(), set parameter X must be a list of variables"
44+
4045
self.env.init_state(instance, oracle, verbose, metrics)
4146

4247
if verbose >= 1:
4348
print(f"Running growacq with {self.inner_algorithm} as inner algorithm")
4449

4550
self.inner_algorithm.env = copy.copy(self.env)
4651

47-
self.env.instance.X = []
52+
Y = []
4853

49-
n_vars = len(self.env.instance.variables.flat)
50-
for x in self.env.instance.variables.flat:
54+
n_vars = len(X)
55+
for x in X:
5156
# we 'grow' the inner bias by adding one extra variable at a time
52-
self.env.instance.X.append(x)
57+
Y.append(x)
5358
# add the constraints involving x and other added variables
54-
self.env.instance.construct_bias_for_var(x)
59+
if len(self.env.instance.bias) == 0:
60+
self.env.instance.construct_bias_for_var(x, Y)
5561
if verbose >= 3:
5662
print(f"Added variable {x} in GrowAcq")
5763
print("size of B in growacq: ", len(self.env.instance.bias))
5864

5965
if verbose >= 2:
60-
print(f"\nGrowAcq: calling inner_algorithm for {len(self.env.instance.X)}/{n_vars} variables")
61-
self.env.instance = self.inner_algorithm.learn(self.env.instance, oracle, verbose=verbose, metrics=self.env.metrics)
66+
print(f"\nGrowAcq: calling inner_algorithm for {len(Y)}/{n_vars} variables")
67+
self.env.instance = self.inner_algorithm.learn(self.env.instance, oracle, verbose=verbose, X=Y, metrics=self.env.metrics)
6268

6369
if verbose >= 3:
6470
print("C_L: ", len(self.env.instance.cl))

pycona/active_algorithms/mquacq.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,21 @@ def __init__(self, ca_env: ActiveCAEnv = None):
2121
"""
2222
super().__init__(ca_env)
2323

24-
def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbose=0, metrics: Metrics = None):
24+
def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbose=0, X=None, metrics: Metrics = None):
2525
"""
2626
Learn constraints using the modified QuAcq algorithm by generating queries and analyzing the results.
2727
2828
:param instance: the problem instance to acquire the constraints for
2929
:param oracle: An instance of Oracle, default is to use the user as the oracle.
3030
:param verbose: Verbosity level, default is 0.
3131
:param metrics: statistics logger during learning
32+
:param X: The set of variables to consider, default is None.
3233
:return: the learned instance
3334
"""
35+
if X is None:
36+
X = instance.X
37+
assert isinstance(X, list) and set(X).issubset(set(instance.X)), "When using .learn(), set parameter X must be a list of variables"
38+
3439
self.env.init_state(instance, oracle, verbose, metrics)
3540

3641
if len(self.env.instance.bias) == 0:
@@ -47,7 +52,7 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos
4752

4853
# generate e in D^X accepted by C_l and rejected by B
4954
gen_start = time.time()
50-
Y = self.env.run_query_generation()
55+
Y = self.env.run_query_generation(X)
5156
gen_end = time.time()
5257

5358
if len(Y) == 0:

pycona/active_algorithms/mquacq2.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,21 @@ def __init__(self, ca_env: ActiveCAEnv = None, *, perform_analyzeAndLearn: bool
3131
self.cl_neighbours = None
3232
self.hashX = None
3333

34-
def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbose=0, metrics: Metrics = None):
34+
def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbose=0, X=None, metrics: Metrics = None):
3535
"""
3636
Learn constraints using the modified QuAcq algorithm by generating queries and analyzing the results.
3737
3838
:param instance: the problem instance to acquire the constraints for
3939
:param oracle: An instance of Oracle, default is to use the user as the oracle.
4040
:param verbose: Verbosity level, default is 0.
4141
:param metrics: statistics logger during learning
42+
:param X: The set of variables to consider, default is None.
4243
:return: the learned instance
4344
"""
45+
if X is None:
46+
X = instance.X
47+
assert isinstance(X, list) and set(X).issubset(set(instance.X)), "When using .learn(), set parameter X must be a list of variables"
48+
4449
self.env.init_state(instance, oracle, verbose, metrics)
4550

4651
# Hash the variables
@@ -52,7 +57,7 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos
5257

5358
while True:
5459
gen_start = time.time()
55-
Y = self.env.run_query_generation()
60+
Y = self.env.run_query_generation(X)
5661
gen_end = time.time()
5762
self.env.metrics.increase_generation_time(gen_end - gen_start)
5863

pycona/active_algorithms/pquacq.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,21 @@ def __init__(self, ca_env: ActiveCAEnv = None):
2525
"""
2626
super().__init__(ca_env)
2727

28-
def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbose=0, metrics: Metrics = None):
28+
def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbose=0, X=None, metrics: Metrics = None):
2929
"""
3030
Learn constraints using the QuAcq algorithm by generating queries and analyzing the results.
3131
3232
:param instance: the problem instance to acquire the constraints for
3333
:param oracle: An instance of Oracle, default is to use the user as the oracle.
3434
:param verbose: Verbosity level, default is 0.
3535
:param metrics: statistics logger during learning
36+
:param X: The set of variables to consider, default is None.
3637
:return: the learned instance
3738
"""
39+
if X is None:
40+
X = instance.X
41+
assert isinstance(X, list) and set(X).issubset(set(instance.X)), "When using .learn(), set parameter X must be a list of variables"
42+
3843
self.env.init_state(instance, oracle, verbose, metrics)
3944

4045
if len(self.env.instance.bias) == 0:
@@ -47,7 +52,7 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos
4752
print("Number of Queries: ", self.env.metrics.membership_queries_count)
4853

4954
gen_start = time.time()
50-
Y = self.env.run_query_generation()
55+
Y = self.env.run_query_generation(X)
5156
gen_end = time.time()
5257

5358
if len(Y) == 0:

pycona/active_algorithms/quacq.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,21 @@ def __init__(self, ca_env: ActiveCAEnv = None):
2121
"""
2222
super().__init__(ca_env)
2323

24-
def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbose=0, metrics: Metrics = None):
24+
def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbose=0, X=None, metrics: Metrics = None):
2525
"""
2626
Learn constraints using the QuAcq algorithm by generating queries and analyzing the results.
2727
2828
:param instance: the problem instance to acquire the constraints for
2929
:param oracle: An instance of Oracle, default is to use the user as the oracle.
3030
:param verbose: Verbosity level, default is 0.
3131
:param metrics: statistics logger during learning
32+
:param X: The set of variables to consider, default is None.
3233
:return: the learned instance
3334
"""
35+
if X is None:
36+
X = instance.X
37+
assert isinstance(X, list) and set(X).issubset(set(instance.X)), "When using .learn(), set parameter X must be a list of variables"
38+
3439
self.env.init_state(instance, oracle, verbose, metrics)
3540

3641
if len(self.env.instance.bias) == 0:
@@ -43,7 +48,7 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos
4348
print("Number of Queries: ", self.env.metrics.membership_queries_count)
4449

4550
gen_start = time.time()
46-
Y = self.env.run_query_generation()
51+
Y = self.env.run_query_generation(X)
4752
gen_end = time.time()
4853

4954
if len(Y) == 0:

pycona/benchmarks/exam_timetabling.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from ..answering_queries.constraint_oracle import ConstraintOracle
44
from ..problem_instance import ProblemInstance, absvar
5-
5+
from cpmpy.transformations.normalize import toplevel_list
66

77
def day_of_exam(course, slots_per_day):
88
return course // slots_per_day
@@ -27,7 +27,7 @@ def construct_examtt_simple(nsemesters=9, courses_per_semester=6, slots_per_day=
2727
for row in courses:
2828
model += cp.AllDifferent(day_of_exam(row, slots_per_day)).decompose()
2929

30-
C_T = list(model.constraints)
30+
C_T = list(set(toplevel_list(model.constraints)))
3131

3232
if model.solve():
3333
courses.clear()

pycona/benchmarks/job_shop_scheduling.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import cpmpy as cp
44
import numpy as np
55
from cpmpy.expressions.utils import all_pairs
6-
6+
from cpmpy.transformations.normalize import toplevel_list
77
from ..answering_queries.constraint_oracle import ConstraintOracle
88
from ..problem_instance import ProblemInstance, absvar
99

@@ -56,7 +56,7 @@ def construct_job_shop_scheduling_problem(n_jobs, machines, horizon, seed=0):
5656
for (j1, t1), (j2, t2) in all_pairs(zip(*tasks_on_mach)):
5757
m += (end[j1, t1] <= start[j2, t2]) | (end[j2, t2] <= start[j1, t1])
5858

59-
C_T = list(model.constraints)
59+
C_T = list(set(toplevel_list(model.constraints)))
6060

6161
max_duration = max(duration)
6262

pycona/benchmarks/jsudoku.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import cpmpy as cp
2-
2+
from cpmpy.transformations.normalize import toplevel_list
33
from ..answering_queries.constraint_oracle import ConstraintOracle
44
from ..problem_instance import ProblemInstance, absvar
55

@@ -49,6 +49,6 @@ def construct_jsudoku():
4949

5050
instance = ProblemInstance(variables=grid, params=parameters, language=lang, name="jsudoku")
5151

52-
oracle = ConstraintOracle(C_T)
52+
oracle = ConstraintOracle(list(set(toplevel_list(C_T))))
5353

5454
return instance, oracle

pycona/benchmarks/murder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11

22
import cpmpy as cp
3-
3+
from cpmpy.transformations.normalize import toplevel_list
44
from ..answering_queries.constraint_oracle import ConstraintOracle
55
from ..problem_instance import ProblemInstance, absvar
66

@@ -45,6 +45,6 @@ def construct_murder_problem():
4545

4646
instance = ProblemInstance(variables=grid, language=lang, name="murder")
4747

48-
oracle = ConstraintOracle(C_T)
48+
oracle = ConstraintOracle(list(set(toplevel_list(C_T))))
4949

5050
return instance, oracle

0 commit comments

Comments
 (0)