Skip to content

Commit

Permalink
feat: tantivy_0.20.1_upgrade (#82)
Browse files Browse the repository at this point in the history
* Added api changes from tantivy-0.20.1

* lint fix

* Increase test writer heap to 10_000_000

* Revert test back to original check

* Update src/searcher.rs

Co-authored-by: Cameron <561860+wallies@users.noreply.github.com>

---------

Co-authored-by: Caleb Hattingh <caleb.hattingh@gmail.com>
Co-authored-by: Cameron <561860+wallies@users.noreply.github.com>
  • Loading branch information
3 people authored Jul 20, 2023
1 parent 1fe7244 commit a266f41
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 61 deletions.
6 changes: 3 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "tantivy"
version = "0.19.2"
version = "0.20.1"
readme = "README.md"
authors = ["Damir Jelić <poljar@termina.org.uk>"]
edition = "2018"
Expand All @@ -15,11 +15,11 @@ pyo3-build-config = "0.18.0"

[dependencies]
chrono = "0.4.23"
tantivy = "0.19.2"
tantivy = "0.20.1"
itertools = "0.10.5"
futures = "0.3.26"
serde_json = "1.0.91"

[dependencies.pyo3]
version = "0.18.0"
features = ["extension-module"]
features = ["extension-module"]
7 changes: 4 additions & 3 deletions src/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ impl Index {
let schema = self.index.schema();
if let Some(default_field_names_vec) = default_field_names {
for default_field_name in &default_field_names_vec {
if let Some(field) = schema.get_field(default_field_name) {
if let Ok(field) = schema.get_field(default_field_name) {
let field_entry = schema.get_field_entry(field);
if !field_entry.is_indexed() {
return Err(exceptions::PyValueError::new_err(
Expand Down Expand Up @@ -385,10 +385,11 @@ impl Index {
];

for (name, lang) in &analyzers {
let an = TextAnalyzer::from(SimpleTokenizer)
let an = TextAnalyzer::builder(SimpleTokenizer::default())
.filter(RemoveLongFilter::limit(40))
.filter(LowerCaser)
.filter(Stemmer::new(*lang));
.filter(Stemmer::new(*lang))
.build();
index.tokenizers().register(name, an);
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ pub(crate) fn get_field(
schema: &tv::schema::Schema,
field_name: &str,
) -> PyResult<tv::schema::Field> {
let field = schema.get_field(field_name).ok_or_else(|| {
let field = schema.get_field(field_name).map_err(|_err| {
exceptions::PyValueError::new_err(format!(
"Field `{field_name}` is not defined in the schema."
))
Expand Down
58 changes: 12 additions & 46 deletions src/schemabuilder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,13 @@ impl SchemaBuilder {
///
/// Returns the associated field handle.
/// Raises a ValueError if there was an error with the field creation.
#[pyo3(signature = (name, stored = false, indexed = false, fast = None))]
#[pyo3(signature = (name, stored = false, indexed = false, fast = false))]
fn add_integer_field(
&mut self,
name: &str,
stored: bool,
indexed: bool,
fast: Option<&str>,
fast: bool,
) -> PyResult<Self> {
let builder = &mut self.builder;

Expand All @@ -132,13 +132,13 @@ impl SchemaBuilder {
Ok(self.clone())
}

#[pyo3(signature = (name, stored = false, indexed = false, fast = None))]
#[pyo3(signature = (name, stored = false, indexed = false, fast = false))]
fn add_float_field(
&mut self,
name: &str,
stored: bool,
indexed: bool,
fast: Option<&str>,
fast: bool,
) -> PyResult<Self> {
let builder = &mut self.builder;

Expand Down Expand Up @@ -174,13 +174,13 @@ impl SchemaBuilder {
///
/// Returns the associated field handle.
/// Raises a ValueError if there was an error with the field creation.
#[pyo3(signature = (name, stored = false, indexed = false, fast = None))]
#[pyo3(signature = (name, stored = false, indexed = false, fast = false))]
fn add_unsigned_field(
&mut self,
name: &str,
stored: bool,
indexed: bool,
fast: Option<&str>,
fast: bool,
) -> PyResult<Self> {
let builder = &mut self.builder;

Expand Down Expand Up @@ -216,13 +216,13 @@ impl SchemaBuilder {
///
/// Returns the associated field handle.
/// Raises a ValueError if there was an error with the field creation.
#[pyo3(signature = (name, stored = false, indexed = false, fast = None))]
#[pyo3(signature = (name, stored = false, indexed = false, fast = false))]
fn add_date_field(
&mut self,
name: &str,
stored: bool,
indexed: bool,
fast: Option<&str>,
fast: bool,
) -> PyResult<Self> {
let builder = &mut self.builder;

Expand All @@ -233,21 +233,8 @@ impl SchemaBuilder {
if indexed {
opts = opts.set_indexed();
}
let fast = match fast {
Some(f) => {
let f = f.to_lowercase();
match f.as_ref() {
"single" => Some(schema::Cardinality::SingleValue),
"multi" => Some(schema::Cardinality::MultiValues),
_ => return Err(exceptions::PyValueError::new_err(
"Invalid index option, valid choices are: 'multi' and 'single'"
)),
}
}
None => None,
};
if let Some(f) = fast {
opts = opts.set_fast(f);
if fast {
opts = opts.set_fast();
}

if let Some(builder) = builder.write().unwrap().as_mut() {
Expand Down Expand Up @@ -368,33 +355,12 @@ impl SchemaBuilder {
fn build_numeric_option(
stored: bool,
indexed: bool,
fast: Option<&str>,
fast: bool,
) -> PyResult<schema::NumericOptions> {
let opts = schema::NumericOptions::default();

let opts = if stored { opts.set_stored() } else { opts };
let opts = if indexed { opts.set_indexed() } else { opts };

let fast = match fast {
Some(f) => {
let f = f.to_lowercase();
match f.as_ref() {
"single" => Some(schema::Cardinality::SingleValue),
"multi" => Some(schema::Cardinality::MultiValues),
_ => return Err(exceptions::PyValueError::new_err(
"Invalid index option, valid choices are: 'multivalue' and 'singlevalue'"
)),
}
}
None => None,
};

let opts = if let Some(f) = fast {
opts.set_fast(f)
} else {
opts
};

let opts = if fast { opts.set_fast() } else { opts };
Ok(opts)
}

Expand Down
5 changes: 2 additions & 3 deletions src/searcher.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#![allow(clippy::new_ret_no_self)]

use crate::{document::Document, get_field, query::Query, to_pyerr};
use crate::{document::Document, query::Query, to_pyerr};
use pyo3::{exceptions::PyValueError, prelude::*};
use tantivy as tv;
use tantivy::collector::{Count, MultiCollector, TopDocs};
Expand Down Expand Up @@ -113,10 +113,9 @@ impl Searcher {

let (mut multifruit, hits) = {
if let Some(order_by) = order_by_field {
let field = get_field(&self.inner.index().schema(), order_by)?;
let collector = TopDocs::with_limit(limit)
.and_offset(offset)
.order_by_u64_field(field);
.order_by_u64_field(order_by);
let top_docs_handle = multicollector.add_collector(collector);
let ret = self.inner.search(query.get(), &multicollector);

Expand Down
10 changes: 5 additions & 5 deletions tests/tantivy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def create_index(dir=None):
# assume all tests will use the same documents for now
# other methods may set up function-local indexes
index = Index(schema(), dir)
writer = index.writer()
writer = index.writer(10_000_000, 1)

# 2 ways of adding documents
# 1
Expand Down Expand Up @@ -77,7 +77,7 @@ def create_index(dir=None):

def create_index_with_numeric_fields(dir=None):
index = Index(schema_numeric_fields(), dir)
writer = index.writer()
writer = index.writer(10_000_000, 1)

doc = Document()
doc.add_integer("id", 1)
Expand Down Expand Up @@ -260,13 +260,13 @@ def test_and_query_numeric_fields(self, ram_index_numeric_fields):

def test_and_query_parser_default_fields(self, ram_index):
query = ram_index.parse_query("winter", default_field_names=["title"])
assert repr(query) == """Query(TermQuery(Term(type=Str, field=0, "winter")))"""
assert repr(query) == """Query(TermQuery(Term(field=0, type=Str, "winter")))"""

def test_and_query_parser_default_fields_undefined(self, ram_index):
query = ram_index.parse_query("winter")
assert (
repr(query)
== """Query(BooleanQuery { subqueries: [(Should, TermQuery(Term(type=Str, field=0, "winter"))), (Should, TermQuery(Term(type=Str, field=1, "winter")))] })"""
== """Query(BooleanQuery { subqueries: [(Should, TermQuery(Term(field=0, type=Str, "winter"))), (Should, TermQuery(Term(field=1, type=Str, "winter")))] })"""
)

def test_query_errors(self, ram_index):
Expand All @@ -278,7 +278,7 @@ def test_query_errors(self, ram_index):
def test_order_by_search(self):
schema = (
SchemaBuilder()
.add_unsigned_field("order", fast="single")
.add_unsigned_field("order", fast=True)
.add_text_field("title", stored=True)
.build()
)
Expand Down

0 comments on commit a266f41

Please sign in to comment.