diff --git a/benches/cakes/src/steps.rs b/benches/cakes/src/steps.rs index 5ad362dc..96e93f55 100644 --- a/benches/cakes/src/steps.rs +++ b/benches/cakes/src/steps.rs @@ -9,7 +9,7 @@ use abd_clam::{ metric::ParMetric, Ball, Cluster, Dataset, FlatVec, }; -use bench_utils::{ann_benchmarks::AnnDataset, reports::CakesResults, types::Row}; +use bench_utils::{ann_benchmarks::AnnDataset, reports::CakesResults}; use distances::Number; use rand::prelude::*; @@ -421,7 +421,7 @@ pub fn read_tabular>( seed: Option, inp_dir: &P, out_dir: &P, -) -> Result<(Vec>, Vec>), String> { +) -> Result<(Vec>, Vec>), String> { let data_path = out_dir.as_ref().join(format!("{}-0.flat_vec", dataset.name())); let queries_path = out_dir.as_ref().join(format!("{}.queries", dataset.name())); let neighbors_path = out_dir.as_ref().join(format!("{}.neighbors", dataset.name())); @@ -443,29 +443,49 @@ pub fn read_tabular>( let (train, queries, neighbors) = (data.train, data.queries, data.neighbors); let (min_dim, max_dim) = train .iter() + .chain(queries.iter()) .fold((usize::MAX, 0), |(min, max), x| (min.min(x.len()), max.max(x.len()))); - let items = train.iter().cloned().map(Row::from).collect(); - let data = FlatVec::new(items)? + let data = FlatVec::new(train)? .with_name(dataset.name()) .with_dim_lower_bound(min_dim) .with_dim_upper_bound(max_dim); ftlog::info!("Writing original data to {data_path:?}..."); data.write_to(&data_path)?; - let queries = queries.iter().cloned().map(Row::from).collect::>(); + let npy_name = format!("{}-0.npy", dataset.name()); + ftlog::info!( + "Writing original data as npy array to {:?}/{npy_name}...", + out_dir.as_ref() + ); + data.write_npy(out_dir, &npy_name)?; + + let query_data = FlatVec::new(queries)? + .with_name(&format!("{}-queries", dataset.name())) + .with_dim_lower_bound(min_dim) + .with_dim_upper_bound(max_dim); + let query_npy_name = format!("{}.npy", query_data.name()); + ftlog::info!( + "Writing queries as npy array to {:?}/{query_npy_name}...", + out_dir.as_ref() + ); + query_data.write_npy(out_dir, &query_npy_name)?; + let queries = query_data.take_items(); + std::fs::write(&queries_path, bitcode::encode(&queries).map_err(|e| e.to_string())?) .map_err(|e| e.to_string())?; std::fs::write(&neighbors_path, bitcode::encode(&neighbors).map_err(|e| e.to_string())?) .map_err(|e| e.to_string())?; - let base_cardinality = data.cardinality(); + let train = data.take_items(); + let base_cardinality = train.len(); let data = AnnDataset { train, queries: Vec::new(), neighbors: Vec::new(), } .augment(1 << max_power, 0.1); + let data = FlatVec::new(data.train)? .with_dim_lower_bound(min_dim) .with_dim_upper_bound(max_dim); @@ -477,6 +497,14 @@ pub fn read_tabular>( let data_path = out_dir.as_ref().join(format!("{name}.flat_vec")); ftlog::info!("Writing {}x augmented data to {data_path:?}...", 1 << power); data_sample.write_to(&data_path)?; + + let npy_name = format!("{name}.npy"); + ftlog::info!( + "Writing {}x augmented data as npy array to {:?}/{npy_name}...", + 1 << power, + out_dir.as_ref() + ); + data_sample.write_npy(out_dir, &npy_name)?; } Ok((queries, neighbors)) @@ -489,7 +517,7 @@ pub fn read_member_set>( dataset: &bench_utils::ann_benchmarks::RawData, inp_dir: &P, out_dir: &P, -) -> Result<(std::path::PathBuf, Vec>, Vec>), String> { +) -> Result<(std::path::PathBuf, Vec>, Vec>), String> { let data_name = format!("{}.flat_vec", dataset.name()); let data_path = out_dir.as_ref().join(&data_name); let queries_path = out_dir.as_ref().join(format!("{}.queries", dataset.name())); @@ -511,12 +539,10 @@ pub fn read_member_set>( .iter() .fold((usize::MAX, 0), |(min, max), x| (min.min(x.len()), max.max(x.len()))); - let data = train.into_iter().map(Row::from).collect(); - let data = FlatVec::new(data)? + let data = FlatVec::new(train)? .with_name(&data_name) .with_dim_lower_bound(min_dim) .with_dim_upper_bound(max_dim); - let queries = queries.into_iter().map(Row::from).collect::>(); ftlog::info!("Writing data to {data_path:?}..."); data.write_to(&data_path)?; diff --git a/crates/abd-clam/src/core/dataset/flat_vec.rs b/crates/abd-clam/src/core/dataset/flat_vec.rs index 742d9a4e..5585cfe8 100644 --- a/crates/abd-clam/src/core/dataset/flat_vec.rs +++ b/crates/abd-clam/src/core/dataset/flat_vec.rs @@ -193,6 +193,12 @@ impl FlatVec { &self.items } + /// Takes the items out of the dataset. + #[must_use] + pub fn take_items(self) -> Vec { + self.items + } + /// Transforms the items in the dataset. /// /// # Type Parameters