Skip to content

Commit

Permalink
wip: added instrumentation for counting the numbers of distance compu…
Browse files Browse the repository at this point in the history
…tations
  • Loading branch information
nishaq503 committed Dec 29, 2024
1 parent e3329c6 commit 43ac0d3
Show file tree
Hide file tree
Showing 18 changed files with 635 additions and 345 deletions.
71 changes: 47 additions & 24 deletions benches/cakes/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@
use std::path::PathBuf;

use abd_clam::metric::ParMetric;
use clap::Parser;
use metric::ParCountingMetric;

mod metric;
mod steps;

/// Reproducible results for the CAKES paper.
Expand All @@ -39,6 +40,10 @@ struct Args {
#[arg(short('q'), long)]
num_queries: usize,

/// Whether to count the number of distance computations during search.
#[arg(short('c'), long, default_value = "false")]
count_distance_calls: bool,

/// The maximum power of 2 to which the cardinality of the dataset should be
/// augmented for scaling experiments.
///
Expand All @@ -47,6 +52,14 @@ struct Args {
#[arg(short('m'), long)]
max_power: Option<u32>,

/// The minimum power of 2 to which the cardinality of the dataset should be
/// augmented for scaling experiments.
///
/// This is only used with the tabular floating-point datasets and is
/// ignored otherwise.
#[arg(short('n'), long, default_value = "0")]
min_power: Option<u32>,

/// The seed for the random number generator.
#[arg(short('s'), long)]
seed: Option<u64>,
Expand All @@ -55,6 +68,14 @@ struct Args {
#[arg(short('t'), long, default_value = "10.0")]
max_time: f32,

/// Whether to run benchmarks with balanced trees.
#[arg(short('b'), long)]
balanced_trees: bool,

/// Whether to run benchmarks with permuted data.
#[arg(short('p'), long)]
permuted_data: bool,

/// Path to the output directory.
#[arg(short('o'), long)]
out_dir: Option<PathBuf>,
Expand Down Expand Up @@ -95,23 +116,35 @@ fn main() -> Result<(), String> {
let radial_fractions = [0.1, 0.25];
let ks = [10, 100];
let seed = args.seed;
let max_power = args.max_power.unwrap_or_default();
let min_power = args.min_power.unwrap_or_default();
let max_power = args.max_power.unwrap_or(5);
let max_time = std::time::Duration::from_secs_f32(args.max_time);

if min_power > max_power {
return Err("min_power must be less than or equal to max_power".to_string());
}

if args.dataset.is_tabular() {
let metric: Box<dyn ParMetric<_, _>> = match args.dataset.metric() {
"cosine" => Box::new(abd_clam::metric::Cosine),
"euclidean" => Box::new(abd_clam::metric::Euclidean),
_ => return Err(format!("Unknown metric: {}", args.dataset.metric())),
let metric = {
let mut metric: Box<dyn ParCountingMetric<_, _>> = match args.dataset.metric() {
"cosine" => Box::new(metric::Cosine::new()),
"euclidean" => Box::new(metric::Euclidean::new()),
_ => return Err(format!("Unknown metric: {}", args.dataset.metric())),
};
if !args.count_distance_calls {
metric.disable_counting();
}
metric
};

let (queries, neighbors) = steps::read_tabular(&args.dataset, max_power, args.seed, &inp_dir, &out_dir)?;
let queries = &queries[..args.num_queries];
let neighbors = &neighbors[..args.num_queries];

for power in 0..=max_power {
let data_path = out_dir.join(format!("{}-{power}.flat_vec", args.dataset.name()));
for power in min_power..=max_power {
let data_path = out_dir.join(format!("{}-{power}.npy", args.dataset.name()));
ftlog::info!("Reading {}x augmented data from: {data_path:?}", 1 << power);
let run_linear = power < 4;

steps::workflow(
&out_dir,
Expand All @@ -123,25 +156,15 @@ fn main() -> Result<(), String> {
&ks,
seed,
max_time,
run_linear,
args.balanced_trees,
args.permuted_data,
)?;
}
} else {
let metric = bench_utils::metrics::Jaccard;
let (data_path, queries, neighbors) = steps::read_member_set(&args.dataset, &inp_dir, &out_dir)?;
let queries = &queries[..args.num_queries];
let neighbors = &neighbors[..args.num_queries];

steps::workflow(
&out_dir,
&data_path,
&metric,
queries,
neighbors,
&radial_fractions,
&ks,
seed,
max_time,
)?;
let msg = format!("Unsupported dataset: {}", args.dataset.name());
ftlog::error!("{msg}");
return Err(msg);
}

Ok(())
Expand Down
83 changes: 83 additions & 0 deletions benches/cakes/src/metric/cosine.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
//! The Cosine distance function.
use std::sync::{Arc, RwLock};

use abd_clam::{metric::ParMetric, Metric};
use distances::number::Float;

use super::{CountingMetric, ParCountingMetric};

/// The Cosine distance function.
pub struct Cosine(Arc<RwLock<usize>>, bool);

impl Cosine {
/// Creates a new `Euclidean` distance metric.
pub fn new() -> Self {
Self(Arc::new(RwLock::new(0)), false)
}
}

impl<I: AsRef<[T]>, T: Float> Metric<I, T> for Cosine {
fn distance(&self, a: &I, b: &I) -> T {
if self.1 {
<Self as CountingMetric<I, T>>::increment(self);
}
distances::vectors::cosine(a.as_ref(), b.as_ref())
}

fn name(&self) -> &str {
"cosine"
}

fn has_identity(&self) -> bool {
true
}

fn has_non_negativity(&self) -> bool {
true
}

fn has_symmetry(&self) -> bool {
true
}

fn obeys_triangle_inequality(&self) -> bool {
true
}

fn is_expensive(&self) -> bool {
false
}
}

impl<I: AsRef<[U]>, U: Float> CountingMetric<I, U> for Cosine {
fn disable_counting(&mut self) {
self.1 = false;
}

fn enable_counting(&mut self) {
self.1 = true;
}

#[allow(clippy::unwrap_used)]
fn count(&self) -> usize {
*self.0.read().unwrap()
}

#[allow(clippy::unwrap_used)]
fn reset_count(&self) -> usize {
let mut count = self.0.write().unwrap();
let old = *count;
*count = 0;
old
}

#[allow(clippy::unwrap_used)]
fn increment(&self) {
*self.0.write().unwrap() += 1;
}
}

impl<I: AsRef<[U]> + Send + Sync, U: Float> ParMetric<I, U> for Cosine {}

impl<I: AsRef<[U]> + Send + Sync, U: Float> ParCountingMetric<I, U> for Cosine {}
83 changes: 83 additions & 0 deletions benches/cakes/src/metric/euclidean.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
//! The `Euclidean` distance metric.
use std::sync::{Arc, RwLock};

use abd_clam::{metric::ParMetric, Metric};
use distances::number::Float;

use super::{CountingMetric, ParCountingMetric};

/// The `Euclidean` distance metric.
pub struct Euclidean(Arc<RwLock<usize>>, bool);

impl Euclidean {
/// Creates a new `Euclidean` distance metric.
pub fn new() -> Self {
Self(Arc::new(RwLock::new(0)), true)
}
}

impl<I: AsRef<[T]>, T: Float> Metric<I, T> for Euclidean {
fn distance(&self, a: &I, b: &I) -> T {
if self.1 {
<Self as CountingMetric<I, T>>::increment(self);
}
distances::vectors::euclidean(a.as_ref(), b.as_ref())
}

fn name(&self) -> &str {
"euclidean"
}

fn has_identity(&self) -> bool {
true
}

fn has_non_negativity(&self) -> bool {
true
}

fn has_symmetry(&self) -> bool {
true
}

fn obeys_triangle_inequality(&self) -> bool {
true
}

fn is_expensive(&self) -> bool {
false
}
}

impl<I: AsRef<[U]>, U: Float> CountingMetric<I, U> for Euclidean {
fn disable_counting(&mut self) {
self.1 = false;
}

fn enable_counting(&mut self) {
self.1 = true;
}

#[allow(clippy::unwrap_used)]
fn count(&self) -> usize {
*self.0.read().unwrap()
}

#[allow(clippy::unwrap_used)]
fn reset_count(&self) -> usize {
let mut count = self.0.write().unwrap();
let old = *count;
*count = 0;
old
}

#[allow(clippy::unwrap_used)]
fn increment(&self) {
*self.0.write().unwrap() += 1;
}
}

impl<I: AsRef<[U]> + Send + Sync, U: Float> ParMetric<I, U> for Euclidean {}

impl<I: AsRef<[U]> + Send + Sync, U: Float> ParCountingMetric<I, U> for Euclidean {}
Loading

0 comments on commit 43ac0d3

Please sign in to comment.