Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: improve constructing dist table #3155

Merged
merged 4 commits into from
Nov 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 50 additions & 4 deletions rust/lance-index/benches/pq_dist_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use arrow_array::types::Float32Type;
use arrow_array::{FixedSizeListArray, UInt8Array};
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use lance_arrow::FixedSizeListArrayExt;
use lance_index::vector::pq::distance::*;
use lance_index::vector::pq::ProductQuantizer;
use lance_linalg::distance::DistanceType;
use lance_testing::datagen::generate_random_array_with_seed;
Expand All @@ -21,7 +22,52 @@ const PQ: usize = 96;
const DIM: usize = 1536;
const TOTAL: usize = 16 * 1000;

fn dist_table(c: &mut Criterion) {
/// Benchmarks building a PQ distance table for a single query, first with
/// `build_distance_table_l2` and then with `build_distance_table_dot`.
///
/// Both cases use a codebook of `256 * DIM` values and a query of `DIM`
/// values from `generate_random_array_with_seed`, and pass `8` as the bit
/// width and `PQ` as the number of sub-vectors.
fn construct_dist_table(c: &mut Criterion) {
    let codebook = generate_random_array_with_seed::<Float32Type>(256 * DIM, [88; 32]);
    let query = generate_random_array_with_seed::<Float32Type>(DIM, [32; 32]);

    // L2 variant.
    let l2_bench_name = format!(
        "construct_dist_table: {},PQ={},DIM={}",
        DistanceType::L2,
        PQ,
        DIM
    );
    c.bench_function(l2_bench_name.as_str(), |bencher| {
        bencher.iter(|| {
            let table = build_distance_table_l2(codebook.values(), 8, PQ, query.values());
            // Keep the result observable so the table construction is not optimized away.
            black_box(table);
        })
    });

    // Dot-product variant.
    let dot_bench_name = format!(
        "construct_dist_table: {},PQ={},DIM={}",
        DistanceType::Dot,
        PQ,
        DIM
    );
    c.bench_function(dot_bench_name.as_str(), |bencher| {
        bencher.iter(|| {
            let table = build_distance_table_dot(codebook.values(), 8, PQ, query.values());
            black_box(table);
        })
    });
}

fn compute_distances(c: &mut Criterion) {
let codebook = generate_random_array_with_seed::<Float32Type>(256 * DIM, [88; 32]);
let query = generate_random_array_with_seed::<Float32Type>(DIM, [32; 32]);

Expand All @@ -38,7 +84,7 @@ fn dist_table(c: &mut Criterion) {
);

c.bench_function(
format!("{},{},PQ={},DIM={}", TOTAL, dt, PQ, DIM).as_str(),
format!("compute_distances: {},{},PQ={},DIM={}", TOTAL, dt, PQ, DIM).as_str(),
|b| {
b.iter(|| {
black_box(pq.compute_distances(&query, &code).unwrap());
Expand All @@ -53,12 +99,12 @@ criterion_group!(
name=benches;
config = Criterion::default().significance_level(0.1).sample_size(10)
.with_profiler(PProfProfiler::new(100, Output::Flamegraph(None)));
targets = dist_table);
targets = construct_dist_table, compute_distances);

#[cfg(not(target_os = "linux"))]
criterion_group!(
name=benches;
config = Criterion::default().significance_level(0.1).sample_size(10);
targets = dist_table);
targets = construct_dist_table, compute_distances);

criterion_main!(benches);
57 changes: 43 additions & 14 deletions rust/lance-index/src/vector/pq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use storage::{ProductQuantizationMetadata, ProductQuantizationStorage, PQ_METADA
use tracing::instrument;

pub mod builder;
mod distance;
pub mod distance;
pub mod storage;
pub mod transform;
pub(crate) mod utils;
Expand Down Expand Up @@ -96,6 +96,26 @@ impl ProductQuantizer {

#[instrument(name = "ProductQuantizer::transform", level = "debug", skip_all)]
fn transform<T: ArrowPrimitiveType>(&self, vectors: &dyn Array) -> Result<ArrayRef>
where
T::Native: Float + L2 + Dot,
{
match self.num_bits {
4 => self.transform_impl::<4, T>(vectors),
8 => self.transform_impl::<8, T>(vectors),
_ => Err(Error::Index {
message: format!(
"ProductQuantization: num_bits {} not supported",
self.num_bits
),
location: location!(),
}),
}
}

fn transform_impl<const NUM_BITS: u32, T: ArrowPrimitiveType>(
&self,
vectors: &dyn Array,
) -> Result<ArrayRef>
where
T::Native: Float + L2 + Dot,
{
Expand All @@ -108,8 +128,7 @@ impl ProductQuantizer {
})?;
let num_sub_vectors = self.num_sub_vectors;
let dim = self.dimension;
let num_bits = self.num_bits;
if num_bits == 4 && num_sub_vectors % 2 != 0 {
if NUM_BITS == 4 && num_sub_vectors % 2 != 0 {
return Err(Error::Index {
message: format!(
"PQ: num_sub_vectors must be divisible by 2 for num_bits=4, but got {}",
Expand All @@ -132,17 +151,16 @@ impl ProductQuantizer {
.chunks_exact(sub_dim)
.enumerate()
.map(|(sub_idx, sub_vector)| {
let centroids = get_sub_vector_centroids(
let centroids = get_sub_vector_centroids::<NUM_BITS, _>(
codebook.values(),
dim,
num_bits,
num_sub_vectors,
sub_idx,
);
compute_partition(centroids, sub_vector, distance_type).unwrap() as u8
})
.collect::<Vec<_>>();
if num_bits == 4 {
if NUM_BITS == 4 {
sub_vec_code
.chunks_exact(2)
.map(|v| (v[1] << 4) | v[0])
Expand All @@ -153,7 +171,7 @@ impl ProductQuantizer {
})
.collect::<Vec<_>>();

let num_sub_vectors_in_byte = if num_bits == 4 {
let num_sub_vectors_in_byte = if NUM_BITS == 4 {
num_sub_vectors / 2
} else {
num_sub_vectors
Expand Down Expand Up @@ -321,13 +339,24 @@ impl ProductQuantizer {
///
/// Returns a flattened `num_centroids * sub_vector_width` slice of `T::Native`.
///
/// # Panics
///
/// Panics if `self.num_bits` is neither 4 nor 8.
pub fn centroids<T: ArrowPrimitiveType>(&self, sub_vector_idx: usize) -> &[T::Native] {
    // Resolve the typed codebook values once; both supported bit widths read
    // the same buffer.
    let codebook: &[T::Native] = self.codebook.values().as_primitive::<T>().values();

    // `get_sub_vector_centroids` takes the bit width as a const generic, so the
    // runtime `num_bits` is mapped onto one of the two instantiations.
    match self.num_bits {
        4 => get_sub_vector_centroids::<4, _>(
            codebook,
            self.dimension,
            self.num_sub_vectors,
            sub_vector_idx,
        ),
        8 => get_sub_vector_centroids::<8, _>(
            codebook,
            self.dimension,
            self.num_sub_vectors,
            sub_vector_idx,
        ),
        _ => panic!(
            "ProductQuantization: num_bits {} not supported",
            self.num_bits
        ),
    }
}
}

Expand Down
40 changes: 35 additions & 5 deletions rust/lance-index/src/vector/pq/distance.rs
Original file line number Diff line number Diff line change
@@ -1,52 +1,82 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

use core::panic;
use std::cmp::min;

use lance_linalg::distance::{dot_distance_batch, l2_distance_batch, Dot, L2};
use lance_table::utils::LanceIteratorExtension;

use super::{num_centroids, utils::get_sub_vector_centroids};

/// Build a Distance Table from the query to each PQ centroid
/// using L2 distance.
pub(super) fn build_distance_table_l2<T: L2>(
pub fn build_distance_table_l2<T: L2>(
codebook: &[T],
num_bits: u32,
num_sub_vectors: usize,
query: &[T],
) -> Vec<f32> {
let dimension = query.len();
match num_bits {
4 => build_distance_table_l2_impl::<4, T>(codebook, num_sub_vectors, query),
8 => build_distance_table_l2_impl::<8, T>(codebook, num_sub_vectors, query),
_ => panic!("Unsupported number of bits: {}", num_bits),
}
}

/// Builds the L2 distance table for a codebook whose bit width is fixed at
/// compile time through `NUM_BITS`.
///
/// The query is cut into chunks of `query.len() / num_sub_vectors` elements.
/// For each chunk, in order, the values yielded by `l2_distance_batch` against
/// that sub-vector's centroids are appended to the returned vector.
///
/// # Panics
///
/// Panics if `num_sub_vectors` is 0 (division by zero) or greater than
/// `query.len()` (the chunk size becomes 0, which `chunks_exact` rejects).
#[inline]
pub fn build_distance_table_l2_impl<const NUM_BITS: u32, T: L2>(
    codebook: &[T],
    num_sub_vectors: usize,
    query: &[T],
) -> Vec<f32> {
    let dimension = query.len();
    let sub_vector_length = dimension / num_sub_vectors;
    let num_centroids = 2_usize.pow(NUM_BITS);
    query
        .chunks_exact(sub_vector_length)
        .enumerate()
        .flat_map(|(i, sub_vec)| {
            // The bit width is supplied as a const generic, not as a runtime argument.
            let subvec_centroids =
                get_sub_vector_centroids::<NUM_BITS, _>(codebook, dimension, num_sub_vectors, i);
            l2_distance_batch(sub_vec, subvec_centroids, sub_vector_length)
        })
        // NOTE(review): presumably declares the exact output length so `collect`
        // can allocate once; confirm against `LanceIteratorExtension`.
        .exact_size(num_sub_vectors * num_centroids)
        .collect()
}

/// Build a Distance Table from the query to each PQ centroid
/// using Dot distance.
pub(super) fn build_distance_table_dot<T: Dot>(
pub fn build_distance_table_dot<T: Dot>(
codebook: &[T],
num_bits: u32,
num_sub_vectors: usize,
query: &[T],
) -> Vec<f32> {
match num_bits {
4 => build_distance_table_dot_impl::<4, T>(codebook, num_sub_vectors, query),
8 => build_distance_table_dot_impl::<8, T>(codebook, num_sub_vectors, query),
_ => panic!("Unsupported number of bits: {}", num_bits),
}
}

/// Builds the Dot distance table for a codebook whose bit width is fixed at
/// compile time through `NUM_BITS`.
///
/// The query is cut into chunks of `query.len() / num_sub_vectors` elements.
/// For each chunk, in order, the values yielded by `dot_distance_batch`
/// against that sub-vector's centroids are appended to the returned vector.
///
/// # Panics
///
/// Panics if `num_sub_vectors` is 0 (division by zero) or greater than
/// `query.len()` (the chunk size becomes 0, which `chunks_exact` rejects).
pub fn build_distance_table_dot_impl<const NUM_BITS: u32, T: Dot>(
    codebook: &[T],
    num_sub_vectors: usize,
    query: &[T],
) -> Vec<f32> {
    let dimension = query.len();
    let sub_vector_length = dimension / num_sub_vectors;
    let num_centroids = 2_usize.pow(NUM_BITS);
    query
        .chunks_exact(sub_vector_length)
        .enumerate()
        .flat_map(|(i, sub_vec)| {
            // The bit width is supplied as a const generic, not as a runtime argument.
            let subvec_centroids =
                get_sub_vector_centroids::<NUM_BITS, _>(codebook, dimension, num_sub_vectors, i);
            dot_distance_batch(sub_vec, subvec_centroids, sub_vector_length)
        })
        // NOTE(review): presumably declares the exact output length so `collect`
        // can allocate once; confirm against `LanceIteratorExtension`.
        .exact_size(num_sub_vectors * num_centroids)
        .collect()
}

Expand Down
7 changes: 3 additions & 4 deletions rust/lance-index/src/vector/pq/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,21 +51,20 @@ pub fn num_centroids(num_bits: impl Into<u32>) -> usize {
}

#[inline]
pub fn get_sub_vector_centroids<T>(
pub fn get_sub_vector_centroids<const NUM_BITS: u32, T>(
codebook: &[T],
dimension: usize,
num_bits: impl Into<u32>,
num_sub_vectors: usize,
sub_vector_idx: usize,
) -> &[T] {
assert!(
debug_assert!(
sub_vector_idx < num_sub_vectors,
"sub_vector idx: {}, num_sub_vectors: {}",
sub_vector_idx,
num_sub_vectors
);

let num_centroids = num_centroids(num_bits);
let num_centroids: usize = 2_usize.pow(NUM_BITS);
let sub_vector_width = dimension / num_sub_vectors;
&codebook[sub_vector_idx * num_centroids * sub_vector_width
..(sub_vector_idx + 1) * num_centroids * sub_vector_width]
Expand Down
2 changes: 1 addition & 1 deletion rust/lance-linalg/src/distance/l2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ pub fn l2_distance_batch<'a, T: L2>(
debug_assert_eq!(from.len(), dimension);
debug_assert_eq!(to.len() % dimension, 0);

Box::new(T::l2_batch(from, to, dimension))
T::l2_batch(from, to, dimension)
}

fn do_l2_distance_arrow_batch<T: ArrowFloatType>(
Expand Down
Loading