diff --git a/crates/abd-clam/src/msa/quality/row_major.rs b/crates/abd-clam/src/msa/quality/row_major.rs index 780218761..9ef846486 100644 --- a/crates/abd-clam/src/msa/quality/row_major.rs +++ b/crates/abd-clam/src/msa/quality/row_major.rs @@ -37,6 +37,46 @@ impl, U: Number, M> FlatVec { } impl, U: Number, M> FlatVec { + /// Calculates the average and maximum `p-distance`s of all pairwise + /// alignments in the MSA. + #[must_use] + pub fn p_distance_stats(&self, gap_char: u8) -> (f32, f32) { + let p_dists = self.p_distances(gap_char); + let n_pairs = p_dists.len(); + let (sum, max) = p_dists + .into_iter() + .fold((0.0, 0.0), |(sum, max), dist| (sum + dist, f32::max(max, dist))); + let avg = sum / n_pairs.as_f32(); + (avg, max) + } + + /// Same as `p_distance_stats`, but only estimates the score for a subset of + /// the pairwise alignments. + #[must_use] + pub fn p_distance_stats_subsample(&self, gap_char: u8) -> (f32, f32) { + let p_dists = self.p_distances_subsample(gap_char); + let n_pairs = p_dists.len(); + let (sum, max) = p_dists + .into_iter() + .fold((0.0, 0.0), |(sum, max), dist| (sum + dist, f32::max(max, dist))); + let avg = sum / n_pairs.as_f32(); + (avg, max) + } + + /// Calculates the `p-distance` of each pairwise alignment in the MSA. + fn p_distances(&self, gap_char: u8) -> Vec { + let scorer = |s1: &[u8], s2: &[u8]| pd_inner(s1, s2, gap_char); + self.apply_pairwise(&self.indices(), scorer) + } + + /// Same as `p_distances`, but only estimates the score for a subset of the + /// pairwise alignments. + fn p_distances_subsample(&self, gap_char: u8) -> Vec { + let indices = utils::choose_samples(&self.indices(), SQRT_THRESH, LOG2_THRESH); + let scorer = |s1: &[u8], s2: &[u8]| pd_inner(s1, s2, gap_char); + self.apply_pairwise(&indices, scorer) + } + /// Scores the MSA using the distortion of the Levenshtein edit distance /// and the Hamming distance between each pair of sequences. #[must_use] @@ -127,10 +167,10 @@ impl, U: Number, M> FlatVec { score.as_f32() / utils::n_pairs(indices.len()).as_f32() } - /// Calculate the sum of the pairwise scores for a given scorer. - pub(crate) fn sum_of_pairs(&self, indices: &[usize], scorer: F) -> G + /// Applies a pairwise scorer to all pairs of sequences in the MSA. + pub(crate) fn apply_pairwise(&self, indices: &[usize], scorer: F) -> Vec where - F: Fn(&[u8], &[u8]) -> G, + F: (Fn(&[u8], &[u8]) -> G), { indices .iter() @@ -143,12 +183,59 @@ impl, U: Number, M> FlatVec { .map(move |&j| (s1, self.get(j).as_ref())) .map(|(s1, s2)| scorer(s1, s2)) }) - .sum() + .collect() + } + + /// Calculate the sum of the pairwise scores for a given scorer. + pub(crate) fn sum_of_pairs(&self, indices: &[usize], scorer: F) -> G + where + F: Fn(&[u8], &[u8]) -> G, + { + self.apply_pairwise(indices, scorer).into_iter().sum() } } // Parallelized implementations here impl + Send + Sync, U: Number, M: Send + Sync> FlatVec { + /// Calculates the average and maximum `p-distance`s of all pairwise + /// alignments in the MSA. + #[must_use] + pub fn par_p_distance_stats(&self, gap_char: u8) -> (f32, f32) { + let p_dists = self.par_p_distances(gap_char); + let n_pairs = p_dists.len(); + let (sum, max) = p_dists + .into_iter() + .fold((0.0, 0.0), |(sum, max), dist| (sum + dist, f32::max(max, dist))); + let avg = sum / n_pairs.as_f32(); + (avg, max) + } + + /// Same as `par_p_distance_stats`, but only estimates the score for a + /// subset of the pairwise alignments. + #[must_use] + pub fn par_p_distance_stats_subsample(&self, gap_char: u8) -> (f32, f32) { + let p_dists = self.par_p_distances_subsample(gap_char); + let n_pairs = p_dists.len(); + let (sum, max) = p_dists + .into_iter() + .fold((0.0, 0.0), |(sum, max), dist| (sum + dist, f32::max(max, dist))); + let avg = sum / n_pairs.as_f32(); + (avg, max) + } + + /// Calculates the `p-distance` of each pairwise alignment in the MSA. + fn par_p_distances(&self, gap_char: u8) -> Vec { + let scorer = |s1: &[u8], s2: &[u8]| pd_inner(s1, s2, gap_char); + self.par_apply_pairwise(&self.indices(), scorer) + } + + /// Parallel version of `p_distance_stats_subsample`. + fn par_p_distances_subsample(&self, gap_char: u8) -> Vec { + let indices = utils::choose_samples(&self.indices(), SQRT_THRESH, LOG2_THRESH); + let scorer = |s1: &[u8], s2: &[u8]| pd_inner(s1, s2, gap_char); + self.par_apply_pairwise(&indices, scorer) + } + /// Parallel version of `distance_distortion`. #[must_use] pub fn par_distance_distortion(&self, gap_char: u8) -> f32 { @@ -212,8 +299,8 @@ impl + Send + Sync, U: Number, M: Send + Sync> FlatVec { score.as_f32() / utils::n_pairs(indices.len()).as_f32() } - /// Calculate the sum of the pairwise scores for a given scorer. - pub(crate) fn par_sum_of_pairs(&self, indices: &[usize], scorer: F) -> G + /// Parallel version of `apply_pairwise`. + pub(crate) fn par_apply_pairwise(&self, indices: &[usize], scorer: F) -> Vec where F: (Fn(&[u8], &[u8]) -> G) + Send + Sync, { @@ -228,7 +315,15 @@ impl + Send + Sync, U: Number, M: Send + Sync> FlatVec { .map(move |&j| (s1, self.get(j).as_ref())) .map(|(s1, s2)| scorer(s1, s2)) }) - .sum() + .collect() + } + + /// Calculate the sum of the pairwise scores for a given scorer. + pub(crate) fn par_sum_of_pairs(&self, indices: &[usize], scorer: F) -> G + where + F: (Fn(&[u8], &[u8]) -> G) + Send + Sync, + { + self.apply_pairwise(indices, scorer).into_iter().sum() } } @@ -295,3 +390,13 @@ fn dd_inner(s1: &[u8], s2: &[u8], gap_char: u8) -> f32 { ham.as_f32() / lev.as_f32() } } + +/// Calculates the p-distance of a pair of sequences. +fn pd_inner(s1: &[u8], s2: &[u8], gap_char: u8) -> f32 { + let num_mismatches = s1 + .iter() + .zip(s2.iter()) + .filter(|(&a, &b)| a != gap_char && b != gap_char && a != b) + .count(); + num_mismatches.as_f32() / s1.len().as_f32() +} diff --git a/crates/results/msa/src/main.rs b/crates/results/msa/src/main.rs index 71e44a590..7ea81b0e0 100644 --- a/crates/results/msa/src/main.rs +++ b/crates/results/msa/src/main.rs @@ -175,6 +175,12 @@ fn main() -> Result<(), String> { // let wps_quality = msa_data.par_weighted_scoring_pairwise(gap_char, gap_open_penalty, gap_ext_penalty, mismatch_penalty); // ftlog::info!("Weighted pairwise scoring metric: {wps_quality}"); + let (avg_p, max_p) = msa_data.par_p_distance_stats_subsample(gap_char); + ftlog::info!("Pairwise distance stats estimate: avg = {avg_p:.4}, max = {max_p:.4}"); + + // let (avg_p, max_p) = msa_data.par_p_distance_stats(gap_char); + // ftlog::info!("Pairwise distance stats: avg = {avg_p}, max = {max_p}"); + let dd_quality = msa_data.par_distance_distortion_subsample(gap_char); ftlog::info!("Distance distortion metric estimate: {dd_quality}");