refactor: faster IVF & PQ #328

Merged
merged 22 commits into main from refactor/faster-ivf-pq on Feb 21, 2024
Changes from all commits
Commits (22)
ebfe90d
add multi-threading for kmeans
whateveraname Jan 26, 2024
4f29b9b
add multi-threading for IVF and PQ index build
whateveraname Jan 27, 2024
cc33815
better memory layout for IVF
whateveraname Feb 1, 2024
7fef203
fix codes
whateveraname Feb 2, 2024
ad0c76a
remove unused code
whateveraname Feb 3, 2024
37d67ff
Merge branch 'main' into refactor/faster-ivf-pq
VoVAllen Feb 2, 2024
a13a590
add multi-threading for kmeans
whateveraname Jan 26, 2024
5ade8fe
add multi-threading for IVF and PQ index build
whateveraname Jan 27, 2024
fd5a16c
better memory layout for IVF
whateveraname Feb 1, 2024
feedefa
fix codes
whateveraname Feb 2, 2024
155b0b2
remove unused code
whateveraname Feb 3, 2024
f56e5cc
parallelize over subquantizers for pq
whateveraname Feb 3, 2024
63a0d87
merge
whateveraname Feb 3, 2024
e47695c
Merge branch 'main' into refactor/faster-ivf-pq
whateveraname Feb 6, 2024
187694b
add table lookup search for ivfpq with l2 distance
whateveraname Feb 7, 2024
2cfed8b
Merge branch 'refactor/faster-ivf-pq' of github.com:whateveraname/pgv…
whateveraname Feb 7, 2024
3a44588
Merge branch 'main' into refactor/faster-ivf-pq
whateveraname Feb 7, 2024
dab1185
refine codes
whateveraname Feb 8, 2024
899540a
Merge branch 'refactor/faster-ivf-pq' of github.com:whateveraname/pgv…
whateveraname Feb 8, 2024
6c000ec
move comment to github
whateveraname Feb 8, 2024
a32f0eb
use table lookup for ip distance
whateveraname Feb 21, 2024
583b345
Merge branch 'main' into refactor/faster-ivf-pq
whateveraname Feb 21, 2024
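
Note: the two table-lookup commits above ("add table lookup search for ivfpq with l2 distance", "use table lookup for ip distance") move PQ scoring to precomputed per-query tables. Below is a minimal sketch of that asymmetric-distance idea; the names (PqMeta, build_lut, lookup_distance) are illustrative, not the crate's actual PQ types.

// Precompute, once per query, the distance from each query sub-vector to every
// centroid of the corresponding subquantizer; scoring a code is then m table
// lookups instead of a full d-dimensional distance computation.
struct PqMeta {
    m: usize,            // number of subquantizers
    ks: usize,           // centroids per subquantizer, e.g. 256
    dsub: usize,         // dimensions handled by each subquantizer
    centroids: Vec<f32>, // m * ks * dsub values, row-major
}

fn build_lut(meta: &PqMeta, query: &[f32]) -> Vec<f32> {
    let mut lut = vec![0.0f32; meta.m * meta.ks];
    for i in 0..meta.m {
        let q = &query[i * meta.dsub..(i + 1) * meta.dsub];
        for k in 0..meta.ks {
            let c = &meta.centroids[(i * meta.ks + k) * meta.dsub..][..meta.dsub];
            // squared L2 here; an inner-product table works the same way
            lut[i * meta.ks + k] = q.iter().zip(c).map(|(a, b)| (a - b) * (a - b)).sum();
        }
    }
    lut
}

fn lookup_distance(meta: &PqMeta, lut: &[f32], code: &[u8]) -> f32 {
    code.iter()
        .enumerate()
        .map(|(i, &k)| lut[i * meta.ks + k as usize])
        .sum()
}
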
57 changes: 37 additions & 20 deletions crates/service/src/algorithms/clustering/elkan_k_means.rs
@@ -3,6 +3,8 @@ use crate::utils::vec2::Vec2;
 use base::scalar::FloatCast;
 use rand::rngs::StdRng;
 use rand::{Rng, SeedableRng};
+use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
+use rayon::slice::ParallelSliceMut;
 use std::ops::{Index, IndexMut};
 
 pub struct ElkanKMeans<S: G> {
@@ -32,13 +34,16 @@ impl<S: G> ElkanKMeans<S> {
         centroids[0].copy_from_slice(&samples[rand.gen_range(0..n)]);
 
         let mut weight = vec![F32::infinity(); n];
+        let mut dis = vec![F32::zero(); n];
         for i in 0..c {
             let mut sum = F32::zero();
+            dis.par_iter_mut().enumerate().for_each(|(j, x)| {
+                *x = S::elkan_k_means_distance(&samples[j], &centroids[i]);
+            });
             for j in 0..n {
-                let dis = S::elkan_k_means_distance(&samples[j], &centroids[i]);
-                lowerbound[(j, i)] = dis;
-                if dis * dis < weight[j] {
-                    weight[j] = dis * dis;
+                lowerbound[(j, i)] = dis[j];
+                if dis[j] * dis[j] < weight[j] {
+                    weight[j] = dis[j] * dis[j];
                 }
                 sum += weight[j];
             }
@@ -132,11 +137,16 @@ impl<S: G> ElkanKMeans<S> {
         // Step 1
         let mut dist0 = Square::new(c, c);
         let mut sp = vec![F32::zero(); c];
-        for i in 0..c {
-            for j in i + 1..c {
-                let dis = S::elkan_k_means_distance(&centroids[i], &centroids[j]) * 0.5;
-                dist0[(i, j)] = dis;
-                dist0[(j, i)] = dis;
+        dist0.v.par_iter_mut().enumerate().for_each(|(ii, v)| {
+            let i = ii / c;
+            let j = ii % c;
+            if i <= j {
+                *v = S::elkan_k_means_distance(&centroids[i], &centroids[j]) * 0.5;
+            }
+        });
+        for i in 1..c {
+            for j in 0..i - 1 {
+                dist0[(i, j)] = dist0[(j, i)];
             }
         }
         for i in 0..c {
@@ -153,12 +163,18 @@ impl<S: G> ElkanKMeans<S> {
             sp[i] = minimal;
         }
 
+        let mut dis = vec![F32::zero(); n];
+        dis.par_iter_mut().enumerate().for_each(|(i, x)| {
+            if upperbound[i] > sp[assign[i]] {
+                *x = S::elkan_k_means_distance(&samples[i], &centroids[assign[i]]);
+            }
+        });
         for i in 0..n {
             // Step 2
            if upperbound[i] <= sp[assign[i]] {
                 continue;
             }
-            let mut minimal = S::elkan_k_means_distance(&samples[i], &centroids[assign[i]]);
+            let mut minimal = dis[i];
             lowerbound[(i, assign[i])] = minimal;
             upperbound[i] = minimal;
             // Step 3
@@ -191,9 +207,9 @@ impl<S: G> ElkanKMeans<S> {
         centroids.fill(S::Scalar::zero());
         for i in 0..n {
             for j in 0..dims as usize {
-                centroids[assign[i]][j] += samples[i][j];
+                centroids[self.assign[i]][j] += samples[i][j];
             }
-            count[assign[i]] += 1.0;
+            count[self.assign[i]] += 1.0;
         }
         for i in 0..c {
             if count[i] == F32::zero() {
@@ -229,22 +245,23 @@ impl<S: G> ElkanKMeans<S> {
             count[i] = count[o] / 2.0;
             count[o] = count[o] - count[i];
         }
-        for i in 0..c {
-            S::elkan_k_means_normalize(&mut centroids[i]);
-        }
+        centroids.par_chunks_mut(dims as usize).for_each(|v| {
+            S::elkan_k_means_normalize(v);
+        });
 
         // Step 5, 6
         let mut dist1 = vec![F32::zero(); c];
-        for i in 0..c {
-            dist1[i] = S::elkan_k_means_distance(&old[i], &centroids[i]);
-        }
+        dist1.par_iter_mut().enumerate().for_each(|(i, v)| {
+            *v = S::elkan_k_means_distance(&old[i], &centroids[i]);
+        });
         for i in 0..n {
             for j in 0..c {
-                lowerbound[(i, j)] = std::cmp::max(lowerbound[(i, j)] - dist1[j], F32::zero());
+                self.lowerbound[(i, j)] =
+                    std::cmp::max(self.lowerbound[(i, j)] - dist1[j], F32::zero());
             }
         }
         for i in 0..n {
-            upperbound[i] += dist1[assign[i]];
+            self.upperbound[i] += dist1[self.assign[i]];
         }
 
         change == 0
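
Note: the elkan_k_means.rs hunks above replace serial distance loops with rayon (par_iter_mut, par_chunks_mut), writing each result into its own slot of a preallocated buffer so no locking is needed. A standalone sketch of the pattern, assuming a stand-in distance function rather than the crate's S::elkan_k_means_distance:

use rayon::prelude::*;

fn distance(a: &[f32], b: &[f32]) -> f32 {
    // stand-in for S::elkan_k_means_distance
    a.iter().zip(b).map(|(x, y)| (x - y) * (x - y)).sum::<f32>().sqrt()
}

fn distances_to_centroid(samples: &[Vec<f32>], centroid: &[f32]) -> Vec<f32> {
    let mut dis = vec![0.0f32; samples.len()];
    // each closure invocation owns exactly one &mut f32, so the parallel
    // writes are disjoint and need no synchronization
    dis.par_iter_mut().enumerate().for_each(|(j, x)| {
        *x = distance(&samples[j], centroid);
    });
    dis
}
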
1 change: 1 addition & 0 deletions crates/service/src/algorithms/flat.rs
@@ -99,6 +99,7 @@ pub fn make<S: G>(
         options.clone(),
         idx_opts.quantization,
         &raw,
+        (0..raw.len()).collect::<Vec<_>>(),
     );
     FlatRam { raw, quantization }
 }
1 change: 1 addition & 0 deletions crates/service/src/algorithms/hnsw.rs
@@ -150,6 +150,7 @@ pub fn make<S: G>(
         options.clone(),
         quantization_opts,
         &raw,
+        (0..raw.len()).collect::<Vec<_>>(),
     );
     let n = raw.len();
     let graph = HnswRamGraph {
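
Note: flat.rs and hnsw.rs now pass an identity permutation ((0..raw.len()).collect()) to the quantizer, while the IVF builder below passes the inverted-list order, so quantized codes end up stored list by list. A small sketch of the idea, assuming a placeholder encode function rather than the crate's Quantization API:

// Encode vectors in a caller-supplied order; with the IVF permutation the
// codes of one inverted list become contiguous in the output.
fn encode(vector: &[f32]) -> Vec<u8> {
    // trivial stand-in quantizer: clamp to [0, 1] and scale to u8
    vector.iter().map(|&x| (x.clamp(0.0, 1.0) * 255.0) as u8).collect()
}

fn build_codes(vectors: &[Vec<f32>], permutation: &[usize]) -> Vec<Vec<u8>> {
    permutation.iter().map(|&i| encode(&vectors[i])).collect()
}
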
124 changes: 57 additions & 67 deletions crates/service/src/algorithms/ivf/ivf_naive.rs
@@ -8,20 +8,18 @@ use crate::index::IndexOptions;
 use crate::index::SearchOptions;
 use crate::index::VectorOptions;
 use crate::prelude::*;
-use crate::utils::cells::SyncUnsafeCell;
 use crate::utils::dir_ops::sync_dir;
 use crate::utils::element_heap::ElementHeap;
 use crate::utils::mmap_array::MmapArray;
 use crate::utils::vec2::Vec2;
 use rand::seq::index::sample;
 use rand::thread_rng;
-use rayon::prelude::{IntoParallelIterator, ParallelIterator};
+use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator};
+use rayon::prelude::ParallelIterator;
 use std::cmp::Reverse;
 use std::collections::BinaryHeap;
 use std::fs::create_dir;
 use std::path::Path;
-use std::sync::atomic::AtomicU32;
-use std::sync::atomic::Ordering::{Acquire, Relaxed, Release};
 use std::sync::Arc;
 
 pub struct IvfNaive<S: G> {
@@ -90,8 +88,8 @@ pub struct IvfRam<S: G> {
     nlist: u32,
     // ----------------------
     centroids: Vec2<S::Scalar>,
-    heads: Vec<AtomicU32>,
-    nexts: Vec<SyncUnsafeCell<u32>>,
+    ptr: Vec<usize>,
+    payloads: Vec<Payload>,
 }
 
 unsafe impl<S: G> Send for IvfRam<S> {}
@@ -106,8 +104,8 @@ pub struct IvfMmap<S: G> {
     nlist: u32,
     // ----------------------
     centroids: MmapArray<S::Scalar>,
-    heads: MmapArray<u32>,
-    nexts: MmapArray<u32>,
+    ptr: MmapArray<usize>,
+    payloads: MmapArray<Payload>,
 }
 
 unsafe impl<S: G> Send for IvfMmap<S> {}
@@ -141,12 +139,6 @@ pub fn make<S: G>(
         sealed,
         growing,
     ));
-    let quantization = Quantization::open(
-        &path.join("quantization"),
-        options.clone(),
-        quantization_opts,
-        &raw,
-    );
     let n = raw.len();
     let m = std::cmp::min(nsample, n);
     let f = sample(&mut thread_rng(), n as usize, m as usize).into_vec();
@@ -165,70 +157,68 @@ pub fn make<S: G>(
         }
     }
     let centroids = k_means.finish();
-    let heads = {
-        let mut heads = Vec::with_capacity(nlist as usize);
-        heads.resize_with(nlist as usize, || AtomicU32::new(u32::MAX));
-        heads
-    };
-    let nexts = {
-        let mut nexts = Vec::with_capacity(nlist as usize);
-        nexts.resize_with(n as usize, || SyncUnsafeCell::new(u32::MAX));
-        nexts
-    };
-    (0..n).into_par_iter().for_each(|i| {
-        let mut vector = S::ref_to_owned(raw.vector(i));
+    let mut idx = vec![0usize; n as usize];
+    idx.par_iter_mut().enumerate().for_each(|(i, x)| {
+        let mut vector = S::ref_to_owned(raw.vector(i as u32));
         S::elkan_k_means_normalize2(&mut vector);
         let mut result = (F32::infinity(), 0);
-        for i in 0..nlist {
-            let dis = S::elkan_k_means_distance2(S::owned_to_ref(&vector), &centroids[i as usize]);
+        for i in 0..nlist as usize {
+            let dis = S::elkan_k_means_distance2(S::owned_to_ref(&vector), &centroids[i]);
             result = std::cmp::min(result, (dis, i));
         }
-        let centroid_id = result.1;
-        loop {
-            let next = heads[centroid_id as usize].load(Acquire);
-            unsafe {
-                nexts[i as usize].get().write(next);
-            }
-            let o = &heads[centroid_id as usize];
-            if o.compare_exchange(next, i, Release, Relaxed).is_ok() {
-                break;
-            }
-        }
+        *x = result.1;
     });
+    let mut invlists_ids = vec![Vec::new(); nlist as usize];
+    let mut invlists_payloads = vec![Vec::new(); nlist as usize];
+    for i in 0..n {
+        invlists_ids[idx[i as usize]].push(i);
+        invlists_payloads[idx[i as usize]].push(raw.payload(i));
+    }
+    let permutation = Vec::from_iter((0..nlist).flat_map(|i| &invlists_ids[i as usize]).copied());
+    let payloads = Vec::from_iter(
+        (0..nlist)
+            .flat_map(|i| &invlists_payloads[i as usize])
+            .copied(),
+    );
+    let quantization = Quantization::create(
+        &path.join("quantization"),
+        options.clone(),
+        quantization_opts,
+        &raw,
+        permutation,
+    );
+    let mut ptr = vec![0usize; nlist as usize + 1];
+    for i in 0..nlist {
+        ptr[i as usize + 1] = ptr[i as usize] + invlists_ids[i as usize].len();
+    }
     IvfRam {
         raw,
         quantization,
         centroids,
-        heads,
-        nexts,
         nlist,
         dims,
+        ptr,
+        payloads,
     }
 }
 
-pub fn save<S: G>(mut ram: IvfRam<S>, path: &Path) -> IvfMmap<S> {
+pub fn save<S: G>(ram: IvfRam<S>, path: &Path) -> IvfMmap<S> {
     let centroids = MmapArray::create(
         &path.join("centroids"),
         (0..ram.nlist)
             .flat_map(|i| &ram.centroids[i as usize])
             .copied(),
     );
-    let heads = MmapArray::create(
-        &path.join("heads"),
-        ram.heads.iter_mut().map(|x| *x.get_mut()),
-    );
-    let nexts = MmapArray::create(
-        &path.join("nexts"),
-        ram.nexts.iter_mut().map(|x| *x.get_mut()),
-    );
+    let ptr = MmapArray::create(&path.join("ptr"), ram.ptr.iter().copied());
+    let payloads = MmapArray::create(&path.join("payload"), ram.payloads.iter().copied());
     IvfMmap {
         raw: ram.raw,
         quantization: ram.quantization,
         dims: ram.dims,
         nlist: ram.nlist,
         centroids,
-        heads,
-        nexts,
+        ptr,
+        payloads,
     }
 }
 
@@ -241,17 +231,17 @@ pub fn open<S: G>(path: &Path, options: IndexOptions) -> IvfMmap<S> {
         &raw,
     );
     let centroids = MmapArray::open(&path.join("centroids"));
-    let heads = MmapArray::open(&path.join("heads"));
-    let nexts = MmapArray::open(&path.join("nexts"));
+    let ptr = MmapArray::open(&path.join("ptr"));
+    let payloads = MmapArray::open(&path.join("payload"));
     let IvfIndexingOptions { nlist, .. } = options.indexing.unwrap_ivf();
     IvfMmap {
         raw,
         quantization,
         dims: options.vector.dims,
         nlist,
         centroids,
-        heads,
-        nexts,
+        ptr,
+        payloads,
     }
 }
 
@@ -277,14 +267,14 @@ pub fn basic<S: G>(
     let lists = lists.into_sorted_vec();
     let mut result = BinaryHeap::new();
     for i in lists.iter().map(|e| e.payload as usize) {
-        let mut j = mmap.heads[i];
-        while u32::MAX != j {
-            let payload = mmap.raw.payload(j);
+        let start = mmap.ptr[i];
+        let end = mmap.ptr[i + 1];
+        for j in start..end {
+            let payload = mmap.payloads[j];
             if filter.check(payload) {
-                let distance = mmap.quantization.distance(vector, j);
+                let distance = mmap.quantization.distance(vector, j as u32);
                 result.push(Reverse(Element { distance, payload }));
             }
-            j = mmap.nexts[j as usize];
         }
     }
     result
@@ -311,15 +301,15 @@ pub fn vbase<'a, S: G>(
     }
     let lists = lists.into_sorted_vec();
     let mut result = Vec::new();
-    for i in lists.iter().map(|e| e.payload as u32) {
-        let mut j = mmap.heads[i as usize];
-        while u32::MAX != j {
-            let payload = mmap.raw.payload(j);
+    for i in lists.iter().map(|e| e.payload as usize) {
+        let start = mmap.ptr[i];
+        let end = mmap.ptr[i + 1];
+        for j in start..end {
+            let payload = mmap.payloads[j];
             if filter.check(payload) {
-                let distance = mmap.quantization.distance(vector, j);
+                let distance = mmap.quantization.distance(vector, j as u32);
                 result.push(Element { distance, payload });
             }
-            j = mmap.nexts[j as usize];
         }
     }
     (result, Box::new(std::iter::empty()))
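
Note: the heads/nexts linked lists are replaced by a CSR-style layout: ptr holds nlist + 1 prefix-sum offsets and payloads stores entries grouped by list, so scanning list i is a contiguous walk over ptr[i]..ptr[i + 1] (as basic and vbase now do) instead of chasing next pointers. A sketch of that layout with illustrative names, not the crate's actual types:

struct InvertedLists<P> {
    ptr: Vec<usize>,  // nlist + 1 prefix-sum offsets
    payloads: Vec<P>, // n entries, grouped by list
}

impl<P: Copy> InvertedLists<P> {
    fn from_assignments(nlist: usize, assignments: &[usize], payloads: &[P]) -> Self {
        // prefix sums over list sizes give the offsets
        let mut ptr = vec![0usize; nlist + 1];
        for &a in assignments {
            ptr[a + 1] += 1;
        }
        for i in 0..nlist {
            ptr[i + 1] += ptr[i];
        }
        // scatter each payload into its list's slot range
        let mut grouped = payloads.to_vec();
        let mut cursor = ptr.clone();
        for (&p, &a) in payloads.iter().zip(assignments) {
            grouped[cursor[a]] = p;
            cursor[a] += 1;
        }
        Self { ptr, payloads: grouped }
    }

    // one list is a contiguous slice, which is cache- and prefetch-friendly
    fn list(&self, i: usize) -> &[P] {
        &self.payloads[self.ptr[i]..self.ptr[i + 1]]
    }
}
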