Skip to content

Commit

Permalink
feat: added avg and max p-distance measures
Browse files Browse the repository at this point in the history
  • Loading branch information
nishaq503 committed Nov 12, 2024
1 parent b96513a commit 5e9c72d
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 7 deletions.
119 changes: 112 additions & 7 deletions crates/abd-clam/src/msa/quality/row_major.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,46 @@ impl<T: AsRef<[u8]>, U: Number, M> FlatVec<T, U, M> {
}

impl<T: AsRef<[u8]>, U: Number, M> FlatVec<T, U, M> {
/// Computes the mean and the largest `p-distance` over every pairwise
/// alignment in the MSA.
///
/// Returns the pair `(average, maximum)`. The running maximum starts at
/// `0.0`, so the reported maximum is never below zero.
///
/// NOTE(review): when there are no pairs, the average is `0.0` divided by a
/// zero count -- presumably `NaN`; confirm whether callers need a guard.
#[must_use]
pub fn p_distance_stats(&self, gap_char: u8) -> (f32, f32) {
    let distances = self.p_distances(gap_char);
    let count = distances.len();

    // Accumulate the total and the largest distance in a single pass.
    let mut total = 0.0_f32;
    let mut largest = 0.0_f32;
    for d in distances {
        total += d;
        largest = largest.max(d);
    }

    (total / count.as_f32(), largest)
}

/// Estimates the mean and the largest `p-distance` from the subset of
/// pairwise alignments produced by `p_distances_subsample`.
///
/// Returns the pair `(average, maximum)`, computed the same way as in
/// `p_distance_stats` but over the subsampled distances only.
///
/// NOTE(review): when the subsample yields no pairs, the average is `0.0`
/// divided by a zero count -- presumably `NaN`; confirm against callers.
#[must_use]
pub fn p_distance_stats_subsample(&self, gap_char: u8) -> (f32, f32) {
    let distances = self.p_distances_subsample(gap_char);
    let count = distances.len();

    // Accumulate the total and the largest distance in a single pass.
    let mut total = 0.0_f32;
    let mut largest = 0.0_f32;
    for d in distances {
        total += d;
        largest = largest.max(d);
    }

    (total / count.as_f32(), largest)
}

/// Computes the `p-distance` for each pair of sequences in the MSA.
///
/// Every index returned by `self.indices()` takes part; the per-pair score
/// is `pd_inner` evaluated with the given `gap_char`.
fn p_distances(&self, gap_char: u8) -> Vec<f32> {
    let all_indices = self.indices();
    self.apply_pairwise(&all_indices, |a: &[u8], b: &[u8]| pd_inner(a, b, gap_char))
}

/// Computes the `p-distance` for pairs drawn from a subset of the MSA.
///
/// The subset of indices is picked by `utils::choose_samples` using the
/// `SQRT_THRESH` and `LOG2_THRESH` thresholds; each pair among those indices
/// is then scored with `pd_inner`, exactly as in `p_distances`.
fn p_distances_subsample(&self, gap_char: u8) -> Vec<f32> {
    let sampled = utils::choose_samples(&self.indices(), SQRT_THRESH, LOG2_THRESH);
    self.apply_pairwise(&sampled, |a: &[u8], b: &[u8]| pd_inner(a, b, gap_char))
}

/// Scores the MSA using the distortion of the Levenshtein edit distance
/// and the Hamming distance between each pair of sequences.
#[must_use]
Expand Down Expand Up @@ -127,10 +167,10 @@ impl<T: AsRef<[u8]>, U: Number, M> FlatVec<T, U, M> {
score.as_f32() / utils::n_pairs(indices.len()).as_f32()
}

/// Calculate the sum of the pairwise scores for a given scorer.
pub(crate) fn sum_of_pairs<F, G: Number>(&self, indices: &[usize], scorer: F) -> G
/// Applies a pairwise scorer to all pairs of sequences in the MSA.
pub(crate) fn apply_pairwise<F, G: Number>(&self, indices: &[usize], scorer: F) -> Vec<G>
where
F: Fn(&[u8], &[u8]) -> G,
F: (Fn(&[u8], &[u8]) -> G),
{
indices
.iter()
Expand All @@ -143,12 +183,59 @@ impl<T: AsRef<[u8]>, U: Number, M> FlatVec<T, U, M> {
.map(move |&j| (s1, self.get(j).as_ref()))
.map(|(s1, s2)| scorer(s1, s2))
})
.sum()
.collect()
}

/// Sums the scores that `scorer` assigns to the pairs of sequences selected
/// by `indices`.
///
/// The individual scores come from `apply_pairwise` and are added up in the
/// order that method returns them.
pub(crate) fn sum_of_pairs<F, G: Number>(&self, indices: &[usize], scorer: F) -> G
where
    F: Fn(&[u8], &[u8]) -> G,
{
    let scores = self.apply_pairwise(indices, scorer);
    scores.into_iter().sum()
}
}

// Parallelized implementations here
impl<T: AsRef<[u8]> + Send + Sync, U: Number, M: Send + Sync> FlatVec<T, U, M> {
/// Computes the mean and the largest `p-distance` over every pairwise
/// alignment in the MSA, obtaining the distances from `par_p_distances`.
///
/// Returns the pair `(average, maximum)`. Only the distance computation is
/// delegated to the parallel helper; the reduction below is a plain
/// sequential pass over the collected distances.
///
/// NOTE(review): when there are no pairs, the average is `0.0` divided by a
/// zero count -- presumably `NaN`; confirm whether callers need a guard.
#[must_use]
pub fn par_p_distance_stats(&self, gap_char: u8) -> (f32, f32) {
    let distances = self.par_p_distances(gap_char);
    let count = distances.len();

    // Accumulate the total and the largest distance in a single pass.
    let mut total = 0.0_f32;
    let mut largest = 0.0_f32;
    for d in distances {
        total += d;
        largest = largest.max(d);
    }

    (total / count.as_f32(), largest)
}

/// Estimates the mean and the largest `p-distance` from the subset of
/// pairwise alignments produced by `par_p_distances_subsample`.
///
/// Returns the pair `(average, maximum)`, computed the same way as in
/// `par_p_distance_stats` but over the subsampled distances only.
///
/// NOTE(review): when the subsample yields no pairs, the average is `0.0`
/// divided by a zero count -- presumably `NaN`; confirm against callers.
#[must_use]
pub fn par_p_distance_stats_subsample(&self, gap_char: u8) -> (f32, f32) {
    let distances = self.par_p_distances_subsample(gap_char);
    let count = distances.len();

    // Accumulate the total and the largest distance in a single pass.
    let mut total = 0.0_f32;
    let mut largest = 0.0_f32;
    for d in distances {
        total += d;
        largest = largest.max(d);
    }

    (total / count.as_f32(), largest)
}

/// Computes the `p-distance` for each pair of sequences in the MSA using
/// `par_apply_pairwise`.
///
/// Every index returned by `self.indices()` takes part; the per-pair score
/// is `pd_inner` evaluated with the given `gap_char`.
fn par_p_distances(&self, gap_char: u8) -> Vec<f32> {
    let all_indices = self.indices();
    self.par_apply_pairwise(&all_indices, |a: &[u8], b: &[u8]| pd_inner(a, b, gap_char))
}

/// Parallel version of `p_distances_subsample`.
///
/// The subset of indices is picked by `utils::choose_samples` using the
/// `SQRT_THRESH` and `LOG2_THRESH` thresholds; each pair among those indices
/// is then scored with `pd_inner` through `par_apply_pairwise`.
fn par_p_distances_subsample(&self, gap_char: u8) -> Vec<f32> {
    let sampled = utils::choose_samples(&self.indices(), SQRT_THRESH, LOG2_THRESH);
    self.par_apply_pairwise(&sampled, |a: &[u8], b: &[u8]| pd_inner(a, b, gap_char))
}

/// Parallel version of `distance_distortion`.
#[must_use]
pub fn par_distance_distortion(&self, gap_char: u8) -> f32 {
Expand Down Expand Up @@ -212,8 +299,8 @@ impl<T: AsRef<[u8]> + Send + Sync, U: Number, M: Send + Sync> FlatVec<T, U, M> {
score.as_f32() / utils::n_pairs(indices.len()).as_f32()
}

/// Calculate the sum of the pairwise scores for a given scorer.
pub(crate) fn par_sum_of_pairs<F, G: Number>(&self, indices: &[usize], scorer: F) -> G
/// Parallel version of `apply_pairwise`.
pub(crate) fn par_apply_pairwise<F, G: Number>(&self, indices: &[usize], scorer: F) -> Vec<G>
where
F: (Fn(&[u8], &[u8]) -> G) + Send + Sync,
{
Expand All @@ -228,7 +315,15 @@ impl<T: AsRef<[u8]> + Send + Sync, U: Number, M: Send + Sync> FlatVec<T, U, M> {
.map(move |&j| (s1, self.get(j).as_ref()))
.map(|(s1, s2)| scorer(s1, s2))
})
.sum()
.collect()
}

/// Parallel version of `sum_of_pairs`: calculates the sum of the pairwise
/// scores for a given scorer.
///
/// The per-pair scores are produced by `par_apply_pairwise`; the collected
/// scores are then added up sequentially.
pub(crate) fn par_sum_of_pairs<F, G: Number>(&self, indices: &[usize], scorer: F) -> G
where
    F: (Fn(&[u8], &[u8]) -> G) + Send + Sync,
{
    // Must go through the parallel helper: calling the sequential
    // `apply_pairwise` here would make this method parallel in name only.
    self.par_apply_pairwise(indices, scorer).into_iter().sum()
}
}

Expand Down Expand Up @@ -295,3 +390,13 @@ fn dd_inner(s1: &[u8], s2: &[u8], gap_char: u8) -> f32 {
ham.as_f32() / lev.as_f32()
}
}

/// Calculates the p-distance of a pair of aligned sequences.
///
/// The sequences are compared position by position (stopping at the end of
/// the shorter one). A position counts as a mismatch when neither byte is
/// `gap_char` and the two bytes differ. The mismatch count is divided by the
/// length of `s1`.
///
/// Returns `0.0` when `s1` is empty instead of dividing by zero.
fn pd_inner(s1: &[u8], s2: &[u8], gap_char: u8) -> f32 {
    // An empty sequence has no positions to compare; avoid computing 0 / 0.
    if s1.is_empty() {
        return 0.0;
    }

    let num_mismatches = s1
        .iter()
        .zip(s2.iter())
        .filter(|&(&a, &b)| a != gap_char && b != gap_char && a != b)
        .count();
    num_mismatches.as_f32() / s1.len().as_f32()
}
6 changes: 6 additions & 0 deletions crates/results/msa/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,12 @@ fn main() -> Result<(), String> {
// let wps_quality = msa_data.par_weighted_scoring_pairwise(gap_char, gap_open_penalty, gap_ext_penalty, mismatch_penalty);
// ftlog::info!("Weighted pairwise scoring metric: {wps_quality}");

let (avg_p, max_p) = msa_data.par_p_distance_stats_subsample(gap_char);
ftlog::info!("Pairwise distance stats estimate: avg = {avg_p:.4}, max = {max_p:.4}");

// let (avg_p, max_p) = msa_data.par_p_distance_stats(gap_char);
// ftlog::info!("Pairwise distance stats: avg = {avg_p}, max = {max_p}");

let dd_quality = msa_data.par_distance_distortion_subsample(gap_char);
ftlog::info!("Distance distortion metric estimate: {dd_quality}");

Expand Down

0 comments on commit 5e9c72d

Please sign in to comment.