Skip to content

Commit

Permalink
feat: columnar metrics are not much faster to compute
Browse files Browse the repository at this point in the history
  • Loading branch information
nishaq503 committed Nov 12, 2024
1 parent c514799 commit b96513a
Show file tree
Hide file tree
Showing 7 changed files with 58 additions and 63 deletions.
7 changes: 7 additions & 0 deletions crates/abd-clam/src/msa/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@ use distances::Number;
pub use builder::Builder;
pub use needleman_wunsch::{Aligner, CostMatrix};

/// The number of characters.
pub(crate) const NUM_CHARS: usize = 1 + (u8::MAX as usize);
/// The square root threshold for sub-sampling.
pub(crate) const SQRT_THRESH: usize = 1000;
/// The logarithmic threshold for sub-sampling.
pub(crate) const LOG2_THRESH: usize = 100_000;

/// A multiple sequence alignment (MSA).
pub struct Msa {
/// The aligned sequences.
Expand Down
3 changes: 1 addition & 2 deletions crates/abd-clam/src/msa/needleman_wunsch/cost_matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@ use std::collections::HashSet;

use distances::{number::Int, Number};

/// The number of characters.
const NUM_CHARS: usize = 1 + (u8::MAX as usize);
use super::super::NUM_CHARS;

/// A substitution matrix for the Needleman-Wunsch aligner.
#[derive(Clone)]
Expand Down
72 changes: 28 additions & 44 deletions crates/abd-clam/src/msa/quality/col_major.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use rayon::prelude::*;

use crate::{utils, Dataset, FlatVec};

use super::{LOG2_THRESH, SQRT_THRESH};
use super::super::NUM_CHARS;

impl<T: AsRef<[u8]>, U: Number, M> FlatVec<T, U, M> {
/// Scores each pair of columns in the MSA, applying a penalty for gaps and
Expand All @@ -20,22 +20,6 @@ impl<T: AsRef<[u8]>, U: Number, M> FlatVec<T, U, M> {
.sum::<usize>();
score.as_f32() / utils::n_pairs(self.cardinality()).as_f32()
}

/// Same as `scoring_columns`, but with a subsample of the rows.
#[must_use]
pub fn scoring_columns_subsample(&self, gap_char: u8, gap_penalty: usize, mismatch_penalty: usize) -> f32 {
let num_rows = self.get(0).as_ref().len();
let row_indices = (0..num_rows).collect::<Vec<_>>();
let samples = utils::choose_samples(&row_indices, SQRT_THRESH, LOG2_THRESH);
let score = self
.instances
.iter()
.map(AsRef::as_ref)
.map(|c| samples.iter().map(|&i| c[i]).collect::<Vec<_>>())
.map(|c| sc_inner(&c, gap_char, gap_penalty, mismatch_penalty))
.sum::<usize>();
score.as_f32() / utils::n_pairs(samples.len()).as_f32()
}
}

impl<T: AsRef<[u8]> + Send + Sync, U: Number, M: Send + Sync> FlatVec<T, U, M> {
Expand All @@ -51,37 +35,37 @@ impl<T: AsRef<[u8]> + Send + Sync, U: Number, M: Send + Sync> FlatVec<T, U, M> {
.sum::<usize>();
score.as_f32() / utils::n_pairs(num_seqs).as_f32()
}

/// Parallel version of `scoring_columns_subsample`.
#[must_use]
pub fn par_scoring_columns_subsample(&self, gap_char: u8, gap_penalty: usize, mismatch_penalty: usize) -> f32 {
let num_rows = self.get(0).as_ref().len();
let row_indices = (0..num_rows).collect::<Vec<_>>();
let seq_ids = utils::choose_samples(&row_indices, SQRT_THRESH, LOG2_THRESH);
let score = self
.instances
.par_iter()
.map(AsRef::as_ref)
.map(|c| seq_ids.iter().map(|&i| c[i]).collect::<Vec<_>>())
.map(|c| sc_inner(&c, gap_char, gap_penalty, mismatch_penalty))
.sum::<usize>();
score.as_f32() / utils::n_pairs(seq_ids.len()).as_f32()
}
}

/// Scores a single pair of columns in the MSA, applying a penalty for gaps and
/// mismatches.
fn sc_inner(col: &[u8], gap_char: u8, gap_penalty: usize, mismatch_penalty: usize) -> usize {
col.iter()
// Create a frequency count of the characters in the column.
let freqs = col.iter().fold([0; NUM_CHARS], |mut freqs, &c| {
freqs[c as usize] += 1;
freqs
});

// Start scoring the column.
let mut score = 0;

// Calculate the number of pairs of characters of which one is a gap and
// apply the gap penalty.
let num_gaps = freqs[gap_char as usize];
score += num_gaps * (col.len() - num_gaps) * gap_penalty / 2;

// Get the frequencies of non-gap characters with non-zero frequency.
let freqs = freqs
.into_iter()
.enumerate()
.filter(|&(i, f)| (f > 0) && (i != gap_char as usize))
.map(|(_, f)| f)
.collect::<Vec<_>>();

// For each combinatorial pair, add mismatch penalties.
freqs
.iter()
.enumerate()
.flat_map(|(i, &a)| col.iter().skip(i + 1).map(move |&b| (a, b)))
.fold(0, |score, (a, b)| {
if a == gap_char || b == gap_char {
score + gap_penalty
} else if a != b {
score + mismatch_penalty
} else {
score
}
})
.flat_map(|(i, &f1)| freqs.iter().skip(i + 1).map(move |&f2| (f1, f2)))
.fold(score, |score, (f1, f2)| score + f1 * f2 * mismatch_penalty)
}
5 changes: 0 additions & 5 deletions crates/abd-clam/src/msa/quality/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,3 @@
pub mod col_major;
pub mod row_major;

/// The square root threshold for sub-sampling.
pub(crate) const SQRT_THRESH: usize = 1000;
/// The logarithmic threshold for sub-sampling.
pub(crate) const LOG2_THRESH: usize = 100_000;
2 changes: 1 addition & 1 deletion crates/abd-clam/src/msa/quality/row_major.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use rayon::prelude::*;

use crate::{utils, Dataset, FlatVec, Metric};

use super::{LOG2_THRESH, SQRT_THRESH};
use super::super::{LOG2_THRESH, SQRT_THRESH};

// TODO: Consider adding a new trait for MSA datasets. Then move these methods
// to that trait.
Expand Down
6 changes: 6 additions & 0 deletions crates/results/msa/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@

## Usage

Run the following command to see the usage information:

```shell
cargo run -r -p results-msa -- --help
```

If you want to run the MSA on all a sequences in a fasta file, you can use the following command:

```shell
Expand Down
26 changes: 15 additions & 11 deletions crates/results/msa/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,22 +156,29 @@ fn main() -> Result<(), String> {
msa_data.dimensionality_hint()
);

let ps_quality = msa_data.par_scoring_pairwise_subsample(b'-', 1, 1);
let gap_char = b'-';
let gap_penalty = 1;
let mismatch_penalty = 1;
let gap_open_penalty = 10;
let gap_ext_penalty = 1;

let ps_quality = msa_data.par_scoring_pairwise_subsample(gap_char, gap_penalty, mismatch_penalty);
ftlog::info!("Pairwise scoring metric estimate: {ps_quality}");

// let ps_quality = msa_data.par_scoring_pairwise(b'-', 1, 1);
// let ps_quality = msa_data.par_scoring_pairwise(gap_char, gap_penalty, mismatch_penalty);
// ftlog::info!("Pairwise scoring metric: {ps_quality}");

let wps_quality = msa_data.par_weighted_scoring_pairwise_subsample(b'-', 10, 1, 10);
let wps_quality =
msa_data.par_weighted_scoring_pairwise_subsample(gap_char, gap_open_penalty, gap_ext_penalty, mismatch_penalty);
ftlog::info!("Weighted pairwise scoring metric estimate: {wps_quality}");

// let wps_quality = msa_data.par_weighted_scoring_pairwise(b'-', 10, 1, 10);
// let wps_quality = msa_data.par_weighted_scoring_pairwise(gap_char, gap_open_penalty, gap_ext_penalty, mismatch_penalty);
// ftlog::info!("Weighted pairwise scoring metric: {wps_quality}");

let dd_quality = msa_data.par_distance_distortion_subsample(b'-');
let dd_quality = msa_data.par_distance_distortion_subsample(gap_char);
ftlog::info!("Distance distortion metric estimate: {dd_quality}");

// let dd_quality = msa_data.par_distance_distortion(b'-');
// let dd_quality = msa_data.par_distance_distortion(gap_char);
// ftlog::info!("Distance distortion metric: {dd_quality}");

ftlog::info!("Finished scoring row-wise.");
Expand All @@ -181,11 +188,8 @@ fn main() -> Result<(), String> {
let col_ms_data = msa_data.as_col_major::<Vec<_>>(metric);
ftlog::info!("Finished converting to column-major format.");

// let cs_quality = col_ms_data.par_scoring_columns(b'-', 1, 1);
// ftlog::info!("Column scoring metric estimate: {cs_quality}");

let cs_quality = col_ms_data.par_scoring_columns_subsample(b'-', 1, 1);
ftlog::info!("Column scoring metric estimate: {cs_quality}");
let cs_quality = col_ms_data.par_scoring_columns(gap_char, gap_penalty, mismatch_penalty);
ftlog::info!("Column scoring metric: {cs_quality}");

ftlog::info!("Finished scoring column-wise.");

Expand Down

0 comments on commit b96513a

Please sign in to comment.