@@ -26,39 +26,45 @@ def __init__(self, ca_env: ActiveCAEnv = None, inner_algorithm: AlgorithmCAInter
2626 super ().__init__ (env )
2727 self .inner_algorithm = inner_algorithm if inner_algorithm is not None else MQuAcq2 (ca_env )
2828
29- def learn (self , instance : ProblemInstance , oracle : Oracle = UserOracle (), verbose = 0 , metrics : Metrics = None ):
29+ def learn (self , instance : ProblemInstance , oracle : Oracle = UserOracle (), verbose = 0 , X = None , metrics : Metrics = None ):
3030 """
3131 Learn constraints by incrementally adding variables and using the inner algorithm to learn constraints
3232 for each added variable.
3333
3434 :param instance: the problem instance to acquire the constraints for
3535 :param oracle: An instance of Oracle, default is to use the user as the oracle.
3636 :param verbose: Verbosity level, default is 0.
37+ :param X: The set of variables to consider, default is None.
3738 :param metrics: statistics logger during learning
3839 :return: the learned instance
3940 """
41+ if X is None :
42+ X = instance .X
43+ assert isinstance (X , list ) and set (X ).issubset (set (instance .X )), "When using .learn(), set parameter X must be a list of variables"
44+
4045 self .env .init_state (instance , oracle , verbose , metrics )
4146
4247 if verbose >= 1 :
4348 print (f"Running growacq with { self .inner_algorithm } as inner algorithm" )
4449
4550 self .inner_algorithm .env = copy .copy (self .env )
4651
47- self . env . instance . X = []
52+ Y = []
4853
49- n_vars = len (self . env . instance . variables . flat )
50- for x in self . env . instance . variables . flat :
54+ n_vars = len (X )
55+ for x in X :
5156 # we 'grow' the inner bias by adding one extra variable at a time
52- self . env . instance . X .append (x )
57+ Y .append (x )
5358 # add the constraints involving x and other added variables
54- self .env .instance .construct_bias_for_var (x )
59+ if len (self .env .instance .bias ) == 0 :
60+ self .env .instance .construct_bias_for_var (x , Y )
5561 if verbose >= 3 :
5662 print (f"Added variable { x } in GrowAcq" )
5763 print ("size of B in growacq: " , len (self .env .instance .bias ))
5864
5965 if verbose >= 2 :
60- print (f"\n GrowAcq: calling inner_algorithm for { len (self . env . instance . X )} /{ n_vars } variables" )
61- self .env .instance = self .inner_algorithm .learn (self .env .instance , oracle , verbose = verbose , metrics = self .env .metrics )
66+ print (f"\n GrowAcq: calling inner_algorithm for { len (Y )} /{ n_vars } variables" )
67+ self .env .instance = self .inner_algorithm .learn (self .env .instance , oracle , verbose = verbose , X = Y , metrics = self .env .metrics )
6268
6369 if verbose >= 3 :
6470 print ("C_L: " , len (self .env .instance .cl ))
0 commit comments