Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 31 additions & 33 deletions smarttree/_cy_column_splitter.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@ cdef int CRITERION_GINI = 1
cdef class CyBaseColumnSplitter:

cdef int criterion
cdef object[:] y
cdef object[:] class_names
cdef int[:] y
cdef Py_ssize_t n_classes
cdef Py_ssize_t n_samples

def __cinit__(self, dataset: Dataset, criterion: Criterion) -> None:
self.criterion = criterion.value
self.y = dataset.y.values
self.class_names = dataset.class_names
self.y = dataset.y
self.n_classes = len(dataset.classes)
self.n_samples = len(dataset.y)

cdef double impurity(self, int8_t[:] mask):
if self.criterion == CRITERION_GINI:
Expand Down Expand Up @@ -76,12 +78,11 @@ cdef class CyBaseColumnSplitter:
]

cdef:
int i
Py_ssize_t n = len(parent_mask_arr)
Py_ssize_t i
long N = 0
long N_parent = 0
int8_t parent_mask_value
for i in range(n):
for i in range(self.n_samples):
N += 1
parent_mask_value = parent_mask_arr[i]
if parent_mask_value:
Expand All @@ -90,16 +91,17 @@ cdef class CyBaseColumnSplitter:
cdef double impurity_parent = self.impurity(parent_mask_arr)

cdef:
int j
Py_ssize_t j
Py_ssize_t n_childs = len(child_mask_arrs)
double weighted_impurity_childs = 0.0
long N_childs = 0
long N_child_j
int8_t child_mask_value
double impurity_child_i
for j in range(len(child_mask_arrs)):
for j in range(n_childs):
N_child_j = 0
child_mask_arr = child_mask_arrs[j]
for i in range(n):
for i in range(self.n_samples):
child_mask_value = child_mask_arr[i]
if child_mask_value:
N_child_j += 1
Expand Down Expand Up @@ -129,30 +131,28 @@ cdef class CyBaseColumnSplitter:
p_i - the probability of choosing a sample with class i.
"""
cdef:
int i
Py_ssize_t n = len(mask)
Py_ssize_t i
int8_t mask_value
long N = 0
for i in range(n):
for i in range(self.n_samples):
mask_value = mask[i]
if mask_value:
N += 1

cdef:
int j
cdef long N_i
cdef object class_name, label
Py_ssize_t j
long N_i
int encoded_label
double p_i = 0.0
gini_index = 1.0
for j in range(len(self.class_names)):
double gini_index = 1.0
for j in range(self.n_classes):
N_i = 0
class_name = self.class_names[j]

for i in range(n):
for i in range(self.n_samples):
mask_value = mask[i]
if mask_value:
label = self.y[i]
if label == class_name:
encoded_label = self.y[i]
if encoded_label == j:
N_i += 1

p_i = <double>N_i / <double>N
Expand All @@ -172,30 +172,28 @@ cdef class CyBaseColumnSplitter:
p_i - probability of the i-th system state.
"""
cdef:
int i
Py_ssize_t n = len(mask)
Py_ssize_t i
int8_t mask_value
long N = 0
for i in range(n):
for i in range(self.n_samples):
mask_value = mask[i]
if mask_value:
N += 1

cdef:
int j
Py_ssize_t j
long N_i = 0
object class_name, label
int encoded_label
double p_i = 0.0
entropy = 0.0
for j in range(len(self.class_names)):
double entropy = 0.0
for j in range(self.n_classes):
N_i = 0
class_name = self.class_names[j]

for i in range(n):
for i in range(self.n_samples):
mask_value = mask[i]
if mask_value:
label = self.y[i]
if label == class_name:
encoded_label = self.y[i]
if encoded_label == j:
N_i += 1

if N_i != 0:
Expand Down
20 changes: 6 additions & 14 deletions smarttree/_dataset.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,15 @@
from dataclasses import dataclass, field

import numpy as np
import pandas as pd
from numpy.typing import NDArray


@dataclass
class Dataset:

X: pd.DataFrame
y: pd.Series
class_names: NDArray = field(init=False)
has_na: dict[str, bool] = field(init=False)
mask_na: dict[str, pd.Series] = field(init=False)

def __post_init__(self) -> None:
self.class_names = np.sort(self.y.unique())
self.has_na = dict()
self.mask_na = dict()
def __init__(self, X: pd.DataFrame, y: pd.Series) -> None:
self.X = X
self.classes = np.sort(y.unique())
self.y = np.searchsorted(self.classes, y.to_numpy()).astype(np.int32)
self.has_na: dict[str, bool] = dict()
self.mask_na: dict[str, pd.Series] = dict()
for column in self.X.columns:
mask_na = self.X[column].isna()
has_na = mask_na.any()
Expand Down
4 changes: 2 additions & 2 deletions tests/column_splitter/test__cy_base_column_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def test__gini_index(dataset):
dataset=dataset, criterion=Criterion.GINI
)

mask = dataset.y.apply(lambda x: True).values.astype(np.int8)
mask = np.ones(dataset.y.shape, dtype=np.int8)

gini_index = cy_base_column_splitter.gini_index(mask)
assert gini_index == 0.6666591342419322
Expand All @@ -22,7 +22,7 @@ def test__entropy(dataset):
dataset=dataset, criterion=Criterion.ENTROPY
)

mask = dataset.y.apply(lambda x: True).values.astype(np.int8)
mask = np.ones(dataset.y.shape, dtype=np.int8)

gini_index = cy_base_column_splitter.entropy(mask)
assert gini_index == 1.584946181877191
Loading