We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 38d7404 commit d4b4e67Copy full SHA for d4b4e67
1 file changed
smarttree/_builder.py
@@ -100,9 +100,15 @@ def build(self, tree: Tree) -> None:
100
tree.leaf_counter -= 1
101
102
def distribution(self, mask: pd.Series) -> NDArray[np.integer]:
103
- return np.array([
104
- (mask & (self.y == class_name)).sum() for class_name in self.class_names
105
- ])
+
+ mask_arr = mask.to_numpy()
+ y_arr = self.y.to_numpy()
106
107
+ result = np.zeros(len(self.class_names), dtype=np.int32)
108
+ for i, class_name in enumerate(self.class_names):
109
+ result[i] = np.sum(mask_arr & (y_arr == class_name))
110
111
+ return result
112
113
def gini_index(self, mask: pd.Series) -> float:
114
r"""
0 commit comments