Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: rework vbase #238

Merged
merged 2 commits into from
Jan 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion crates/service/src/algorithms/flat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,19 @@ impl<S: G> Flat<S> {
self.mmap.raw.payload(i)
}

pub fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap {
pub fn search(&self, vector: &[S::Scalar], k: usize, filter: &mut impl Filter) -> Heap {
search(&self.mmap, k, vector, filter)
}

pub fn vbase<'a>(
&'a self,
vector: &'a [S::Scalar],
) -> (
Vec<HeapElement>,
Box<(dyn Iterator<Item = HeapElement> + 'a)>,
) {
vbase(&self.mmap, vector)
}
}

unsafe impl<S: G> Send for Flat<S> {}
Expand Down Expand Up @@ -121,3 +131,16 @@ pub fn search<S: G>(
}
result
}

pub fn vbase<'a, S: G>(
mmap: &'a FlatMmap<S>,
vector: &'a [S::Scalar],
) -> (Vec<HeapElement>, Box<dyn Iterator<Item = HeapElement> + 'a>) {
let mut result = Vec::new();
for i in 0..mmap.raw.len() {
let distance = mmap.quantization.distance(vector, i);
let payload = mmap.raw.payload(i);
result.push(HeapElement { distance, payload });
}
(result, Box::new(std::iter::empty()))
}
228 changes: 102 additions & 126 deletions crates/service/src/algorithms/hnsw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::index::indexing::hnsw::HnswIndexingOptions;
use crate::index::segments::growing::GrowingSegment;
use crate::index::segments::sealed::SealedSegment;
use crate::index::IndexOptions;
use crate::index::SearchOptions;
use crate::prelude::*;
use crate::utils::dir_ops::sync_dir;
use crate::utils::mmap_array::MmapArray;
Expand Down Expand Up @@ -51,12 +52,24 @@ impl<S: G> Hnsw<S> {
self.mmap.raw.payload(i)
}

pub fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap {
search(&self.mmap, k, vector, filter)
pub fn search(
&self,
vector: &[S::Scalar],
opts: &SearchOptions,
filter: &mut impl Filter,
) -> Heap {
search(&self.mmap, vector, opts.search_k, filter)
}

pub fn search_vbase(&self, range: usize, vector: &[S::Scalar]) -> HnswIndexIter<'_, S> {
search_vbase(&self.mmap, range, vector)
pub fn vbase<'a>(
&'a self,
vector: &'a [S::Scalar],
opts: &'a SearchOptions,
) -> (
Vec<HeapElement>,
Box<(dyn Iterator<Item = HeapElement> + 'a)>,
) {
vbase(&self.mmap, vector, opts.vbase_range)
}
}

Expand Down Expand Up @@ -185,7 +198,6 @@ pub fn make<S: G>(
k: usize,
i: u8,
) -> Vec<(F32, u32)> {
assert!(k > 0);
let mut visited = visited.fetch();
let mut candidates = BinaryHeap::<Reverse<(F32, u32)>>::new();
let mut results = BinaryHeap::new();
Expand Down Expand Up @@ -380,8 +392,8 @@ pub fn load<S: G>(path: PathBuf, options: IndexOptions) -> HnswMmap<S> {

pub fn search<S: G>(
mmap: &HnswMmap<S>,
k: usize,
vector: &[S::Scalar],
k: usize,
filter: &mut impl Filter,
) -> Heap {
let Some(s) = entry(mmap, filter) else {
Expand All @@ -392,18 +404,34 @@ pub fn search<S: G>(
local_search(mmap, k, u, vector, filter)
}

pub fn search_vbase<'a, S: G>(
pub fn vbase<'a, S: G>(
mmap: &'a HnswMmap<S>,
vector: &'a [S::Scalar],
range: usize,
vector: &[S::Scalar],
) -> HnswIndexIter<'a, S> {
let filter_fn = &mut |_| true;
let Some(s) = entry(mmap, filter_fn) else {
return HnswIndexIter(None);
) -> (
Vec<HeapElement>,
Box<(dyn Iterator<Item = HeapElement> + 'a)>,
) {
let Some(s) = entry(mmap, &mut |_| true) else {
return (Vec::new(), Box::new(std::iter::empty()));
};
let levels = count_layers_of_a_vertex(mmap.m, s) - 1;
let u = fast_search(mmap, 1..=levels, s, vector, filter_fn);
local_search_vbase(mmap, range, u, vector)
let u = fast_search(mmap, 1..=levels, s, vector, &mut |_| true);
let mut iter = local_search_vbase(mmap, u, vector);
let mut queue = BinaryHeap::<HeapElement>::with_capacity(1 + range);
let mut stage1 = Vec::new();
for x in &mut iter {
if queue.len() == range && queue.peek().unwrap().distance < x.distance {
stage1.push(x);
break;
}
if queue.len() == range {
queue.pop();
}
queue.push(x);
stage1.push(x);
}
(stage1, Box::new(iter))
}

pub fn entry<S: G>(mmap: &HnswMmap<S>, filter: &mut impl Filter) -> Option<u32> {
Expand Down Expand Up @@ -471,7 +499,6 @@ pub fn local_search<S: G>(
vector: &[S::Scalar],
filter: &mut impl Filter,
) -> Heap {
assert!(k > 0);
let mut visited = mmap.visited.fetch();
let mut visited = visited.fetch();
let mut candidates = BinaryHeap::<Reverse<(F32, u32)>>::new();
Expand Down Expand Up @@ -510,59 +537,34 @@ pub fn local_search<S: G>(
results
}

fn local_search_vbase<'a, S: G>(
pub fn local_search_vbase<'a, S: G>(
mmap: &'a HnswMmap<S>,
range: usize,
s: u32,
vector: &[S::Scalar],
) -> HnswIndexIter<'a, S> {
assert!(range > 0);
let mut visited_guard = mmap.visited.fetch();
let mut visited = visited_guard.fetch();
vector: &'a [S::Scalar],
) -> impl Iterator<Item = HeapElement> + 'a {
let mut visited = mmap.visited.fetch2();
let mut candidates = BinaryHeap::<Reverse<(F32, u32)>>::new();
let mut results = Heap::new(range);
let mut lost = Vec::<Reverse<HeapElement>>::new();
visited.mark(s);
let s_dis = mmap.quantization.distance(vector, s);
candidates.push(Reverse((s_dis, s)));
results.push(HeapElement {
distance: s_dis,
payload: mmap.raw.payload(s),
});
while let Some(Reverse((u_dis, u))) = candidates.pop() {
if !results.check(u_dis) {
candidates.push(Reverse((u_dis, u)));
break;
}
let edges = find_edges(mmap, u, 0);
for &HnswMmapEdge(_, v) in edges.iter() {
if !visited.check(v) {
continue;
}
visited.mark(v);
let v_dis = mmap.quantization.distance(vector, v);
if !results.check(v_dis) {
continue;
}
candidates.push(Reverse((v_dis, v)));
if let Some(val) = results.push(HeapElement {
distance: v_dis,
payload: mmap.raw.payload(v),
}) {
lost.push(Reverse(val));
std::iter::from_fn(move || {
let Reverse((u_dis, u)) = candidates.pop()?;
{
let edges = find_edges(mmap, u, 0);
for &HnswMmapEdge(_, v) in edges.iter() {
if !visited.check(v) {
continue;
}
visited.mark(v);
let v_dis = mmap.quantization.distance(vector, v);
candidates.push(Reverse((v_dis, v)));
}
}
}
lost.sort_unstable();
HnswIndexIter(Some(HnswIndexIterInner {
mmap,
range,
candidates,
results: results.into_reversed_heap(),
lost,
visited: visited_guard,
vector: vector.to_vec(),
}))
Some(HeapElement {
distance: u_dis,
payload: mmap.raw.payload(u),
})
})
}

fn count_layers_of_a_vertex(m: u32, i: u32) -> u8 {
Expand Down Expand Up @@ -621,6 +623,21 @@ impl VisitedPool {
.unwrap_or_else(|| VisitedBuffer::new(self.n as _));
VisitedGuard { buffer, pool: self }
}

fn fetch2(&self) -> VisitedGuardChecker {
let mut buffer = self
.locked_buffers
.lock()
.pop()
.unwrap_or_else(|| VisitedBuffer::new(self.n as _));
{
buffer.version = buffer.version.wrapping_add(1);
if buffer.version == 0 {
buffer.data.fill(0);
}
}
VisitedGuardChecker { buffer, pool: self }
}
}

struct VisitedGuard<'a> {
Expand All @@ -638,11 +655,6 @@ impl<'a> VisitedGuard<'a> {
buffer: &mut self.buffer,
}
}
fn fetch_current_version(&mut self) -> VisitedChecker {
VisitedChecker {
buffer: &mut self.buffer,
}
}
}

impl<'a> Drop for VisitedGuard<'a> {
Expand All @@ -669,77 +681,41 @@ impl<'a> VisitedChecker<'a> {
}
}

struct VisitedBuffer {
version: usize,
data: Vec<usize>,
struct VisitedGuardChecker<'a> {
buffer: VisitedBuffer,
pool: &'a VisitedPool,
}

impl VisitedBuffer {
fn new(capacity: usize) -> Self {
Self {
version: 0,
data: bytemuck::zeroed_vec(capacity),
}
impl<'a> VisitedGuardChecker<'a> {
fn check(&mut self, i: u32) -> bool {
self.buffer.data[i as usize] != self.buffer.version
}
fn mark(&mut self, i: u32) {
self.buffer.data[i as usize] = self.buffer.version;
}
}

pub struct HnswIndexIter<'mmap, S: G>(Option<HnswIndexIterInner<'mmap, S>>);

pub struct HnswIndexIterInner<'mmap, S: G> {
mmap: &'mmap HnswMmap<S>,
range: usize,
candidates: BinaryHeap<Reverse<(F32, u32)>>,
results: BinaryHeap<Reverse<HeapElement>>,
// The points lost in the first stage, we should keep it to the second stage.
lost: Vec<Reverse<HeapElement>>,
visited: VisitedGuard<'mmap>,
vector: Vec<S::Scalar>,
}

impl<S: G> Iterator for HnswIndexIter<'_, S> {
type Item = HeapElement;
fn next(&mut self) -> Option<Self::Item> {
self.0.as_mut()?.next()
impl<'a> Drop for VisitedGuardChecker<'a> {
fn drop(&mut self) {
let src = VisitedBuffer {
version: 0,
data: Vec::new(),
};
let buffer = std::mem::replace(&mut self.buffer, src);
self.pool.locked_buffers.lock().push(buffer);
}
}

impl<S: G> Iterator for HnswIndexIterInner<'_, S> {
type Item = HeapElement;
fn next(&mut self) -> Option<Self::Item> {
if self.results.len() > self.range {
return self.pop();
}

let mut visited = self.visited.fetch_current_version();
while let Some(Reverse((_, u))) = self.candidates.pop() {
let edges = find_edges(self.mmap, u, 0);
for &HnswMmapEdge(_, v) in edges.iter() {
if !visited.check(v) {
continue;
}
visited.mark(v);
let v_dis = self.mmap.quantization.distance(&self.vector, v);
self.candidates.push(Reverse((v_dis, v)));
self.results.push(Reverse(HeapElement {
distance: v_dis,
payload: self.mmap.raw.payload(v),
}));
}
if self.results.len() > self.range {
return self.pop();
}
}

self.pop()
}
struct VisitedBuffer {
version: usize,
data: Vec<usize>,
}

impl<S: G> HnswIndexIterInner<'_, S> {
fn pop(&mut self) -> Option<HeapElement> {
if self.results.peek() > self.lost.last() {
self.results.pop().map(|x| x.0)
} else {
self.lost.pop().map(|x| x.0)
impl VisitedBuffer {
fn new(capacity: usize) -> Self {
Self {
version: 0,
data: bytemuck::zeroed_vec(capacity),
}
}
}
Loading
Loading