30 changes: 19 additions & 11 deletions rust/lance-index/src/scalar/inverted/index.rs
@@ -62,10 +62,7 @@ use super::{
encoding::compress_posting_list,
iter::CompressedPostingListIterator,
};
use super::{
encoding::compress_positions,
iter::{PostingListIterator, TokenIterator, TokenSource},
};
use super::{encoding::compress_positions, iter::PostingListIterator};
use super::{wand::*, InvertedIndexBuilder, InvertedIndexParams};
use crate::frag_reuse::FragReuseIndex;
use crate::pbold;
@@ -938,13 +935,6 @@ impl TokenSet {
self.len() == 0
}

pub(crate) fn iter(&self) -> TokenIterator<'_> {
TokenIterator::new(match &self.tokens {
TokenMap::HashMap(map) => TokenSource::HashMap(map.iter()),
TokenMap::Fst(map) => TokenSource::Fst(map.stream()),
})
}

pub fn to_batch(self, format: TokenSetFormat) -> Result<RecordBatch> {
match format {
TokenSetFormat::Arrow => self.into_arrow_batch(),
@@ -1150,6 +1140,24 @@ impl TokenSet {
token_id
}

pub(crate) fn get_or_add(&mut self, token: &str) -> u32 {
let next_id = self.next_id;
match self.tokens {
TokenMap::HashMap(ref mut map) => {
if let Some(&token_id) = map.get(token) {
return token_id;
}

map.insert(token.to_owned(), next_id);
}
_ => unreachable!("tokens must be HashMap while indexing"),
}

self.next_id += 1;
self.total_length += token.len();
next_id
}

pub fn get(&self, token: &str) -> Option<u32> {
match self.tokens {
TokenMap::HashMap(ref map) => map.get(token).copied(),
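The new `get_or_add` replaces the previous add-then-look-up dance during merges. A minimal sketch of its contract, collapsing the `TokenMap` enum to a bare `HashMap` for brevity (an illustration of the semantics, not the Lance implementation): a hit returns the existing id unchanged; a miss assigns `next_id`, bumps it, and grows `total_length` by the token's byte length.

```rust
use std::collections::HashMap;

// Stand-in for TokenSet with the TokenMap enum collapsed to a HashMap.
struct Tokens {
    map: HashMap<String, u32>,
    next_id: u32,
    total_length: usize,
}

impl Tokens {
    fn get_or_add(&mut self, token: &str) -> u32 {
        if let Some(&token_id) = self.map.get(token) {
            return token_id; // hit: the id is stable, no state changes
        }
        let token_id = self.next_id;
        self.map.insert(token.to_owned(), token_id);
        self.next_id += 1;
        self.total_length += token.len(); // byte length, as in the diff
        token_id
    }
}

fn main() {
    let mut tokens = Tokens { map: HashMap::new(), next_id: 0, total_length: 0 };
    assert_eq!(tokens.get_or_add("banana"), 0);
    assert_eq!(tokens.get_or_add("carrot"), 1);
    assert_eq!(tokens.get_or_add("banana"), 0); // repeated token keeps its id
}
```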
32 changes: 0 additions & 32 deletions rust/lance-index/src/scalar/inverted/iter.rs
@@ -1,47 +1,15 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

use std::collections::hash_map;

use arrow::array::AsArray;
use arrow_array::{Array, LargeBinaryArray, ListArray};
use fst::Streamer;

use super::{
builder::BLOCK_SIZE,
encoding::{decompress_positions, decompress_posting_block, decompress_posting_remainder},
PostingList,
};

pub enum TokenSource<'a> {
HashMap(hash_map::Iter<'a, String, u32>),
Fst(fst::map::Stream<'a>),
}
pub struct TokenIterator<'a> {
source: TokenSource<'a>,
}

impl<'a> TokenIterator<'a> {
pub(crate) fn new(source: TokenSource<'a>) -> Self {
Self { source }
}
}

impl Iterator for TokenIterator<'_> {
type Item = (String, u32);

fn next(&mut self) -> Option<Self::Item> {
match &mut self.source {
TokenSource::HashMap(iter) => iter
.next()
.map(|(token, token_id)| (token.clone(), *token_id)),
TokenSource::Fst(iter) => iter.next().map(|(token, token_id)| {
(String::from_utf8_lossy(token).into_owned(), token_id as u32)
}),
}
}
}

pub enum PostingListIterator<'a> {
Plain(PlainPostingListIterator<'a>),
Compressed(Box<CompressedPostingListIterator>),
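With `TokenIterator` and `TokenSource` removed, traversal now happens at the call site by matching on the token map directly, as the `merger.rs` hunk below does. A minimal sketch of that pattern with a stand-in enum (the real `TokenMap` lives in `index.rs`; this mirrors its two variants rather than the exact definition):

```rust
use std::collections::HashMap;

use fst::{Map, Streamer};

// Stand-in mirroring TokenMap's two variants in index.rs.
enum TokenMap {
    HashMap(HashMap<String, u32>),
    Fst(Map<Vec<u8>>),
}

// Visit every (token, token_id) pair, whichever representation is loaded.
fn for_each_token(tokens: &TokenMap, mut f: impl FnMut(&str, u32)) {
    match tokens {
        TokenMap::HashMap(map) => {
            for (token, token_id) in map {
                f(token, *token_id);
            }
        }
        TokenMap::Fst(map) => {
            let mut stream = map.stream();
            while let Some((token, token_id)) = stream.next() {
                // FST keys are raw bytes; values are u64 ids.
                f(&String::from_utf8_lossy(token), token_id as u32);
            }
        }
    }
}

fn main() {
    let mut map = HashMap::new();
    map.insert("apple".to_owned(), 0u32);
    for_each_token(&TokenMap::HashMap(map), |token, id| println!("{token} -> {id}"));
}
```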
134 changes: 126 additions & 8 deletions rust/lance-index/src/scalar/inverted/merger.rs
@@ -1,15 +1,14 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

use std::collections::HashMap;

use fst::Streamer;
use lance_core::Result;

use crate::scalar::IndexStore;

use super::{
builder::{doc_file_path, posting_file_path, token_file_path, InnerBuilder, PositionRecorder},
InvertedPartition, PostingListBuilder, TokenSetFormat,
InvertedPartition, PostingListBuilder, TokenMap, TokenSetFormat,
};

pub trait Merger {
@@ -123,11 +122,28 @@ impl Merger for SizeBasedMerger<'_> {
estimated_size = 0;
}

let mut inv_token = HashMap::with_capacity(part.tokens.len());
// merge token set
for (token, token_id) in part.tokens.iter() {
self.builder.tokens.add(token.clone());
inv_token.insert(token_id, token);
let mut token_id_map = vec![u32::MAX; part.tokens.len()];
match &part.tokens.tokens {
TokenMap::HashMap(map) => {
for (token, token_id) in map.iter() {
let new_token_id = self.builder.tokens.get_or_add(token.as_str());
let index = *token_id as usize;
debug_assert!(index < token_id_map.len());
token_id_map[index] = new_token_id;
}
}
TokenMap::Fst(map) => {
let mut stream = map.stream();
while let Some((token, token_id)) = stream.next() {
let token_id = token_id as u32;
let token = String::from_utf8_lossy(token);
let new_token_id = self.builder.tokens.get_or_add(token.as_ref());
let index = token_id as usize;
debug_assert!(index < token_id_map.len());
token_id_map[index] = new_token_id;
}
}
}
// merge doc set
let doc_id_offset = self.builder.docs.len() as u32;
@@ -149,7 +165,8 @@
let posting_list = part
.inverted_list
.posting_list_from_batch(&postings.slice(token_id as usize, 1), token_id)?;
let new_token_id = self.builder.tokens.get(&inv_token[&token_id]).unwrap();
let new_token_id = token_id_map[token_id as usize];
debug_assert_ne!(new_token_id, u32::MAX);
let builder = &mut self.builder.posting_lists[new_token_id as usize];
let old_size = builder.size();
for (doc_id, freq, positions) in posting_list.iter() {
@@ -175,3 +192,104 @@
Ok(self.partitions.clone())
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::metrics::NoOpMetricsCollector;
use crate::scalar::lance_format::LanceIndexStore;
use lance_core::cache::LanceCache;
use lance_core::utils::tempfile::TempObjDir;
use lance_io::object_store::ObjectStore;
use std::sync::Arc;

#[tokio::test]
async fn test_merge_reuses_token_ids_for_shared_tokens() -> Result<()> {
let src_dir = TempObjDir::default();
let dest_dir = TempObjDir::default();
let src_store = Arc::new(LanceIndexStore::new(
ObjectStore::local().into(),
src_dir.clone(),
Arc::new(LanceCache::no_cache()),
));
let dest_store = Arc::new(LanceIndexStore::new(
ObjectStore::local().into(),
dest_dir.clone(),
Arc::new(LanceCache::no_cache()),
));

let token_set_format = TokenSetFormat::default();

let mut builder0 = InnerBuilder::new(0, false, token_set_format);
let apple_id = builder0.tokens.add("apple".to_owned());
let banana_id = builder0.tokens.add("banana".to_owned());
builder0
.posting_lists
.resize_with(builder0.tokens.len(), || PostingListBuilder::new(false));
let doc_id = builder0.docs.append(10, 2);
builder0.posting_lists[apple_id as usize].add(doc_id, PositionRecorder::Count(1));
builder0.posting_lists[banana_id as usize].add(doc_id, PositionRecorder::Count(1));
builder0.write(src_store.as_ref()).await?;

let mut builder1 = InnerBuilder::new(1, false, token_set_format);
let banana_id = builder1.tokens.add("banana".to_owned());
let carrot_id = builder1.tokens.add("carrot".to_owned());
builder1
.posting_lists
.resize_with(builder1.tokens.len(), || PostingListBuilder::new(false));
let doc_id = builder1.docs.append(20, 2);
builder1.posting_lists[banana_id as usize].add(doc_id, PositionRecorder::Count(1));
builder1.posting_lists[carrot_id as usize].add(doc_id, PositionRecorder::Count(1));
builder1.write(src_store.as_ref()).await?;

let partition0 = InvertedPartition::load(
src_store.clone(),
0,
None,
&LanceCache::no_cache(),
token_set_format,
)
.await?;
let partition1 = InvertedPartition::load(
src_store.clone(),
1,
None,
&LanceCache::no_cache(),
token_set_format,
)
.await?;

let mut merger = SizeBasedMerger::new(
dest_store.as_ref(),
vec![partition0, partition1],
u64::MAX,
token_set_format,
);
let merged_partitions = merger.merge().await?;
assert_eq!(merged_partitions, vec![2]);

let merged = InvertedPartition::load(
dest_store.clone(),
merged_partitions[0],
None,
&LanceCache::no_cache(),
token_set_format,
)
.await?;

assert_eq!(merged.tokens.len(), 3);
assert_eq!(merged.docs.len(), 2);
assert_eq!(merged.docs.row_id(0), 10);
assert_eq!(merged.docs.row_id(1), 20);

let banana_token_id = merged.tokens.get("banana").unwrap();
let posting = merged
.inverted_list
.posting_list(banana_token_id, false, &NoOpMetricsCollector)
.await?;
let doc_ids: Vec<u64> = posting.iter().map(|(doc_id, _, _)| doc_id).collect();
assert_eq!(doc_ids, vec![0, 1]);

Ok(())
}
}
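Taken together, the merger now remaps ids with a dense `Vec` instead of the old `HashMap` inverse plus a per-posting-list `get`. This works because each partition assigns token ids contiguously from 0, so the old id can index the vector directly, with `u32::MAX` as the "never seen" sentinel that the `debug_assert`s check. A self-contained sketch of the idea, with `get_or_add` abstracted as a closure (hypothetical helper names, not the Lance API):

```rust
// Build old-id -> merged-id mapping for one partition's tokens.
// Assumes old token ids are contiguous in 0..n within the partition.
fn build_token_id_map(
    partition_tokens: &[(String, u32)],
    mut get_or_add: impl FnMut(&str) -> u32,
) -> Vec<u32> {
    let mut token_id_map = vec![u32::MAX; partition_tokens.len()];
    for (token, old_id) in partition_tokens {
        token_id_map[*old_id as usize] = get_or_add(token);
    }
    token_id_map
}

fn main() {
    use std::collections::HashMap;
    let mut global: HashMap<String, u32> = HashMap::new();
    let part = vec![("banana".to_owned(), 0), ("carrot".to_owned(), 1)];
    let map = build_token_id_map(&part, |t| {
        let next = global.len() as u32;
        *global.entry(t.to_owned()).or_insert(next)
    });
    assert_eq!(map, vec![0, 1]); // old id -> merged id
}
```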