This tutorial further explains how SORTD works and how you can add a new optimization task to it.
Here we assume that your new task is called MyNewTask.
- Copy the `include/tasks/accuracy/accuracy.h` into your new `include/tasks/my_new_task.h`.
- Copy the `src/tasks/accuracy/accuracy.cpp` into your new `src/tasks/my_new_task.cpp`.
- Add both new files to the `CMakeList.txt`, in the header and source file lists respectively.
- Change the class name `Accuracy` in both the new header and source file to `MyNewTask`.
- Add your new task to every .cpp file that instantiates SORTD's templates, e.g., `src/solver/solver.cpp`, `src/utils/file_reader.cpp`, etc.
- Add `"my-new-task"` as a valid value for the `"task"` parameter in `src/solver/define_parameters.cpp`.
- Check for `"my-new-task"` in `src/main.cpp`, create a new `Solver<MyNewTask>`, and call `ReadData<MyNewTask>()`.
- Check that the code compiles and run SORTD with `-task my-new-task` to see if it runs properly.

You are now ready to modify the new task as you wish!
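After the copy-and-rename steps, the new header is essentially the accuracy task under a new name. The sketch below illustrates only the renaming pattern with a hypothetical, heavily simplified class body; the `Name()` helper is illustrative and not part of SORTD's real API (SORTD identifies tasks by the `"task"` parameter value).

```cpp
// Hypothetical, simplified sketch of include/tasks/my_new_task.h after renaming.
// The real file copied from accuracy.h declares many more members.
#include <string>

class MyNewTask {   // was: class Accuracy
public:
    // Illustrative helper only, not part of the real task interface.
    static std::string Name() { return "my-new-task"; }
};
```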
To modify your new task, you need to understand a few things about SORTD. This section first explains the class types for your class and the constants that define an optimization task. It then further explains the functions an optimization task is expected to have. Finally, more attention is given to the constants and functions required to implement the special depth-two solver for your task.
Note that many default values for the types, constants and functions are defined in `include/tasks/optimization_task.h`.
Since C++ is strongly typed, SORTD uses templates to work with different class types. You can set the types for your class in your my_new_task.h header file.
- `SolType`: The data type of your solution value (e.g., `int` for misclassification score, `double` for MSE).
- `TestSolType`: The data type of your test solution value. This could differ from `SolType`, e.g., when you compute `double` branching costs while training, but want to ignore this and only measure `int` misclassification costs in test evaluation.
- `LabelType`: The label type of an instance. For classification this is `int` and for regression this is `double`.
- `SolLabelType`: The label type assigned to a leaf node. Commonly this type is the same as `LabelType`.
- `ET`: The type of the extra data per instance (beyond feature and label data). E.g., for `PrescriptivePolicy` the extra counterfactual data is stored in a `PPGData` class.
- `ContextType`: The type of the context class (i.e., the state of a search node). Default is `BranchContext`, which just stores the current branch.
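For a simple classification-style task, the type definitions might look like the sketch below. This is a hypothetical, simplified example: `BranchContext` is a stand-in struct here, whereas in SORTD it is the real default context class.

```cpp
// Hypothetical sketch of the class types for a classification-style task.
struct BranchContext {};                 // stand-in for SORTD's default context

class MyNewTask {
public:
    using SolType      = int;            // misclassification count while training
    using TestSolType  = int;            // same metric in test evaluation
    using LabelType    = int;            // classification labels are integers
    using SolLabelType = int;            // leaf labels share the instance label type
    using ContextType  = BranchContext;  // default: only the current branch
};
```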
Implicitly defined:
- `SolContainer`, or equivalently `Node<MyNewTask>`: The data type of information on the solution (e.g., solution value, feature, label).
The following constants (or constexpr) define your optimization task and its behavior in SORTD.
- `custom_leaf` (bool): `True` if you provide a custom solve-leaf-node function. The default is to iterate over the set of possible labels. Therefore, a custom leaf node function is required when the label is not discrete.
- `custom_get_label` (bool): `True` if you provide a custom function to get the optimal label for a leaf node, e.g., for regression, where the optimal label is the mean of the instance labels. Otherwise the label with minimum cost is selected.
- `has_constraint` (bool): `True` if the task has a constraint, e.g., a fairness constraint. Note that a minimum leaf node size constraint is specified separately through the parameter `"min-leaf-node-size"`.
- `element_additive` (bool): `True` if the solution values are element-wise additive. This means SORTD can use its similarity lower bound. `False` disables the similarity lower bound.
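As a hypothetical example, a plain misclassification-style task that relies on all the defaults could declare its constants as follows:

```cpp
// Hypothetical sketch of the task constants for a simple misclassification task.
class MyNewTask {
public:
    static constexpr bool custom_leaf      = false; // default: iterate over labels
    static constexpr bool custom_get_label = false; // default: minimum-cost label
    static constexpr bool has_constraint   = false; // no task-specific constraint
    static constexpr bool element_additive = true;  // keep the similarity lower bound
};
```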
Related to branching costs:
- `has_branching_costs` (bool): `True` if the optimization task has branching costs.
- `element_branching_costs` (bool): `True` if the branching costs depend on individual instances in the data set (not implemented for the depth-two solver).
- `constant_branching_costs` (bool): `True` if the branching costs are constant and do not depend on the context or on the data set, e.g., for cost-complexity pruning.
Related to preprocessing:
- `preprocess_data` (bool): `True` if the task performs preprocessing on the data (both train and test). This allows modification of instances before computing.
- `preprocess_train_test_data` (bool): `True` if the task performs preprocessing on the train or test data.
Related to task specific optimizations:
- `custom_lower_bound` (bool): `True` if the task provides a custom lower bound.
- `custom_similarity_lb` (bool): `True` if the task provides a custom similarity lower bound.
Best and worst solution values or labels:

- `worst` (SolType): The worst solution value possible, e.g., `INT32_MAX`.
- `best` (SolType): The best solution value possible, e.g., `0`.
- `worst_label` (SolLabelType): The default label for an uninitialized node, e.g., `INT32_MAX`.
- `minimum_difference` (SolType): The minimum difference between two non-equivalent solutions (e.g., 1 for misclassification score). This is used to compute an upper bound from a given solution.
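For a misclassification task these constants could be defined as in the hypothetical sketch below: solutions are integer counts, so the best value is zero and two distinct solutions differ by at least one.

```cpp
// Hypothetical sketch of the best/worst constants for a misclassification task.
#include <cstdint>

class MyNewTask {
public:
    using SolType      = int;
    using SolLabelType = int;

    static constexpr SolType      worst              = INT32_MAX; // nothing is worse
    static constexpr SolType      best               = 0;         // zero misclassifications
    static constexpr SolLabelType worst_label        = INT32_MAX; // uninitialized node
    static constexpr SolType      minimum_difference = 1;         // counts are integral
};
```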
Related to preprocessing:

- `void UpdateParameters(const ParameterHandler& parameters)`: Informs the task of the (updated) parameters.
- `void InformTrainData(const ADataView&, const DataSummary&)`: Informs the task about the training data (before training).
- `void InformTestData(const ADataView&, const DataSummary&)`: Informs the task about the test data (before evaluating).
- `void PreprocessData(AData& data, bool train)`: Preprocesses the data, with `train == True` if this is the training phase. Only if `preprocess_data` is `True`.
- `void PreprocessTrainData(ADataView& train_data)`: Preprocesses the training data. `PreprocessTestData` is defined similarly. Only if `preprocess_train_test_data` is `True`.
Related to branching:
- `bool MayBranchOnFeature(int feature)`: Returns `False` if this feature is not available for branching.
- `void GetLeftContext(const ADataView&, const ContextType& context, int feature, ContextType& left_context)`: Updates the `left_context` from the `context` when branching left on the specified feature.
- `void GetRightContext(const ADataView&, const ContextType& context, int feature, ContextType& right_context)`: Updates the `right_context` from the `context` when branching right on the specified feature.
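The branching hooks can be sketched as follows. This is a hypothetical example with stand-in types: the context records which features were branched on, and one feature (say, a protected attribute) is excluded from branching; the real SORTD classes have different members.

```cpp
// Hypothetical sketch of the branching hooks with simplified stand-in types.
#include <vector>

struct ADataView {};                   // stand-in for SORTD's data view

struct BranchContext {                 // stand-in for SORTD's default context
    std::vector<int> left_features;    // features branched left on so far
    std::vector<int> right_features;   // features branched right on so far
};

class MyNewTask {
public:
    using ContextType = BranchContext;

    bool MayBranchOnFeature(int feature) const {
        return feature != 0;  // e.g., feature 0 is a protected attribute
    }
    void GetLeftContext(const ADataView&, const ContextType& context,
                        int feature, ContextType& left_context) const {
        left_context = context;                       // start from the parent context
        left_context.left_features.push_back(feature);
    }
    void GetRightContext(const ADataView&, const ContextType& context,
                         int feature, ContextType& right_context) const {
        right_context = context;
        right_context.right_features.push_back(feature);
    }
};
```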
Related to branching costs (only if `has_branching_costs` is `True`):

- `SolType GetBranchingCosts(const ADataView&, const ContextType& context, int feature)`: Gets the branching costs for the data in the given context when branching on `feature`. Similarly, `GetTestBranchingCosts` returns the branching costs when evaluating the test performance.
Related to (optimizing) the leaf nodes:
- `SolType GetLeafCosts(const ADataView& data, const ContextType& context, SolLabelType label)`: Returns the leaf costs for the given data in the given context for the assigned label.
- `TestSolType GetTestLeafCosts(const ADataView& data, const ContextType& context, SolLabelType label)`: Returns the test leaf costs for the given data in the given context for the assigned label.
- `LabelType Classify(const AInstance*, SolLabelType label)`: Returns the label for the given instance if the label of the leaf node where this instance ends up is `label`.
- `SolContainer SolveLeafNode(const ADataView&, const ContextType&)`: Returns the optimal solution for the leaf node defined by the given data and context. Only define this if `custom_leaf` is `True`.
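For a misclassification task, the leaf costs count the instances whose label differs from the assigned leaf label. The sketch below uses hypothetical, simplified stand-ins for `AInstance` and `ADataView`; the real SORTD classes differ.

```cpp
// Hypothetical sketch of leaf costs for misclassification, with stand-in types.
#include <vector>

struct AInstance { int label; };                    // stand-in instance
struct ADataView { std::vector<AInstance> data; };  // stand-in data view
struct BranchContext {};                            // stand-in context

class MyNewTask {
public:
    using SolType      = int;
    using SolLabelType = int;
    using LabelType    = int;
    using ContextType  = BranchContext;

    SolType GetLeafCosts(const ADataView& data, const ContextType&,
                         SolLabelType label) const {
        SolType costs = 0;
        for (const auto& inst : data.data)
            if (inst.label != label) costs++;  // one unit per misclassified instance
        return costs;
    }
    LabelType Classify(const AInstance*, SolLabelType label) const {
        return label;  // every instance in the leaf is assigned the leaf's label
    }
};
```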
Related to testing constraint satisfaction (only when `has_constraint` is `True`):

- `bool SatisfiesConstraint(const Node<MyNewTask>& sol, const ContextType& context)`: Returns true if the solution satisfies the constraint.
Related to score and solution values:
- `static SolType Add(const SolType left, const SolType right)`: Returns `left + right`. Idem for `TestAdd`, which adds values of `TestSolType`; for `Add(const SolType left, const SolType right, SolType& out)`, which returns the value in `out`; and for `Subtract`, which subtracts the values and returns the result through `out`.
- `static std::string SolToString(SolType val)`: Returns `val` as a string.
- `static std::string ScoreToString(double val)`: Returns `val` as a string. Note that the score is different from the solution value. E.g., the solution value is the misclassification score, while the score is the accuracy.
- `static bool CompareScore(SolType v1, SolType v2)`: Returns true if `v1` is better than `v2`.
- `double ComputeTrainScore(SolType test_value)`: Returns the training score on the training data. Similarly, `ComputeTrainTestScore` computes the training score on the test data, and `ComputeTestTestScore` computes the test score on the test data.
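A hypothetical sketch of these functions for an additive misclassification task, where a lower solution value is better:

```cpp
// Hypothetical sketch: additive solution values, lower misclassification is better.
#include <string>

class MyNewTask {
public:
    using SolType = int;

    static SolType Add(const SolType left, const SolType right) {
        return left + right;                 // element-wise additive solution values
    }
    static std::string SolToString(SolType val)  { return std::to_string(val); }
    static std::string ScoreToString(double val) { return std::to_string(val); }
    static bool CompareScore(SolType v1, SolType v2) {
        return v1 < v2;                      // fewer misclassifications is better
    }
};
```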
Related to the similarity lower bound:
- `SolType GetWorstPerLabel(LabelType label)`: Returns the worst contribution to the solution value of a single instance with the given label.
Related to the custom lower bound (only if `custom_lower_bound` is `True`):

- `SolContainer ComputeLowerBound(const ADataView& data, const Branch& branch, int max_depth, int num_nodes)`: Returns a lower bound on the solution for the given data and branch, for trees limited to the given maximum depth and number of nodes.
Related to the custom similarity lower bound (only if `custom_similarity_lb` is `True`):

- `PairWorstCount<MyNewTask> ComputeSimilarityLowerBound(const ADataView& data_old, const ADataView& data_new)`: Returns a `PairWorstCount` that holds the subtracted lower bound of `SolType` and the count of the number of differences.
Related to hyper-tuning:
- `static TuneRunConfiguration GetTuneRunConfiguration(const ParameterHandler& default_config, const ADataView& train_data, int phase)`: Gets the tuning configuration from the given default configuration, for the given training data and the given tuning phase.
Some additional types, constants and functions need to be defined for the depth-two solver.

- `SolD2Type`: The type of solutions in the depth-two solver.
- `BranchSolD2Type`: The type of the branching costs in the depth-two solver.
- `use_terminal` (bool): `True` if the task implements a depth-two solver.
- `terminal_compute_context` (bool): `True` if the context needs to be computed in the depth-two solver, e.g., for checking constraint satisfaction.
- `terminal_filter` (bool): `True` if the depth-two solver should filter infeasible solutions and solutions that are dominated by the upper bound. Default is `False`. Set to `True` if you think this will yield a performance increase.
- `void GetInstanceLeafD2Costs(const AInstance* instance, int org_label, int label, SolD2Type& costs, int multiplier)`: Stores the costs of this instance with original label `org_label` when it is assigned `label` as its label. The multiplier is either `1` or `-1`. Note that for tasks with a real-valued label, both `org_label` and `label` are always zero.
- `void ComputeD2Costs(const SolD2Type& d2costs, int count, int label, SolType& costs)`: Computes the costs from the depth-two costs, for the given label (always zero for real-valued labels).
- `bool IsD2ZeroCost(const SolD2Type d2costs)`: Returns `True` if the given costs are zero.
- `BranchSolD2Type GetBranchingCosts(const ContextType& context, int feature)`: Gets the branching costs in the given context (independent of the data set).
- `SolType ComputeD2BranchingCosts(const BranchSolD2Type& d2costs, int count)`: Gets the solution value of the branching costs, given the depth-two branching costs and the number of instances in the leaf node.
- `SolLabelType GetLabel(const SolD2Type& costs, int count)`: Gets the label that should be assigned to the leaf node for the given depth-two costs and count.
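For a misclassification task, the depth-two cost accounting can be sketched as below. This is a hypothetical, simplified example: the per-instance cost is 1 when the assigned label differs from the original label, and the multiplier lets the solver add or remove an instance's contribution.

```cpp
// Hypothetical sketch of depth-two cost functions for misclassification.
struct AInstance {};  // stand-in; the instance data is not needed in this sketch

class MyNewTask {
public:
    using SolType   = int;
    using SolD2Type = int;

    void GetInstanceLeafD2Costs(const AInstance*, int org_label, int label,
                                SolD2Type& costs, int multiplier) const {
        // Cost 1 if misclassified; the multiplier (1 or -1) adds or removes it.
        costs = (org_label != label) ? multiplier : 0;
    }
    void ComputeD2Costs(const SolD2Type& d2costs, int /*count*/, int /*label*/,
                        SolType& costs) const {
        costs = d2costs;  // the depth-two costs already equal the misclassification count
    }
    bool IsD2ZeroCost(const SolD2Type d2costs) const { return d2costs == 0; }
};
```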
After all the functions in C++ are defined, you can expose your new task through the Python bindings. First adapt `bindings.cpp`, then update the Python files.
- Add `my_new_task` to the `enum` `task_type`.
- Add `my_new_task` to the `get_task_type_code` function when `task == "my-new-task"`.
- Add `DefineSolver<MyNewTask>(m, "MyNewTask");`.
- Add a case `my_new_task` to `initialize_sortd_solver` and call `Solver<MyNewTask>`.
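The enum and dispatch additions above can be sketched in isolation as follows. This is a hypothetical, self-contained mock of the string-to-enum mapping; the real `bindings.cpp` uses SORTD's own `task_type` enum and the `DefineSolver<>` machinery.

```cpp
// Hypothetical sketch of the task-string dispatch added to the bindings.
#include <stdexcept>
#include <string>

enum class task_type { accuracy, my_new_task };  // my_new_task added to the enum

task_type get_task_type_code(const std::string& task) {
    if (task == "accuracy")    return task_type::accuracy;
    if (task == "my-new-task") return task_type::my_new_task;  // new case
    throw std::invalid_argument("Unknown task: " + task);
}
```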
If your task is similar to one of the existing Python classes (e.g., `SORTDClassifier`), add it to that class.
Otherwise, create a new Python class that inherits from `BaseSORTDSolver` in `pysortd/base.py`.
- Define constraints on new parameters, using `_parameter_constraints`.
- Define the `__init__` constructor and provide default parameters. Pass all parameters except newly defined parameters to `BaseSORTDSolver`.
- Store new parameters in this class. Override the `_initialize_param_handler` method.