Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 20 additions & 14 deletions src/ensemble/base_forest_regressor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,25 +161,31 @@ impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1
/// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training.
pub fn predict_oob(&self, x: &X) -> Result<Y, Failed> {
let (n, _) = x.shape();
if self.samples.is_none() {
Err(Failed::because(
FailedError::PredictFailed,
"Need samples=true for OOB predictions.",
))
} else if self.samples.as_ref().unwrap()[0].len() != n {
Err(Failed::because(

let samples = match &self.samples {
Some(s) => s,
None => {
return Err(Failed::because(
FailedError::PredictFailed,
"Need samples=true for OOB predictions.",
))
}
};

if samples[0].len() != n {
return Err(Failed::because(
FailedError::PredictFailed,
"Prediction matrix must match matrix used in training for OOB predictions.",
))
} else {
let mut result = Y::zeros(n);
));
}

for i in 0..n {
result.set(i, self.predict_for_row_oob(x, i));
}
let mut result = Y::zeros(n);

Ok(result)
for i in 0..n {
result.set(i, self.predict_for_row_oob(x, i));
}

Ok(result)
}

fn predict_for_row_oob(&self, x: &X, row: usize) -> TY {
Expand Down
41 changes: 24 additions & 17 deletions src/ensemble/random_forest_classifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -539,27 +539,34 @@ impl<TX: FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY
/// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training.
pub fn predict_oob(&self, x: &X) -> Result<Y, Failed> {
let (n, _) = x.shape();
if self.samples.is_none() {
Err(Failed::because(
FailedError::PredictFailed,
"Need samples=true for OOB predictions.",
))
} else if self.samples.as_ref().unwrap()[0].len() != n {
Err(Failed::because(

let samples = match &self.samples {
Some(s) => s,
None => {
return Err(Failed::because(
FailedError::PredictFailed,
"Need samples=true for OOB predictions.",
));
}
};

if samples[0].len() != n {
return Err(Failed::because(
FailedError::PredictFailed,
"Prediction matrix must match matrix used in training for OOB predictions.",
))
} else {
let mut result = Y::zeros(n);
));
}

for i in 0..n {
result.set(
i,
self.classes.as_ref().unwrap()[self.predict_for_row_oob(x, i)],
);
}
Ok(result)
let mut result = Y::zeros(n);

for i in 0..n {
result.set(
i,
self.classes.as_ref().unwrap()[self.predict_for_row_oob(x, i)],
);
}

Ok(result)
}

fn predict_for_row_oob(&self, x: &X, row: usize) -> usize {
Expand Down
136 changes: 136 additions & 0 deletions src/metrics/distance/jaccard.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
//! # Jaccard Distance
//!
//! Jaccard Distance measures dissimilarity between two integer-valued vectors of the same length.
//! Given two vectors \\( x \in ℤ^n \\), \\( y \in ℤ^n \\), the Jaccard distance between \\( x \\) and \\( y \\) is defined as
//!
//! \\[ d(x, y) = 1 - \frac{|x \cap y|}{|x \cup y|} \\]
//!
//! where \\(|x \cap y|\\) is the number of positions where both vectors are non-zero,
//! and \\(|x \cup y|\\) is the number of positions where at least one of the vectors is non-zero.
//!
//! Example:
//!
//! ```
//! use smartcore::metrics::distance::Distance;
//! use smartcore::metrics::distance::jaccard::Jaccard;
//!
//! let a = vec![1, 0, 1, 1];
//! let b = vec![1, 1, 0, 1];
//!
//! let j: f64 = Jaccard::new().distance(&a, &b);
//!
//! ```
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use std::marker::PhantomData;

use super::Distance;
use crate::linalg::basic::arrays::ArrayView1;
use crate::numbers::basenum::Number;

/// Jaccard distance between two integer-valued vectors.
///
/// The struct stores no data; `T` only records the element type of the
/// vectors that will be compared.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct Jaccard<T>
where
    T: Number,
{
    // Zero-sized marker tying the metric to its element type.
    _t: PhantomData<T>,
}

impl<T: Number> Jaccard<T> {
    /// Instantiate a new Jaccard distance metric.
    pub fn new() -> Self {
        Self { _t: PhantomData }
    }
}

impl<T: Number> Default for Jaccard<T> {
fn default() -> Self {
Self::new()
}
}

impl<T: Number, A: ArrayView1<T>> Distance<A> for Jaccard<T> {
    /// Computes the Jaccard distance between `x` and `y`.
    ///
    /// A position counts towards the intersection when both vectors are
    /// non-zero there, and towards the union when at least one of them is.
    /// The result is `1 - intersection / union`; when the union is empty
    /// (no non-zero entry in either vector) the distance is `0.0`.
    ///
    /// # Panics
    ///
    /// Panics if `x` and `y` do not have the same shape.
    fn distance(&self, x: &A, y: &A) -> f64 {
        if x.shape() != y.shape() {
            panic!("Input vector sizes are different");
        }

        let mut both_set: usize = 0;
        let mut any_set: usize = 0;

        for (xi, yi) in x.iterator(0).zip(y.iterator(0)) {
            let x_set = *xi != T::zero();
            let y_set = *yi != T::zero();

            if x_set && y_set {
                both_set += 1;
            }
            if x_set || y_set {
                any_set += 1;
            }
        }

        // Empty union: avoid dividing by zero and treat the vectors as identical.
        if any_set == 0 {
            return 0.0;
        }

        1.0 - both_set as f64 / any_set as f64
    }
}

// Unit tests for the Jaccard distance. Every test carries the wasm
// `cfg_attr` so that all of them are registered with `wasm_bindgen_test`
// on wasm32 (non-WASI) targets, not only the first one.
#[cfg(test)]
mod tests {
    use super::*;

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn jaccard_distance() {
        let a = vec![1, 0, 1, 1];
        let b = vec![1, 1, 0, 1];

        let j: f64 = Jaccard::new().distance(&a, &b);

        // Intersection = 2 (positions 0 and 3), union = 4 -> 1 - 2/4.
        assert!((j - 0.5).abs() < 1e-8);
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn jaccard_identical_vectors() {
        let a = vec![1, 0, 1, 0];
        let b = vec![1, 0, 1, 0];

        let j: f64 = Jaccard::new().distance(&a, &b);

        assert!((j - 0.0).abs() < 1e-8);
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn jaccard_both_zero_vectors() {
        let a = vec![0, 0, 0];
        let b = vec![0, 0, 0];

        let j: f64 = Jaccard::new().distance(&a, &b);

        // Empty union is defined as distance 0.
        assert!((j - 0.0).abs() < 1e-8);
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn jaccard_disjoint_vectors() {
        let a = vec![1, 0, 1, 0];
        let b = vec![0, 1, 0, 1];

        let j: f64 = Jaccard::new().distance(&a, &b);

        // No shared non-zero position: intersection = 0, union = 4.
        assert!((j - 1.0).abs() < 1e-8);
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn jaccard_symmetry() {
        let a = vec![1, 0, 1, 1];
        let b = vec![0, 1, 1, 0];

        let j = Jaccard::new();

        let d1 = j.distance(&a, &b);
        let d2 = j.distance(&b, &a);

        assert!((d1 - d2).abs() < 1e-12);
    }
}

7 changes: 7 additions & 0 deletions src/metrics/distance/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ pub mod cosine;
pub mod euclidian;
/// Hamming Distance between two strings is the number of positions at which the corresponding symbols are different.
pub mod hamming;
/// Jaccard distance between two integer-valued vectors.
pub mod jaccard;
/// The Mahalanobis distance is the distance between two points in multivariate space.
pub mod mahalanobis;
/// Also known as rectilinear distance, city block distance, taxicab metric.
Expand Down Expand Up @@ -67,6 +69,11 @@ impl Distances {
hamming::Hamming::new()
}

/// Jaccard distance, see [`Jaccard`](jaccard/index.html)
pub fn jaccard<T>() -> jaccard::Jaccard<T>
where
    T: Number,
{
    jaccard::Jaccard::<T>::new()
}

/// Mahalanobis distance, see [`Mahalanobis`](mahalanobis/index.html)
pub fn mahalanobis<T: Number, M: Array2<T>, C: Array2<f64> + LUDecomposable<f64>>(
data: &M,
Expand Down
Loading