diff --git a/src/core/Cargo.toml b/src/core/Cargo.toml index 3de8c37734..2c9997604d 100644 --- a/src/core/Cargo.toml +++ b/src/core/Cargo.toml @@ -43,6 +43,7 @@ serde_json = "1.0.53" primal-check = "0.3.1" thiserror = "1.0" typed-builder = "0.9.0" +atomic_float = "0.1.0" [target.'cfg(all(target_arch = "wasm32", target_vendor="unknown"))'.dependencies.wasm-bindgen] version = "0.2.62" diff --git a/src/core/src/ffi/search.rs b/src/core/src/ffi/search.rs index 09087c35f2..8a6a94a075 100644 --- a/src/core/src/ffi/search.rs +++ b/src/core/src/ffi/search.rs @@ -1,8 +1,6 @@ use crate::index::{JaccardSearch, SearchType}; -use crate::signature::Signature; -use crate::ffi::signature::SourmashSignature; -use crate::ffi::utils::{ForeignObject, SourmashStr}; +use crate::ffi::utils::ForeignObject; pub struct SourmashSearchFn; @@ -22,25 +20,3 @@ pub unsafe extern "C" fn searchfn_new( ) -> *mut SourmashSearchFn { SourmashSearchFn::from_rust(JaccardSearch::with_threshold(search_type, threshold)) } - -/* -#[no_mangle] -pub unsafe extern "C" fn searchresult_score(ptr: *const SourmashSearchResult) -> f64 { - let result = SourmashSearchResult::as_rust(ptr); - result.0 -} - -#[no_mangle] -pub unsafe extern "C" fn searchresult_filename(ptr: *const SourmashSearchResult) -> SourmashStr { - let result = SourmashSearchResult::as_rust(ptr); - (result.2).clone().into() -} - -#[no_mangle] -pub unsafe extern "C" fn searchresult_signature( - ptr: *const SourmashSearchResult, -) -> *mut SourmashSignature { - let result = SourmashSearchResult::as_rust(ptr); - SourmashSignature::from_rust((result.1).clone()) -} -*/ diff --git a/src/core/src/index/linear.rs b/src/core/src/index/linear.rs index 29adad6b22..26a3f3ec6f 100644 --- a/src/core/src/index/linear.rs +++ b/src/core/src/index/linear.rs @@ -1,3 +1,4 @@ +use std::convert::TryInto; use std::fs::File; use std::io::{BufReader, Read}; use std::path::Path; @@ -7,10 +8,13 @@ use std::rc::Rc; use serde::{Deserialize, Serialize}; use typed_builder::TypedBuilder; -use crate::index::storage::{FSStorage, ReadData, Storage, StorageInfo, ToWriter}; use crate::index::{Comparable, DatasetInfo, Index, JaccardSearch, SigStore}; use crate::signature::Signature; use crate::Error; +use crate::{ + index::storage::{FSStorage, ReadData, Storage, StorageInfo, ToWriter}, + sketch::Sketch, +}; #[derive(TypedBuilder)] pub struct LinearIndex { @@ -193,6 +197,47 @@ impl LinearIndex { search_fn: &JaccardSearch, query: &Signature, ) -> Result, Error> { - unimplemented!() + search_fn.check_is_compatible(&query)?; + + let query_mh; + if let Sketch::MinHash(mh) = &query.signatures[0] { + query_mh = mh; + } else { + unimplemented!() + } + + // TODO: prepare_subject and prepare_query + let location: String = "TODO".into(); + + Ok(self + .datasets + .iter() + .filter_map(|subj| { + let subj_sig = subj.data().unwrap(); + let subj_mh; + if let Sketch::MinHash(mh) = &subj_sig.signatures[0] { + subj_mh = mh; + } else { + unimplemented!() + } + + let (shared_size, total_size) = query_mh.intersection_size(&subj_mh).unwrap(); + let query_size = query.size(); + let subj_size = subj.size(); + + let score: f64 = search_fn.score( + query_size.try_into().unwrap(), + shared_size, + subj_size.try_into().unwrap(), + total_size, + ); + + if search_fn.passes(score) && search_fn.collect(score, subj) { + Some((score, subj_sig.clone(), location.clone())) + } else { + None + } + }) + .collect()) } } diff --git a/src/core/src/index/mod.rs b/src/core/src/index/mod.rs index 591a4baae7..66d89dc70f 100644 --- a/src/core/src/index/mod.rs +++ b/src/core/src/index/mod.rs @@ -14,7 +14,9 @@ pub mod search; use std::ops::Deref; use std::path::Path; use std::rc::Rc; +use std::sync::atomic::Ordering; +use atomic_float::AtomicF64; use once_cell::sync::OnceCell; use serde::{Deserialize, Serialize}; use typed_builder::TypedBuilder; @@ -340,7 +342,7 @@ pub enum SearchType { pub struct JaccardSearch { search_type: SearchType, - threshold: f64, + threshold: AtomicF64, require_scaled: bool, } @@ -354,17 +356,67 @@ impl JaccardSearch { JaccardSearch { search_type, require_scaled, - threshold: 0., + threshold: AtomicF64::new(0.0), } } pub fn with_threshold(search_type: SearchType, threshold: f64) -> Self { - let mut s = Self::new(search_type); + let s = Self::new(search_type); s.set_threshold(threshold); s } - pub fn set_threshold(&mut self, threshold: f64) { - self.threshold = threshold; + pub fn set_threshold(&self, threshold: f64) { + self.threshold + .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |_| Some(threshold)) + .unwrap(); + } + + pub fn check_is_compatible(&self, sig: &Signature) -> Result<(), Error> { + // TODO: implement properly + Ok(()) + } + + pub fn score( + &self, + query_size: u64, + shared_size: u64, + subject_size: u64, + total_size: u64, + ) -> f64 { + let shared_size = shared_size as f64; + match self.search_type { + SearchType::Jaccard => shared_size / total_size as f64, + SearchType::Containment => { + if query_size == 0 { + 0.0 + } else { + shared_size / query_size as f64 + } + } + SearchType::MaxContainment => { + let min_denom = query_size.min(subject_size); + if min_denom == 0 { + 0.0 + } else { + shared_size / min_denom as f64 + } + } + } + } + + /// Return True if this match should be collected. + pub fn collect(&self, score: f64, subj: &Signature) -> bool { + true + } + + /// Return true if this score meets or exceeds the threshold. + /// + /// Note: this can be used whenever a score or estimate is available + /// (e.g. internal nodes on an SBT). `collect(...)`, below, decides + /// whether a particular signature should be collected, and/or can + /// update the threshold (used for BestOnly behavior). + pub fn passes(&self, score: f64) -> bool { + score > 0. && score >= self.threshold.load(Ordering::SeqCst) } } diff --git a/src/sourmash/search.py b/src/sourmash/search.py index 9c907f481e..f03e5c3798 100644 --- a/src/sourmash/search.py +++ b/src/sourmash/search.py @@ -164,7 +164,7 @@ def _as_rust(self): return rustcall( lib.searchfn_new, - self.search_type.value(), + self.search_type.value, self.threshold, )