Skip to content

Commit

Permalink
expose regex query (#241)
Browse files Browse the repository at this point in the history
Co-authored-by: alexau <alexau@hket.com>
  • Loading branch information
alex-au-922 and ct-alex-au authored Apr 25, 2024
1 parent c74990a commit 5c36663
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 87 deletions.
21 changes: 20 additions & 1 deletion src/query.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{make_term, Schema};
use crate::{get_field, make_term, to_pyerr, Schema};
use pyo3::{
exceptions,
prelude::*,
Expand Down Expand Up @@ -187,4 +187,23 @@ impl Query {
inner: Box::new(inner),
})
}

#[staticmethod]
#[pyo3(signature = (schema, field_name, regex_pattern))]
pub(crate) fn regex_query(
schema: &Schema,
field_name: &str,
regex_pattern: &str,
) -> PyResult<Query> {
let field = get_field(&schema.inner, field_name)?;

let inner_result =
tv::query::RegexQuery::from_pattern(regex_pattern, field);
match inner_result {
Ok(inner) => Ok(Query {
inner: Box::new(inner),
}),
Err(e) => Err(to_pyerr(e)),
}
}
}
173 changes: 87 additions & 86 deletions tantivy/tantivy.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,108 +2,105 @@ import datetime
from enum import Enum
from typing import Any, Optional, Sequence


class Schema:
pass

class SchemaBuilder:

@staticmethod
def is_valid_field_name(name: str) -> bool:
pass

def add_text_field(
self,
name: str,
stored: bool = False,
tokenizer_name: str = "default",
index_option: str = "position",
self,
name: str,
stored: bool = False,
tokenizer_name: str = "default",
index_option: str = "position",
) -> SchemaBuilder:
pass

def add_integer_field(
self,
name: str,
stored: bool = False,
indexed: bool = False,
fast: bool = False,
self,
name: str,
stored: bool = False,
indexed: bool = False,
fast: bool = False,
) -> SchemaBuilder:
pass

def add_float_field(
self,
name: str,
stored: bool = False,
indexed: bool = False,
fast: bool = False,
self,
name: str,
stored: bool = False,
indexed: bool = False,
fast: bool = False,
) -> SchemaBuilder:
pass

def add_unsigned_field(
self,
name: str,
stored: bool = False,
indexed: bool = False,
fast: bool = False,
self,
name: str,
stored: bool = False,
indexed: bool = False,
fast: bool = False,
) -> SchemaBuilder:
pass

def add_boolean_field(
self,
name: str,
stored: bool = False,
indexed: bool = False,
fast: bool = False,
self,
name: str,
stored: bool = False,
indexed: bool = False,
fast: bool = False,
) -> SchemaBuilder:
pass

def add_date_field(
self,
name: str,
stored: bool = False,
indexed: bool = False,
fast: bool = False,
self,
name: str,
stored: bool = False,
indexed: bool = False,
fast: bool = False,
) -> SchemaBuilder:
pass

def add_json_field(
self,
name: str,
stored: bool = False,
tokenizer_name: str = "default",
index_option: str = "position",
self,
name: str,
stored: bool = False,
tokenizer_name: str = "default",
index_option: str = "position",
) -> SchemaBuilder:
pass

def add_facet_field(
self,
name: str,
self,
name: str,
) -> SchemaBuilder:
pass

def add_bytes_field(
self,
name: str,
stored: bool = False,
indexed: bool = False,
fast: bool = False,
index_option: str = "position",
self,
name: str,
stored: bool = False,
indexed: bool = False,
fast: bool = False,
index_option: str = "position",
) -> SchemaBuilder:
pass

def add_ip_addr_field(
self,
name: str,
stored: bool = False,
indexed: bool = False,
fast: bool = False,
self,
name: str,
stored: bool = False,
indexed: bool = False,
fast: bool = False,
) -> SchemaBuilder:
pass

def build(self) -> Schema:
pass


class Facet:
@staticmethod
def from_encoded(encoded_bytes: bytes) -> Facet:
Expand All @@ -130,9 +127,7 @@ class Facet:
def to_path_str(self) -> str:
pass


class Document:

def __new__(cls, **kwargs) -> Document:
pass

Expand Down Expand Up @@ -194,17 +189,29 @@ class Occur(Enum):

class Query:
@staticmethod
def term_query(schema: Schema, field_name: str, field_value: Any, index_option: str = "position") -> Query:
def term_query(
schema: Schema,
field_name: str,
field_value: Any,
index_option: str = "position",
) -> Query:
pass

@staticmethod
def all_query() -> Query:
pass

@staticmethod
def fuzzy_term_query(schema: Schema, field_name: str, text: str, distance: int = 1, transposition_cost_one: bool = True, prefix = False) -> Query:
def fuzzy_term_query(
schema: Schema,
field_name: str,
text: str,
distance: int = 1,
transposition_cost_one: bool = True,
prefix=False,
) -> Query:
pass

@staticmethod
def boolean_query(subqueries: Sequence[tuple[Occur, Query]]) -> Query:
pass
Expand All @@ -218,13 +225,15 @@ class Query:
pass


@staticmethod
def regex_query(schema: Schema, field_name: str, regex_pattern: str) -> Query:
pass

class Order(Enum):
Asc = 1
Desc = 2


class DocAddress:

def __new__(cls, segment_ord: int, doc: int) -> DocAddress:
pass

Expand All @@ -237,22 +246,19 @@ class DocAddress:
pass

class SearchResult:

@property
def hits(self) -> list[tuple[Any, DocAddress]]:
pass


class Searcher:

def search(
self,
query: Query,
limit: int = 10,
count: bool = True,
order_by_field: Optional[str] = None,
offset: int = 0,
order: Order = Order.Desc,
self,
query: Query,
limit: int = 10,
count: bool = True,
order_by_field: Optional[str] = None,
offset: int = 0,
order: Order = Order.Desc,
) -> SearchResult:
pass

Expand All @@ -267,9 +273,7 @@ class Searcher:
def doc(self, doc_address: DocAddress) -> Document:
pass


class IndexWriter:

def add_document(self, doc: Document) -> int:
pass

Expand Down Expand Up @@ -298,10 +302,10 @@ class IndexWriter:
def wait_merging_threads(self) -> None:
pass


class Index:

def __new__(cls, schema: Schema, path: Optional[str] = None, reuse: bool = True) -> Index:
def __new__(
cls, schema: Schema, path: Optional[str] = None, reuse: bool = True
) -> Index:
pass

@staticmethod
Expand All @@ -311,7 +315,9 @@ class Index:
def writer(self, heap_size: int = 128_000_000, num_threads: int = 0) -> IndexWriter:
pass

def config_reader(self, reload_policy: str = "commit", num_warmers: int = 0) -> None:
def config_reader(
self, reload_policy: str = "commit", num_warmers: int = 0
) -> None:
pass

def searcher(self) -> Searcher:
Expand All @@ -328,15 +334,17 @@ class Index:
def reload(self) -> None:
pass

def parse_query(self, query: str, default_field_names: Optional[list[str]] = None) -> Query:
def parse_query(
self, query: str, default_field_names: Optional[list[str]] = None
) -> Query:
pass

def parse_query_lenient(self, query: str, default_field_names: Optional[list[str]] = None) -> Query:
def parse_query_lenient(
self, query: str, default_field_names: Optional[list[str]] = None
) -> Query:
pass


class Range:

@property
def start(self) -> int:
pass
Expand All @@ -345,24 +353,17 @@ class Range:
def end(self) -> int:
pass


class Snippet:

def to_html(self) -> str:
pass

def highlighted(self) -> list[Range]:
pass


class SnippetGenerator:

@staticmethod
def create(
searcher: Searcher,
query: Query,
schema: Schema,
field_name: str
searcher: Searcher, query: Query, schema: Schema, field_name: str
) -> SnippetGenerator:
pass

Expand Down
33 changes: 33 additions & 0 deletions tests/tantivy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -995,3 +995,36 @@ def test_boost_query(self, ram_index):
# no boost type error
with pytest.raises(TypeError, match = r"Query.boost_query\(\) missing 1 required positional argument: 'boost'"):
Query.boost_query(query1)


def test_regex_query(self, ram_index):
index = ram_index

query = Query.regex_query(index.schema, "body", "fish")
result = index.searcher().search(query, 10)
assert len(result.hits) == 1
_, doc_address = result.hits[0]
searched_doc = index.searcher().doc(doc_address)
assert searched_doc["title"] == ["The Old Man and the Sea"]

query = Query.regex_query(index.schema, "title", "(?:man|men)")
result = index.searcher().search(query, 10)
assert len(result.hits) == 2
_, doc_address = result.hits[0]
searched_doc = index.searcher().doc(doc_address)
assert searched_doc["title"] == ["The Old Man and the Sea"]
_, doc_address = result.hits[1]
searched_doc = index.searcher().doc(doc_address)
assert searched_doc["title"] == ["Of Mice and Men"]

# unknown field in the schema
with pytest.raises(
ValueError, match="Field `unknown_field` is not defined in the schema."
):
Query.regex_query(index.schema, "unknown_field", "fish")

# invalid regex pattern
with pytest.raises(
ValueError, match=r"An invalid argument was passed: 'fish\('"
):
Query.regex_query(index.schema, "body", "fish(")

0 comments on commit 5c36663

Please sign in to comment.