Commit

Add query matching terms in a set (#1539)
trinity-1686a authored Sep 28, 2022
1 parent f842da7 commit a86b0df
Showing 2 changed files with 225 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/query/mod.rs
@@ -21,6 +21,7 @@ mod range_query;
mod regex_query;
mod reqopt_scorer;
mod scorer;
mod set_query;
mod term_query;
mod union;
mod weight;
@@ -58,6 +59,7 @@ pub use self::score_combiner::{
    DisjunctionMaxCombiner, ScoreCombiner, SumCombiner, SumWithCoordsCombiner,
};
pub use self::scorer::Scorer;
pub use self::set_query::TermSetQuery;
pub use self::term_query::TermQuery;
pub use self::union::Union;
#[cfg(test)]
223 changes: 223 additions & 0 deletions src/query/set_query.rs
@@ -0,0 +1,223 @@
use std::collections::HashMap;

use tantivy_fst::raw::CompiledAddr;
use tantivy_fst::{Automaton, Map};

use crate::query::score_combiner::DoNothingCombiner;
use crate::query::{AutomatonWeight, BooleanWeight, Occur, Query, Weight};
use crate::schema::Field;
use crate::{Searcher, Term};

/// A Term Set Query matches all documents containing any of the provided terms.
#[derive(Debug, Clone)]
pub struct TermSetQuery {
    terms_map: HashMap<Field, Vec<Term>>,
}

impl TermSetQuery {
    /// Create a Term Set Query
    pub fn new<T: IntoIterator<Item = Term>>(terms: T) -> Self {
        let mut terms_map: HashMap<_, Vec<_>> = HashMap::new();
        for term in terms {
            terms_map.entry(term.field()).or_default().push(term);
        }

        for terms in terms_map.values_mut() {
            terms.sort_unstable();
            terms.dedup();
        }

        TermSetQuery { terms_map }
    }

    fn specialized_weight(
        &self,
        searcher: &Searcher,
    ) -> crate::Result<BooleanWeight<DoNothingCombiner>> {
        let mut sub_queries: Vec<(_, Box<dyn Weight>)> = Vec::with_capacity(self.terms_map.len());

        for (&field, sorted_terms) in self.terms_map.iter() {
            let field_entry = searcher.schema().get_field_entry(field);
            let field_type = field_entry.field_type();
            if !field_type.is_indexed() {
                let error_msg = format!("Field {:?} is not indexed.", field_entry.name());
                return Err(crate::TantivyError::SchemaError(error_msg));
            }
            let field_type = field_type.value_type();

            // In practice this won't fail because:
            // - we are writing to memory, so no IoError
            // - Terms are ordered
            let map = Map::from_iter(sorted_terms.iter().map(|key| (key.value_bytes(), 0)))
                .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;

            sub_queries.push((
                Occur::Should,
                Box::new(AutomatonWeight::new(field, SetDfaWrapper(map))),
            ));
        }

        Ok(BooleanWeight::new(
            sub_queries,
            false,
            Box::new(|| DoNothingCombiner),
        ))
    }
}

impl Query for TermSetQuery {
    fn weight(
        &self,
        searcher: &Searcher,
        _scoring_enabled: bool,
    ) -> crate::Result<Box<dyn Weight>> {
        Ok(Box::new(self.specialized_weight(searcher)?))
    }
}

struct SetDfaWrapper(Map<Vec<u8>>);

impl Automaton for SetDfaWrapper {
    type State = Option<CompiledAddr>;

    fn start(&self) -> Option<CompiledAddr> {
        Some(self.0.as_ref().root().addr())
    }

    fn is_match(&self, state_opt: &Option<CompiledAddr>) -> bool {
        if let Some(state) = state_opt {
            self.0.as_ref().node(*state).is_final()
        } else {
            false
        }
    }

    fn accept(&self, state_opt: &Option<CompiledAddr>, byte: u8) -> Option<CompiledAddr> {
        let state = state_opt.as_ref()?;
        let node = self.0.as_ref().node(*state);
        let transition = node.find_input(byte)?;
        Some(node.transition_addr(transition))
    }

    fn can_match(&self, state: &Self::State) -> bool {
        state.is_some()
    }
}

#[cfg(test)]
mod tests {

    use crate::collector::TopDocs;
    use crate::query::TermSetQuery;
    use crate::schema::{Schema, TEXT};
    use crate::{assert_nearly_equals, Index, Term};

    #[test]
    pub fn test_term_set_query() -> crate::Result<()> {
        let mut schema_builder = Schema::builder();
        let field1 = schema_builder.add_text_field("field1", TEXT);
        let field2 = schema_builder.add_text_field("field2", TEXT);
        let schema = schema_builder.build();
        let index = Index::create_in_ram(schema);
        {
            let mut index_writer = index.writer_for_tests()?;
            index_writer.add_document(doc!(
                field1 => "doc1",
                field2 => "val1",
            ))?;
            index_writer.add_document(doc!(
                field1 => "doc2",
                field2 => "val2",
            ))?;
            index_writer.add_document(doc!(
                field1 => "doc3",
                field2 => "val3",
            ))?;
            index_writer.add_document(doc!(
                field1 => "val3",
                field2 => "doc3",
            ))?;
            index_writer.commit()?;
        }
        let reader = index.reader()?;
        let searcher = reader.searcher();

        {
            // single element
            let terms = vec![Term::from_field_text(field1, "doc1")];

            let term_set_query = TermSetQuery::new(terms);
            let top_docs = searcher.search(&term_set_query, &TopDocs::with_limit(2))?;
            assert_eq!(top_docs.len(), 1, "Expected 1 document");
            let (score, _) = top_docs[0];
            assert_nearly_equals!(1.0, score);
        }

        {
            // single element, absent
            let terms = vec![Term::from_field_text(field1, "doc4")];

            let term_set_query = TermSetQuery::new(terms);
            let top_docs = searcher.search(&term_set_query, &TopDocs::with_limit(1))?;
            assert!(top_docs.is_empty(), "Expected 0 documents");
        }

        {
            // multiple elements
            let terms = vec![
                Term::from_field_text(field1, "doc1"),
                Term::from_field_text(field1, "doc2"),
            ];

            let term_set_query = TermSetQuery::new(terms);
            let top_docs = searcher.search(&term_set_query, &TopDocs::with_limit(2))?;
            assert_eq!(top_docs.len(), 2, "Expected 2 documents");
            for (score, _) in top_docs {
                assert_nearly_equals!(1.0, score);
            }
        }

        {
            // multiple elements, mixed fields
            let terms = vec![
                Term::from_field_text(field1, "doc1"),
                Term::from_field_text(field1, "doc1"),
                Term::from_field_text(field2, "val2"),
            ];

            let term_set_query = TermSetQuery::new(terms);
            let top_docs = searcher.search(&term_set_query, &TopDocs::with_limit(3))?;

            assert_eq!(top_docs.len(), 2, "Expected 2 documents");
            for (score, _) in top_docs {
                assert_nearly_equals!(1.0, score);
            }
        }

        {
            // no field crosstalk
            let terms = vec![Term::from_field_text(field1, "doc3")];

            let term_set_query = TermSetQuery::new(terms);
            let top_docs = searcher.search(&term_set_query, &TopDocs::with_limit(3))?;
            assert_eq!(top_docs.len(), 1, "Expected 1 document");

            let terms = vec![Term::from_field_text(field2, "doc3")];

            let term_set_query = TermSetQuery::new(terms);
            let top_docs = searcher.search(&term_set_query, &TopDocs::with_limit(3))?;
            assert_eq!(top_docs.len(), 1, "Expected 1 document");

            let terms = vec![
                Term::from_field_text(field1, "doc3"),
                Term::from_field_text(field2, "doc3"),
            ];

            let term_set_query = TermSetQuery::new(terms);
            let top_docs = searcher.search(&term_set_query, &TopDocs::with_limit(3))?;
            assert_eq!(top_docs.len(), 2, "Expected 2 documents");
        }

        Ok(())
    }
}
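
For reference, here is a minimal usage sketch of the new TermSetQuery as seen from outside the crate. It is not part of the commit: the field name, writer heap size, and sample documents are illustrative assumptions, and the call pattern simply mirrors the tests above.

// Usage sketch (not part of the commit). Assumes an indexed text field named
// "title"; the heap size and sample data are placeholders.
use tantivy::collector::TopDocs;
use tantivy::query::TermSetQuery;
use tantivy::schema::{Schema, TEXT};
use tantivy::{doc, Index, Term};

fn main() -> tantivy::Result<()> {
    // Build a minimal schema with one indexed text field.
    let mut schema_builder = Schema::builder();
    let title = schema_builder.add_text_field("title", TEXT);
    let index = Index::create_in_ram(schema_builder.build());

    // Index a few documents.
    let mut writer = index.writer(50_000_000)?;
    writer.add_document(doc!(title => "apple"))?;
    writer.add_document(doc!(title => "banana"))?;
    writer.add_document(doc!(title => "cherry"))?;
    writer.commit()?;

    // Match any document whose `title` contains "apple" or "cherry".
    let query = TermSetQuery::new(vec![
        Term::from_field_text(title, "apple"),
        Term::from_field_text(title, "cherry"),
    ]);
    let searcher = index.reader()?.searcher();
    let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?;
    assert_eq!(top_docs.len(), 2);
    Ok(())
}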
