Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions build-extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,17 @@
from setuptools import Extension
from setuptools.command.build_ext import build_ext

import numpy as np


if sys.platform == "win32":
COMPILE_ARGS = ["/O2", "/fp:fast"]
LINK_ARGS = []
INCLUDE_DIRS = []
LIBRARIES = []
else:
COMPILE_ARGS = ["-march=native", "-O3", "-msse", "-msse2", "-mfma", "-mfpmath=sse"]
LINK_ARGS = []
INCLUDE_DIRS = []
LIBRARIES = ["m"]
LINK_ARGS = []
INCLUDE_DIRS = [np.get_include()]


def build() -> None:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[build-system]
requires = ["poetry-core", "cython", "setuptools"]
requires = ["poetry-core", "setuptools", "cython", "numpy"]
build-backend = "poetry.core.masonry.api"

[tool.poetry]
Expand Down
5 changes: 2 additions & 3 deletions smarttree/_builder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import bisect

import numpy as np
import pandas as pd

from ._criterion import ClassificationCriterion, Entropy, Gini
Expand Down Expand Up @@ -45,7 +44,7 @@ def build(self, tree: Tree) -> None:
self.available_features.remove(value)

mask = self.y.apply(lambda x: True)
mask_np = mask.to_numpy(dtype=np.int8)
mask_np = mask.to_numpy()
root = tree.create_node(
mask=mask,
hierarchy=self.hierarchy,
Expand Down Expand Up @@ -76,7 +75,7 @@ def build(self, tree: Tree) -> None:
else: # str
node.available_features.append(value)

child_mask_np = child_mask.to_numpy(dtype=np.int8)
child_mask_np = child_mask.to_numpy()
child_node = tree.create_node(
mask=child_mask,
hierarchy=node.hierarchy,
Expand Down
22 changes: 17 additions & 5 deletions smarttree/_column_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,15 @@ def __init__(
def split(self, *args, **kwargs) -> ColumnSplitResult:
raise NotImplementedError

def pre_information_gain(
self,
parent_mask: pd.Series,
child_masks: list[pd.Series],
) -> tuple[NDArray[np.bool_], list[NDArray[np.bool_]]]:
parent_mask_np = parent_mask.to_numpy()
child_masks_np = [child_mask.to_numpy() for child_mask in child_masks]
return parent_mask_np, child_masks_np

def foo(
self,
parent_mask: pd.Series,
Expand All @@ -78,7 +87,8 @@ def foo(
else:
assert False
else:
information_gain = self.information_gain(parent_mask, child_masks)
parent_mask_np, child_masks_np = self.pre_information_gain(parent_mask, child_masks)
information_gain = self.information_gain(parent_mask_np, child_masks_np)
return information_gain, child_masks, -1

def include_all_split(
Expand All @@ -93,7 +103,8 @@ def include_all_split(
if child_masks[i].sum() < self.min_samples_leaf:
return NO_INFORMATION_GAIN, [], -1

information_gain = self.information_gain(parent_mask, child_masks, normalize=True)
parent_mask_np, child_masks_np = self.pre_information_gain(parent_mask, child_masks)
information_gain = self.information_gain(parent_mask_np, child_masks_np, normalize=True)

return information_gain, child_masks, -1

Expand All @@ -119,7 +130,8 @@ def include_best_split(
best_child_masks = []
best_child_na_index = -1
for child_na_index, child_masks in enumerate(candidates):
information_gain = self.information_gain(parent_mask, child_masks)
parent_mask_np, child_masks_np = self.pre_information_gain(parent_mask, child_masks)
information_gain = self.information_gain(parent_mask_np, child_masks_np)
if best_information_gain < information_gain:
best_information_gain = information_gain
best_child_masks = child_masks
Expand All @@ -129,8 +141,8 @@ def include_best_split(

def information_gain(
self,
parent_mask: pd.Series,
child_masks: list[pd.Series],
parent_mask: NDArray[np.bool_],
child_masks: list[NDArray[np.bool_]],
normalize: bool = False,
) -> float:
r"""
Expand Down
10 changes: 5 additions & 5 deletions smarttree/_criterion.pxd
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
from libc.stdint cimport int8_t
cimport numpy as cnp


cdef class ClassificationCriterion:

cdef int[:] y
cdef cnp.int64_t[:] y
cdef Py_ssize_t n_classes
cdef Py_ssize_t n_samples

cpdef long[:] distribution(self, int8_t[:] mask)
cpdef cnp.int64_t[:] distribution(self, cnp.npy_bool[:] mask)


cdef class Gini(ClassificationCriterion):
cpdef double impurity(self, int8_t[:] mask)
cpdef double impurity(self, cnp.npy_bool[:] mask)

cdef class Entropy(ClassificationCriterion):
cpdef double impurity(self, int8_t[:] mask)
cpdef double impurity(self, cnp.npy_bool[:] mask)
8 changes: 4 additions & 4 deletions smarttree/_criterion.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,16 @@ class ClassificationCriterion(ABC):
...

@abstractmethod
def impurity(self, mask: NDArray[np.int8]) -> float:
def impurity(self, mask: NDArray[np.bool_]) -> float:
raise NotImplementedError

def distribution(self, mask: NDArray[np.int8]) -> NDArray[np.int32]:
def distribution(self, mask: NDArray[np.bool_]) -> NDArray[np.int64]:
...


class Gini(ClassificationCriterion):

def impurity(self, mask: NDArray[np.int8]) -> float:
def impurity(self, mask: NDArray[np.bool_]) -> float:
r"""
Calculates Gini index in a tree node.

Expand All @@ -35,7 +35,7 @@ class Gini(ClassificationCriterion):

class Entropy(ClassificationCriterion):

def impurity(self, mask: NDArray[np.int8]) -> float:
def impurity(self, mask: NDArray[np.bool_]) -> float:
r"""
Calculates entropy in a tree node.

Expand Down
47 changes: 29 additions & 18 deletions smarttree/_criterion.pyx
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
cimport cython
from libc.math cimport log2
from libc.stdint cimport int8_t

import numpy as np
cimport numpy as cnp

from ._dataset import Dataset


cnp.import_array()


cdef class ClassificationCriterion:

def __cinit__(self, dataset: Dataset) -> None:
Expand All @@ -16,15 +19,19 @@ cdef class ClassificationCriterion:

@cython.boundscheck(False)
@cython.wraparound(False)
cpdef long[:] distribution(self, int8_t[:] mask):
cpdef cnp.int64_t[:] distribution(self, cnp.npy_bool[:] mask):

cdef Py_ssize_t i
cdef long[:] result
cdef cnp.int64_t[:] result
cdef cnp.npy_bool mask_value
cdef cnp.int64_t class_index

result = np.zeros(self.n_classes, dtype=np.int32)
result = np.zeros(self.n_classes, dtype=np.int64)
for i in range(self.n_samples):
if mask[i]:
result[self.y[i]] += 1
mask_value = mask[i]
if mask_value:
class_index = self.y[i]
result[class_index] += 1

return result

Expand All @@ -34,22 +41,24 @@ cdef class Gini(ClassificationCriterion):
@cython.boundscheck(False)
@cython.cdivision(True)
@cython.wraparound(False)
cpdef double impurity(self, int8_t[:] mask):
cpdef double impurity(self, cnp.npy_bool[:] mask):

cdef Py_ssize_t i
cdef long[:] distribution
cdef long N
cdef cnp.int64_t[:] distribution
cdef cnp.int64_t N, count
cdef double p_i, gini

distribution = self.distribution(mask)
N = 0
for i in range(self.n_classes):
N += distribution[i]
count = distribution[i]
N += count

gini = 1.0
for i in range(self.n_classes):
if distribution[i] > 0:
p_i = <double>distribution[i] / <double>N
count = distribution[i]
if count > 0:
p_i = <double>count / <double>N
gini -= p_i * p_i

return gini
Expand All @@ -60,22 +69,24 @@ cdef class Entropy(ClassificationCriterion):
@cython.boundscheck(False)
@cython.cdivision(True)
@cython.wraparound(False)
cpdef double impurity(self, int8_t[:] mask):
cpdef double impurity(self, cnp.npy_bool[:] mask):

cdef Py_ssize_t i
cdef long[:] distribution
cdef long N
cdef cnp.int64_t[:] distribution
cdef cnp.int64_t N, count
cdef double p_i, gini

distribution = self.distribution(mask)
N = 0
for i in range(self.n_classes):
N += distribution[i]
count = distribution[i]
N += count

entropy = 0.0
for i in range(self.n_classes):
if distribution[i] > 0:
p_i = <double>distribution[i] / <double>N
count = distribution[i]
if count > 0:
p_i = <double>count / <double>N
entropy -= p_i * log2(p_i)

return entropy
4 changes: 2 additions & 2 deletions smarttree/_cy_column_splitter.pxd
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from libc.stdint cimport int8_t
cimport numpy as cnp

from ._criterion cimport ClassificationCriterion


cdef class CyBaseColumnSplitter:

cdef ClassificationCriterion criterion
cdef int[:] y
cdef cnp.int64_t[:] y
cdef Py_ssize_t n_classes
cdef Py_ssize_t n_samples
7 changes: 4 additions & 3 deletions smarttree/_cy_column_splitter.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pandas as pd
import numpy as np
from numpy.typing import NDArray

from ._dataset import Dataset
from ._types import Criterion
Expand All @@ -10,8 +11,8 @@ class CyBaseColumnSplitter:

def information_gain(
self,
parent_mask: pd.Series,
child_masks: list[pd.Series],
parent_mask: NDArray[np.bool_],
child_masks: list[NDArray[np.bool_]],
normalize: bool = False,
) -> float:
r"""
Expand Down
30 changes: 13 additions & 17 deletions smarttree/_cy_column_splitter.pyx
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
cimport cython
from libc.stdint cimport int8_t

import numpy as np
cimport numpy as cnp
import pandas as pd

from ._dataset import Dataset
from ._types import Criterion
from ._criterion cimport Entropy, Gini


cnp.import_array()

cdef int CRITERION_GINI = 1


Expand All @@ -25,41 +27,35 @@ cdef class CyBaseColumnSplitter:

def information_gain(
self,
parent_mask: pd.Series,
child_masks: list[pd.Series],
normalize: bool = False,
) -> float:
cnp.npy_bool[:] parent_mask,
list[cnp.npy_bool[:]] child_masks,
bint normalize,
):

cdef int8_t[:] parent_mask_arr, child_mask_arr
cdef Py_ssize_t i, j, n_childs
cdef long N, N_parent, N_childs, N_child_j
cdef double impurity_parent, weighted_impurity_childs, impurity_child_i

parent_mask_arr = parent_mask.to_numpy(dtype=np.int8)
child_mask_arrs = [
child_mask.to_numpy(dtype=np.int8) for child_mask in child_masks
]

N = 0
N_parent = 0
for i in range(self.n_samples):
N += 1
if parent_mask_arr[i]:
if parent_mask[i]:
N_parent += 1

impurity_parent = self.criterion.impurity(parent_mask_arr)
impurity_parent = self.criterion.impurity(parent_mask)

N_childs = 0
n_childs = len(child_mask_arrs)
n_childs = len(child_masks)
weighted_impurity_childs = 0.0
for j in range(n_childs):
N_child_j = 0
child_mask_arr = child_mask_arrs[j]
child_mask = child_masks[j]
for i in range(self.n_samples):
if child_mask_arr[i]:
if child_mask[i]:
N_child_j += 1
N_childs += N_child_j
impurity_child_i = self.criterion.impurity(child_mask_arr)
impurity_child_i = self.criterion.impurity(child_mask)
weighted_impurity_childs += (<double>N_child_j / <double>N_parent) * impurity_child_i

cdef double norm_coef
Expand Down
2 changes: 1 addition & 1 deletion smarttree/_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ class Dataset:
def __init__(self, X: pd.DataFrame, y: pd.Series) -> None:
self.X = X
self.classes = np.sort(y.unique())
self.y = np.searchsorted(self.classes, y.to_numpy()).astype(np.int32)
self.y = np.searchsorted(self.classes, y.to_numpy()).astype(np.int64)
self.has_na: dict[str, bool] = dict()
self.mask_na: dict[str, pd.Series] = dict()
for column in self.X.columns:
Expand Down
4 changes: 4 additions & 0 deletions tests/column_splitter/test__base_column_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ def test__information_gain(concrete_column_splitter, y):
right_child_mask = ~left_child_mask

child_masks = [left_child_mask, right_child_mask]

parent_mask = parent_mask.to_numpy()
child_masks = [child_mask.to_numpy() for child_mask in child_masks]

inf_gain = concrete_column_splitter.information_gain(parent_mask, child_masks)

assert inf_gain == 0.0016794443115909496
Loading
Loading