diff --git a/src/learning/agglomerative.rs b/src/learning/agglomerative.rs index 8093ff9c..a5924ea7 100644 --- a/src/learning/agglomerative.rs +++ b/src/learning/agglomerative.rs @@ -5,12 +5,12 @@ //! # Usage //! //! ``` -//! use rusty_machine::learning::agglomerative::{AgglomerativeClustering, Metrics}; +//! use rusty_machine::learning::agglomerative::{AgglomerativeClustering, Linkage}; //! use rusty_machine::learning::SupModel; //! use rusty_machine::linalg::{Matrix, Vector}; //! //! let inputs = Matrix::new(4, 2, vec![1., 3., 2., 3., 4., 3., 5., 3.]); -//! let mut agg = AgglomerativeClustering::new(2, Metrics::Single); +//! let mut agg = AgglomerativeClustering::new(2, Linkage::Single); //! //! // Train the model and get the clustering result //! let res = agg.train(&inputs).unwrap(); @@ -27,7 +27,7 @@ use learning::{LearningResult}; /// Agglomerative clustering distances #[derive(Debug)] -pub enum Metrics { +pub enum Linkage { /// Single linkage clustering Single, /// Complete linkage clustering @@ -49,7 +49,7 @@ pub enum Metrics { Ward2, } -impl Metrics { +impl Linkage { // calculate distance using Lance-Williams algorithm fn dist(&self, ci: &Cluster, cj: &Cluster, ck: &Cluster, dmat: &DistanceMatrix) -> f64 { @@ -58,38 +58,38 @@ impl Metrics { let djk = dmat.get(ck.id, cj.id); match self { - &Metrics::Single => { + &Linkage::Single => { // 0.5 * dik + 0.5 * djk + 0. * dij - 0.5 * (dik - djk).abs() dik.min(djk) }, - &Metrics::Complete => { + &Linkage::Complete => { // 0.5 * dik + 0.5 * djk + 0. * dij + 0.5 * (dik - djk).abs() dik.max(djk) }, - &Metrics::Average => { + &Linkage::Average => { let s = ci.size + cj.size; ci.size / s * dik + cj.size / s * djk }, - &Metrics::Centroid => { + &Linkage::Centroid => { let s = ci.size + cj.size; let ai = ci.size / s; let aj = cj.size / s; let dij = dmat.get(ci.id, cj.id); ai * dik + aj * djk - ai * aj * dij }, - &Metrics::Median => { + &Linkage::Median => { let dij = dmat.get(ci.id, cj.id); 0.5 * dik + 0.5 * djk - 0.25 * dij }, - &Metrics::Ward1 => { + &Linkage::Ward1 => { let s = ci.size + cj.size + ck.size; let dij = dmat.get(ci.id, cj.id); (ci.size + ck.size) / s * dik + (cj.size + ck.size) / s * djk - ck.size / s * dij }, - &Metrics::Ward | &Metrics::Ward2 => { + &Linkage::Ward | &Linkage::Ward2 => { let s = ci.size + cj.size + ck.size; let dij = dmat.get(ci.id, cj.id); - ((ci.size + ck.size) / s * dik.powf(2.) + (cj.size + ck.size) / s * djk.powf(2.) - ck.size / s * dij.powf(2.)).sqrt() + ((ci.size + ck.size) / s * dik * dik + (cj.size + ck.size) / s * djk * djk - ck.size / s * dij * dij).sqrt() } } } @@ -147,11 +147,11 @@ impl DistanceMatrix { unsafe { for i in 0..n { - for j in i..inputs.rows() { + for j in (i + 1)..inputs.rows() { let mut val = 0.; for k in 0..inputs.cols() { - val += (inputs.get_unchecked([i, k]) - - inputs.get_unchecked([j, k])).abs().powf(2.); + let d = inputs.get_unchecked([i, k]) - inputs.get_unchecked([j, k]); + val += d * d; } val = val.sqrt(); data.insert((i, j), val); @@ -177,13 +177,13 @@ impl DistanceMatrix { /// Add distance between i-th and j-th item /// i must be smaller than j fn insert(&mut self, i: usize, j: usize, dist: f64) { - assert!(i < j, "i must be smaller than j"); + debug_assert!(i < j, "i must be smaller than j"); self.data.insert((i, j), dist); } /// Delete distance between i-th and j-th item fn delete(&mut self, i: usize, j: usize) { - assert!(i != j, "DistanceMatrix doesn't store distance when i == j, because it is 0.0"); + debug_assert!(i != j, "DistanceMatrix doesn't store distance when i == j, because it is 0.0"); if i > j { self.data.remove(&(j, i)); } else { @@ -196,7 +196,7 @@ impl DistanceMatrix { #[derive(Debug)] pub struct AgglomerativeClustering { n: usize, - method: Metrics, + linkage: Linkage, // internally stores distances / merged history (currently for testing) distances: Option>, @@ -208,19 +208,19 @@ impl AgglomerativeClustering { /// Constructs an untrained Decision Tree with specified /// /// - `n` - Number of clusters - /// - `method` - Distance metrics + /// - `linkage` - Linkage method /// /// # Examples /// /// ``` - /// use rusty_machine::learning::agglomerative::{AgglomerativeClustering, Metrics}; + /// use rusty_machine::learning::agglomerative::{AgglomerativeClustering, Linkage}; /// - /// let _ = AgglomerativeClustering::new(3, Metrics::Single); + /// let _ = AgglomerativeClustering::new(3, Linkage::Single); /// ``` - pub fn new(n: usize, method: Metrics) -> Self { + pub fn new(n: usize, linkage: Linkage) -> Self { AgglomerativeClustering { n: n, - method: method, + linkage: linkage, distances: None, merged: None @@ -269,7 +269,7 @@ impl AgglomerativeClustering { // update distances using Lance Williams algorithm for ck in clusters.iter() { - let d = self.method.dist(&ci, &cj, ck, &dmat); + let d = self.linkage.dist(&ci, &cj, ck, &dmat); dmat.insert(ck.id, id, d); // remove unnecessary distances @@ -301,7 +301,7 @@ impl AgglomerativeClustering { #[cfg(test)] mod tests { - use super::{AgglomerativeClustering, DistanceMatrix, Metrics}; + use super::{AgglomerativeClustering, DistanceMatrix, Linkage}; #[test] fn test_distance_matrix() { @@ -348,7 +348,7 @@ mod tests { 55., 65., 80., 75., 85.; 90., 85., 88., 92., 95.]; - let mut hclust = AgglomerativeClustering::new(1, Metrics::Single); + let mut hclust = AgglomerativeClustering::new(1, Linkage::Single); let _ = hclust.train(&data); let exp = vec![12.409673645990857, 21.307275752662516, 28.478061731796284, 38.1051177665153, 47.10626285325551, 54.31390245600108]; @@ -356,7 +356,7 @@ mod tests { let exp = vec![(1, 5), (2, 4), (0, 8), (6, 7), (3, 9), (10, 11)]; assert_eq!(hclust.merged.unwrap(), exp); - let mut hclust = AgglomerativeClustering::new(1, Metrics::Complete); + let mut hclust = AgglomerativeClustering::new(1, Linkage::Complete); let _ = hclust.train(&data); let exp = vec![12.409673645990857, 21.307275752662516, 33.77869150810907, 45.58508528016593, 60.13318551349163, 91.53141537199127]; @@ -364,7 +364,7 @@ mod tests { let exp = vec![(1, 5), (2, 4), (0, 8), (6, 7), (3, 9), (10, 11)]; assert_eq!(hclust.merged.unwrap(), exp); - let mut hclust = AgglomerativeClustering::new(1, Metrics::Average); + let mut hclust = AgglomerativeClustering::new(1, Linkage::Average); let _ = hclust.train(&data); let exp = vec![12.409673645990857, 21.307275752662516, 31.128376619952675, 41.84510152334062, 53.305905710336944, 69.92295649225116]; @@ -372,7 +372,7 @@ mod tests { let exp = vec![(1, 5), (2, 4), (0, 8), (6, 7), (3, 9), (10, 11)]; assert_eq!(hclust.merged.unwrap(), exp); - let mut hclust = AgglomerativeClustering::new(1, Metrics::Centroid); + let mut hclust = AgglomerativeClustering::new(1, Linkage::Centroid); let _ = hclust.train(&data); let exp = vec![12.409673645990857, 21.307275752662516, 25.801557681787045, 38.7426831118429, 44.021013600051624, 44.02758328256392]; @@ -380,7 +380,7 @@ mod tests { let exp = vec![(1, 5), (2, 4), (0, 8), (6, 7), (3, 9), (10, 11)]; assert_eq!(hclust.merged.unwrap(), exp); - let mut hclust = AgglomerativeClustering::new(1, Metrics::Median); + let mut hclust = AgglomerativeClustering::new(1, Linkage::Median); let _ = hclust.train(&data); let exp = vec![12.409673645990857, 21.307275752662516, 25.801557681787045, 38.7426831118429, 45.898926771596045, 45.42216730738696]; @@ -388,7 +388,7 @@ mod tests { let exp = vec![(1, 5), (2, 4), (0, 8), (6, 7), (3, 9), (10, 11)]; assert_eq!(hclust.merged.unwrap(), exp); - let mut hclust = AgglomerativeClustering::new(1, Metrics::Ward1); + let mut hclust = AgglomerativeClustering::new(1, Linkage::Ward1); let _ = hclust.train(&data); let exp = vec![12.409673645990857, 21.307275752662516, 34.4020769090494, 51.65691081579053, 66.03152040007744, 150.95171411164773]; @@ -396,18 +396,18 @@ mod tests { let exp = vec![(1, 5), (2, 4), (0, 8), (6, 7), (3, 9), (10, 11)]; assert_eq!(hclust.merged.unwrap(), exp); - let mut hclust = AgglomerativeClustering::new(1, Metrics::Ward2); + let mut hclust = AgglomerativeClustering::new(1, Linkage::Ward2); let _ = hclust.train(&data); let exp = vec![12.409673645990857, 21.307275752662516, 33.911649915626334, - 47.97916214358062, 62.481997407253225, 115.91869071527186]; + 47.97916214358062, 62.48199740725323, 115.91869071527186]; assert_eq!(hclust.distances.unwrap(), exp); let exp = vec![(1, 5), (2, 4), (0, 8), (6, 7), (3, 9), (10, 11)]; assert_eq!(hclust.merged.unwrap(), exp); - let mut hclust = AgglomerativeClustering::new(1, Metrics::Ward); + let mut hclust = AgglomerativeClustering::new(1, Linkage::Ward); let _ = hclust.train(&data); let exp = vec![12.409673645990857, 21.307275752662516, 33.911649915626334, - 47.97916214358062, 62.481997407253225, 115.91869071527186]; + 47.97916214358062, 62.48199740725323, 115.91869071527186]; assert_eq!(hclust.distances.unwrap(), exp); let exp = vec![(1, 5), (2, 4), (0, 8), (6, 7), (3, 9), (10, 11)]; assert_eq!(hclust.merged.unwrap(), exp); diff --git a/tests/learning/agglomerative.rs b/tests/learning/agglomerative.rs index 247a5c92..6306f1ec 100644 --- a/tests/learning/agglomerative.rs +++ b/tests/learning/agglomerative.rs @@ -1,5 +1,5 @@ use rm::linalg::{Matrix, Vector}; -use rm::learning::agglomerative::{AgglomerativeClustering, Metrics}; +use rm::learning::agglomerative::{AgglomerativeClustering, Linkage}; #[test] fn test_cluster() { @@ -11,42 +11,42 @@ fn test_cluster() { 55., 65., 80., 75., 85., 90., 85., 88., 92., 95.]); - let mut hclust = AgglomerativeClustering::new(3, Metrics::Single); + let mut hclust = AgglomerativeClustering::new(3, Linkage::Single); let res = hclust.train(&data); let exp = Vector::new(vec![1, 2, 1, 0, 1, 2, 2]); assert_eq!(res.unwrap(), exp); - let mut hclust = AgglomerativeClustering::new(3, Metrics::Complete); + let mut hclust = AgglomerativeClustering::new(3, Linkage::Complete); let res = hclust.train(&data); let exp = Vector::new(vec![1, 2, 1, 0, 1, 2, 2]); assert_eq!(res.unwrap(), exp); - let mut hclust = AgglomerativeClustering::new(3, Metrics::Average); + let mut hclust = AgglomerativeClustering::new(3, Linkage::Average); let res = hclust.train(&data); let exp = Vector::new(vec![1, 2, 1, 0, 1, 2, 2]); assert_eq!(res.unwrap(), exp); - let mut hclust = AgglomerativeClustering::new(3, Metrics::Centroid); + let mut hclust = AgglomerativeClustering::new(3, Linkage::Centroid); let res = hclust.train(&data); let exp = Vector::new(vec![1, 2, 1, 0, 1, 2, 2]); assert_eq!(res.unwrap(), exp); - let mut hclust = AgglomerativeClustering::new(3, Metrics::Median); + let mut hclust = AgglomerativeClustering::new(3, Linkage::Median); let res = hclust.train(&data); let exp = Vector::new(vec![1, 2, 1, 0, 1, 2, 2]); assert_eq!(res.unwrap(), exp); - let mut hclust = AgglomerativeClustering::new(3, Metrics::Ward1); + let mut hclust = AgglomerativeClustering::new(3, Linkage::Ward1); let res = hclust.train(&data); let exp = Vector::new(vec![1, 2, 1, 0, 1, 2, 2]); assert_eq!(res.unwrap(), exp); - let mut hclust = AgglomerativeClustering::new(3, Metrics::Ward2); + let mut hclust = AgglomerativeClustering::new(3, Linkage::Ward2); let res = hclust.train(&data); let exp = Vector::new(vec![1, 2, 1, 0, 1, 2, 2]); assert_eq!(res.unwrap(), exp); - let mut hclust = AgglomerativeClustering::new(3, Metrics::Ward); + let mut hclust = AgglomerativeClustering::new(3, Linkage::Ward); let res = hclust.train(&data); let exp = Vector::new(vec![1, 2, 1, 0, 1, 2, 2]); assert_eq!(res.unwrap(), exp);