Skip to content

Commit

Permalink
feat: added parallelized methods in SMC
Browse files Browse the repository at this point in the history
  • Loading branch information
nishaq503 committed Nov 30, 2024
1 parent d417b71 commit e331d29
Show file tree
Hide file tree
Showing 4 changed files with 484 additions and 72 deletions.
26 changes: 26 additions & 0 deletions crates/abd-clam/src/chaoda/graph/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,31 @@ impl<'a, T: Number, S: Cluster<T>> Graph<'a, T, S> {
}

impl<'a, T: Number, S: ParCluster<T>> Graph<'a, T, S> {
/// Parallel version of `Graph::from_root_uniform_depth`.
///
/// Builds a `Graph` from the vertices at a uniform `depth` in the tree,
/// also admitting leaves that terminate above that depth, by delegating to
/// `par_from_root` with a binary cluster scorer.
pub fn par_from_root_uniform_depth<I: Send + Sync, D: ParDataset<I>, M: ParMetric<I, T>>(
    root: &'a Vertex<T, S>,
    data: &D,
    metric: &M,
    depth: usize,
    min_depth: usize,
) -> Self
where
    T: 'a,
{
    // Score 1.0 for clusters exactly at the target depth, or for leaves
    // that bottom out before reaching it; everything else scores 0.0.
    let cluster_scorer = |clusters: &[&'a Vertex<T, S>]| {
        let mut scores = Vec::with_capacity(clusters.len());
        for c in clusters {
            let selected = c.depth() == depth || (c.is_leaf() && c.depth() < depth);
            scores.push(if selected { 1.0 } else { 0.0 });
        }
        scores
    };
    Self::par_from_root(root, data, metric, cluster_scorer, min_depth)
}
/// Parallel version of `Graph::from_root`.
pub fn par_from_root<I: Send + Sync, D: ParDataset<I>, M: ParMetric<I, T>>(
root: &'a Vertex<T, S>,
Expand Down Expand Up @@ -244,6 +269,7 @@ impl<'a, T: Number, S: ParCluster<T>> Graph<'a, T, S> {
let population = vertices.iter().map(|v| v.cardinality()).sum();
let cardinality = components.iter().map(Component::cardinality).sum();
let diameter = components.iter().map(Component::diameter).max().unwrap_or_default();

Self {
components,
population,
Expand Down
160 changes: 146 additions & 14 deletions crates/abd-clam/src/chaoda/inference/trained_smc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,17 @@
use distances::Number;
use ndarray::prelude::*;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};

use crate::{
chaoda::{roc_auc_score, Vertex},
cluster::{adapter::Adapter, Partition},
cluster::{
adapter::{Adapter, ParAdapter},
ParCluster, ParPartition, Partition,
},
dataset::ParDataset,
metric::ParMetric,
Cluster, Dataset, Metric,
};

Expand Down Expand Up @@ -103,22 +109,22 @@ impl TrainedSmc {
M: Metric<I, T>,
S: Cluster<T>,
{
let mut num_discerning = 0;
let mut scores = Vec::new();
for (i, combination) in self.0.iter().enumerate() {
if combination.discerns(tol) {
ftlog::info!("Predicting with combination {} of {}", i + 1, self.0.len());
num_discerning += 1;
let (_, mut row) = combination.predict(root, data, metric, min_depth);
scores.append(&mut row);
} else {
let (num_discerning, scores) = self
.0
.iter()
.enumerate()
.filter(|(_, combination)| combination.discerns(tol))
.fold((0, Vec::new()), |(num_discerning, mut scores), (i, combination)| {
ftlog::info!(
"Skipping combination {} of {} because its expected roc-score is within {tol} of `0.5`.",
"Predicting with combination {}/{} {}",
i + 1,
self.0.len()
self.0.len(),
combination.name()
);
}
}
let (_, mut row) = combination.predict(root, data, metric, min_depth);
scores.append(&mut row);
(num_discerning + 1, scores)
});

if num_discerning == 0 {
ftlog::warn!("No discerning combinations found. Returning all scores as `0.5`.");
Expand Down Expand Up @@ -235,4 +241,130 @@ impl TrainedSmc {
roc_auc_score(labels, &scores)
.unwrap_or_else(|e| unreachable!("Could not compute ROC-AUC score for dataset {}: {e}", data.name()))
}

/// Parallel version of `create_tree`.
///
/// Builds the source cluster tree in parallel using the partition
/// `criteria` (and optional `seed` for reproducibility), then adapts it
/// into a `Vertex` tree.
fn par_create_tree<I, T, D, M, S, C>(data: &D, metric: &M, criteria: &C, seed: Option<u64>) -> Vertex<T, S>
where
    I: Send + Sync,
    T: Number,
    D: ParDataset<I>,
    M: ParMetric<I, T>,
    S: ParPartition<T>,
    C: (Fn(&S) -> bool) + Send + Sync,
{
    // Build first, then adapt; `None` means the root vertex has no parent.
    Vertex::par_adapt_tree(S::par_new_tree(data, metric, criteria, seed), None, data, metric)
}

/// Parallel version of `predict_from_tree`.
///
/// Runs every discerning combination's prediction in parallel and averages
/// the per-point scores. If no combination discerns at tolerance `tol`,
/// every point gets a neutral score of `0.5`.
///
/// # Arguments
///
/// * `data` - The dataset to predict on.
/// * `metric` - The metric used for distance computations.
/// * `root` - The root of the pre-built `Vertex` tree.
/// * `min_depth` - The minimum depth of clusters used for graphs.
/// * `tol` - Combinations whose expected roc-score is within `tol` of
///   `0.5` are skipped as non-discerning.
pub fn par_predict_from_tree<I, T, D, M, S>(
    &self,
    data: &D,
    metric: &M,
    root: &Vertex<T, S>,
    min_depth: usize,
    tol: f32,
) -> Vec<f32>
where
    I: Send + Sync,
    T: Number,
    D: ParDataset<I>,
    M: ParMetric<I, T>,
    S: ParCluster<T>,
{
    // Collect one row of scores per discerning combination. Rayon's
    // `collect` preserves the source iteration order even after `filter`,
    // so the layout of the flattened score vector is deterministic and
    // matches the serial `predict_from_tree`. The previous fold/reduce
    // concatenated rows in a thread-dependent order, which made the
    // reshape below (and therefore the averaged scores) nondeterministic.
    let rows = self
        .0
        .par_iter()
        .enumerate()
        .filter(|(_, combination)| combination.discerns(tol))
        .map(|(i, combination)| {
            ftlog::info!(
                "Predicting with combination {}/{} {}",
                i + 1,
                self.0.len(),
                combination.name()
            );
            combination.par_predict(root, data, metric, min_depth).1
        })
        .collect::<Vec<_>>();

    let num_discerning = rows.len();
    if num_discerning == 0 {
        ftlog::warn!("No discerning combinations found. Returning all scores as `0.5`.");
        return vec![0.5; data.cardinality()];
    }

    ftlog::info!("Averaging scores from {num_discerning} discerning combinations.");
    let scores = rows.into_iter().flatten().collect::<Vec<_>>();

    // NOTE(review): rows are concatenated combination-major but reshaped as
    // (cardinality, num_discerning); this mirrors the serial
    // `predict_from_tree` — confirm the intended layout of each
    // combination's row against `predict`.
    let shape = (data.cardinality(), num_discerning);
    let scores_len = scores.len();
    let scores = Array2::from_shape_vec(shape, scores).unwrap_or_else(|e| {
        unreachable!(
            "Could not convert Vec<T> of len {scores_len} to Array2<T> of shape {:?}: {e}",
            shape
        )
    });

    scores
        .mean_axis(Axis(1))
        .unwrap_or_else(|| unreachable!("Could not compute mean of Array2<T> along axis 1"))
        .to_vec()
}

/// Parallel version of `predict`.
///
/// Convenience wrapper: builds a tree in parallel from `data` using the
/// partition `criteria` and optional `seed`, then scores every point with
/// `par_predict_from_tree`.
pub fn par_predict<I, T, D, M, S, C>(
    &self,
    data: &D,
    metric: &M,
    criteria: &C,
    seed: Option<u64>,
    min_depth: usize,
    tol: f32,
) -> Vec<f32>
where
    I: Send + Sync,
    T: Number,
    D: ParDataset<I>,
    M: ParMetric<I, T>,
    S: ParPartition<T>,
    C: (Fn(&S) -> bool) + Send + Sync,
{
    let tree = Self::par_create_tree(data, metric, criteria, seed);
    self.par_predict_from_tree(data, metric, &tree, min_depth, tol)
}

/// Parallel version of `evaluate`.
///
/// Builds a tree, predicts anomaly scores for `data` in parallel, and
/// returns the ROC-AUC of those scores against the boolean `labels`.
#[allow(clippy::too_many_arguments)]
pub fn par_evaluate<I, T, D, M, S, C>(
    &self,
    data: &D,
    labels: &[bool],
    metric: &M,
    criteria: &C,
    seed: Option<u64>,
    min_depth: usize,
    tol: f32,
) -> f32
where
    I: Send + Sync,
    T: Number,
    D: ParDataset<I>,
    M: ParMetric<I, T>,
    S: ParPartition<T>,
    C: (Fn(&S) -> bool) + Send + Sync,
{
    let tree = Self::par_create_tree(data, metric, criteria, seed);
    let predictions = self.par_predict_from_tree(data, metric, &tree, min_depth, tol);
    // Scoring failure indicates a broken invariant (labels/scores mismatch),
    // not a recoverable condition.
    match roc_auc_score(labels, &predictions) {
        Ok(auc) => auc,
        Err(e) => unreachable!("Could not compute ROC-AUC score for dataset {}: {e}", data.name()),
    }
}
}
Loading

0 comments on commit e331d29

Please sign in to comment.