Skip to content
This repository has been archived by the owner on Jul 16, 2021. It is now read-only.

Commit

Permalink
rename linkage, use debug_assert
Browse files Browse the repository at this point in the history
  • Loading branch information
sinhrks committed Dec 20, 2016
1 parent 406cce0 commit 4c46d25
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 44 deletions.
70 changes: 35 additions & 35 deletions src/learning/agglomerative.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
//! # Usage
//!
//! ```
//! use rusty_machine::learning::agglomerative::{AgglomerativeClustering, Metrics};
//! use rusty_machine::learning::agglomerative::{AgglomerativeClustering, Linkage};
//! use rusty_machine::learning::SupModel;
//! use rusty_machine::linalg::{Matrix, Vector};
//!
//! let inputs = Matrix::new(4, 2, vec![1., 3., 2., 3., 4., 3., 5., 3.]);
//! let mut agg = AgglomerativeClustering::new(2, Metrics::Single);
//! let mut agg = AgglomerativeClustering::new(2, Linkage::Single);
//!
//! // Train the model and get the clustering result
//! let res = agg.train(&inputs).unwrap();
Expand All @@ -27,7 +27,7 @@ use learning::{LearningResult};

/// Linkage criteria for agglomerative clustering
#[derive(Debug)]
pub enum Metrics {
pub enum Linkage {
/// Single linkage clustering
Single,
/// Complete linkage clustering
Expand All @@ -49,7 +49,7 @@ pub enum Metrics {
Ward2,
}

impl Metrics {
impl Linkage {

// calculate distance using Lance-Williams algorithm
fn dist(&self, ci: &Cluster, cj: &Cluster, ck: &Cluster, dmat: &DistanceMatrix) -> f64 {
Expand All @@ -58,38 +58,38 @@ impl Metrics {
let djk = dmat.get(ck.id, cj.id);

match self {
&Metrics::Single => {
&Linkage::Single => {
// 0.5 * dik + 0.5 * djk + 0. * dij - 0.5 * (dik - djk).abs()
dik.min(djk)
},
&Metrics::Complete => {
&Linkage::Complete => {
// 0.5 * dik + 0.5 * djk + 0. * dij + 0.5 * (dik - djk).abs()
dik.max(djk)
},
&Metrics::Average => {
&Linkage::Average => {
let s = ci.size + cj.size;
ci.size / s * dik + cj.size / s * djk
},
&Metrics::Centroid => {
&Linkage::Centroid => {
let s = ci.size + cj.size;
let ai = ci.size / s;
let aj = cj.size / s;
let dij = dmat.get(ci.id, cj.id);
ai * dik + aj * djk - ai * aj * dij
},
&Metrics::Median => {
&Linkage::Median => {
let dij = dmat.get(ci.id, cj.id);
0.5 * dik + 0.5 * djk - 0.25 * dij
},
&Metrics::Ward1 => {
&Linkage::Ward1 => {
let s = ci.size + cj.size + ck.size;
let dij = dmat.get(ci.id, cj.id);
(ci.size + ck.size) / s * dik + (cj.size + ck.size) / s * djk - ck.size / s * dij
},
&Metrics::Ward | &Metrics::Ward2 => {
&Linkage::Ward | &Linkage::Ward2 => {
let s = ci.size + cj.size + ck.size;
let dij = dmat.get(ci.id, cj.id);
((ci.size + ck.size) / s * dik.powf(2.) + (cj.size + ck.size) / s * djk.powf(2.) - ck.size / s * dij.powf(2.)).sqrt()
((ci.size + ck.size) / s * dik * dik + (cj.size + ck.size) / s * djk * djk - ck.size / s * dij * dij).sqrt()
}
}
}
Expand Down Expand Up @@ -147,11 +147,11 @@ impl DistanceMatrix {

unsafe {
for i in 0..n {
for j in i..inputs.rows() {
for j in (i + 1)..inputs.rows() {
let mut val = 0.;
for k in 0..inputs.cols() {
val += (inputs.get_unchecked([i, k]) -
inputs.get_unchecked([j, k])).abs().powf(2.);
let d = inputs.get_unchecked([i, k]) - inputs.get_unchecked([j, k]);
val += d * d;
}
val = val.sqrt();
data.insert((i, j), val);
Expand All @@ -177,13 +177,13 @@ impl DistanceMatrix {
/// Add distance between i-th and j-th item
/// i must be smaller than j
fn insert(&mut self, i: usize, j: usize, dist: f64) {
assert!(i < j, "i must be smaller than j");
debug_assert!(i < j, "i must be smaller than j");
self.data.insert((i, j), dist);
}

/// Delete distance between i-th and j-th item
fn delete(&mut self, i: usize, j: usize) {
assert!(i != j, "DistanceMatrix doesn't store distance when i == j, because it is 0.0");
debug_assert!(i != j, "DistanceMatrix doesn't store distance when i == j, because it is 0.0");
if i > j {
self.data.remove(&(j, i));
} else {
Expand All @@ -196,7 +196,7 @@ impl DistanceMatrix {
#[derive(Debug)]
pub struct AgglomerativeClustering {
n: usize,
method: Metrics,
linkage: Linkage,

// internally stores distances / merged history (currently for testing)
distances: Option<Vec<f64>>,
Expand All @@ -208,19 +208,19 @@ impl AgglomerativeClustering {
/// Constructs an untrained agglomerative clustering model with the specified parameters
///
/// - `n` - Number of clusters
/// - `method` - Distance metrics
/// - `linkage` - Linkage method
///
/// # Examples
///
/// ```
/// use rusty_machine::learning::agglomerative::{AgglomerativeClustering, Metrics};
/// use rusty_machine::learning::agglomerative::{AgglomerativeClustering, Linkage};
///
/// let _ = AgglomerativeClustering::new(3, Metrics::Single);
/// let _ = AgglomerativeClustering::new(3, Linkage::Single);
/// ```
pub fn new(n: usize, method: Metrics) -> Self {
pub fn new(n: usize, linkage: Linkage) -> Self {
AgglomerativeClustering {
n: n,
method: method,
linkage: linkage,

distances: None,
merged: None
Expand Down Expand Up @@ -269,7 +269,7 @@ impl AgglomerativeClustering {

// update distances using the Lance-Williams algorithm
for ck in clusters.iter() {
let d = self.method.dist(&ci, &cj, ck, &dmat);
let d = self.linkage.dist(&ci, &cj, ck, &dmat);
dmat.insert(ck.id, id, d);

// remove unnecessary distances
Expand Down Expand Up @@ -301,7 +301,7 @@ impl AgglomerativeClustering {
#[cfg(test)]
mod tests {

use super::{AgglomerativeClustering, DistanceMatrix, Metrics};
use super::{AgglomerativeClustering, DistanceMatrix, Linkage};

#[test]
fn test_distance_matrix() {
Expand Down Expand Up @@ -348,66 +348,66 @@ mod tests {
55., 65., 80., 75., 85.;
90., 85., 88., 92., 95.];

let mut hclust = AgglomerativeClustering::new(1, Metrics::Single);
let mut hclust = AgglomerativeClustering::new(1, Linkage::Single);
let _ = hclust.train(&data);
let exp = vec![12.409673645990857, 21.307275752662516, 28.478061731796284,
38.1051177665153, 47.10626285325551, 54.31390245600108];
assert_eq!(hclust.distances.unwrap(), exp);
let exp = vec![(1, 5), (2, 4), (0, 8), (6, 7), (3, 9), (10, 11)];
assert_eq!(hclust.merged.unwrap(), exp);

let mut hclust = AgglomerativeClustering::new(1, Metrics::Complete);
let mut hclust = AgglomerativeClustering::new(1, Linkage::Complete);
let _ = hclust.train(&data);
let exp = vec![12.409673645990857, 21.307275752662516, 33.77869150810907,
45.58508528016593, 60.13318551349163, 91.53141537199127];
assert_eq!(hclust.distances.unwrap(), exp);
let exp = vec![(1, 5), (2, 4), (0, 8), (6, 7), (3, 9), (10, 11)];
assert_eq!(hclust.merged.unwrap(), exp);

let mut hclust = AgglomerativeClustering::new(1, Metrics::Average);
let mut hclust = AgglomerativeClustering::new(1, Linkage::Average);
let _ = hclust.train(&data);
let exp = vec![12.409673645990857, 21.307275752662516, 31.128376619952675,
41.84510152334062, 53.305905710336944, 69.92295649225116];
assert_eq!(hclust.distances.unwrap(), exp);
let exp = vec![(1, 5), (2, 4), (0, 8), (6, 7), (3, 9), (10, 11)];
assert_eq!(hclust.merged.unwrap(), exp);

let mut hclust = AgglomerativeClustering::new(1, Metrics::Centroid);
let mut hclust = AgglomerativeClustering::new(1, Linkage::Centroid);
let _ = hclust.train(&data);
let exp = vec![12.409673645990857, 21.307275752662516, 25.801557681787045,
38.7426831118429, 44.021013600051624, 44.02758328256392];
assert_eq!(hclust.distances.unwrap(), exp);
let exp = vec![(1, 5), (2, 4), (0, 8), (6, 7), (3, 9), (10, 11)];
assert_eq!(hclust.merged.unwrap(), exp);

let mut hclust = AgglomerativeClustering::new(1, Metrics::Median);
let mut hclust = AgglomerativeClustering::new(1, Linkage::Median);
let _ = hclust.train(&data);
let exp = vec![12.409673645990857, 21.307275752662516, 25.801557681787045,
38.7426831118429, 45.898926771596045, 45.42216730738696];
assert_eq!(hclust.distances.unwrap(), exp);
let exp = vec![(1, 5), (2, 4), (0, 8), (6, 7), (3, 9), (10, 11)];
assert_eq!(hclust.merged.unwrap(), exp);

let mut hclust = AgglomerativeClustering::new(1, Metrics::Ward1);
let mut hclust = AgglomerativeClustering::new(1, Linkage::Ward1);
let _ = hclust.train(&data);
let exp = vec![12.409673645990857, 21.307275752662516, 34.4020769090494,
51.65691081579053, 66.03152040007744, 150.95171411164773];
assert_eq!(hclust.distances.unwrap(), exp);
let exp = vec![(1, 5), (2, 4), (0, 8), (6, 7), (3, 9), (10, 11)];
assert_eq!(hclust.merged.unwrap(), exp);

let mut hclust = AgglomerativeClustering::new(1, Metrics::Ward2);
let mut hclust = AgglomerativeClustering::new(1, Linkage::Ward2);
let _ = hclust.train(&data);
let exp = vec![12.409673645990857, 21.307275752662516, 33.911649915626334,
47.97916214358062, 62.481997407253225, 115.91869071527186];
47.97916214358062, 62.48199740725323, 115.91869071527186];
assert_eq!(hclust.distances.unwrap(), exp);
let exp = vec![(1, 5), (2, 4), (0, 8), (6, 7), (3, 9), (10, 11)];
assert_eq!(hclust.merged.unwrap(), exp);

let mut hclust = AgglomerativeClustering::new(1, Metrics::Ward);
let mut hclust = AgglomerativeClustering::new(1, Linkage::Ward);
let _ = hclust.train(&data);
let exp = vec![12.409673645990857, 21.307275752662516, 33.911649915626334,
47.97916214358062, 62.481997407253225, 115.91869071527186];
47.97916214358062, 62.48199740725323, 115.91869071527186];
assert_eq!(hclust.distances.unwrap(), exp);
let exp = vec![(1, 5), (2, 4), (0, 8), (6, 7), (3, 9), (10, 11)];
assert_eq!(hclust.merged.unwrap(), exp);
Expand Down
18 changes: 9 additions & 9 deletions tests/learning/agglomerative.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use rm::linalg::{Matrix, Vector};
use rm::learning::agglomerative::{AgglomerativeClustering, Metrics};
use rm::learning::agglomerative::{AgglomerativeClustering, Linkage};

#[test]
fn test_cluster() {
Expand All @@ -11,42 +11,42 @@ fn test_cluster() {
55., 65., 80., 75., 85.,
90., 85., 88., 92., 95.]);

let mut hclust = AgglomerativeClustering::new(3, Metrics::Single);
let mut hclust = AgglomerativeClustering::new(3, Linkage::Single);
let res = hclust.train(&data);
let exp = Vector::new(vec![1, 2, 1, 0, 1, 2, 2]);
assert_eq!(res.unwrap(), exp);

let mut hclust = AgglomerativeClustering::new(3, Metrics::Complete);
let mut hclust = AgglomerativeClustering::new(3, Linkage::Complete);
let res = hclust.train(&data);
let exp = Vector::new(vec![1, 2, 1, 0, 1, 2, 2]);
assert_eq!(res.unwrap(), exp);

let mut hclust = AgglomerativeClustering::new(3, Metrics::Average);
let mut hclust = AgglomerativeClustering::new(3, Linkage::Average);
let res = hclust.train(&data);
let exp = Vector::new(vec![1, 2, 1, 0, 1, 2, 2]);
assert_eq!(res.unwrap(), exp);

let mut hclust = AgglomerativeClustering::new(3, Metrics::Centroid);
let mut hclust = AgglomerativeClustering::new(3, Linkage::Centroid);
let res = hclust.train(&data);
let exp = Vector::new(vec![1, 2, 1, 0, 1, 2, 2]);
assert_eq!(res.unwrap(), exp);

let mut hclust = AgglomerativeClustering::new(3, Metrics::Median);
let mut hclust = AgglomerativeClustering::new(3, Linkage::Median);
let res = hclust.train(&data);
let exp = Vector::new(vec![1, 2, 1, 0, 1, 2, 2]);
assert_eq!(res.unwrap(), exp);

let mut hclust = AgglomerativeClustering::new(3, Metrics::Ward1);
let mut hclust = AgglomerativeClustering::new(3, Linkage::Ward1);
let res = hclust.train(&data);
let exp = Vector::new(vec![1, 2, 1, 0, 1, 2, 2]);
assert_eq!(res.unwrap(), exp);

let mut hclust = AgglomerativeClustering::new(3, Metrics::Ward2);
let mut hclust = AgglomerativeClustering::new(3, Linkage::Ward2);
let res = hclust.train(&data);
let exp = Vector::new(vec![1, 2, 1, 0, 1, 2, 2]);
assert_eq!(res.unwrap(), exp);

let mut hclust = AgglomerativeClustering::new(3, Metrics::Ward);
let mut hclust = AgglomerativeClustering::new(3, Linkage::Ward);
let res = hclust.train(&data);
let exp = Vector::new(vec![1, 2, 1, 0, 1, 2, 2]);
assert_eq!(res.unwrap(), exp);
Expand Down

0 comments on commit 4c46d25

Please sign in to comment.