Skip to content

Commit

Permalink
perf: improve constructing dist table (#3155)
Browse files Browse the repository at this point in the history
```
construc_dist_table: 16000,l2,PQ=96,DIM=1536
                        time:   [157.75 µs 157.97 µs 158.19 µs]
                        change: [-28.865% -28.731% -28.581%] (p = 0.00 < 0.10)
                        Performance has improved.

construc_dist_table: 16000,dot,PQ=96,DIM=1536
                        time:   [275.56 µs 276.02 µs 276.47 µs]
                        change: [-2.8930% -2.4637% -2.0839%] (p = 0.00 < 0.10)
                        Performance has improved.
```

---------

Signed-off-by: BubbleCal <bubble-cal@outlook.com>
  • Loading branch information
BubbleCal authored Nov 23, 2024
1 parent 02b2ca0 commit 439db38
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 28 deletions.
54 changes: 50 additions & 4 deletions rust/lance-index/benches/pq_dist_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use arrow_array::types::Float32Type;
use arrow_array::{FixedSizeListArray, UInt8Array};
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use lance_arrow::FixedSizeListArrayExt;
use lance_index::vector::pq::distance::*;
use lance_index::vector::pq::ProductQuantizer;
use lance_linalg::distance::DistanceType;
use lance_testing::datagen::generate_random_array_with_seed;
Expand All @@ -21,7 +22,52 @@ const PQ: usize = 96;
const DIM: usize = 1536;
const TOTAL: usize = 16 * 1000;

fn dist_table(c: &mut Criterion) {
fn construct_dist_table(c: &mut Criterion) {
let codebook = generate_random_array_with_seed::<Float32Type>(256 * DIM, [88; 32]);
let query = generate_random_array_with_seed::<Float32Type>(DIM, [32; 32]);

c.bench_function(
format!(
"construct_dist_table: {},PQ={},DIM={}",
DistanceType::L2,
PQ,
DIM
)
.as_str(),
|b| {
b.iter(|| {
black_box(build_distance_table_l2(
codebook.values(),
8,
PQ,
query.values(),
));
})
},
);

c.bench_function(
format!(
"construct_dist_table: {},PQ={},DIM={}",
DistanceType::Dot,
PQ,
DIM
)
.as_str(),
|b| {
b.iter(|| {
black_box(build_distance_table_dot(
codebook.values(),
8,
PQ,
query.values(),
));
})
},
);
}

fn compute_distances(c: &mut Criterion) {
let codebook = generate_random_array_with_seed::<Float32Type>(256 * DIM, [88; 32]);
let query = generate_random_array_with_seed::<Float32Type>(DIM, [32; 32]);

Expand All @@ -38,7 +84,7 @@ fn dist_table(c: &mut Criterion) {
);

c.bench_function(
format!("{},{},PQ={},DIM={}", TOTAL, dt, PQ, DIM).as_str(),
format!("compute_distances: {},{},PQ={},DIM={}", TOTAL, dt, PQ, DIM).as_str(),
|b| {
b.iter(|| {
black_box(pq.compute_distances(&query, &code).unwrap());
Expand All @@ -53,12 +99,12 @@ criterion_group!(
name=benches;
config = Criterion::default().significance_level(0.1).sample_size(10)
.with_profiler(PProfProfiler::new(100, Output::Flamegraph(None)));
targets = dist_table);
targets = construct_dist_table, compute_distances);

#[cfg(not(target_os = "linux"))]
criterion_group!(
name=benches;
config = Criterion::default().significance_level(0.1).sample_size(10);
targets = dist_table);
targets = construct_dist_table, compute_distances);

criterion_main!(benches);
57 changes: 43 additions & 14 deletions rust/lance-index/src/vector/pq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use storage::{ProductQuantizationMetadata, ProductQuantizationStorage, PQ_METADA
use tracing::instrument;

pub mod builder;
mod distance;
pub mod distance;
pub mod storage;
pub mod transform;
pub(crate) mod utils;
Expand Down Expand Up @@ -96,6 +96,26 @@ impl ProductQuantizer {

#[instrument(name = "ProductQuantizer::transform", level = "debug", skip_all)]
fn transform<T: ArrowPrimitiveType>(&self, vectors: &dyn Array) -> Result<ArrayRef>
where
T::Native: Float + L2 + Dot,
{
match self.num_bits {
4 => self.transform_impl::<4, T>(vectors),
8 => self.transform_impl::<8, T>(vectors),
_ => Err(Error::Index {
message: format!(
"ProductQuantization: num_bits {} not supported",
self.num_bits
),
location: location!(),
}),
}
}

fn transform_impl<const NUM_BITS: u32, T: ArrowPrimitiveType>(
&self,
vectors: &dyn Array,
) -> Result<ArrayRef>
where
T::Native: Float + L2 + Dot,
{
Expand All @@ -108,8 +128,7 @@ impl ProductQuantizer {
})?;
let num_sub_vectors = self.num_sub_vectors;
let dim = self.dimension;
let num_bits = self.num_bits;
if num_bits == 4 && num_sub_vectors % 2 != 0 {
if NUM_BITS == 4 && num_sub_vectors % 2 != 0 {
return Err(Error::Index {
message: format!(
"PQ: num_sub_vectors must be divisible by 2 for num_bits=4, but got {}",
Expand All @@ -132,17 +151,16 @@ impl ProductQuantizer {
.chunks_exact(sub_dim)
.enumerate()
.map(|(sub_idx, sub_vector)| {
let centroids = get_sub_vector_centroids(
let centroids = get_sub_vector_centroids::<NUM_BITS, _>(
codebook.values(),
dim,
num_bits,
num_sub_vectors,
sub_idx,
);
compute_partition(centroids, sub_vector, distance_type).unwrap() as u8
})
.collect::<Vec<_>>();
if num_bits == 4 {
if NUM_BITS == 4 {
sub_vec_code
.chunks_exact(2)
.map(|v| (v[1] << 4) | v[0])
Expand All @@ -153,7 +171,7 @@ impl ProductQuantizer {
})
.collect::<Vec<_>>();

let num_sub_vectors_in_byte = if num_bits == 4 {
let num_sub_vectors_in_byte = if NUM_BITS == 4 {
num_sub_vectors / 2
} else {
num_sub_vectors
Expand Down Expand Up @@ -321,13 +339,24 @@ impl ProductQuantizer {
///
/// Returns a flatten `num_centroids * sub_vector_width` f32 array.
pub fn centroids<T: ArrowPrimitiveType>(&self, sub_vector_idx: usize) -> &[T::Native] {
get_sub_vector_centroids(
self.codebook.values().as_primitive::<T>().values(),
self.dimension,
self.num_bits,
self.num_sub_vectors,
sub_vector_idx,
)
match self.num_bits {
4 => get_sub_vector_centroids::<4, _>(
self.codebook.values().as_primitive::<T>().values(),
self.dimension,
self.num_sub_vectors,
sub_vector_idx,
),
8 => get_sub_vector_centroids::<8, _>(
self.codebook.values().as_primitive::<T>().values(),
self.dimension,
self.num_sub_vectors,
sub_vector_idx,
),
_ => panic!(
"ProductQuantization: num_bits {} not supported",
self.num_bits
),
}
}
}

Expand Down
40 changes: 35 additions & 5 deletions rust/lance-index/src/vector/pq/distance.rs
Original file line number Diff line number Diff line change
@@ -1,52 +1,82 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

use core::panic;
use std::cmp::min;

use lance_linalg::distance::{dot_distance_batch, l2_distance_batch, Dot, L2};
use lance_table::utils::LanceIteratorExtension;

use super::{num_centroids, utils::get_sub_vector_centroids};

/// Build a Distance Table from the query to each PQ centroid
/// using L2 distance.
pub(super) fn build_distance_table_l2<T: L2>(
pub fn build_distance_table_l2<T: L2>(
codebook: &[T],
num_bits: u32,
num_sub_vectors: usize,
query: &[T],
) -> Vec<f32> {
let dimension = query.len();
match num_bits {
4 => build_distance_table_l2_impl::<4, T>(codebook, num_sub_vectors, query),
8 => build_distance_table_l2_impl::<8, T>(codebook, num_sub_vectors, query),
_ => panic!("Unsupported number of bits: {}", num_bits),
}
}

#[inline]
pub fn build_distance_table_l2_impl<const NUM_BITS: u32, T: L2>(
codebook: &[T],
num_sub_vectors: usize,
query: &[T],
) -> Vec<f32> {
let dimension = query.len();
let sub_vector_length = dimension / num_sub_vectors;
let num_centroids = 2_usize.pow(NUM_BITS);
query
.chunks_exact(sub_vector_length)
.enumerate()
.flat_map(|(i, sub_vec)| {
let subvec_centroids =
get_sub_vector_centroids(codebook, dimension, num_bits, num_sub_vectors, i);
get_sub_vector_centroids::<NUM_BITS, _>(codebook, dimension, num_sub_vectors, i);
l2_distance_batch(sub_vec, subvec_centroids, sub_vector_length)
})
.exact_size(num_sub_vectors * num_centroids)
.collect()
}

/// Build a Distance Table from the query to each PQ centroid
/// using Dot distance.
pub(super) fn build_distance_table_dot<T: Dot>(
pub fn build_distance_table_dot<T: Dot>(
codebook: &[T],
num_bits: u32,
num_sub_vectors: usize,
query: &[T],
) -> Vec<f32> {
match num_bits {
4 => build_distance_table_dot_impl::<4, T>(codebook, num_sub_vectors, query),
8 => build_distance_table_dot_impl::<8, T>(codebook, num_sub_vectors, query),
_ => panic!("Unsupported number of bits: {}", num_bits),
}
}

pub fn build_distance_table_dot_impl<const NUM_BITS: u32, T: Dot>(
codebook: &[T],
num_sub_vectors: usize,
query: &[T],
) -> Vec<f32> {
let dimension = query.len();
let sub_vector_length = dimension / num_sub_vectors;
let num_centroids = 2_usize.pow(NUM_BITS);
query
.chunks_exact(sub_vector_length)
.enumerate()
.flat_map(|(i, sub_vec)| {
let subvec_centroids =
get_sub_vector_centroids(codebook, dimension, num_bits, num_sub_vectors, i);
get_sub_vector_centroids::<NUM_BITS, _>(codebook, dimension, num_sub_vectors, i);
dot_distance_batch(sub_vec, subvec_centroids, sub_vector_length)
})
.exact_size(num_sub_vectors * num_centroids)
.collect()
}

Expand Down
7 changes: 3 additions & 4 deletions rust/lance-index/src/vector/pq/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,21 +51,20 @@ pub fn num_centroids(num_bits: impl Into<u32>) -> usize {
}

#[inline]
pub fn get_sub_vector_centroids<T>(
pub fn get_sub_vector_centroids<const NUM_BITS: u32, T>(
codebook: &[T],
dimension: usize,
num_bits: impl Into<u32>,
num_sub_vectors: usize,
sub_vector_idx: usize,
) -> &[T] {
assert!(
debug_assert!(
sub_vector_idx < num_sub_vectors,
"sub_vector idx: {}, num_sub_vectors: {}",
sub_vector_idx,
num_sub_vectors
);

let num_centroids = num_centroids(num_bits);
let num_centroids: usize = 2_usize.pow(NUM_BITS);
let sub_vector_width = dimension / num_sub_vectors;
&codebook[sub_vector_idx * num_centroids * sub_vector_width
..(sub_vector_idx + 1) * num_centroids * sub_vector_width]
Expand Down
2 changes: 1 addition & 1 deletion rust/lance-linalg/src/distance/l2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ pub fn l2_distance_batch<'a, T: L2>(
debug_assert_eq!(from.len(), dimension);
debug_assert_eq!(to.len() % dimension, 0);

Box::new(T::l2_batch(from, to, dimension))
T::l2_batch(from, to, dimension)
}

fn do_l2_distance_arrow_batch<T: ArrowFloatType>(
Expand Down

0 comments on commit 439db38

Please sign in to comment.