Skip to content

Commit

Permalink
refactor: faster IVF & PQ (#328)
Browse files Browse the repository at this point in the history
* add multi-threading for kmeans

Signed-off-by: whateveraname <12011319@mail.sustech.edu.cn>

* add multi-threading for IVF and PQ index build

fix PQ training for IVF residuals

Signed-off-by: whateveraname <12011319@mail.sustech.edu.cn>

* better memory layout for IVF

Signed-off-by: whateveraname <12011319@mail.sustech.edu.cn>

* fix codes

Signed-off-by: whateveraname <12011319@mail.sustech.edu.cn>

* remove unused code

Signed-off-by: whateveraname <12011319@mail.sustech.edu.cn>

* add multi-threading for kmeans

Signed-off-by: whateveraname <12011319@mail.sustech.edu.cn>

* add multi-threading for IVF and PQ index build

fix PQ training for IVF residuals

Signed-off-by: whateveraname <12011319@mail.sustech.edu.cn>

* better memory layout for IVF

Signed-off-by: whateveraname <12011319@mail.sustech.edu.cn>

* fix codes

Signed-off-by: whateveraname <12011319@mail.sustech.edu.cn>

* remove unused code

Signed-off-by: whateveraname <12011319@mail.sustech.edu.cn>

* parallelize over subquantizers for pq

Signed-off-by: whateveraname <12011319@mail.sustech.edu.cn>

* add table lookup search for ivfpq with l2 distance

Signed-off-by: whateveraname <12011319@mail.sustech.edu.cn>

* refine codes

Signed-off-by: whateveraname <12011319@mail.sustech.edu.cn>

* move comment to github

Signed-off-by: whateveraname <12011319@mail.sustech.edu.cn>

* use table lookup for ip distance

Signed-off-by: whateveraname <12011319@mail.sustech.edu.cn>

---------

Signed-off-by: whateveraname <12011319@mail.sustech.edu.cn>
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
  • Loading branch information
whateveraname and VoVAllen authored Feb 21, 2024
1 parent 42fa583 commit 3d1621b
Show file tree
Hide file tree
Showing 21 changed files with 502 additions and 215 deletions.
57 changes: 37 additions & 20 deletions crates/service/src/algorithms/clustering/elkan_k_means.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ use crate::utils::vec2::Vec2;
use base::scalar::FloatCast;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
use rayon::slice::ParallelSliceMut;
use std::ops::{Index, IndexMut};

pub struct ElkanKMeans<S: G> {
Expand Down Expand Up @@ -32,13 +34,16 @@ impl<S: G> ElkanKMeans<S> {
centroids[0].copy_from_slice(&samples[rand.gen_range(0..n)]);

let mut weight = vec![F32::infinity(); n];
let mut dis = vec![F32::zero(); n];
for i in 0..c {
let mut sum = F32::zero();
dis.par_iter_mut().enumerate().for_each(|(j, x)| {
*x = S::elkan_k_means_distance(&samples[j], &centroids[i]);
});
for j in 0..n {
let dis = S::elkan_k_means_distance(&samples[j], &centroids[i]);
lowerbound[(j, i)] = dis;
if dis * dis < weight[j] {
weight[j] = dis * dis;
lowerbound[(j, i)] = dis[j];
if dis[j] * dis[j] < weight[j] {
weight[j] = dis[j] * dis[j];
}
sum += weight[j];
}
Expand Down Expand Up @@ -132,11 +137,16 @@ impl<S: G> ElkanKMeans<S> {
// Step 1
let mut dist0 = Square::new(c, c);
let mut sp = vec![F32::zero(); c];
for i in 0..c {
for j in i + 1..c {
let dis = S::elkan_k_means_distance(&centroids[i], &centroids[j]) * 0.5;
dist0[(i, j)] = dis;
dist0[(j, i)] = dis;
dist0.v.par_iter_mut().enumerate().for_each(|(ii, v)| {
let i = ii / c;
let j = ii % c;
if i <= j {
*v = S::elkan_k_means_distance(&centroids[i], &centroids[j]) * 0.5;
}
});
// Mirror the upper triangle of `dist0` (the entries with i <= j, filled by the
// parallel pass above) into the lower triangle so the matrix is symmetric.
// The inner range is `0..i`, not `0..i - 1`: the half-open range must include
// j = i - 1, otherwise the sub-diagonal entries (i, i - 1) are never copied.
for i in 1..c {
    for j in 0..i {
        dist0[(i, j)] = dist0[(j, i)];
    }
}
for i in 0..c {
Expand All @@ -153,12 +163,18 @@ impl<S: G> ElkanKMeans<S> {
sp[i] = minimal;
}

let mut dis = vec![F32::zero(); n];
dis.par_iter_mut().enumerate().for_each(|(i, x)| {
if upperbound[i] > sp[assign[i]] {
*x = S::elkan_k_means_distance(&samples[i], &centroids[assign[i]]);
}
});
for i in 0..n {
// Step 2
if upperbound[i] <= sp[assign[i]] {
continue;
}
let mut minimal = S::elkan_k_means_distance(&samples[i], &centroids[assign[i]]);
let mut minimal = dis[i];
lowerbound[(i, assign[i])] = minimal;
upperbound[i] = minimal;
// Step 3
Expand Down Expand Up @@ -191,9 +207,9 @@ impl<S: G> ElkanKMeans<S> {
centroids.fill(S::Scalar::zero());
for i in 0..n {
for j in 0..dims as usize {
centroids[assign[i]][j] += samples[i][j];
centroids[self.assign[i]][j] += samples[i][j];
}
count[assign[i]] += 1.0;
count[self.assign[i]] += 1.0;
}
for i in 0..c {
if count[i] == F32::zero() {
Expand Down Expand Up @@ -229,22 +245,23 @@ impl<S: G> ElkanKMeans<S> {
count[i] = count[o] / 2.0;
count[o] = count[o] - count[i];
}
for i in 0..c {
S::elkan_k_means_normalize(&mut centroids[i]);
}
centroids.par_chunks_mut(dims as usize).for_each(|v| {
S::elkan_k_means_normalize(v);
});

// Step 5, 6
let mut dist1 = vec![F32::zero(); c];
for i in 0..c {
dist1[i] = S::elkan_k_means_distance(&old[i], &centroids[i]);
}
dist1.par_iter_mut().enumerate().for_each(|(i, v)| {
*v = S::elkan_k_means_distance(&old[i], &centroids[i]);
});
for i in 0..n {
for j in 0..c {
lowerbound[(i, j)] = std::cmp::max(lowerbound[(i, j)] - dist1[j], F32::zero());
self.lowerbound[(i, j)] =
std::cmp::max(self.lowerbound[(i, j)] - dist1[j], F32::zero());
}
}
for i in 0..n {
upperbound[i] += dist1[assign[i]];
self.upperbound[i] += dist1[self.assign[i]];
}

change == 0
Expand Down
1 change: 1 addition & 0 deletions crates/service/src/algorithms/flat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ pub fn make<S: G>(
options.clone(),
idx_opts.quantization,
&raw,
(0..raw.len()).collect::<Vec<_>>(),
);
FlatRam { raw, quantization }
}
Expand Down
1 change: 1 addition & 0 deletions crates/service/src/algorithms/hnsw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ pub fn make<S: G>(
options.clone(),
quantization_opts,
&raw,
(0..raw.len()).collect::<Vec<_>>(),
);
let n = raw.len();
let graph = HnswRamGraph {
Expand Down
124 changes: 57 additions & 67 deletions crates/service/src/algorithms/ivf/ivf_naive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,18 @@ use crate::index::IndexOptions;
use crate::index::SearchOptions;
use crate::index::VectorOptions;
use crate::prelude::*;
use crate::utils::cells::SyncUnsafeCell;
use crate::utils::dir_ops::sync_dir;
use crate::utils::element_heap::ElementHeap;
use crate::utils::mmap_array::MmapArray;
use crate::utils::vec2::Vec2;
use rand::seq::index::sample;
use rand::thread_rng;
use rayon::prelude::{IntoParallelIterator, ParallelIterator};
use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator};
use rayon::prelude::ParallelIterator;
use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::fs::create_dir;
use std::path::Path;
use std::sync::atomic::AtomicU32;
use std::sync::atomic::Ordering::{Acquire, Relaxed, Release};
use std::sync::Arc;

pub struct IvfNaive<S: G> {
Expand Down Expand Up @@ -90,8 +88,8 @@ pub struct IvfRam<S: G> {
nlist: u32,
// ----------------------
centroids: Vec2<S::Scalar>,
heads: Vec<AtomicU32>,
nexts: Vec<SyncUnsafeCell<u32>>,
ptr: Vec<usize>,
payloads: Vec<Payload>,
}

unsafe impl<S: G> Send for IvfRam<S> {}
Expand All @@ -106,8 +104,8 @@ pub struct IvfMmap<S: G> {
nlist: u32,
// ----------------------
centroids: MmapArray<S::Scalar>,
heads: MmapArray<u32>,
nexts: MmapArray<u32>,
ptr: MmapArray<usize>,
payloads: MmapArray<Payload>,
}

unsafe impl<S: G> Send for IvfMmap<S> {}
Expand Down Expand Up @@ -141,12 +139,6 @@ pub fn make<S: G>(
sealed,
growing,
));
let quantization = Quantization::open(
&path.join("quantization"),
options.clone(),
quantization_opts,
&raw,
);
let n = raw.len();
let m = std::cmp::min(nsample, n);
let f = sample(&mut thread_rng(), n as usize, m as usize).into_vec();
Expand All @@ -165,70 +157,68 @@ pub fn make<S: G>(
}
}
let centroids = k_means.finish();
let heads = {
let mut heads = Vec::with_capacity(nlist as usize);
heads.resize_with(nlist as usize, || AtomicU32::new(u32::MAX));
heads
};
let nexts = {
let mut nexts = Vec::with_capacity(nlist as usize);
nexts.resize_with(n as usize, || SyncUnsafeCell::new(u32::MAX));
nexts
};
(0..n).into_par_iter().for_each(|i| {
let mut vector = S::ref_to_owned(raw.vector(i));
let mut idx = vec![0usize; n as usize];
idx.par_iter_mut().enumerate().for_each(|(i, x)| {
let mut vector = S::ref_to_owned(raw.vector(i as u32));
S::elkan_k_means_normalize2(&mut vector);
let mut result = (F32::infinity(), 0);
for i in 0..nlist {
let dis = S::elkan_k_means_distance2(S::owned_to_ref(&vector), &centroids[i as usize]);
for i in 0..nlist as usize {
let dis = S::elkan_k_means_distance2(S::owned_to_ref(&vector), &centroids[i]);
result = std::cmp::min(result, (dis, i));
}
let centroid_id = result.1;
loop {
let next = heads[centroid_id as usize].load(Acquire);
unsafe {
nexts[i as usize].get().write(next);
}
let o = &heads[centroid_id as usize];
if o.compare_exchange(next, i, Release, Relaxed).is_ok() {
break;
}
}
*x = result.1;
});
let mut invlists_ids = vec![Vec::new(); nlist as usize];
let mut invlists_payloads = vec![Vec::new(); nlist as usize];
for i in 0..n {
invlists_ids[idx[i as usize]].push(i);
invlists_payloads[idx[i as usize]].push(raw.payload(i));
}
let permutation = Vec::from_iter((0..nlist).flat_map(|i| &invlists_ids[i as usize]).copied());
let payloads = Vec::from_iter(
(0..nlist)
.flat_map(|i| &invlists_payloads[i as usize])
.copied(),
);
let quantization = Quantization::create(
&path.join("quantization"),
options.clone(),
quantization_opts,
&raw,
permutation,
);
let mut ptr = vec![0usize; nlist as usize + 1];
for i in 0..nlist {
ptr[i as usize + 1] = ptr[i as usize] + invlists_ids[i as usize].len();
}
IvfRam {
raw,
quantization,
centroids,
heads,
nexts,
nlist,
dims,
ptr,
payloads,
}
}

pub fn save<S: G>(mut ram: IvfRam<S>, path: &Path) -> IvfMmap<S> {
pub fn save<S: G>(ram: IvfRam<S>, path: &Path) -> IvfMmap<S> {
let centroids = MmapArray::create(
&path.join("centroids"),
(0..ram.nlist)
.flat_map(|i| &ram.centroids[i as usize])
.copied(),
);
let heads = MmapArray::create(
&path.join("heads"),
ram.heads.iter_mut().map(|x| *x.get_mut()),
);
let nexts = MmapArray::create(
&path.join("nexts"),
ram.nexts.iter_mut().map(|x| *x.get_mut()),
);
let ptr = MmapArray::create(&path.join("ptr"), ram.ptr.iter().copied());
let payloads = MmapArray::create(&path.join("payload"), ram.payloads.iter().copied());
IvfMmap {
raw: ram.raw,
quantization: ram.quantization,
dims: ram.dims,
nlist: ram.nlist,
centroids,
heads,
nexts,
ptr,
payloads,
}
}

Expand All @@ -241,17 +231,17 @@ pub fn open<S: G>(path: &Path, options: IndexOptions) -> IvfMmap<S> {
&raw,
);
let centroids = MmapArray::open(&path.join("centroids"));
let heads = MmapArray::open(&path.join("heads"));
let nexts = MmapArray::open(&path.join("nexts"));
let ptr = MmapArray::open(&path.join("ptr"));
let payloads = MmapArray::open(&path.join("payload"));
let IvfIndexingOptions { nlist, .. } = options.indexing.unwrap_ivf();
IvfMmap {
raw,
quantization,
dims: options.vector.dims,
nlist,
centroids,
heads,
nexts,
ptr,
payloads,
}
}

Expand All @@ -277,14 +267,14 @@ pub fn basic<S: G>(
let lists = lists.into_sorted_vec();
let mut result = BinaryHeap::new();
for i in lists.iter().map(|e| e.payload as usize) {
let mut j = mmap.heads[i];
while u32::MAX != j {
let payload = mmap.raw.payload(j);
let start = mmap.ptr[i];
let end = mmap.ptr[i + 1];
for j in start..end {
let payload = mmap.payloads[j];
if filter.check(payload) {
let distance = mmap.quantization.distance(vector, j);
let distance = mmap.quantization.distance(vector, j as u32);
result.push(Reverse(Element { distance, payload }));
}
j = mmap.nexts[j as usize];
}
}
result
Expand All @@ -311,15 +301,15 @@ pub fn vbase<'a, S: G>(
}
let lists = lists.into_sorted_vec();
let mut result = Vec::new();
for i in lists.iter().map(|e| e.payload as u32) {
let mut j = mmap.heads[i as usize];
while u32::MAX != j {
let payload = mmap.raw.payload(j);
for i in lists.iter().map(|e| e.payload as usize) {
let start = mmap.ptr[i];
let end = mmap.ptr[i + 1];
for j in start..end {
let payload = mmap.payloads[j];
if filter.check(payload) {
let distance = mmap.quantization.distance(vector, j);
let distance = mmap.quantization.distance(vector, j as u32);
result.push(Element { distance, payload });
}
j = mmap.nexts[j as usize];
}
}
(result, Box::new(std::iter::empty()))
Expand Down
Loading

0 comments on commit 3d1621b

Please sign in to comment.