Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 8 additions & 11 deletions smarttree/_column_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
import numpy as np
from numpy.typing import NDArray

from ._cy_column_splitter import CyBaseColumnSplitter
from ._criterion import ClassificationCriterion, Entropy, Gini
from ._dataset import Dataset
from ._tree import TreeNode
from ._types import ClassificationCriterionType, Criterion, NaModeType
from ._types import ClassificationCriterionType, NaModeType


NO_INFORMATION_GAIN = float("-inf")
Expand All @@ -35,12 +35,6 @@ def no_split(cls) -> ColumnSplitResult:

class BaseColumnSplitter(ABC):

mapping: dict[ClassificationCriterionType, Criterion] = {
"gini": Criterion.GINI,
"entropy": Criterion.ENTROPY,
"log_loss": Criterion.LOG_LOSS,
}

def __init__(
self,
dataset: Dataset,
Expand All @@ -51,7 +45,11 @@ def __init__(
) -> None:

self.dataset = dataset
self.criterion = self.mapping[criterion]
self.criterion: ClassificationCriterion
if criterion == "gini":
self.criterion = Gini(dataset)
else: # "entropy" | "log_loss"
self.criterion = Entropy(dataset)
self.min_samples_split = min_samples_split
self.min_samples_leaf = min_samples_leaf
self.feature_na_mode = feature_na_mode
Expand Down Expand Up @@ -166,8 +164,7 @@ def information_gain(
\item $\text{impurity}_{\text{child}_i}$ — child node impurity.
\end{itemize}
"""
cs = CyBaseColumnSplitter(self.dataset, self.criterion)
return cs.information_gain(parent_mask, child_masks, normalize)
return self.criterion.impurity_decrease(parent_mask, child_masks, normalize)


class NumColumnSplitter(BaseColumnSplitter):
Expand Down
6 changes: 6 additions & 0 deletions smarttree/_criterion.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@ cdef class ClassificationCriterion:
cdef Py_ssize_t n_classes
cdef Py_ssize_t n_samples

# Impurity decrease (information gain) of splitting the parent node into the
# given child nodes; the implementation lives in _criterion.pyx.
# `normalize` switches on rescaling of the weighted child impurity.
cpdef double impurity_decrease(
self,
cnp.npy_bool[:] parent_mask,
list[cnp.npy_bool[:]] child_masks,
bint normalize,
)
cpdef cnp.int64_t[:] distribution(self, cnp.npy_bool[:] mask)


Expand Down
42 changes: 42 additions & 0 deletions smarttree/_criterion.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,48 @@ class ClassificationCriterion(ABC):
def __init__(self, dataset: Dataset) -> None:
...

def impurity_decrease(
self,
parent_mask: NDArray[np.bool_],
child_masks: list[NDArray[np.bool_]],
normalize: bool = False,
) -> float:
r"""
Calculate the impurity decrease (information gain) of the split.

Parameters:
parent_mask: NDArray[np.bool_]
boolean mask of parent node.
child_masks: list[NDArray[np.bool_]]
list of boolean masks of child nodes.
normalize: bool, default=False
if True, the weighted child impurity is multiplied by
$N_{\text{parent}} / \sum^C_{i=1} N_{\text{child}_i}$ before it is
subtracted from the parent impurity, to handle unbalanced splits.

Returns:
float: information gain.

Formula in LaTeX:
\begin{align*}
\text{Information Gain} =
\frac{N_{\text{parent}}}{N} \cdot
\Biggl( & \text{impurity}_{\text{parent}} - \\
& \sum^C_{i=1} \frac{N_{\text{child}_i}}{N_{\text{parent}}}
\cdot \text{impurity}_{\text{child}_i} \Biggr)
\end{align*}
where:
\begin{itemize}
\item $\text{Information Gain}$ — information gain;
\item $N$ — number of samples in entire training set;
\item $N_{\text{parent}}$ — number of samples in parent node;
\item $\text{impurity}_{\text{parent}}$ — parent node impurity;
\item $C$ — number of child nodes;
\item $N_{\text{child}_i}$ — number of samples in child node;
\item $\text{impurity}_{\text{child}_i}$ — child node impurity.
\end{itemize}
"""
# NOTE(review): the compiled signature declares `bint normalize` without a
# default value — confirm that omitting `normalize` really works at runtime.
...

@abstractmethod
def impurity(self, mask: NDArray[np.bool_]) -> float:
raise NotImplementedError
Expand Down
43 changes: 43 additions & 0 deletions smarttree/_criterion.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,49 @@ cdef class ClassificationCriterion:
self.n_classes = len(dataset.classes)
self.n_samples = len(dataset.y)

cpdef double impurity_decrease(
    self,
    cnp.npy_bool[:] parent_mask,
    list[cnp.npy_bool[:]] child_masks,
    bint normalize,
):
    """
    Return the impurity decrease (information gain) of a split.

    The result is ``(N_parent / N) * (impurity_parent - weighted_children)``
    where ``N`` is ``self.n_samples``, ``N_parent`` is the number of true
    entries in ``parent_mask`` and ``weighted_children`` is the sum over
    all children of ``(N_child / N_parent) * impurity_child``.

    Parameters:
        parent_mask: boolean mask of the parent node; entries
            ``0 .. self.n_samples - 1`` are read.
        child_masks: list of boolean masks, one per child node; each is
            read over the same index range.
        normalize: if true, ``weighted_children`` is multiplied by
            ``N_parent / sum(N_child)`` before the subtraction.

    Returns:
        double: the impurity decrease of the split.
    """
    cdef Py_ssize_t i, j
    cdef Py_ssize_t n_total = self.n_samples
    cdef Py_ssize_t n_children = len(child_masks)
    cdef Py_ssize_t n_parent = 0
    cdef Py_ssize_t n_in_children = 0
    cdef Py_ssize_t n_child
    cdef double parent_impurity, child_impurity
    cdef double children_impurity = 0.0
    # Typed view so the per-sample loop below indexes at C level instead of
    # going through Python object indexing for every element.
    cdef cnp.npy_bool[:] child_mask

    # N in the formula is just self.n_samples, so only the parent needs counting.
    for i in range(n_total):
        if parent_mask[i]:
            n_parent += 1

    parent_impurity = self.impurity(parent_mask)

    for j in range(n_children):
        child_mask = child_masks[j]
        n_child = 0
        for i in range(n_total):
            if child_mask[i]:
                n_child += 1
        n_in_children += n_child
        child_impurity = self.impurity(child_mask)
        children_impurity += (<double>n_child / <double>n_parent) * child_impurity

    if normalize:
        # Rescale by parent size over the total number of samples that
        # actually landed in some child.
        children_impurity *= <double>n_parent / <double>n_in_children

    return (<double>n_parent / <double>n_total) * (parent_impurity - children_impurity)

@cython.boundscheck(False)
@cython.wraparound(False)
cpdef cnp.int64_t[:] distribution(self, cnp.npy_bool[:] mask):
Expand Down
11 changes: 0 additions & 11 deletions smarttree/_cy_column_splitter.pxd

This file was deleted.

52 changes: 0 additions & 52 deletions smarttree/_cy_column_splitter.pyi

This file was deleted.

70 changes: 0 additions & 70 deletions smarttree/_cy_column_splitter.pyx

This file was deleted.

7 changes: 0 additions & 7 deletions smarttree/_types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from enum import Enum
from typing import Literal


Expand All @@ -12,9 +11,3 @@
VerboseType = Literal["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"] | int

SplitType = Literal["numerical", "categorical", "rank"]


# Identifiers for the classification split criteria.
class Criterion(Enum):
GINI = 1
ENTROPY = 2
# Shares the value 2 with ENTROPY, so Enum makes LOG_LOSS an alias of
# ENTROPY rather than a distinct member.
LOG_LOSS = 2