Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: use log crate instead of println #3

Merged
merged 2 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ rand = "0.9.0-alpha.2"
ndarray-rand = "0.15.0"
rand_distr = "0.4.3"
rayon = "1.10.0"
log = "0.4.22"

[dev-dependencies]
criterion = { version = "0.5", features = ["html_reports"] }
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ fn main() -> Result<()> {
// Configure PQ parameters
let m = 8; // Number of subspaces (controls compression ratio)
let ks = 256; // Number of centroids per subspace (usually 256 for uint8)
let mut pq = PQ::try_new(m, ks, Some(true))?;
let mut pq = PQ::try_new(m, ks)?;

// Train the quantizer on the data
println!("Training PQ model...");
Expand Down
2 changes: 1 addition & 1 deletion src/bin/example.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ fn main() -> Result<()> {
let ks = 256; // Number of clusters per subspace
let verbose = Some(true);

let mut pq = PQ::try_new(m, ks, verbose)?;
let mut pq = PQ::try_new(m, ks)?;

// Step 3: Train the PQ Model
let iterations = 20; // Number of iterations for k-means
Expand Down
2 changes: 1 addition & 1 deletion src/bin/readme_example.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ fn main() -> Result<()> {
// Configure PQ parameters
let m = 8; // Number of subspaces (controls compression ratio)
let ks = 256; // Number of centroids per subspace (usually 256 for uint8)
let mut pq = PQ::try_new(m, ks, Some(true))?;
let mut pq = PQ::try_new(m, ks)?;

// Train the quantizer on the data
println!("Training PQ model...");
Expand Down
49 changes: 23 additions & 26 deletions src/pq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use anyhow::Result;
use ndarray::parallel::prelude::*;
use ndarray::{s, Array2, Array3, Axis};
use rayon::prelude::*;
use log::{debug, error, info, trace, warn};

#[derive(Debug, Clone, Copy)]
pub enum CodeType {
Expand All @@ -14,15 +15,14 @@ pub enum CodeType {
pub struct PQ {
m: usize,
ks: u32,
verbose: bool,
code_dtype: CodeType,
codewords: Option<Array3<f32>>,
ds: Option<Vec<usize>>,
dim: Option<usize>,
}

impl PQ {
pub fn try_new(m: usize, ks: u32, verbose: Option<bool>) -> Result<Self> {
pub fn try_new(m: usize, ks: u32) -> Result<Self> {
if ks == 0 {
anyhow::bail!(
"cluster subspaces (ks) must be a u32 between 1 and 2**32 - 1. Got {}",
Expand All @@ -37,7 +37,6 @@ impl PQ {
Ok(Self {
m,
ks,
verbose: verbose.unwrap_or(false),
code_dtype: determine_code_type(ks),
codewords: None,
ds: None,
Expand Down Expand Up @@ -94,15 +93,13 @@ impl PQ {
let trained_codewords: Vec<(usize, Array2<f32>)> = (0..self.m)
.into_par_iter()
.map(|m| {
if self.verbose {
println!(
"# Training the subspace: {} / {}, {} -> {}",
m,
self.m,
self.ds.as_ref().unwrap()[m],
self.ds.as_ref().unwrap()[m + 1]
);
}
info!(
"Training the subspace: {} / {}, {} -> {}",
m,
self.m,
self.ds.as_ref().unwrap()[m],
self.ds.as_ref().unwrap()[m + 1]
);

let ds_ref = self.ds.as_ref().unwrap();

Expand Down Expand Up @@ -256,13 +253,13 @@ mod tests {
// Edge case: ks is zero (rejected) or equal to u32::MAX (the largest accepted value).
#[test]
fn test_try_new_invalid_ks_zero() {
let pq = PQ::try_new(4, 0, None);
let pq = PQ::try_new(4, 0);
assert!(pq.is_err(), "Initialization should fail when ks is zero");
}

#[test]
fn test_try_new_invalid_ks_max() {
let pq = PQ::try_new(4, u32::MAX, None);
let pq = PQ::try_new(4, u32::MAX);
assert!(
pq.is_ok(),
"Initialization should succeed when ks is u32::MAX"
Expand All @@ -272,7 +269,7 @@ mod tests {
// Edge Case: m is zero.
#[test]
fn test_try_new_invalid_m_zero() {
let pq = PQ::try_new(0, 256, None);
let pq = PQ::try_new(0, 256);
assert!(
pq.is_err(),
"Initialization should fail when m is zero, but it succeeded"
Expand All @@ -282,7 +279,7 @@ mod tests {
// Edge Case: Number of training vectors is less than ks.
#[test]
fn test_fit_vectors_less_than_ks() {
let mut pq = PQ::try_new(4, 256, None).unwrap();
let mut pq = PQ::try_new(4, 256).unwrap();
let vecs = create_dummy_vectors(100, 128); // Less than ks
let result = pq.fit(&vecs, 10);
assert!(
Expand All @@ -294,7 +291,7 @@ mod tests {
// Edge Case: Vectors have zero dimensions or m exceeds vector dimensions.
#[test]
fn test_fit_zero_dimensions() {
let mut pq = PQ::try_new(4, 256, None).unwrap();
let mut pq = PQ::try_new(4, 256).unwrap();
let vecs = create_dummy_vectors(1000, 0); // Zero dimensions
let result = pq.fit(&vecs, 10);
assert!(
Expand All @@ -305,7 +302,7 @@ mod tests {

#[test]
fn test_fit_m_greater_than_dimensions() {
let mut pq = PQ::try_new(200, 256, None).unwrap();
let mut pq = PQ::try_new(200, 256).unwrap();
let vecs = create_dummy_vectors(1000, 128); // m > dimensions
let result = pq.fit(&vecs, 10);
assert!(
Expand All @@ -317,7 +314,7 @@ mod tests {
// Edge Case: Calling encode before fit.
#[test]
fn test_encode_without_fit() {
let pq = PQ::try_new(4, 256, None).unwrap();
let pq = PQ::try_new(4, 256).unwrap();
let vecs = create_dummy_vectors(1000, 128);
let result = pq.encode(&vecs);
assert!(
Expand All @@ -329,7 +326,7 @@ mod tests {
// Edge Case: Vectors have different dimensions than those used in fit.
#[test]
fn test_encode_mismatched_dimensions() {
let mut pq = PQ::try_new(4, 256, None).unwrap();
let mut pq = PQ::try_new(4, 256).unwrap();
let train_vecs = create_dummy_vectors(1000, 128);
pq.fit(&train_vecs, 10).unwrap();

Expand All @@ -344,7 +341,7 @@ mod tests {
// Edge Case: Codes have incorrect dimensions or contain invalid values.
#[test]
fn test_decode_invalid_code_m() {
let mut pq = PQ::try_new(4, 256, None).unwrap();
let mut pq = PQ::try_new(4, 256).unwrap();
let train_vecs = create_dummy_vectors(1000, 128);
pq.fit(&train_vecs, 10).unwrap();

Expand All @@ -358,7 +355,7 @@ mod tests {

#[test]
fn test_decode_code_value_exceeds_ks() {
let mut pq = PQ::try_new(4, 256, None).unwrap();
let mut pq = PQ::try_new(4, 256).unwrap();
let train_vecs = create_dummy_vectors(1000, 128);
pq.fit(&train_vecs, 10).unwrap();

Expand All @@ -374,7 +371,7 @@ mod tests {
// Edge Case: Ensuring compress works end-to-end.
#[test]
fn test_compress() {
let mut pq = PQ::try_new(4, 256, None).unwrap();
let mut pq = PQ::try_new(4, 256).unwrap();
let vecs = create_dummy_vectors(1000, 128);
pq.fit(&vecs, 10).unwrap();

Expand All @@ -389,7 +386,7 @@ mod tests {
// Edge Case: Ensuring code values fit within specified data types.
#[test]
fn test_encode_code_dtype_u8_overflow() {
let mut pq = PQ::try_new(4, 300, None).unwrap(); // ks exceeds u8::MAX
let mut pq = PQ::try_new(4, 300).unwrap(); // ks exceeds u8::MAX
pq.code_dtype = CodeType::U8;
let vecs = create_random_vectors(1000, 128);
pq.fit(&vecs, 10).unwrap();
Expand All @@ -403,7 +400,7 @@ mod tests {

#[test]
fn test_encode_code_dtype_u16_overflow() {
let mut pq = PQ::try_new(4, 70000, None).unwrap();
let mut pq = PQ::try_new(4, 70000).unwrap();
pq.code_dtype = CodeType::U16;
pq.codewords = Some(Array3::zeros((pq.m, pq.ks as usize, 128 / pq.m)));
pq.dim = Some(128);
Expand All @@ -418,7 +415,7 @@ mod tests {

#[test]
fn test_encode_code_dtype_u8_valid() {
let mut pq = PQ::try_new(4, 200, None).unwrap(); // ks within u8::MAX
let mut pq = PQ::try_new(4, 200).unwrap(); // ks within u8::MAX
pq.code_dtype = CodeType::U8;
let vecs = create_random_vectors(1000, 128);
pq.fit(&vecs, 10).unwrap();
Expand Down
Loading