Skip to content

Commit d4b4e67

Browse files
optimized (#108)
1 parent 38d7404 commit d4b4e67

1 file changed

Lines changed: 9 additions & 3 deletions

File tree

smarttree/_builder.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,15 @@ def build(self, tree: Tree) -> None:
100100
tree.leaf_counter -= 1
101101

102102
def distribution(self, mask: pd.Series) -> NDArray[np.integer]:
103-
return np.array([
104-
(mask & (self.y == class_name)).sum() for class_name in self.class_names
105-
])
103+
104+
mask_arr = mask.to_numpy()
105+
y_arr = self.y.to_numpy()
106+
107+
result = np.zeros(len(self.class_names), dtype=np.int32)
108+
for i, class_name in enumerate(self.class_names):
109+
result[i] = np.sum(mask_arr & (y_arr == class_name))
110+
111+
return result
106112

107113
def gini_index(self, mask: pd.Series) -> float:
108114
r"""

0 commit comments

Comments
 (0)