Skip to content

Commit

Permalink
feat: added lev-vs-ham distance distortion metric
Browse files Browse the repository at this point in the history
  • Loading branch information
nishaq503 committed Nov 10, 2024
1 parent 42f517e commit 330e581
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 59 deletions.
4 changes: 4 additions & 0 deletions crates/abd-clam/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ smartcore = { git = "https://github.com/smartcorelib/smartcore.git", rev = "239c
ordered-float = { version = "4.2.2", optional = true }
bincode = { workspace = true, optional = true }

# For:
# - MSA
stringzilla = { workspace = true }

[dev-dependencies]
symagen = { workspace = true }
bincode = { workspace = true }
Expand Down
116 changes: 65 additions & 51 deletions crates/abd-clam/src/msa/quality/row_major.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,23 @@ impl<T: AsRef<[u8]>, U: Number, M> FlatVec<T, U, M> {
}

impl<T: AsRef<[u8]>, U: Number, M> FlatVec<T, U, M> {
/// Scores the MSA by how much the Hamming distance between aligned rows
/// distorts the Levenshtein edit distance between the same sequences once
/// their gaps are stripped.
///
/// Every pair of rows is scored with `dd_inner`; the summed score is then
/// divided by `utils::n_pairs(self.cardinality())`.
///
/// # Arguments
///
/// * `gap_char` - The byte that marks a gap in the aligned sequences.
#[must_use]
pub fn distance_distortion(&self, gap_char: u8) -> f32 {
    let all_rows = self.indices();
    let distortion = |a: &[u8], b: &[u8]| dd_inner(a, b, gap_char);
    let total: f32 = self.sum_of_pairs(&all_rows, distortion);
    let pair_count = utils::n_pairs(self.cardinality());
    total.as_f32() / pair_count.as_f32()
}

/// Estimates the `distance_distortion` score from a sample of the rows
/// instead of scoring every pairwise alignment.
///
/// The rows are picked by `utils::choose_samples`, called with
/// `SQRT_THRESH / 4` and `LOG2_THRESH / 4`; the summed score over the
/// sampled pairs is divided by the number of pairs among the samples.
///
/// # Arguments
///
/// * `gap_char` - The byte that marks a gap in the aligned sequences.
#[must_use]
pub fn distance_distortion_subsample(&self, gap_char: u8) -> f32 {
    let sampled = utils::choose_samples(&self.indices(), SQRT_THRESH / 4, LOG2_THRESH / 4);
    let distortion = |a: &[u8], b: &[u8]| dd_inner(a, b, gap_char);
    let total: f32 = self.sum_of_pairs(&sampled, distortion);
    let pair_count = utils::n_pairs(sampled.len());
    total.as_f32() / pair_count.as_f32()
}

/// Scores each pairwise alignment in the MSA, applying a penalty for gaps
/// and mismatches.
///
Expand All @@ -54,33 +71,19 @@ impl<T: AsRef<[u8]>, U: Number, M> FlatVec<T, U, M> {
/// number of pairwise alignments.
#[must_use]
pub fn scoring_pairwise(&self, gap_char: u8, gap_penalty: usize, mismatch_penalty: usize) -> f32 {
let m = self._scoring_pairwise(gap_char, gap_penalty, mismatch_penalty, &self.indices());
m.as_f32() / utils::n_pairs(self.cardinality()).as_f32()
let scorer = |s1: &[u8], s2: &[u8]| sp_inner(s1, s2, gap_char, gap_penalty, mismatch_penalty);
let score = self.sum_of_pairs(&self.indices(), scorer);
score.as_f32() / utils::n_pairs(self.cardinality()).as_f32()
}

/// Same as `scoring_pairwise`, but only estimates the score for a subset of
/// the pairwise alignments.
#[must_use]
pub fn scoring_pairwise_subsample(&self, gap_char: u8, gap_penalty: usize, mismatch_penalty: usize) -> f32 {
let indices = utils::choose_samples(&self.indices(), SQRT_THRESH, LOG2_THRESH);
let m = self._scoring_pairwise(gap_char, gap_penalty, mismatch_penalty, &indices);
m.as_f32() / utils::n_pairs(indices.len()).as_f32()
}

/// Helper function for `scoring_pairwise` and `scoring_pairwise_subsample`.
fn _scoring_pairwise(&self, gap_char: u8, gap_penalty: usize, mismatch_penalty: usize, indices: &[usize]) -> usize {
indices
.iter()
.map(|&i| self.get(i).as_ref())
.enumerate()
.flat_map(|(i, s1)| {
indices
.iter()
.skip(i + 1)
.map(move |&j| (s1, self.get(j).as_ref()))
.map(|(s1, s2)| sp_inner(s1, s2, gap_char, gap_penalty, mismatch_penalty))
})
.sum()
let scorer = |s1: &[u8], s2: &[u8]| sp_inner(s1, s2, gap_char, gap_penalty, mismatch_penalty);
let score = self.sum_of_pairs(&indices, scorer);
score.as_f32() / utils::n_pairs(indices.len()).as_f32()
}

/// Scores each pairwise alignment in the MSA, applying penalties for
Expand Down Expand Up @@ -120,17 +123,17 @@ impl<T: AsRef<[u8]>, U: Number, M> FlatVec<T, U, M> {
gap_ext_penalty: usize,
mismatch_penalty: usize,
) -> f32 {
let indices = utils::choose_samples(&self.indices(), SQRT_THRESH, LOG2_THRESH);
let indices = utils::choose_samples(&self.indices(), SQRT_THRESH / 4, LOG2_THRESH / 4);
let scorer =
|s1: &[u8], s2: &[u8]| wsp_inner(s1, s2, gap_char, gap_open_penalty, gap_ext_penalty, mismatch_penalty);
let score = self.sum_of_pairs(&indices, scorer);
score.as_f32() / utils::n_pairs(indices.len()).as_f32()
}

/// Calculate the sum of the pairwise scores for a given scorer.
pub(crate) fn sum_of_pairs<F>(&self, indices: &[usize], scorer: F) -> usize
pub(crate) fn sum_of_pairs<F, G: Number>(&self, indices: &[usize], scorer: F) -> G
where
F: Fn(&[u8], &[u8]) -> usize,
F: Fn(&[u8], &[u8]) -> G,
{
indices
.iter()
Expand All @@ -149,41 +152,36 @@ impl<T: AsRef<[u8]>, U: Number, M> FlatVec<T, U, M> {

// Parallelized implementations here
impl<T: AsRef<[u8]> + Send + Sync, U: Number, M: Send + Sync> FlatVec<T, U, M> {
/// Parallel counterpart of `distance_distortion`.
///
/// Every pair of rows is scored with `dd_inner` via `par_sum_of_pairs`; the
/// summed score is then divided by `utils::n_pairs(self.cardinality())`.
///
/// # Arguments
///
/// * `gap_char` - The byte that marks a gap in the aligned sequences.
#[must_use]
pub fn par_distance_distortion(&self, gap_char: u8) -> f32 {
    let all_rows = self.indices();
    let distortion = |a: &[u8], b: &[u8]| dd_inner(a, b, gap_char);
    let total: f32 = self.par_sum_of_pairs(&all_rows, distortion);
    let pair_count = utils::n_pairs(self.cardinality());
    total.as_f32() / pair_count.as_f32()
}

/// Parallel counterpart of `distance_distortion_subsample`.
///
/// The rows are picked by `utils::choose_samples`, called with
/// `SQRT_THRESH / 8` and `LOG2_THRESH / 8`. Note that the serial variant
/// divides the same thresholds by 4 rather than 8, so the two methods do not
/// pass the same sampling arguments. The summed score over the sampled pairs
/// is divided by the number of pairs among the samples.
///
/// # Arguments
///
/// * `gap_char` - The byte that marks a gap in the aligned sequences.
#[must_use]
pub fn par_distance_distortion_subsample(&self, gap_char: u8) -> f32 {
    let sampled = utils::choose_samples(&self.indices(), SQRT_THRESH / 8, LOG2_THRESH / 8);
    let distortion = |a: &[u8], b: &[u8]| dd_inner(a, b, gap_char);
    let total: f32 = self.par_sum_of_pairs(&sampled, distortion);
    let pair_count = utils::n_pairs(sampled.len());
    total.as_f32() / pair_count.as_f32()
}

/// Parallel version of `scoring_pairwise`.
#[must_use]
pub fn par_scoring_pairwise(&self, gap_char: u8, gap_penalty: usize, mismatch_penalty: usize) -> f32 {
let m = self._par_scoring_pairwise(gap_char, gap_penalty, mismatch_penalty, &self.indices());
m.as_f32() / utils::n_pairs(self.cardinality()).as_f32()
let scorer = |s1: &[u8], s2: &[u8]| sp_inner(s1, s2, gap_char, gap_penalty, mismatch_penalty);
let score = self.par_sum_of_pairs(&self.indices(), scorer);
score.as_f32() / utils::n_pairs(self.cardinality()).as_f32()
}

/// Parallel version of `scoring_pairwise_subsample`.
#[must_use]
pub fn par_scoring_pairwise_subsample(&self, gap_char: u8, gap_penalty: usize, mismatch_penalty: usize) -> f32 {
let indices = utils::choose_samples(&self.indices(), SQRT_THRESH, LOG2_THRESH);
let m = self._par_scoring_pairwise(gap_char, gap_penalty, mismatch_penalty, &indices);
m.as_f32() / utils::n_pairs(indices.len()).as_f32()
}

/// Parallel version of `_scoring_pairwise`.
fn _par_scoring_pairwise(
&self,
gap_char: u8,
gap_penalty: usize,
mismatch_penalty: usize,
indices: &[usize],
) -> usize {
indices
.par_iter()
.map(|&i| self.get(i).as_ref())
.enumerate()
.flat_map(|(i, s1)| {
indices
.par_iter()
.skip(i + 1)
.map(move |&j| (s1, self.get(j).as_ref()))
.map(|(s1, s2)| sp_inner(s1, s2, gap_char, gap_penalty, mismatch_penalty))
})
.sum()
let scorer = |s1: &[u8], s2: &[u8]| sp_inner(s1, s2, gap_char, gap_penalty, mismatch_penalty);
let score = self.par_sum_of_pairs(&indices, scorer);
score.as_f32() / utils::n_pairs(indices.len()).as_f32()
}

/// Parallel version of `weighted_scoring_pairwise`.
Expand Down Expand Up @@ -218,9 +216,9 @@ impl<T: AsRef<[u8]> + Send + Sync, U: Number, M: Send + Sync> FlatVec<T, U, M> {
}

/// Calculate the sum of the pairwise scores for a given scorer.
pub(crate) fn par_sum_of_pairs<F>(&self, indices: &[usize], scorer: F) -> usize
pub(crate) fn par_sum_of_pairs<F, G: Number>(&self, indices: &[usize], scorer: F) -> G
where
F: (Fn(&[u8], &[u8]) -> usize) + Send + Sync,
F: (Fn(&[u8], &[u8]) -> G) + Send + Sync,
{
indices
.par_iter()
Expand Down Expand Up @@ -284,3 +282,19 @@ fn wsp_inner(
}
})
}

/// Computes the distance distortion for one pair of aligned rows: the ratio
/// of the Hamming distance between the rows as given to the Levenshtein edit
/// distance between the same rows with every `gap_char` removed.
///
/// # Arguments
///
/// * `s1` - The first aligned row.
/// * `s2` - The second aligned row.
/// * `gap_char` - The byte that marks a gap in the aligned rows.
///
/// # Returns
///
/// `ham / lev`, or `1.0` when the edit distance is zero (which avoids a
/// division by zero).
fn dd_inner(s1: &[u8], s2: &[u8], gap_char: u8) -> f32 {
    // Hamming distance: the number of zipped positions whose bytes differ.
    // `zip` stops at the shorter row, so trailing bytes of a longer row are
    // not counted.
    let ham = s1.iter().zip(s2).filter(|(a, b)| a != b).count();

    // Strip the gaps before measuring the edit distance.
    let strip_gaps = |row: &[u8]| -> Vec<u8> { row.iter().copied().filter(|&c| c != gap_char).collect() };
    let lev = stringzilla::sz::edit_distance(strip_gaps(s1), strip_gaps(s2));

    match lev {
        0 => 1.0,
        _ => ham.as_f32() / lev.as_f32(),
    }
}
22 changes: 14 additions & 8 deletions crates/results/msa/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,19 +168,25 @@ fn main() -> Result<(), String> {
// let wps_quality = msa_data.par_weighted_scoring_pairwise(b'-', 10, 1, 10);
// ftlog::info!("Weighted pairwise scoring metric: {wps_quality}");

let dd_quality = msa_data.par_distance_distortion_subsample(b'-');
ftlog::info!("Distance distortion metric estimate: {dd_quality}");

// let dd_quality = msa_data.par_distance_distortion(b'-');
// ftlog::info!("Distance distortion metric: {dd_quality}");

ftlog::info!("Finished scoring row-wise.");

ftlog::info!("Convert to column-major format.");
let metric = Metric::default();
let col_ms_data = msa_data.as_col_major::<Vec<_>>(metric);
// ftlog::info!("Convert to column-major format.");
// let metric = Metric::default();
// let col_ms_data = msa_data.as_col_major::<Vec<_>>(metric);

let cs_quality = col_ms_data.par_scoring_columns(b'-', 1, 1);
ftlog::info!("Column scoring metric estimate: {cs_quality}");
// let cs_quality = col_ms_data.par_scoring_columns(b'-', 1, 1);
// ftlog::info!("Column scoring metric estimate: {cs_quality}");

let wcs_quality = col_ms_data.par_weighted_scoring_columns(b'-', 10, 1, 10);
ftlog::info!("Weighted column scoring metric estimate: {wcs_quality}");
// let wcs_quality = col_ms_data.par_weighted_scoring_columns(b'-', 10, 1, 10);
// ftlog::info!("Weighted column scoring metric estimate: {wcs_quality}");

ftlog::info!("Finished scoring column-wise.");
// ftlog::info!("Finished scoring column-wise.");

Ok(())
}

0 comments on commit 330e581

Please sign in to comment.