Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 12 additions & 48 deletions smarttree/_builder.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import bisect
import math

import numpy as np
import pandas as pd
from numpy.typing import NDArray

from ._criterion import ClassificationCriterion, Entropy, Gini
from ._dataset import Dataset
from ._node_splitter import NodeSplitter
from ._tree import Tree, TreeNode
from ._types import ClassificationCriterionType
Expand All @@ -21,19 +22,14 @@ def __init__(
hierarchy: dict[str, str | list[str]],
) -> None:

self.X = X
self.available_features = X.columns.to_list()
self.y = y
self.criterion = criterion
self.splitter = splitter
self.max_leaf_nodes = max_leaf_nodes
self.hierarchy = hierarchy

match self.criterion:
case "gini":
self.impurity = self.gini_index
case "entropy" | "log_loss":
self.impurity = self.entropy

if self.criterion in ("gini", "entropy", "log_loss"):
self.class_names = np.sort(self.y.unique())

Expand Down Expand Up @@ -110,44 +106,12 @@ def distribution(self, mask: pd.Series) -> NDArray[np.integer]:

return result

def gini_index(self, mask: pd.Series) -> float:
r"""
Calculates Gini index in a tree node.

Gini index formula in LaTeX:
\text{Gini Index} = 1 - \sum^C_{i=1} p_i^2
where
C - total number of classes;
p_i - the probability of choosing a sample with class i.
"""
N = mask.sum()

gini_index = 1
for label in self.class_names:
N_i = (mask & (self.y == label)).sum()
p_i = N_i / N
gini_index -= pow(p_i, 2)

return gini_index

def entropy(self, mask: pd.Series) -> float:
r"""
Calculates entropy in a tree node.

Entropy formula in LaTeX:
H = \log{\overline{N}} = \sum^N_{i=1} p_i \log{(1/p_i)} = -\sum^N_{i=1} p_i \log{p_i}
where
H - entropy;
\overline{N} - effective number of states;
p_i - probability of the i-th system state.
"""
N = mask.sum()

entropy = 0
for label in self.class_names:
N_i = (mask & (self.y == label)).sum()
if N_i != 0:
p_i = N_i / N
entropy -= p_i * math.log2(p_i)

return entropy
def impurity(self, mask: pd.Series) -> float:
    """Compute the impurity of the node whose rows are selected by ``mask``.

    Dispatches on ``self.criterion``: ``"gini"`` uses the Gini index,
    while ``"entropy"`` / ``"log_loss"`` use Shannon entropy. The boolean
    row mask is converted to an int8 buffer for the Cython criterion.
    """
    criterion_cls = Gini if self.criterion == "gini" else Entropy
    criterion: ClassificationCriterion = criterion_cls(Dataset(self.X, self.y))
    return criterion.impurity(mask.to_numpy(dtype=np.int8))
15 changes: 15 additions & 0 deletions smarttree/_criterion.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from libc.stdint cimport int8_t


cdef class ClassificationCriterion:
    # Base declaration for classification impurity criteria.
    # Attributes are populated by __cinit__ in _criterion.pyx:
    #   y         - per-sample class label codes; the int[:] view implies
    #               int32 storage — presumably produced by Dataset, confirm
    #               against _dataset
    #   n_classes - number of distinct class labels
    #   n_samples - total number of samples, len(y)

    cdef int[:] y
    cdef Py_ssize_t n_classes
    cdef Py_ssize_t n_samples


cdef class Gini(ClassificationCriterion):
    # Gini index criterion; `mask` is a 0/1 int8 buffer selecting node rows.
    cpdef double impurity(self, int8_t[:] mask)

cdef class Entropy(ClassificationCriterion):
    # Shannon-entropy criterion; `mask` is a 0/1 int8 buffer selecting node rows.
    cpdef double impurity(self, int8_t[:] mask)
46 changes: 46 additions & 0 deletions smarttree/_criterion.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from abc import ABC, abstractmethod

import numpy as np
from numpy.typing import NDArray

from ._dataset import Dataset

class ClassificationCriterion(ABC):
    """Abstract base for node-impurity criteria backed by a :class:`Dataset`."""

    def __init__(self, dataset: Dataset) -> None:
        ...

    @abstractmethod
    def impurity(self, mask: NDArray[np.int8]) -> float:
        """Return the impurity of the samples selected by the 0/1 ``mask``."""
        raise NotImplementedError


class Gini(ClassificationCriterion):
    """Gini-index impurity criterion."""

    def impurity(self, mask: NDArray[np.int8]) -> float:
        r"""
        Calculates Gini index in a tree node.

        Gini index formula in LaTeX:
        \text{Gini Index} = 1 - \sum^C_{i=1} p_i^2
        where
        C - total number of classes;
        p_i - the probability of choosing a sample with class i.
        """
        ...


class Entropy(ClassificationCriterion):
    """Shannon-entropy impurity criterion."""

    def impurity(self, mask: NDArray[np.int8]) -> float:
        r"""
        Calculates entropy in a tree node.

        Entropy formula in LaTeX:
        H = \log{\overline{N}} = \sum^N_{i=1} p_i \log{(1/p_i)} = -\sum^N_{i=1} p_i \log{p_i}
        where
        H - entropy;
        \overline{N} - effective number of states;
        p_i - probability of the i-th system state.
        """
        ...
71 changes: 71 additions & 0 deletions smarttree/_criterion.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
cimport cython
from libc.math cimport log2
from libc.stdint cimport int8_t

import numpy as np

from ._dataset import Dataset


cdef class ClassificationCriterion:
    """Base class holding the label data shared by all impurity criteria."""

    def __cinit__(self, dataset: Dataset) -> None:
        # NOTE(review): assumes dataset.y is a buffer compatible with the
        # `int[:]` declaration in _criterion.pxd (i.e. int32 label codes in
        # [0, n_classes)) and dataset.classes lists the distinct labels —
        # confirm against _dataset.
        self.y = dataset.y
        self.n_classes = len(dataset.classes)
        self.n_samples = len(dataset.y)


cdef class Gini(ClassificationCriterion):
    r"""Gini-index impurity: \text{Gini} = 1 - \sum^C_{i=1} p_i^2."""

    @cython.boundscheck(False)
    @cython.cdivision(True)
    @cython.wraparound(False)
    cpdef double impurity(self, int8_t[:] mask):
        """Return the Gini index of the samples selected by ``mask``.

        ``mask`` is a 0/1 int8 buffer of length n_samples; nonzero entries
        select the rows belonging to this node.
        """
        cdef Py_ssize_t i
        # BUG FIX: `counts` was declared `long[:]` but allocated with
        # dtype=np.int32; on LP64 platforms (Linux/macOS) C long is 64-bit,
        # so the memoryview assignment raised "Buffer dtype mismatch".
        # Py_ssize_t always matches np.intp, on every platform.
        cdef Py_ssize_t[:] counts
        cdef long N
        cdef double p_i, gini

        counts = np.zeros(self.n_classes, dtype=np.intp)
        N = 0
        for i in range(self.n_samples):
            if mask[i]:
                N += 1
                # boundscheck is off: relies on y[i] being in [0, n_classes)
                counts[self.y[i]] += 1

        gini = 1.0
        for i in range(self.n_classes):
            # guard skips empty classes and also avoids 0/0 on an empty mask
            if counts[i] > 0:
                p_i = <double>counts[i] / <double>N
                gini -= p_i * p_i

        return gini


cdef class Entropy(ClassificationCriterion):
    r"""Shannon-entropy impurity: H = -\sum^C_{i=1} p_i \log_2{p_i}."""

    @cython.boundscheck(False)
    @cython.cdivision(True)
    @cython.wraparound(False)
    cpdef double impurity(self, int8_t[:] mask):
        """Return the entropy of the samples selected by ``mask``.

        ``mask`` is a 0/1 int8 buffer of length n_samples; nonzero entries
        select the rows belonging to this node.
        """
        cdef Py_ssize_t i
        # BUG FIX: `counts` was declared `long[:]` but allocated with
        # dtype=np.int32; on LP64 platforms (Linux/macOS) C long is 64-bit,
        # so the memoryview assignment raised "Buffer dtype mismatch".
        # Py_ssize_t always matches np.intp, on every platform.
        cdef Py_ssize_t[:] counts
        cdef long N
        cdef double p_i, entropy

        counts = np.zeros(self.n_classes, dtype=np.intp)
        N = 0
        for i in range(self.n_samples):
            if mask[i]:
                N += 1
                # boundscheck is off: relies on y[i] being in [0, n_classes)
                counts[self.y[i]] += 1

        entropy = 0.0
        for i in range(self.n_classes):
            # guard skips empty classes (0*log(0) -> 0) and avoids 0/0
            # on an empty mask
            if counts[i] > 0:
                p_i = <double>counts[i] / <double>N
                entropy -= p_i * log2(p_i)

        return entropy
8 changes: 3 additions & 5 deletions smarttree/_cy_column_splitter.pxd
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
from libc.stdint cimport int8_t

from ._criterion cimport ClassificationCriterion


cdef class CyBaseColumnSplitter:

cdef int criterion
cdef ClassificationCriterion criterion
cdef int[:] y
cdef Py_ssize_t n_classes
cdef Py_ssize_t n_samples

cdef double impurity(self, int8_t[:] mask)
cpdef double gini_index(self, int8_t[:] mask)
cpdef double entropy(self, int8_t[:] mask)
27 changes: 0 additions & 27 deletions smarttree/_cy_column_splitter.pyi
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import numpy as np
import pandas as pd
from numpy.typing import NDArray

from ._dataset import Dataset
from ._types import Criterion
Expand Down Expand Up @@ -51,28 +49,3 @@ class CyBaseColumnSplitter:
\end{itemize}
"""
...

def gini_index(self, mask: NDArray[np.int8]) -> float:
r"""
Calculates Gini index in a tree node.

Gini index formula in LaTeX:
\text{Gini Index} = 1 - \sum^C_{i=1} p_i^2
where
C - total number of classes;
p_i - the probability of choosing a sample with class i.
"""
...

def entropy(self, mask: NDArray[np.int8]) -> float:
r"""
Calculates entropy in a tree node.

Entropy formula in LaTeX:
H = \log{\overline{N}} = \sum^N_{i=1} p_i \log{(1/p_i)} = -\sum^N_{i=1} p_i \log{p_i}
where
H - entropy;
\overline{N} - effective number of states;
p_i - probability of the i-th system state.
"""
...
71 changes: 9 additions & 62 deletions smarttree/_cy_column_splitter.pyx
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
cimport cython
from libc.math cimport log2
from libc.stdint cimport int8_t

import numpy as np
import pandas as pd

from ._dataset import Dataset
from ._types import Criterion
from ._criterion cimport Entropy, Gini


cdef int CRITERION_GINI = 1
Expand All @@ -15,10 +15,13 @@ cdef int CRITERION_GINI = 1
cdef class CyBaseColumnSplitter:

def __cinit__(self, dataset: Dataset, criterion: Criterion) -> None:
self.criterion = criterion.value
self.y = dataset.y
self.n_classes = len(dataset.classes)
self.n_samples = len(dataset.y)
if criterion.value == CRITERION_GINI:
self.criterion = Gini(dataset)
else:
self.criterion = Entropy(dataset)

def information_gain(
self,
Expand All @@ -32,9 +35,9 @@ cdef class CyBaseColumnSplitter:
cdef long N, N_parent, N_childs, N_child_j
cdef double impurity_parent, weighted_impurity_childs, impurity_child_i

parent_mask_arr = parent_mask.values.astype(np.int8)
parent_mask_arr = parent_mask.to_numpy(dtype=np.int8)
child_mask_arrs = [
child_mask.values.astype(np.int8) for child_mask in child_masks
child_mask.to_numpy(dtype=np.int8) for child_mask in child_masks
]

N = 0
Expand All @@ -44,7 +47,7 @@ cdef class CyBaseColumnSplitter:
if parent_mask_arr[i]:
N_parent += 1

impurity_parent = self.impurity(parent_mask_arr)
impurity_parent = self.criterion.impurity(parent_mask_arr)

N_childs = 0
n_childs = len(child_mask_arrs)
Expand All @@ -56,7 +59,7 @@ cdef class CyBaseColumnSplitter:
if child_mask_arr[i]:
N_child_j += 1
N_childs += N_child_j
impurity_child_i = self.impurity(child_mask_arr)
impurity_child_i = self.criterion.impurity(child_mask_arr)
weighted_impurity_childs += (<double>N_child_j / <double>N_parent) * impurity_child_i

cdef double norm_coef
Expand All @@ -69,59 +72,3 @@ cdef class CyBaseColumnSplitter:
cdef double information_gain = (<double>N_parent / <double>N) * local_information_gain

return information_gain

cdef double impurity(self, int8_t[:] mask):
if self.criterion == CRITERION_GINI:
return self.gini_index(mask)
else:
return self.entropy(mask)

@cython.boundscheck(False)
@cython.cdivision(True)
@cython.wraparound(False)
cpdef double gini_index(self, int8_t[:] mask):

cdef Py_ssize_t i
cdef long[:] counts
cdef long N
cdef double p_i, gini

counts = np.zeros(self.n_classes, dtype=np.int32)
N = 0
for i in range(self.n_samples):
if mask[i]:
N += 1
counts[self.y[i]] += 1

gini = 1.0
for i in range(self.n_classes):
if counts[i] > 0:
p_i = <double>counts[i] / <double>N
gini -= p_i * p_i

return gini

@cython.boundscheck(False)
@cython.cdivision(True)
@cython.wraparound(False)
cpdef double entropy(self, int8_t[:] mask):

cdef Py_ssize_t i
cdef long[:] counts
cdef long N
cdef double p_i, entropy

counts = np.zeros(self.n_classes, dtype=np.int32)
N = 0
for i in range(self.n_samples):
if mask[i]:
N += 1
counts[self.y[i]] += 1

entropy = 0.0
for i in range(self.n_classes):
if counts[i] > 0:
p_i = <double>counts[i] / <double>N
entropy -= p_i * log2(p_i)

return entropy
Loading
Loading