diff --git a/rust/lance-index/src/scalar/inverted/index.rs b/rust/lance-index/src/scalar/inverted/index.rs index 7eaaff89cf2..1862750cbfe 100644 --- a/rust/lance-index/src/scalar/inverted/index.rs +++ b/rust/lance-index/src/scalar/inverted/index.rs @@ -62,10 +62,7 @@ use super::{ encoding::compress_posting_list, iter::CompressedPostingListIterator, }; -use super::{ - encoding::compress_positions, - iter::{PostingListIterator, TokenIterator, TokenSource}, -}; +use super::{encoding::compress_positions, iter::PostingListIterator}; use super::{wand::*, InvertedIndexBuilder, InvertedIndexParams}; use crate::frag_reuse::FragReuseIndex; use crate::pbold; @@ -938,13 +935,6 @@ impl TokenSet { self.len() == 0 } - pub(crate) fn iter(&self) -> TokenIterator<'_> { - TokenIterator::new(match &self.tokens { - TokenMap::HashMap(map) => TokenSource::HashMap(map.iter()), - TokenMap::Fst(map) => TokenSource::Fst(map.stream()), - }) - } - pub fn to_batch(self, format: TokenSetFormat) -> Result { match format { TokenSetFormat::Arrow => self.into_arrow_batch(), @@ -1150,6 +1140,24 @@ impl TokenSet { token_id } + pub(crate) fn get_or_add(&mut self, token: &str) -> u32 { + let next_id = self.next_id; + match self.tokens { + TokenMap::HashMap(ref mut map) => { + if let Some(&token_id) = map.get(token) { + return token_id; + } + + map.insert(token.to_owned(), next_id); + } + _ => unreachable!("tokens must be HashMap while indexing"), + } + + self.next_id += 1; + self.total_length += token.len(); + next_id + } + pub fn get(&self, token: &str) -> Option { match self.tokens { TokenMap::HashMap(ref map) => map.get(token).copied(), diff --git a/rust/lance-index/src/scalar/inverted/iter.rs b/rust/lance-index/src/scalar/inverted/iter.rs index b54fe543e9a..9c52b7a4873 100644 --- a/rust/lance-index/src/scalar/inverted/iter.rs +++ b/rust/lance-index/src/scalar/inverted/iter.rs @@ -1,11 +1,8 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors -use std::collections::hash_map; - use arrow::array::AsArray; use arrow_array::{Array, LargeBinaryArray, ListArray}; -use fst::Streamer; use super::{ builder::BLOCK_SIZE, @@ -13,35 +10,6 @@ use super::{ PostingList, }; -pub enum TokenSource<'a> { - HashMap(hash_map::Iter<'a, String, u32>), - Fst(fst::map::Stream<'a>), -} -pub struct TokenIterator<'a> { - source: TokenSource<'a>, -} - -impl<'a> TokenIterator<'a> { - pub(crate) fn new(source: TokenSource<'a>) -> Self { - Self { source } - } -} - -impl Iterator for TokenIterator<'_> { - type Item = (String, u32); - - fn next(&mut self) -> Option { - match &mut self.source { - TokenSource::HashMap(iter) => iter - .next() - .map(|(token, token_id)| (token.clone(), *token_id)), - TokenSource::Fst(iter) => iter.next().map(|(token, token_id)| { - (String::from_utf8_lossy(token).into_owned(), token_id as u32) - }), - } - } -} - pub enum PostingListIterator<'a> { Plain(PlainPostingListIterator<'a>), Compressed(Box), diff --git a/rust/lance-index/src/scalar/inverted/merger.rs b/rust/lance-index/src/scalar/inverted/merger.rs index 6440a736ec8..66f1ee51de3 100644 --- a/rust/lance-index/src/scalar/inverted/merger.rs +++ b/rust/lance-index/src/scalar/inverted/merger.rs @@ -1,15 +1,14 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors -use std::collections::HashMap; - +use fst::Streamer; use lance_core::Result; use crate::scalar::IndexStore; use super::{ builder::{doc_file_path, posting_file_path, token_file_path, InnerBuilder, PositionRecorder}, - InvertedPartition, PostingListBuilder, TokenSetFormat, + InvertedPartition, PostingListBuilder, TokenMap, TokenSetFormat, }; pub trait Merger { @@ -123,11 +122,28 @@ impl Merger for SizeBasedMerger<'_> { estimated_size = 0; } - let mut inv_token = HashMap::with_capacity(part.tokens.len()); // merge token set - for (token, token_id) in part.tokens.iter() { - self.builder.tokens.add(token.clone()); - inv_token.insert(token_id, token); + let mut token_id_map = vec![u32::MAX; part.tokens.len()]; + match &part.tokens.tokens { + TokenMap::HashMap(map) => { + for (token, token_id) in map.iter() { + let new_token_id = self.builder.tokens.get_or_add(token.as_str()); + let index = *token_id as usize; + debug_assert!(index < token_id_map.len()); + token_id_map[index] = new_token_id; + } + } + TokenMap::Fst(map) => { + let mut stream = map.stream(); + while let Some((token, token_id)) = stream.next() { + let token_id = token_id as u32; + let token = String::from_utf8_lossy(token); + let new_token_id = self.builder.tokens.get_or_add(token.as_ref()); + let index = token_id as usize; + debug_assert!(index < token_id_map.len()); + token_id_map[index] = new_token_id; + } + } } // merge doc set let doc_id_offset = self.builder.docs.len() as u32; @@ -149,7 +165,8 @@ impl Merger for SizeBasedMerger<'_> { let posting_list = part .inverted_list .posting_list_from_batch(&postings.slice(token_id as usize, 1), token_id)?; - let new_token_id = self.builder.tokens.get(&inv_token[&token_id]).unwrap(); + let new_token_id = token_id_map[token_id as usize]; + debug_assert_ne!(new_token_id, u32::MAX); let builder = &mut self.builder.posting_lists[new_token_id as usize]; let old_size = builder.size(); for (doc_id, freq, positions) in posting_list.iter() { @@ -175,3 +192,104 @@ impl Merger for SizeBasedMerger<'_> { Ok(self.partitions.clone()) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::metrics::NoOpMetricsCollector; + use crate::scalar::lance_format::LanceIndexStore; + use lance_core::cache::LanceCache; + use lance_core::utils::tempfile::TempObjDir; + use lance_io::object_store::ObjectStore; + use std::sync::Arc; + + #[tokio::test] + async fn test_merge_reuses_token_ids_for_shared_tokens() -> Result<()> { + let src_dir = TempObjDir::default(); + let dest_dir = TempObjDir::default(); + let src_store = Arc::new(LanceIndexStore::new( + ObjectStore::local().into(), + src_dir.clone(), + Arc::new(LanceCache::no_cache()), + )); + let dest_store = Arc::new(LanceIndexStore::new( + ObjectStore::local().into(), + dest_dir.clone(), + Arc::new(LanceCache::no_cache()), + )); + + let token_set_format = TokenSetFormat::default(); + + let mut builder0 = InnerBuilder::new(0, false, token_set_format); + let apple_id = builder0.tokens.add("apple".to_owned()); + let banana_id = builder0.tokens.add("banana".to_owned()); + builder0 + .posting_lists + .resize_with(builder0.tokens.len(), || PostingListBuilder::new(false)); + let doc_id = builder0.docs.append(10, 2); + builder0.posting_lists[apple_id as usize].add(doc_id, PositionRecorder::Count(1)); + builder0.posting_lists[banana_id as usize].add(doc_id, PositionRecorder::Count(1)); + builder0.write(src_store.as_ref()).await?; + + let mut builder1 = InnerBuilder::new(1, false, token_set_format); + let banana_id = builder1.tokens.add("banana".to_owned()); + let carrot_id = builder1.tokens.add("carrot".to_owned()); + builder1 + .posting_lists + .resize_with(builder1.tokens.len(), || PostingListBuilder::new(false)); + let doc_id = builder1.docs.append(20, 2); + builder1.posting_lists[banana_id as usize].add(doc_id, PositionRecorder::Count(1)); + builder1.posting_lists[carrot_id as usize].add(doc_id, PositionRecorder::Count(1)); + builder1.write(src_store.as_ref()).await?; + + let partition0 = InvertedPartition::load( + src_store.clone(), + 0, + None, + &LanceCache::no_cache(), + token_set_format, + ) + .await?; + let partition1 = InvertedPartition::load( + src_store.clone(), + 1, + None, + &LanceCache::no_cache(), + token_set_format, + ) + .await?; + + let mut merger = SizeBasedMerger::new( + dest_store.as_ref(), + vec![partition0, partition1], + u64::MAX, + token_set_format, + ); + let merged_partitions = merger.merge().await?; + assert_eq!(merged_partitions, vec![2]); + + let merged = InvertedPartition::load( + dest_store.clone(), + merged_partitions[0], + None, + &LanceCache::no_cache(), + token_set_format, + ) + .await?; + + assert_eq!(merged.tokens.len(), 3); + assert_eq!(merged.docs.len(), 2); + assert_eq!(merged.docs.row_id(0), 10); + assert_eq!(merged.docs.row_id(1), 20); + + let banana_token_id = merged.tokens.get("banana").unwrap(); + let posting = merged + .inverted_list + .posting_list(banana_token_id, false, &NoOpMetricsCollector) + .await?; + let doc_ids: Vec = posting.iter().map(|(doc_id, _, _)| doc_id).collect(); + assert_eq!(doc_ids, vec![0, 1]); + + Ok(()) + } +}